# Copyright: (c) 2019, Andrew Klychkov (@Andersson007)
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

import pytest

import ansible_collections.community.general.plugins.module_utils.postgres as pg


INPUT_DICT = dict(
    session_role=dict(default=''),
    login_user=dict(default='postgres'),
    login_password=dict(default='test', no_log=True),
    login_host=dict(default='test'),
    login_unix_socket=dict(default=''),
    port=dict(type='int', default=5432, aliases=['login_port']),
    ssl_mode=dict(
        default='prefer',
        choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full']
    ),
    ca_cert=dict(aliases=['ssl_rootcert']),
)

EXPECTED_DICT = dict(
    user=dict(default='postgres'),
    password=dict(default='test', no_log=True),
    host=dict(default='test'),
    port=dict(type='int', default=5432, aliases=['login_port']),
    sslmode=dict(
        default='prefer',
        choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full']
    ),
    sslrootcert=dict(aliases=['ssl_rootcert']),
)


class TestPostgresCommonArgSpec():

    """
    Namespace for testing postgres_common_argument_spec() function.
    """

    def test_postgres_common_argument_spec(self):
        """
        Test for postgres_common_argument_spec() function.

        The tested function just returns a dictionary with the default
        parameters and their values for PostgreSQL modules.
        The returned and expected dictionaries must be equal.
        """
        expected_dict = dict(
            login_user=dict(default='postgres'),
            login_password=dict(default='', no_log=True),
            login_host=dict(default=''),
            login_unix_socket=dict(default=''),
            port=dict(type='int', default=5432, aliases=['login_port']),
            ssl_mode=dict(
                default='prefer',
                choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full']
            ),
            ca_cert=dict(aliases=['ssl_rootcert']),
        )
        assert pg.postgres_common_argument_spec() == expected_dict


# psycopg2 emulation.
# The mock classes live at module level (rather than being created as
# globals inside the fixture) so tests can always refer to them, e.g. to
# check the types returned by connect_to_db() with isinstance().

class Cursor(object):
    """Dummy psycopg2 cursor that remembers the last executed query."""

    def __init__(self):
        self.passed_query = None

    def execute(self, query):
        self.passed_query = query

    def close(self):
        pass


class DbConnection(object):
    """Dummy psycopg2 connection object."""

    def cursor(self, cursor_factory=None):
        return Cursor()

    def set_session(self, autocommit=None):
        pass

    def set_isolation_level(self, isolevel):
        pass


class Extras(object):
    """Stand-in for the psycopg2.extras module."""

    def __init__(self):
        self.DictCursor = True


class Extensions(object):
    """Stand-in for the psycopg2.extensions module."""

    def __init__(self):
        self.ISOLATION_LEVEL_AUTOCOMMIT = True


class DummyPsycopg2(object):
    """Stand-in for the psycopg2 module; reports version 2.4.3 by default."""

    def __init__(self):
        self.__version__ = '2.4.3'
        self.extras = Extras()
        self.extensions = Extensions()

    def connect(self, host=None, port=None, user=None, password=None,
                sslmode=None, sslrootcert=None):
        """Return a DbConnection, or raise Exception when user == 'Exception'.

        The magic user name lets tests simulate a connection failure.
        """
        if user == 'Exception':
            raise Exception()

        return DbConnection()


@pytest.fixture
def m_psycopg2():
    """Return a fresh mock object for psycopg2 emulation.

    A new object per test means version changes made by one test
    (e.g. setting __version__) do not affect other tests.
    """
    return DummyPsycopg2()


