1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

Handle error message from psycopg2 using to_native (#32371)

This fix adds handling of error/exception message using
to_native API instead of decoding.
Also, fixes PEP8 errors.

Fixes: #31825

Signed-off-by: Abhijeet Kasurde <akasurde@redhat.com>
This commit is contained in:
Abhijeet Kasurde 2017-11-02 01:41:52 +05:30 committed by ansibot
parent 61ca2a711d
commit 51c5b51ea3
2 changed files with 20 additions and 38 deletions

View file

@ -288,17 +288,17 @@ class Connection(object):
# check which values are empty and don't include in the **kw # check which values are empty and don't include in the **kw
# dictionary # dictionary
params_map = { params_map = {
"host":"host", "host": "host",
"login":"user", "login": "user",
"password":"password", "password": "password",
"port":"port", "port": "port",
"database": "database", "database": "database",
"ssl_mode":"sslmode", "ssl_mode": "sslmode",
"ssl_rootcert":"sslrootcert" "ssl_rootcert": "sslrootcert"
} }
kw = dict( (params_map[k], getattr(params, k)) for k in params_map kw = dict((params_map[k], getattr(params, k)) for k in params_map
if getattr(params, k) != '' and getattr(params, k) is not None ) if getattr(params, k) != '' and getattr(params, k) is not None)
# If a unix_socket is specified, incorporate it here. # If a unix_socket is specified, incorporate it here.
is_localhost = "host" not in kw or kw["host"] == "" or kw["host"] == "localhost" is_localhost = "host" not in kw or kw["host"] == "" or kw["host"] == "localhost"
@ -312,11 +312,9 @@ class Connection(object):
self.connection = psycopg2.connect(**kw) self.connection = psycopg2.connect(**kw)
self.cursor = self.connection.cursor() self.cursor = self.connection.cursor()
def commit(self): def commit(self):
self.connection.commit() self.connection.commit()
def rollback(self): def rollback(self):
self.connection.rollback() self.connection.rollback()
@ -325,8 +323,7 @@ class Connection(object):
"""Connection encoding in Python-compatible form""" """Connection encoding in Python-compatible form"""
return psycopg2.extensions.encodings[self.connection.encoding] return psycopg2.extensions.encodings[self.connection.encoding]
# Methods for querying database objects
### Methods for querying database objects
# PostgreSQL < 9.0 doesn't support "ALL TABLES IN SCHEMA schema"-like # PostgreSQL < 9.0 doesn't support "ALL TABLES IN SCHEMA schema"-like
# phrases in GRANT or REVOKE statements, therefore alternative methods are # phrases in GRANT or REVOKE statements, therefore alternative methods are
@ -338,7 +335,6 @@ class Connection(object):
self.cursor.execute(query, (schema,)) self.cursor.execute(query, (schema,))
return self.cursor.fetchone()[0] > 0 return self.cursor.fetchone()[0] > 0
def get_all_tables_in_schema(self, schema): def get_all_tables_in_schema(self, schema):
if not self.schema_exists(schema): if not self.schema_exists(schema):
raise Error('Schema "%s" does not exist.' % schema) raise Error('Schema "%s" does not exist.' % schema)
@ -349,7 +345,6 @@ class Connection(object):
self.cursor.execute(query, (schema,)) self.cursor.execute(query, (schema,))
return [t[0] for t in self.cursor.fetchall()] return [t[0] for t in self.cursor.fetchall()]
def get_all_sequences_in_schema(self, schema): def get_all_sequences_in_schema(self, schema):
if not self.schema_exists(schema): if not self.schema_exists(schema):
raise Error('Schema "%s" does not exist.' % schema) raise Error('Schema "%s" does not exist.' % schema)
@ -360,9 +355,7 @@ class Connection(object):
self.cursor.execute(query, (schema,)) self.cursor.execute(query, (schema,))
return [t[0] for t in self.cursor.fetchall()] return [t[0] for t in self.cursor.fetchall()]
# Methods for getting access control lists and group membership info
### Methods for getting access control lists and group membership info
# To determine whether anything has changed after granting/revoking # To determine whether anything has changed after granting/revoking
# privileges, we compare the access control lists of the specified database # privileges, we compare the access control lists of the specified database
@ -379,7 +372,6 @@ class Connection(object):
self.cursor.execute(query, (schema, tables)) self.cursor.execute(query, (schema, tables))
return [t[0] for t in self.cursor.fetchall()] return [t[0] for t in self.cursor.fetchall()]
def get_sequence_acls(self, schema, sequences): def get_sequence_acls(self, schema, sequences):
query = """SELECT relacl query = """SELECT relacl
FROM pg_catalog.pg_class c FROM pg_catalog.pg_class c
@ -389,7 +381,6 @@ class Connection(object):
self.cursor.execute(query, (schema, sequences)) self.cursor.execute(query, (schema, sequences))
return [t[0] for t in self.cursor.fetchall()] return [t[0] for t in self.cursor.fetchall()]
def get_function_acls(self, schema, function_signatures): def get_function_acls(self, schema, function_signatures):
funcnames = [f.split('(', 1)[0] for f in function_signatures] funcnames = [f.split('(', 1)[0] for f in function_signatures]
query = """SELECT proacl query = """SELECT proacl
@ -400,35 +391,30 @@ class Connection(object):
self.cursor.execute(query, (schema, funcnames)) self.cursor.execute(query, (schema, funcnames))
return [t[0] for t in self.cursor.fetchall()] return [t[0] for t in self.cursor.fetchall()]
def get_schema_acls(self, schemas): def get_schema_acls(self, schemas):
query = """SELECT nspacl FROM pg_catalog.pg_namespace query = """SELECT nspacl FROM pg_catalog.pg_namespace
WHERE nspname = ANY (%s) ORDER BY nspname""" WHERE nspname = ANY (%s) ORDER BY nspname"""
self.cursor.execute(query, (schemas,)) self.cursor.execute(query, (schemas,))
return [t[0] for t in self.cursor.fetchall()] return [t[0] for t in self.cursor.fetchall()]
def get_language_acls(self, languages): def get_language_acls(self, languages):
query = """SELECT lanacl FROM pg_catalog.pg_language query = """SELECT lanacl FROM pg_catalog.pg_language
WHERE lanname = ANY (%s) ORDER BY lanname""" WHERE lanname = ANY (%s) ORDER BY lanname"""
self.cursor.execute(query, (languages,)) self.cursor.execute(query, (languages,))
return [t[0] for t in self.cursor.fetchall()] return [t[0] for t in self.cursor.fetchall()]
def get_tablespace_acls(self, tablespaces): def get_tablespace_acls(self, tablespaces):
query = """SELECT spcacl FROM pg_catalog.pg_tablespace query = """SELECT spcacl FROM pg_catalog.pg_tablespace
WHERE spcname = ANY (%s) ORDER BY spcname""" WHERE spcname = ANY (%s) ORDER BY spcname"""
self.cursor.execute(query, (tablespaces,)) self.cursor.execute(query, (tablespaces,))
return [t[0] for t in self.cursor.fetchall()] return [t[0] for t in self.cursor.fetchall()]
def get_database_acls(self, databases): def get_database_acls(self, databases):
query = """SELECT datacl FROM pg_catalog.pg_database query = """SELECT datacl FROM pg_catalog.pg_database
WHERE datname = ANY (%s) ORDER BY datname""" WHERE datname = ANY (%s) ORDER BY datname"""
self.cursor.execute(query, (databases,)) self.cursor.execute(query, (databases,))
return [t[0] for t in self.cursor.fetchall()] return [t[0] for t in self.cursor.fetchall()]
def get_group_memberships(self, groups): def get_group_memberships(self, groups):
query = """SELECT roleid, grantor, member, admin_option query = """SELECT roleid, grantor, member, admin_option
FROM pg_catalog.pg_auth_members am FROM pg_catalog.pg_auth_members am
@ -438,8 +424,7 @@ class Connection(object):
self.cursor.execute(query, (groups,)) self.cursor.execute(query, (groups,))
return self.cursor.fetchall() return self.cursor.fetchall()
# Manipulating privileges
### Manipulating privileges
def manipulate_privs(self, obj_type, privs, objs, roles, def manipulate_privs(self, obj_type, privs, objs, roles,
state, grant_option, schema_qualifier=None): state, grant_option, schema_qualifier=None):
@ -545,7 +530,7 @@ class Connection(object):
def main(): def main():
module = AnsibleModule( module = AnsibleModule(
argument_spec = dict( argument_spec=dict(
database=dict(required=True, aliases=['db']), database=dict(required=True, aliases=['db']),
state=dict(default='present', choices=['present', 'absent']), state=dict(default='present', choices=['present', 'absent']),
privs=dict(required=False, aliases=['priv']), privs=dict(required=False, aliases=['priv']),
@ -571,7 +556,7 @@ def main():
ssl_mode=dict(default="prefer", choices=['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']), ssl_mode=dict(default="prefer", choices=['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']),
ssl_rootcert=dict(default=None) ssl_rootcert=dict(default=None)
), ),
supports_check_mode = True supports_check_mode=True
) )
# Create type object as namespace for module params # Create type object as namespace for module params
@ -643,12 +628,12 @@ def main():
roles = p.roles.split(',') roles = p.roles.split(',')
changed = conn.manipulate_privs( changed = conn.manipulate_privs(
obj_type = p.type, obj_type=p.type,
privs = privs, privs=privs,
objs = objs, objs=objs,
roles = roles, roles=roles,
state = p.state, state=p.state,
grant_option = p.grant_option, grant_option=p.grant_option,
schema_qualifier=p.schema schema_qualifier=p.schema
) )
@ -658,9 +643,7 @@ def main():
except psycopg2.Error as e: except psycopg2.Error as e:
conn.rollback() conn.rollback()
# psycopg2 errors come in connection encoding module.fail_json(msg=to_native(e.message))
msg = to_text(e.message(encoding=conn.encoding))
module.fail_json(msg=msg)
if module.check_mode: if module.check_mode:
conn.rollback() conn.rollback()

View file

@ -135,7 +135,6 @@ lib/ansible/modules/database/mssql/mssql_db.py
lib/ansible/modules/database/postgresql/postgresql_db.py lib/ansible/modules/database/postgresql/postgresql_db.py
lib/ansible/modules/database/postgresql/postgresql_ext.py lib/ansible/modules/database/postgresql/postgresql_ext.py
lib/ansible/modules/database/postgresql/postgresql_lang.py lib/ansible/modules/database/postgresql/postgresql_lang.py
lib/ansible/modules/database/postgresql/postgresql_privs.py
lib/ansible/modules/database/postgresql/postgresql_schema.py lib/ansible/modules/database/postgresql/postgresql_schema.py
lib/ansible/modules/database/vertica/vertica_configuration.py lib/ansible/modules/database/vertica/vertica_configuration.py
lib/ansible/modules/database/vertica/vertica_facts.py lib/ansible/modules/database/vertica/vertica_facts.py