From 51c5b51ea3794ed8a74e72a4612103c90a5bca22 Mon Sep 17 00:00:00 2001 From: Abhijeet Kasurde Date: Thu, 2 Nov 2017 01:41:52 +0530 Subject: [PATCH] Handle error message from psycopg2 using to_native (#32371) This fix adds handling of error/exception message using to_native API instead of decoding. Also, fixes PEP8 errors. Fixes: #31825 Signed-off-by: Abhijeet Kasurde --- .../database/postgresql/postgresql_privs.py | 57 +++++++------------ test/sanity/pep8/legacy-files.txt | 1 - 2 files changed, 20 insertions(+), 38 deletions(-) diff --git a/lib/ansible/modules/database/postgresql/postgresql_privs.py b/lib/ansible/modules/database/postgresql/postgresql_privs.py index 1ec5e45773..e889cb315b 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_privs.py +++ b/lib/ansible/modules/database/postgresql/postgresql_privs.py @@ -288,17 +288,17 @@ class Connection(object): # check which values are empty and don't include in the **kw # dictionary params_map = { - "host":"host", - "login":"user", - "password":"password", - "port":"port", + "host": "host", + "login": "user", + "password": "password", + "port": "port", "database": "database", - "ssl_mode":"sslmode", - "ssl_rootcert":"sslrootcert" + "ssl_mode": "sslmode", + "ssl_rootcert": "sslrootcert" } - kw = dict( (params_map[k], getattr(params, k)) for k in params_map - if getattr(params, k) != '' and getattr(params, k) is not None ) + kw = dict((params_map[k], getattr(params, k)) for k in params_map + if getattr(params, k) != '' and getattr(params, k) is not None) # If a unix_socket is specified, incorporate it here. is_localhost = "host" not in kw or kw["host"] == "" or kw["host"] == "localhost" @@ -312,11 +312,9 @@ class Connection(object): self.connection = psycopg2.connect(**kw) self.cursor = self.connection.cursor() - def commit(self): self.connection.commit() - def rollback(self): self.connection.rollback() @@ -325,8 +323,7 @@ class Connection(object): """Connection encoding in Python-compatible form""" return psycopg2.extensions.encodings[self.connection.encoding] - - ### Methods for querying database objects + # Methods for querying database objects # PostgreSQL < 9.0 doesn't support "ALL TABLES IN SCHEMA schema"-like # phrases in GRANT or REVOKE statements, therefore alternative methods are @@ -338,7 +335,6 @@ class Connection(object): self.cursor.execute(query, (schema,)) return self.cursor.fetchone()[0] > 0 - def get_all_tables_in_schema(self, schema): if not self.schema_exists(schema): raise Error('Schema "%s" does not exist.' % schema) @@ -349,7 +345,6 @@ class Connection(object): self.cursor.execute(query, (schema,)) return [t[0] for t in self.cursor.fetchall()] - def get_all_sequences_in_schema(self, schema): if not self.schema_exists(schema): raise Error('Schema "%s" does not exist.' % schema) @@ -360,9 +355,7 @@ class Connection(object): self.cursor.execute(query, (schema,)) return [t[0] for t in self.cursor.fetchall()] - - - ### Methods for getting access control lists and group membership info + # Methods for getting access control lists and group membership info # To determine whether anything has changed after granting/revoking # privileges, we compare the access control lists of the specified database @@ -379,7 +372,6 @@ class Connection(object): self.cursor.execute(query, (schema, tables)) return [t[0] for t in self.cursor.fetchall()] - def get_sequence_acls(self, schema, sequences): query = """SELECT relacl FROM pg_catalog.pg_class c @@ -389,7 +381,6 @@ class Connection(object): self.cursor.execute(query, (schema, sequences)) return [t[0] for t in self.cursor.fetchall()] - def get_function_acls(self, schema, function_signatures): funcnames = [f.split('(', 1)[0] for f in function_signatures] query = """SELECT proacl @@ -400,35 +391,30 @@ class Connection(object): self.cursor.execute(query, (schema, funcnames)) return [t[0] for t in self.cursor.fetchall()] - def get_schema_acls(self, schemas): query = """SELECT nspacl FROM pg_catalog.pg_namespace WHERE nspname = ANY (%s) ORDER BY nspname""" self.cursor.execute(query, (schemas,)) return [t[0] for t in self.cursor.fetchall()] - def get_language_acls(self, languages): query = """SELECT lanacl FROM pg_catalog.pg_language WHERE lanname = ANY (%s) ORDER BY lanname""" self.cursor.execute(query, (languages,)) return [t[0] for t in self.cursor.fetchall()] - def get_tablespace_acls(self, tablespaces): query = """SELECT spcacl FROM pg_catalog.pg_tablespace WHERE spcname = ANY (%s) ORDER BY spcname""" self.cursor.execute(query, (tablespaces,)) return [t[0] for t in self.cursor.fetchall()] - def get_database_acls(self, databases): query = """SELECT datacl FROM pg_catalog.pg_database WHERE datname = ANY (%s) ORDER BY datname""" self.cursor.execute(query, (databases,)) return [t[0] for t in self.cursor.fetchall()] - def get_group_memberships(self, groups): query = """SELECT roleid, grantor, member, admin_option FROM pg_catalog.pg_auth_members am @@ -438,8 +424,7 @@ class Connection(object): self.cursor.execute(query, (groups,)) return self.cursor.fetchall() - - ### Manipulating privileges + # Manipulating privileges def manipulate_privs(self, obj_type, privs, objs, roles, state, grant_option, schema_qualifier=None): @@ -545,7 +530,7 @@ class Connection(object): def main(): module = AnsibleModule( - argument_spec = dict( + argument_spec=dict( database=dict(required=True, aliases=['db']), state=dict(default='present', choices=['present', 'absent']), privs=dict(required=False, aliases=['priv']), @@ -571,7 +556,7 @@ def main(): ssl_mode=dict(default="prefer", choices=['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']), ssl_rootcert=dict(default=None) ), - supports_check_mode = True + supports_check_mode=True ) # Create type object as namespace for module params @@ -643,12 +628,12 @@ def main(): roles = p.roles.split(',') changed = conn.manipulate_privs( - obj_type = p.type, - privs = privs, - objs = objs, - roles = roles, - state = p.state, - grant_option = p.grant_option, + obj_type=p.type, + privs=privs, + objs=objs, + roles=roles, + state=p.state, + grant_option=p.grant_option, schema_qualifier=p.schema ) @@ -658,9 +643,7 @@ def main(): except psycopg2.Error as e: conn.rollback() - # psycopg2 errors come in connection encoding - msg = to_text(e.message(encoding=conn.encoding)) - module.fail_json(msg=msg) + module.fail_json(msg=to_native(e.message)) if module.check_mode: conn.rollback() diff --git a/test/sanity/pep8/legacy-files.txt b/test/sanity/pep8/legacy-files.txt index b14dabb500..af8de143b4 100644 --- a/test/sanity/pep8/legacy-files.txt +++ b/test/sanity/pep8/legacy-files.txt @@ -135,7 +135,6 @@ lib/ansible/modules/database/mssql/mssql_db.py lib/ansible/modules/database/postgresql/postgresql_db.py lib/ansible/modules/database/postgresql/postgresql_ext.py lib/ansible/modules/database/postgresql/postgresql_lang.py -lib/ansible/modules/database/postgresql/postgresql_privs.py lib/ansible/modules/database/postgresql/postgresql_schema.py lib/ansible/modules/database/vertica/vertica_configuration.py lib/ansible/modules/database/vertica/vertica_facts.py