class TestEnsureReqLibs():

    """
    Namespace for testing ensure_required_libs() function.

    If there is something wrong with libs, the function invokes fail_json()
    method of AnsibleModule object passed as an argument called 'module'.
    Therefore we must check:
    1. value of err_msg attribute of m_ansible_module mock object.
    """

    @pytest.fixture
    def m_ansible_module(self):
        """Return a fresh object of dummy AnsibleModule class.

        Function-scoped on purpose: tests mutate params and err_msg, so a
        class-scoped object would carry state from one test into the next.
        """
        class DummyAnsibleModule():
            def __init__(self):
                self.params = {'ca_cert': False}
                self.err_msg = ''

            def fail_json(self, msg):
                # Record the message instead of exiting, so tests can inspect it.
                self.err_msg = msg

        return DummyAnsibleModule()

    def test_ensure_req_libs_has_not_psycopg2(self, m_ansible_module, monkeypatch):
        """Test ensure_required_libs() with psycopg2 is None."""
        # Set the flag explicitly instead of relying on psycopg2 being
        # absent from the test environment.
        monkeypatch.setattr(pg, 'HAS_PSYCOPG2', False)

        pg.ensure_required_libs(m_ansible_module)
        assert 'Failed to import the required Python library (psycopg2)' in m_ansible_module.err_msg

    def test_ensure_req_libs_has_psycopg2(self, m_ansible_module, monkeypatch):
        """Test ensure_required_libs() with psycopg2 is not None."""
        monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)

        pg.ensure_required_libs(m_ansible_module)
        assert m_ansible_module.err_msg == ''

    def test_ensure_req_libs_ca_cert(self, m_ansible_module, m_psycopg2, monkeypatch):
        """
        Test with module.params['ca_cert'], psycopg2 version is suitable.
        """
        m_ansible_module.params['ca_cert'] = True
        monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)
        monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)

        pg.ensure_required_libs(m_ansible_module)
        assert m_ansible_module.err_msg == ''

    def test_ensure_req_libs_ca_cert_low_psycopg2_ver(self, m_ansible_module, m_psycopg2, monkeypatch):
        """
        Test with module.params['ca_cert'], psycopg2 version is too low.
        """
        m_ansible_module.params['ca_cert'] = True
        monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)
        # Set a psycopg2 version number lower than required for ca_cert:
        m_psycopg2.__version__ = '2.4.2'
        monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)

        pg.ensure_required_libs(m_ansible_module)
        assert 'psycopg2 must be at least 2.4.3' in m_ansible_module.err_msg


@pytest.fixture
def m_ansible_module():
    """Return a fresh object of dummy AnsibleModule class.

    params holds the raw argument spec returned by
    pg.postgres_common_argument_spec() (spec dicts, not resolved values).
    Function-scoped on purpose: tests change params and read err_msg /
    warn_msg, so sharing one object would leak state between tests.
    """
    class DummyAnsibleModule():
        def __init__(self):
            self.params = pg.postgres_common_argument_spec()
            self.err_msg = ''
            self.warn_msg = ''

        def fail_json(self, msg):
            # Record the message instead of exiting, so tests can inspect it.
            self.err_msg = msg

        def warn(self, msg):
            self.warn_msg = msg

    return DummyAnsibleModule()


