From 44e21f7062977b2cf6bd0c9fdab8b89a1ef896c4 Mon Sep 17 00:00:00 2001 From: Toshio Kuratomi Date: Tue, 19 Apr 2016 20:08:01 -0700 Subject: [PATCH] Allow AnsibleModules to be instantiated more than once in a module Fix SELINUX monkeypatch in test_basic --- lib/ansible/module_utils/basic.py | 41 ++++++++++++------- .../module_utils/basic/test_exit_json.py | 3 ++ test/units/module_utils/basic/test_log.py | 4 ++ .../module_utils/basic/test_run_command.py | 1 + .../module_utils/basic/test_safe_eval.py | 7 ++++ test/units/module_utils/test_basic.py | 27 +++++++++++- 6 files changed, 66 insertions(+), 17 deletions(-) diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index 59e808afeb..7ae2cb941c 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -223,6 +223,12 @@ from ansible import __version__ # Backwards compat. New code should just import and use __version__ ANSIBLE_VERSION = __version__ +# Internal global holding passed in params and constants. This is consulted +# in case multiple AnsibleModules are created. Otherwise each AnsibleModule +# would attempt to read from stdin. Other code should not use this directly +# as it is an internal implementation detail +_ANSIBLE_ARGS = None + FILE_COMMON_ARGUMENTS=dict( src = dict(), mode = dict(type='raw'), @@ -1457,23 +1463,28 @@ class AnsibleModule(object): ''' read the input and set the params attribute. Sets the constants as well.''' # debug overrides to read args from file or cmdline - # Avoid tracebacks when locale is non-utf8 - # We control the args and we pass them as utf8 - if len(sys.argv) > 1: - if os.path.isfile(sys.argv[1]): - fd = open(sys.argv[1], 'rb') - buffer = fd.read() - fd.close() - else: - buffer = sys.argv[1] - if sys.version_info >= (3,): - buffer = buffer.encode('utf-8', errors='surrogateescape') - # default case, read from stdin + global _ANSIBLE_ARGS + if _ANSIBLE_ARGS is not None: + buffer = _ANSIBLE_ARGS else: - if sys.version_info < (3,): - buffer = sys.stdin.read() + # Avoid tracebacks when locale is non-utf8 + # We control the args and we pass them as utf8 + if len(sys.argv) > 1: + if os.path.isfile(sys.argv[1]): + fd = open(sys.argv[1], 'rb') + buffer = fd.read() + fd.close() + else: + buffer = sys.argv[1] + if sys.version_info >= (3,): + buffer = buffer.encode('utf-8', errors='surrogateescape') + # default case, read from stdin else: - buffer = sys.stdin.buffer.read() + if sys.version_info < (3,): + buffer = sys.stdin.read() + else: + buffer = sys.stdin.buffer.read() + _ANSIBLE_ARGS = buffer try: params = json.loads(buffer.decode('utf-8')) diff --git a/test/units/module_utils/basic/test_exit_json.py b/test/units/module_utils/basic/test_exit_json.py index 1bd25002d4..73a97ed687 100644 --- a/test/units/module_utils/basic/test_exit_json.py +++ b/test/units/module_utils/basic/test_exit_json.py @@ -45,6 +45,7 @@ class TestAnsibleModuleExitJson(unittest.TestCase): self.stdout_swap_ctx = swap_stdout() self.fake_stream = self.stdout_swap_ctx.__enter__() + reload(basic) self.module = basic.AnsibleModule(argument_spec=dict()) def tearDown(self): @@ -125,6 +126,7 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase): params = json.dumps(params) with swap_stdin_and_argv(stdin_data=params): + reload(basic) with swap_stdout(): module = basic.AnsibleModule( argument_spec = dict( @@ -146,6 +148,7 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase): params = dict(ANSIBLE_MODULE_ARGS=args, ANSIBLE_MODULE_CONSTANTS={}) params = json.dumps(params) with swap_stdin_and_argv(stdin_data=params): + reload(basic) with swap_stdout(): module = basic.AnsibleModule( argument_spec = dict( diff --git a/test/units/module_utils/basic/test_log.py b/test/units/module_utils/basic/test_log.py index e0258999db..f9a00bb55e 100644 --- a/test/units/module_utils/basic/test_log.py +++ b/test/units/module_utils/basic/test_log.py @@ -49,6 +49,7 @@ class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase): self.stdin_swap = swap_stdin_and_argv(stdin_data=args) self.stdin_swap.__enter__() + reload(basic) self.am = basic.AnsibleModule( argument_spec = dict(), ) @@ -85,6 +86,7 @@ class TestAnsibleModuleJournaldSmokeTest(unittest.TestCase): self.stdin_swap = swap_stdin_and_argv(stdin_data=args) self.stdin_swap.__enter__() + reload(basic) self.am = basic.AnsibleModule( argument_spec = dict(), ) @@ -132,6 +134,7 @@ class TestAnsibleModuleLogSyslog(unittest.TestCase): self.stdin_swap = swap_stdin_and_argv(stdin_data=args) self.stdin_swap.__enter__() + reload(basic) self.am = basic.AnsibleModule( argument_spec = dict(), ) @@ -192,6 +195,7 @@ class TestAnsibleModuleLogJournal(unittest.TestCase): self.stdin_swap = swap_stdin_and_argv(stdin_data=args) self.stdin_swap.__enter__() + reload(basic) self.am = basic.AnsibleModule( argument_spec = dict(), ) diff --git a/test/units/module_utils/basic/test_run_command.py b/test/units/module_utils/basic/test_run_command.py index 781dfa286e..aaf9b0e134 100644 --- a/test/units/module_utils/basic/test_run_command.py +++ b/test/units/module_utils/basic/test_run_command.py @@ -67,6 +67,7 @@ class TestAnsibleModuleRunCommand(unittest.TestCase): self.stdin_swap = swap_stdin_and_argv(stdin_data=args) self.stdin_swap.__enter__() + reload(basic) self.module = AnsibleModule(argument_spec=dict()) self.module.fail_json = MagicMock(side_effect=SystemExit) diff --git a/test/units/module_utils/basic/test_safe_eval.py b/test/units/module_utils/basic/test_safe_eval.py index 912dee17e3..e38ed9b46e 100644 --- a/test/units/module_utils/basic/test_safe_eval.py +++ b/test/units/module_utils/basic/test_safe_eval.py @@ -26,6 +26,12 @@ import json from ansible.compat.tests import unittest from units.mock.procenv import swap_stdin_and_argv +try: + from importlib import reload +except: + # Py2 has reload as a builtin + pass + class TestAnsibleModuleExitJson(unittest.TestCase): def test_module_utils_basic_safe_eval(self): @@ -34,6 +40,7 @@ class TestAnsibleModuleExitJson(unittest.TestCase): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) with swap_stdin_and_argv(stdin_data=args): + reload(basic) am = basic.AnsibleModule( argument_spec=dict(), ) diff --git a/test/units/module_utils/test_basic.py b/test/units/module_utils/test_basic.py index 8beefed6e8..662d3fe37c 100644 --- a/test/units/module_utils/test_basic.py +++ b/test/units/module_utils/test_basic.py @@ -31,6 +31,12 @@ try: except ImportError: import __builtin__ as builtins +try: + from importlib import reload +except: + # Py2 has reload as a builtin + pass + from units.mock.procenv import swap_stdin_and_argv from ansible.compat.tests import unittest @@ -291,6 +297,7 @@ class TestModuleUtilsBasic(unittest.TestCase): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo": "hello"}, ANSIBLE_MODULE_CONSTANTS={})) with swap_stdin_and_argv(stdin_data=args): + reload(basic) am = basic.AnsibleModule( argument_spec = arg_spec, mutually_exclusive = mut_ex, @@ -307,6 +314,7 @@ class TestModuleUtilsBasic(unittest.TestCase): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) with swap_stdin_and_argv(stdin_data=args): + reload(basic) self.assertRaises( SystemExit, basic.AnsibleModule, @@ -353,6 +361,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_load_file_common_arguments(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -401,6 +410,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_mls_enabled(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -420,6 +430,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_initial_context(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -433,6 +444,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_enabled(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -464,6 +476,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_default_context(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -499,6 +512,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_context(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -540,6 +554,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_is_special_selinux_path(self): from ansible.module_utils import basic + reload(basic) args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={"SELINUX_SPECIAL_FS": "nfs,nfsd,foos"})) @@ -584,6 +599,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_to_filesystem_str(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -608,6 +624,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_find_mount_point(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -631,18 +648,19 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_context_if_different(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), ) - basic.HAS_SELINUX = False + basic.HAVE_SELINUX = False am.selinux_enabled = MagicMock(return_value=False) self.assertEqual(am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], True), True) self.assertEqual(am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], False), False) - basic.HAS_SELINUX = True + basic.HAVE_SELINUX = True am.selinux_enabled = MagicMock(return_value=True) am.selinux_context = MagicMock(return_value=['bar_u', 'bar_r', None, None]) @@ -675,6 +693,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_owner_if_different(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -713,6 +732,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_group_if_different(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -751,6 +771,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_mode_if_different(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -838,6 +859,7 @@ class TestModuleUtilsBasic(unittest.TestCase): ): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(), @@ -1015,6 +1037,7 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module__symbolic_mode_to_octal(self): from ansible.module_utils import basic + reload(basic) am = basic.AnsibleModule( argument_spec = dict(),