1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

Allow AnsibleModules to be instantiated more than once in a module

Fix SELINUX monkeypatch in test_basic
This commit is contained in:
Toshio Kuratomi 2016-04-19 20:08:01 -07:00
parent 0f373c1767
commit 44e21f7062
6 changed files with 66 additions and 17 deletions

View file

@ -223,6 +223,12 @@ from ansible import __version__
# Backwards compat. New code should just import and use __version__ # Backwards compat. New code should just import and use __version__
ANSIBLE_VERSION = __version__ ANSIBLE_VERSION = __version__
# Internal global holding passed in params and constants. This is consulted
# in case multiple AnsibleModules are created. Otherwise each AnsibleModule
# would attempt to read from stdin. Other code should not use this directly
# as it is an internal implementation detail
_ANSIBLE_ARGS = None
FILE_COMMON_ARGUMENTS=dict( FILE_COMMON_ARGUMENTS=dict(
src = dict(), src = dict(),
mode = dict(type='raw'), mode = dict(type='raw'),
@ -1457,23 +1463,28 @@ class AnsibleModule(object):
''' read the input and set the params attribute. Sets the constants as well.''' ''' read the input and set the params attribute. Sets the constants as well.'''
# debug overrides to read args from file or cmdline # debug overrides to read args from file or cmdline
# Avoid tracebacks when locale is non-utf8 global _ANSIBLE_ARGS
# We control the args and we pass them as utf8 if _ANSIBLE_ARGS is not None:
if len(sys.argv) > 1: buffer = _ANSIBLE_ARGS
if os.path.isfile(sys.argv[1]):
fd = open(sys.argv[1], 'rb')
buffer = fd.read()
fd.close()
else:
buffer = sys.argv[1]
if sys.version_info >= (3,):
buffer = buffer.encode('utf-8', errors='surrogateescape')
# default case, read from stdin
else: else:
if sys.version_info < (3,): # Avoid tracebacks when locale is non-utf8
buffer = sys.stdin.read() # We control the args and we pass them as utf8
if len(sys.argv) > 1:
if os.path.isfile(sys.argv[1]):
fd = open(sys.argv[1], 'rb')
buffer = fd.read()
fd.close()
else:
buffer = sys.argv[1]
if sys.version_info >= (3,):
buffer = buffer.encode('utf-8', errors='surrogateescape')
# default case, read from stdin
else: else:
buffer = sys.stdin.buffer.read() if sys.version_info < (3,):
buffer = sys.stdin.read()
else:
buffer = sys.stdin.buffer.read()
_ANSIBLE_ARGS = buffer
try: try:
params = json.loads(buffer.decode('utf-8')) params = json.loads(buffer.decode('utf-8'))

View file

@ -45,6 +45,7 @@ class TestAnsibleModuleExitJson(unittest.TestCase):
self.stdout_swap_ctx = swap_stdout() self.stdout_swap_ctx = swap_stdout()
self.fake_stream = self.stdout_swap_ctx.__enter__() self.fake_stream = self.stdout_swap_ctx.__enter__()
reload(basic)
self.module = basic.AnsibleModule(argument_spec=dict()) self.module = basic.AnsibleModule(argument_spec=dict())
def tearDown(self): def tearDown(self):
@ -125,6 +126,7 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase):
params = json.dumps(params) params = json.dumps(params)
with swap_stdin_and_argv(stdin_data=params): with swap_stdin_and_argv(stdin_data=params):
reload(basic)
with swap_stdout(): with swap_stdout():
module = basic.AnsibleModule( module = basic.AnsibleModule(
argument_spec = dict( argument_spec = dict(
@ -146,6 +148,7 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase):
params = dict(ANSIBLE_MODULE_ARGS=args, ANSIBLE_MODULE_CONSTANTS={}) params = dict(ANSIBLE_MODULE_ARGS=args, ANSIBLE_MODULE_CONSTANTS={})
params = json.dumps(params) params = json.dumps(params)
with swap_stdin_and_argv(stdin_data=params): with swap_stdin_and_argv(stdin_data=params):
reload(basic)
with swap_stdout(): with swap_stdout():
module = basic.AnsibleModule( module = basic.AnsibleModule(
argument_spec = dict( argument_spec = dict(

View file

@ -49,6 +49,7 @@ class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase):
self.stdin_swap = swap_stdin_and_argv(stdin_data=args) self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
self.stdin_swap.__enter__() self.stdin_swap.__enter__()
reload(basic)
self.am = basic.AnsibleModule( self.am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -85,6 +86,7 @@ class TestAnsibleModuleJournaldSmokeTest(unittest.TestCase):
self.stdin_swap = swap_stdin_and_argv(stdin_data=args) self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
self.stdin_swap.__enter__() self.stdin_swap.__enter__()
reload(basic)
self.am = basic.AnsibleModule( self.am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -132,6 +134,7 @@ class TestAnsibleModuleLogSyslog(unittest.TestCase):
self.stdin_swap = swap_stdin_and_argv(stdin_data=args) self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
self.stdin_swap.__enter__() self.stdin_swap.__enter__()
reload(basic)
self.am = basic.AnsibleModule( self.am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -192,6 +195,7 @@ class TestAnsibleModuleLogJournal(unittest.TestCase):
self.stdin_swap = swap_stdin_and_argv(stdin_data=args) self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
self.stdin_swap.__enter__() self.stdin_swap.__enter__()
reload(basic)
self.am = basic.AnsibleModule( self.am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )

View file

@ -67,6 +67,7 @@ class TestAnsibleModuleRunCommand(unittest.TestCase):
self.stdin_swap = swap_stdin_and_argv(stdin_data=args) self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
self.stdin_swap.__enter__() self.stdin_swap.__enter__()
reload(basic)
self.module = AnsibleModule(argument_spec=dict()) self.module = AnsibleModule(argument_spec=dict())
self.module.fail_json = MagicMock(side_effect=SystemExit) self.module.fail_json = MagicMock(side_effect=SystemExit)

View file

@ -26,6 +26,12 @@ import json
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from units.mock.procenv import swap_stdin_and_argv from units.mock.procenv import swap_stdin_and_argv
try:
from importlib import reload
except:
# Py2 has reload as a builtin
pass
class TestAnsibleModuleExitJson(unittest.TestCase): class TestAnsibleModuleExitJson(unittest.TestCase):
def test_module_utils_basic_safe_eval(self): def test_module_utils_basic_safe_eval(self):
@ -34,6 +40,7 @@ class TestAnsibleModuleExitJson(unittest.TestCase):
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
with swap_stdin_and_argv(stdin_data=args): with swap_stdin_and_argv(stdin_data=args):
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec=dict(), argument_spec=dict(),
) )

View file

@ -31,6 +31,12 @@ try:
except ImportError: except ImportError:
import __builtin__ as builtins import __builtin__ as builtins
try:
from importlib import reload
except:
# Py2 has reload as a builtin
pass
from units.mock.procenv import swap_stdin_and_argv from units.mock.procenv import swap_stdin_and_argv
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
@ -291,6 +297,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo": "hello"}, ANSIBLE_MODULE_CONSTANTS={})) args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo": "hello"}, ANSIBLE_MODULE_CONSTANTS={}))
with swap_stdin_and_argv(stdin_data=args): with swap_stdin_and_argv(stdin_data=args):
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = arg_spec, argument_spec = arg_spec,
mutually_exclusive = mut_ex, mutually_exclusive = mut_ex,
@ -307,6 +314,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
with swap_stdin_and_argv(stdin_data=args): with swap_stdin_and_argv(stdin_data=args):
reload(basic)
self.assertRaises( self.assertRaises(
SystemExit, SystemExit,
basic.AnsibleModule, basic.AnsibleModule,
@ -353,6 +361,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_load_file_common_arguments(self): def test_module_utils_basic_ansible_module_load_file_common_arguments(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -401,6 +410,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_mls_enabled(self): def test_module_utils_basic_ansible_module_selinux_mls_enabled(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -420,6 +430,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_initial_context(self): def test_module_utils_basic_ansible_module_selinux_initial_context(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -433,6 +444,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_enabled(self): def test_module_utils_basic_ansible_module_selinux_enabled(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -464,6 +476,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_default_context(self): def test_module_utils_basic_ansible_module_selinux_default_context(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -499,6 +512,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_context(self): def test_module_utils_basic_ansible_module_selinux_context(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -540,6 +554,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_is_special_selinux_path(self): def test_module_utils_basic_ansible_module_is_special_selinux_path(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={"SELINUX_SPECIAL_FS": "nfs,nfsd,foos"})) args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={"SELINUX_SPECIAL_FS": "nfs,nfsd,foos"}))
@ -584,6 +599,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_to_filesystem_str(self): def test_module_utils_basic_ansible_module_to_filesystem_str(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -608,6 +624,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_find_mount_point(self): def test_module_utils_basic_ansible_module_find_mount_point(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -631,18 +648,19 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_context_if_different(self): def test_module_utils_basic_ansible_module_set_context_if_different(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
basic.HAS_SELINUX = False basic.HAVE_SELINUX = False
am.selinux_enabled = MagicMock(return_value=False) am.selinux_enabled = MagicMock(return_value=False)
self.assertEqual(am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], True), True) self.assertEqual(am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], True), True)
self.assertEqual(am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], False), False) self.assertEqual(am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], False), False)
basic.HAS_SELINUX = True basic.HAVE_SELINUX = True
am.selinux_enabled = MagicMock(return_value=True) am.selinux_enabled = MagicMock(return_value=True)
am.selinux_context = MagicMock(return_value=['bar_u', 'bar_r', None, None]) am.selinux_context = MagicMock(return_value=['bar_u', 'bar_r', None, None])
@ -675,6 +693,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_owner_if_different(self): def test_module_utils_basic_ansible_module_set_owner_if_different(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -713,6 +732,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_group_if_different(self): def test_module_utils_basic_ansible_module_set_group_if_different(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -751,6 +771,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_mode_if_different(self): def test_module_utils_basic_ansible_module_set_mode_if_different(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -838,6 +859,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
): ):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
@ -1015,6 +1037,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module__symbolic_mode_to_octal(self): def test_module_utils_basic_ansible_module__symbolic_mode_to_octal(self):
from ansible.module_utils import basic from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),