class TestConnectToDb():

    """
    Namespace for testing connect_to_db() function.

    When a connection error occurs, connect_to_db() catches it and
    invokes the fail_json() or warn() method of the AnsibleModule object,
    depending on the passed parameters.
    connect_to_db() may return a db_connection object, or None if errors occurred.
    Therefore we must check:
    1. Values of err_msg and warn_msg attributes of m_ansible_module mock object.
    2. Types of return objects (db_connection and cursor).
    """

    def test_connect_to_db(self, m_ansible_module, monkeypatch, m_psycopg2):
        """Test connect_to_db(), common test."""
        monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)
        monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)

        conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
        db_connection = pg.connect_to_db(m_ansible_module, conn_params)

        # db_connection is None if errors occurred; check it before using it:
        assert isinstance(db_connection, DbConnection)
        assert isinstance(db_connection.cursor(), Cursor)
        assert m_ansible_module.err_msg == ''
        # The default behaviour, normal in this case:
        assert 'Database name has not been passed' in m_ansible_module.warn_msg

    def test_session_role(self, m_ansible_module, monkeypatch, m_psycopg2):
        """Test connect_to_db(), switch on session_role."""
        monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)
        monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)

        m_ansible_module.params['session_role'] = 'test_role'
        conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
        db_connection = pg.connect_to_db(m_ansible_module, conn_params)

        # db_connection is None if errors occurred; check it before using it:
        assert isinstance(db_connection, DbConnection)
        assert isinstance(db_connection.cursor(), Cursor)
        assert m_ansible_module.err_msg == ''
        # The default behaviour, normal in this case:
        assert 'Database name has not been passed' in m_ansible_module.warn_msg

    def test_fail_on_conn_true(self, m_ansible_module, monkeypatch, m_psycopg2):
        """
        Test connect_to_db(), fail_on_conn arg passed as True (the default behavior).
        """
        monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)
        monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)

        m_ansible_module.params['login_user'] = 'Exception'  # causes Exception

        conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
        db_connection = pg.connect_to_db(m_ansible_module, conn_params, fail_on_conn=True)

        assert 'unable to connect to database' in m_ansible_module.err_msg
        assert db_connection is None

    def test_fail_on_conn_false(self, m_ansible_module, monkeypatch, m_psycopg2):
        """
        Test connect_to_db(), fail_on_conn arg passed as False.
        """
        monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)
        monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)

        m_ansible_module.params['login_user'] = 'Exception'  # causes Exception

        conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
        db_connection = pg.connect_to_db(m_ansible_module, conn_params, fail_on_conn=False)

        assert m_ansible_module.err_msg == ''
        assert 'PostgreSQL server is unavailable' in m_ansible_module.warn_msg
        assert db_connection is None

    def test_autocommit_true(self, m_ansible_module, monkeypatch, m_psycopg2):
        """
        Test connect_to_db(), autocommit arg passed as True (the default is False).
        """
        monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)

        # case 1: psycopg2.__version__ >= 2.4.2 (m_psycopg2 reports 2.4.3)
        monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)

        conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
        db_connection = pg.connect_to_db(m_ansible_module, conn_params, autocommit=True)

        # db_connection is None if errors occurred; check it before using it:
        assert isinstance(db_connection, DbConnection)
        assert isinstance(db_connection.cursor(), Cursor)
        assert m_ansible_module.err_msg == ''

        # case 2: psycopg2.__version__ < 2.4.2
        m_psycopg2.__version__ = '2.4.1'
        monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)

        conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
        db_connection = pg.connect_to_db(m_ansible_module, conn_params, autocommit=True)

        assert isinstance(db_connection, DbConnection)
        assert isinstance(db_connection.cursor(), Cursor)
        # NOTE: params['ca_cert'] holds a non-empty (truthy) spec dict here.
        # This error presumably comes from the ca_cert version check (compare
        # test_ensure_req_libs_ca_cert_low_psycopg2_ver). The dummy fail_json()
        # only records the message, so a connection is still returned above.
        assert 'psycopg2 must be at least 2.4.3' in m_ansible_module.err_msg


class TestGetConnParams():

    """Namespace for testing get_conn_params() function."""

    def test_get_conn_params_def(self, m_ansible_module):
        """Test get_conn_params(), warn_db_default kwarg is default."""
        assert pg.get_conn_params(m_ansible_module, INPUT_DICT) == EXPECTED_DICT
        assert m_ansible_module.warn_msg == 'Database name has not been passed, used default database to connect to.'

    def test_get_conn_params_warn_db_def_false(self, m_ansible_module):
        """Test get_conn_params(), warn_db_default kwarg is False."""
        assert pg.get_conn_params(m_ansible_module, INPUT_DICT, warn_db_default=False) == EXPECTED_DICT
        assert m_ansible_module.warn_msg == ''