diff --git a/bin/ansible b/bin/ansible index ccd6bb8e17..59b4fec305 100755 --- a/bin/ansible +++ b/bin/ansible @@ -33,6 +33,7 @@ except Exception: pass import os +import shutil import sys import traceback @@ -40,6 +41,7 @@ import traceback from multiprocessing import Lock debug_lock = Lock() +import ansible.constants as C from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleParserError from ansible.utils.display import Display from ansible.utils.unicode import to_unicode @@ -87,28 +89,28 @@ if __name__ == '__main__': cli = mycli(sys.argv) cli.parse() - sys.exit(cli.run()) + exit_code = cli.run() except AnsibleOptionsError as e: cli.parser.print_help() display.error(to_unicode(e), wrap_text=False) - sys.exit(5) + exit_code = 5 except AnsibleParserError as e: display.error(to_unicode(e), wrap_text=False) - sys.exit(4) + exit_code = 4 # TQM takes care of these, but leaving comment to reserve the exit codes # except AnsibleHostUnreachable as e: # display.error(str(e)) -# sys.exit(3) +# exit_code = 3 # except AnsibleHostFailed as e: # display.error(str(e)) -# sys.exit(2) +# exit_code = 2 except AnsibleError as e: display.error(to_unicode(e), wrap_text=False) - sys.exit(1) + exit_code = 1 except KeyboardInterrupt: display.error("User interrupted execution") - sys.exit(99) + exit_code = 99 except Exception as e: have_cli_options = cli is not None and cli.options is not None display.error("Unexpected Exception: %s" % to_unicode(e), wrap_text=False) @@ -116,4 +118,9 @@ if __name__ == '__main__': display.display(u"the full traceback was:\n\n%s" % to_unicode(traceback.format_exc())) else: display.display("to see the full traceback, use -vvv") - sys.exit(250) + exit_code = 250 + finally: + # Remove ansible tempdir + shutil.rmtree(C.DEFAULT_LOCAL_TMP, True) + + sys.exit(exit_code) diff --git a/docsite/rst/intro_configuration.rst b/docsite/rst/intro_configuration.rst index d5820a5b5f..9628735087 100644 --- a/docsite/rst/intro_configuration.rst +++ b/docsite/rst/intro_configuration.rst @@ -452,6 +452,22 @@ This is the default location Ansible looks to find modules:: Ansible knows how to look in multiple locations if you feed it a colon separated path, and it also will look for modules in the "./library" directory alongside a playbook. +.. _local_tmp: + +local_tmp +========= + +When Ansible gets ready to send a module to a remote machine it usually has to +add a few things to the module: Some boilerplate code, the module's +parameters, and a few constants from the config file. This combination of +things gets stored in a temporary file until ansible exits and cleans up after +itself. The default location is a subdirectory of the user's home directory. +If you'd like to change that, you can do so by altering this setting:: + + local_tmp = $HOME/.ansible/tmp + +Ansible will then choose a random directory name inside this location. + .. _log_path: log_path diff --git a/examples/ansible.cfg b/examples/ansible.cfg index 672c99dfc3..2b50d71e24 100644 --- a/examples/ansible.cfg +++ b/examples/ansible.cfg @@ -14,6 +14,7 @@ #inventory = /etc/ansible/hosts #library = /usr/share/my_modules/ #remote_tmp = $HOME/.ansible/tmp +#local_tmp = $HOME/.ansible/tmp #forks = 5 #poll_interval = 15 #sudo_user = root diff --git a/hacking/test-module b/hacking/test-module index a3a6aa6fab..5d852f1eae 100755 --- a/hacking/test-module +++ b/hacking/test-module @@ -29,12 +29,14 @@ # test-module -m ../library/file/lineinfile -a "dest=/etc/exports line='/srv/home hostname1(rw,sync)'" --check # test-module -m ../library/commands/command -a "echo hello" -n -o "test_hello" -import sys import base64 +from multiprocessing import Lock +import optparse import os import subprocess +import sys import traceback -import optparse + import ansible.utils.vars as utils_vars from ansible.parsing.dataloader import DataLoader from ansible.parsing.utils.jsonify import jsonify @@ -133,10 +135,12 @@ def boilerplate_module(modfile, args, interpreter, check, destfile): modname = os.path.basename(modfile) modname = os.path.splitext(modname)[0] + action_write_lock = Lock() (module_data, module_style, shebang) = module_common.modify_module( modname, modfile, complex_args, + action_write_lock, task_vars=task_vars ) diff --git a/lib/ansible/constants.py b/lib/ansible/constants.py index cd7659a0c6..514fda8160 100644 --- a/lib/ansible/constants.py +++ b/lib/ansible/constants.py @@ -20,6 +20,7 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type import os +import tempfile from string import ascii_letters, digits from ansible.compat.six import string_types @@ -47,7 +48,7 @@ def shell_expand(path): path = os.path.expanduser(os.path.expandvars(path)) return path -def get_config(p, section, key, env_var, default, boolean=False, integer=False, floating=False, islist=False, isnone=False, ispath=False, ispathlist=False): +def get_config(p, section, key, env_var, default, boolean=False, integer=False, floating=False, islist=False, isnone=False, ispath=False, ispathlist=False, istmppath=False): ''' return a configuration variable with casting ''' value = _get_config(p, section, key, env_var, default) if boolean: @@ -65,6 +66,11 @@ def get_config(p, section, key, env_var, default, boolean=False, integer=False, value = None elif ispath: value = shell_expand(value) + elif istmppath: + value = shell_expand(value) + if not os.path.exists(value): + os.makedirs(value, 0o700) + value = tempfile.mkdtemp(prefix='ansible-local-tmp', dir=value) elif ispathlist: if isinstance(value, string_types): value = [shell_expand(x) for x in value.split(os.pathsep)] @@ -136,6 +142,7 @@ DEFAULT_HOST_LIST = get_config(p, DEFAULTS,'inventory', 'ANSIBLE_INVENTO DEFAULT_MODULE_PATH = get_config(p, DEFAULTS, 'library', 'ANSIBLE_LIBRARY', None, ispathlist=True) DEFAULT_ROLES_PATH = get_config(p, DEFAULTS, 'roles_path', 'ANSIBLE_ROLES_PATH', '/etc/ansible/roles', ispathlist=True) DEFAULT_REMOTE_TMP = get_config(p, DEFAULTS, 'remote_tmp', 'ANSIBLE_REMOTE_TEMP', '$HOME/.ansible/tmp') +DEFAULT_LOCAL_TMP = get_config(p, DEFAULTS, 'local_tmp', 'ANSIBLE_LOCAL_TEMP', '$HOME/.ansible/tmp', istmppath=True) DEFAULT_MODULE_NAME = get_config(p, DEFAULTS, 'module_name', None, 'command') DEFAULT_FORKS = get_config(p, DEFAULTS, 'forks', 'ANSIBLE_FORKS', 5, integer=True) DEFAULT_MODULE_ARGS = get_config(p, DEFAULTS, 'module_args', 'ANSIBLE_MODULE_ARGS', '') diff --git a/lib/ansible/executor/module_common.py b/lib/ansible/executor/module_common.py index a22e43f240..50b4947b59 100644 --- a/lib/ansible/executor/module_common.py +++ b/lib/ansible/executor/module_common.py @@ -20,6 +20,7 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import ast import base64 import json import os @@ -32,6 +33,7 @@ from ansible import __version__ from ansible import constants as C from ansible.errors import AnsibleError from ansible.utils.unicode import to_bytes, to_unicode +from ansible.plugins.strategy import action_write_locks try: from __main__ import display @@ -48,7 +50,7 @@ REPLACER_SELINUX = b"<>" # We could end up writing out parameters with unicode characters so we need to # specify an encoding for the python source file -ENCODING_STRING = b'# -*- coding: utf-8 -*-' +ENCODING_STRING = u'# -*- coding: utf-8 -*-' # we've moved the module_common relative to the snippets, so fix the path _SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils') @@ -56,7 +58,7 @@ _SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils') # ****************************************************************************** ZIPLOADER_TEMPLATE = u'''%(shebang)s -# -*- coding: utf-8 -*-' +%(coding)s # This code is part of Ansible, but is an independent component. # The code in this particular templatable string, and this templatable string # only, is BSD licensed. Modules which end up using this snippet, which is @@ -87,17 +89,49 @@ ZIPLOADER_TEMPLATE = u'''%(shebang)s import os import sys import base64 +import shutil +import zipfile import tempfile import subprocess if sys.version_info < (3,): bytes = str + PY3 = False else: unicode = str + PY3 = True +try: + # Python-2.6+ + from io import BytesIO as IOStream +except ImportError: + # Python < 2.6 + from StringIO import StringIO as IOStream ZIPDATA = """%(zipdata)s""" -def debug(command, zipped_mod): +def invoke_module(module, modlib_path, json_params): + pythonpath = os.environ.get('PYTHONPATH') + if pythonpath: + os.environ['PYTHONPATH'] = ':'.join((modlib_path, pythonpath)) + else: + os.environ['PYTHONPATH'] = modlib_path + + p = subprocess.Popen(['%(interpreter)s', module], env=os.environ, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE) + (stdout, stderr) = p.communicate(json_params) + + if not isinstance(stderr, (bytes, unicode)): + stderr = stderr.read() + if not isinstance(stdout, (bytes, unicode)): + stdout = stdout.read() + if PY3: + sys.stderr.buffer.write(stderr) + sys.stdout.buffer.write(stdout) + else: + sys.stderr.write(stderr) + sys.stdout.write(stdout) + return p.returncode + +def debug(command, zipped_mod, json_params): # The code here normally doesn't run. It's only used for debugging on the # remote machine. Run with ANSIBLE_KEEP_REMOTE_FILES=1 envvar and -vvv # to save the module file remotely. Login to the remote machine and use @@ -105,7 +139,7 @@ def debug(command, zipped_mod): # files. Edit the source files to instrument the code or experiment with # different values. Then use /path/to/module execute to run the extracted # files you've edited instead of the actual zipped module. - # + # Okay to use __file__ here because we're running from a kept file basedir = os.path.dirname(__file__) if command == 'explode': @@ -113,11 +147,11 @@ def debug(command, zipped_mod): # print the path to the code. This is an easy way for people to look # at the code on the remote machine for debugging it in that # environment - import zipfile z = zipfile.ZipFile(zipped_mod) for filename in z.namelist(): if filename.startswith('/'): raise Exception('Something wrong with this module zip file: should not contain absolute paths') + dest_filename = os.path.join(basedir, filename) if dest_filename.endswith(os.path.sep) and not os.path.exists(dest_filename): os.makedirs(dest_filename) @@ -128,26 +162,17 @@ def debug(command, zipped_mod): f = open(dest_filename, 'w') f.write(z.read(filename)) f.close() + print('Module expanded into:') print('%%s' %% os.path.join(basedir, 'ansible')) + exitcode = 0 + elif command == 'execute': # Execute the exploded code instead of executing the module from the # embedded ZIPDATA. This allows people to easily run their modified # code on the remote machine to see how changes will affect it. - pythonpath = os.environ.get('PYTHONPATH') - if pythonpath: - os.environ['PYTHONPATH'] = ':'.join((basedir, pythonpath)) - else: - os.environ['PYTHONPATH'] = basedir - p = subprocess.Popen(['%(interpreter)s', '-m', 'ansible.module_exec.%(ansible_module)s.__main__'], env=os.environ, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - (stdout, stderr) = p.communicate() - if not isinstance(stderr, (bytes, unicode)): - stderr = stderr.read() - if not isinstance(stdout, (bytes, unicode)): - stdout = stdout.read() - sys.stderr.write(stderr) - sys.stdout.write(stdout) - sys.exit(p.returncode) + exitcode = invoke_module(os.path.join(basedir, 'ansible_module_%(ansible_module)s.py'), basedir, json_params) + elif command == 'excommunicate': # This attempts to run the module in-process (by importing a main # function and then calling it). It is not the way ansible generally @@ -157,43 +182,78 @@ def debug(command, zipped_mod): # when using this that are only artifacts of how we're invoking here, # not actual bugs (as they don't affect the real way that we invoke # ansible modules) + sys.stdin = IOStream(json_params) sys.path.insert(0, basedir) - from ansible.module_exec.%(ansible_module)s.__main__ import main + from ansible_module_%(ansible_module)s import main main() - -os.environ['ANSIBLE_MODULE_ARGS'] = %(args)s -os.environ['ANSIBLE_MODULE_CONSTANTS'] = %(constants)s - -try: - temp_fd, temp_path = tempfile.mkstemp(prefix='ansible_') - os.write(temp_fd, base64.b64decode(ZIPDATA)) - if len(sys.argv) == 2: - debug(sys.argv[1], temp_path) + print('WARNING: Module returned to wrapper instead of exiting') + sys.exit(1) else: - pythonpath = os.environ.get('PYTHONPATH') - if pythonpath: - os.environ['PYTHONPATH'] = ':'.join((temp_path, pythonpath)) - else: - os.environ['PYTHONPATH'] = temp_path - p = subprocess.Popen(['%(interpreter)s', '-m', 'ansible.module_exec.%(ansible_module)s.__main__'], env=os.environ, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - (stdout, stderr) = p.communicate() - if not isinstance(stderr, (bytes, unicode)): - stderr = stderr.read() - if not isinstance(stdout, (bytes, unicode)): - stdout = stdout.read() - sys.stderr.write(stderr) - sys.stdout.write(stdout) - sys.exit(p.returncode) + print('WARNING: Unknown debug command. Doing nothing.') + exitcode = 0 -finally: + return exitcode + +if __name__ == '__main__': + ZIPLOADER_PARAMS = %(params)s + if PY3: + ZIPLOADER_PARAMS = ZIPLOADER_PARAMS.encode('utf-8') try: - os.close(temp_fd) - os.remove(temp_path) - except NameError: - # mkstemp failed - pass + temp_path = tempfile.mkdtemp(prefix='ansible_') + zipped_mod = os.path.join(temp_path, 'ansible_modlib.zip') + modlib = open(zipped_mod, 'wb') + modlib.write(base64.b64decode(ZIPDATA)) + modlib.close() + if len(sys.argv) == 2: + exitcode = debug(sys.argv[1], zipped_mod, ZIPLOADER_PARAMS) + else: + z = zipfile.ZipFile(zipped_mod) + module = os.path.join(temp_path, 'ansible_module_%(ansible_module)s.py') + f = open(module, 'wb') + f.write(z.read('ansible_module_%(ansible_module)s.py')) + f.close() + exitcode = invoke_module(module, zipped_mod, ZIPLOADER_PARAMS) + finally: + try: + shutil.rmtree(temp_path) + except OSError: + # tempdir creation probably failed + pass + sys.exit(exitcode) ''' +class ModuleDepFinder(ast.NodeVisitor): + # Caveats: + # This code currently does not handle: + # * relative imports from py2.6+ from . import urls + # * python packages (directories with __init__.py in them) + IMPORT_PREFIX_SIZE = len('ansible.module_utils.') + + def __init__(self, *args, **kwargs): + super(ModuleDepFinder, self).__init__(*args, **kwargs) + self.module_files = set() + + def visit_Import(self, node): + # import ansible.module_utils.MODLIB[.other] + for alias in (a for a in node.names if a.name.startswith('ansible.module_utils.')): + py_mod = alias.name[self.IMPORT_PREFIX_SIZE:].split('.', 1)[0] + self.module_files.add(py_mod) + self.generic_visit(node) + + def visit_ImportFrom(self, node): + if node.module.startswith('ansible.module_utils'): + where_from = node.module[self.IMPORT_PREFIX_SIZE:] + # from ansible.module_utils.MODLIB[.other] import foo + if where_from: + py_mod = where_from.split('.', 1)[0] + self.module_files.add(py_mod) + else: + # from ansible.module_utils import MODLIB + for alias in node.names: + self.module_files.add(alias.name) + self.generic_visit(node) + + def _strip_comments(source): # Strip comments and blank lines from the wrapper buf = [] @@ -242,6 +302,28 @@ def _get_facility(task_vars): facility = task_vars['ansible_syslog_facility'] return facility +def recursive_finder(data, snippet_names, snippet_data, zf): + """ + Using ModuleDepFinder, make sure we have all of the module_utils files that + the module its module_utils files needs. + """ + tree = ast.parse(data) + finder = ModuleDepFinder() + finder.visit(tree) + + new_snippets = set() + for snippet_name in finder.module_files.difference(snippet_names): + fname = '%s.py' % snippet_name + new_snippets.add(snippet_name) + if snippet_name not in snippet_data: + snippet_data[snippet_name] = _slurp(os.path.join(_SNIPPET_PATH, fname)) + zf.writestr(os.path.join("ansible/module_utils", fname), snippet_data[snippet_name]) + snippet_names.update(new_snippets) + + for snippet_name in tuple(new_snippets): + recursive_finder(snippet_data[snippet_name], snippet_names, snippet_data, zf) + del snippet_data[snippet_name] + def _find_snippet_imports(module_name, module_data, module_path, module_args, task_vars, module_compression): """ Given the source of the module, convert it to a Jinja2 template to insert @@ -280,59 +362,87 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta if module_style in ('old', 'non_native_want_json'): return module_data, module_style, shebang - module_args_json = to_bytes(json.dumps(module_args)) - output = BytesIO() - lines = module_data.split(b'\n') - snippet_names = set() if module_substyle == 'python': # ziploader for new-style python classes - python_repred_args = to_bytes(repr(module_args_json)) constants = dict( SELINUX_SPECIAL_FS=C.DEFAULT_SELINUX_SPECIAL_FS, SYSLOG_FACILITY=_get_facility(task_vars), ) - python_repred_constants = to_bytes(repr(json.dumps(constants)), errors='strict') + params = dict(ANSIBLE_MODULE_ARGS=module_args, + ANSIBLE_MODULE_CONSTANTS=constants, + ) + #python_repred_args = to_bytes(repr(module_args_json)) + #python_repred_constants = to_bytes(repr(json.dumps(constants)), errors='strict') + python_repred_params = to_bytes(repr(json.dumps(params)), errors='strict') try: compression_method = getattr(zipfile, module_compression) except AttributeError: display.warning(u'Bad module compression string specified: %s. Using ZIP_STORED (no compression)' % module_compression) compression_method = zipfile.ZIP_STORED - zipoutput = BytesIO() - zf = zipfile.ZipFile(zipoutput, mode='w', compression=compression_method) - zf.writestr('ansible/__init__.py', b''.join((b"__version__ = '", to_bytes(__version__), b"'\n"))) - zf.writestr('ansible/module_utils/__init__.py', b'') - zf.writestr('ansible/module_exec/__init__.py', b'') - zf.writestr('ansible/module_exec/%s/__init__.py' % module_name, b"") - final_data = [] + lookup_path = os.path.join(C.DEFAULT_LOCAL_TMP, 'ziploader_cache') + if not os.path.exists(lookup_path): + os.mkdir(lookup_path) + cached_module_filename = os.path.join(lookup_path, "%s-%s" % (module_name, module_compression)) - for line in lines: - if line.startswith(b'from ansible.module_utils.'): - tokens=line.split(b".") - snippet_name = tokens[2].split()[0] - snippet_names.add(snippet_name) - fname = to_unicode(snippet_name + b".py") - zf.writestr(os.path.join("ansible/module_utils", fname), _slurp(os.path.join(_SNIPPET_PATH, fname))) - final_data.append(line) - else: - final_data.append(line) + zipdata = None + # Optimization -- don't lock if the module has already been cached + if os.path.exists(cached_module_filename): + zipdata = open(cached_module_filename, 'rb').read() + # Fool the check later... I think we should just remove the check + snippet_names.add('basic') + else: + with action_write_locks[module_name]: + # Check that no other process has created this while we were + # waiting for the lock + if not os.path.exists(cached_module_filename): + # Create the module zip data + zipoutput = BytesIO() + zf = zipfile.ZipFile(zipoutput, mode='w', compression=compression_method) + zf.writestr('ansible/__init__.py', b''.join((b"__version__ = '", to_bytes(__version__), b"'\n"))) + zf.writestr('ansible/module_utils/__init__.py', b'') - zf.writestr('ansible/module_exec/%s/__main__.py' % module_name, b"\n".join(final_data)) - zf.close() + zf.writestr('ansible_module_%s.py' % module_name, module_data) + + snippet_data = dict() + recursive_finder(module_data, snippet_names, snippet_data, zf) + zf.close() + zipdata = base64.b64encode(zipoutput.getvalue()) + + # Write the assembled module to a temp file (write to temp + # so that no one looking for the file reads a partially + # written file) + with open(cached_module_filename + '-part', 'w') as f: + f.write(zipdata) + + # Rename the file into its final position in the cache so + # future users of this module can read it off the + # filesystem instead of constructing from scratch. + os.rename(cached_module_filename + '-part', cached_module_filename) + + if zipdata is None: + # Another process wrote the file while we were waiting for + # the write lock. Go ahead and read the data from disk + # instead of re-creating it. + zipdata = open(cached_module_filename, 'rb').read() + # Fool the check later... I think we should just remove the check + snippet_names.add('basic') shebang, interpreter = _get_shebang(u'/usr/bin/python', task_vars) if shebang is None: shebang = u'#!/usr/bin/python' output.write(to_bytes(STRIPPED_ZIPLOADER_TEMPLATE % dict( - zipdata=base64.b64encode(zipoutput.getvalue()), + zipdata=zipdata, ansible_module=module_name, - args=python_repred_args, - constants=python_repred_constants, + #args=python_repred_args, + #constants=python_repred_constants, + params=python_repred_params, shebang=shebang, interpreter=interpreter, + coding=ENCODING_STRING, ))) module_data = output.getvalue() @@ -340,11 +450,12 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta # modules that use ziploader may implement their own helpers and not # need basic.py. All the constants that we substituted into basic.py # for module_replacer are now available in other, better ways. - if b'basic' not in snippet_names: + if 'basic' not in snippet_names: raise AnsibleError("missing required import in %s: Did not import ansible.module_utils.basic for boilerplate helper code" % module_path) elif module_substyle == 'powershell': # Module replacer for jsonargs and windows + lines = module_data.split(b'\n') for line in lines: if REPLACER_WINDOWS in line: ps_data = _slurp(os.path.join(_SNIPPET_PATH, "powershell.ps1")) @@ -353,6 +464,8 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta continue output.write(line + b'\n') module_data = output.getvalue() + + module_args_json = to_bytes(json.dumps(module_args)) module_data = module_data.replace(REPLACER_JSONARGS, module_args_json) # Sanity check from 1.x days. This is currently useless as we only @@ -363,11 +476,14 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path) elif module_substyle == 'jsonargs': + module_args_json = to_bytes(json.dumps(module_args)) + # these strings could be included in a third-party module but # officially they were included in the 'basic' snippet for new-style # python modules (which has been replaced with something else in # ziploader) If we remove them from jsonargs-style module replacer # then we can remove them everywhere. + python_repred_args = to_bytes(repr(module_args_json)) module_data = module_data.replace(REPLACER_VERSION, to_bytes(repr(__version__))) module_data = module_data.replace(REPLACER_COMPLEX, python_repred_args) module_data = module_data.replace(REPLACER_SELINUX, to_bytes(','.join(C.DEFAULT_SELINUX_SPECIAL_FS))) @@ -409,17 +525,6 @@ def modify_module(module_name, module_path, module_args, task_vars=dict(), modul which results in the inclusion of the common code from powershell.ps1 """ - ### TODO: Optimization ideas if this code is actually a source of slowness: - # * Fix comment stripping: Currently doesn't preserve shebangs and encoding info (but we unconditionally add encoding info) - # * Use pyminifier if installed - # * comment stripping/pyminifier needs to have config setting to turn it - # off for debugging purposes (goes along with keep remote but should be - # separate otherwise users wouldn't be able to get info on what the - # minifier output) - # * Only split into lines and recombine into strings once - # * Cache the modified module? If only the args are different and we do - # that as the last step we could cache all the work up to that point. - with open(module_path, 'rb') as f: # read in the module source @@ -440,7 +545,7 @@ def modify_module(module_name, module_path, module_args, task_vars=dict(), modul lines[0] = shebang = new_shebang if os.path.basename(interpreter).startswith(b'python'): - lines.insert(1, ENCODING_STRING) + lines.insert(1, to_bytes(ENCODING_STRING)) else: # No shebang, assume a binary module? pass diff --git a/lib/ansible/executor/process/worker.py b/lib/ansible/executor/process/worker.py index ce5de460a1..d1bc56f637 100644 --- a/lib/ansible/executor/process/worker.py +++ b/lib/ansible/executor/process/worker.py @@ -95,12 +95,9 @@ class WorkerProcess(multiprocessing.Process): def run(self): ''' - Called when the process is started, and loops indefinitely - until an error is encountered (typically an IOerror from the - queue pipe being disconnected). During the loop, we attempt - to pull tasks off the job queue and run them, pushing the result - onto the results queue. We also remove the host from the blocked - hosts list, to signify that they are ready for their next task. + Called when the process is started. Pushes the result onto the + results queue. We also remove the host from the blocked hosts list, to + signify that they are ready for their next task. ''' if HAS_ATFORK: diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index f202dba028..983fcb7ec6 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -223,23 +223,6 @@ from ansible import __version__ # Backwards compat. New code should just import and use __version__ ANSIBLE_VERSION = __version__ -try: - # MODULE_COMPLEX_ARGS is an old name kept for backwards compat - MODULE_COMPLEX_ARGS = os.environ.pop('ANSIBLE_MODULE_ARGS') -except KeyError: - # This file might be used for its utility functions. So don't fail if - # running outside of a module environment (will fail in _load_params() - # instead) - MODULE_COMPLEX_ARGS = None - -try: - # ARGS are for parameters given in the playbook. Constants are for things - # that ansible needs to configure controller side but are passed to all - # modules. - MODULE_CONSTANTS = os.environ.pop('ANSIBLE_MODULE_CONSTANTS') -except KeyError: - MODULE_CONSTANTS = None - FILE_COMMON_ARGUMENTS=dict( src = dict(), mode = dict(type='raw'), @@ -560,7 +543,6 @@ class AnsibleModule(object): if k not in self.argument_spec: self.argument_spec[k] = v - self._load_constants() self._load_params() self._set_fallbacks() @@ -1452,32 +1434,29 @@ class AnsibleModule(object): continue def _load_params(self): - ''' read the input and set the params attribute''' - if MODULE_COMPLEX_ARGS is None: - # This helper used too early for fail_json to work. - print('{"msg": "Error: ANSIBLE_MODULE_ARGS not found in environment. Unable to figure out what parameters were passed", "failed": true}') - sys.exit(1) - - params = json_dict_unicode_to_bytes(json.loads(MODULE_COMPLEX_ARGS)) - if params is None: - params = dict() - self.params = params - - def _load_constants(self): - ''' read the input and set the constants attribute''' - if MODULE_CONSTANTS is None: - # This helper used too early for fail_json to work. - print('{"msg": "Error: ANSIBLE_MODULE_CONSTANTS not found in environment. Unable to figure out what constants were passed", "failed": true}') - sys.exit(1) - - # Make constants into "native string" - if sys.version_info >= (3,): - constants = json_dict_bytes_to_unicode(json.loads(MODULE_CONSTANTS)) + ''' read the input and set the params attribute. Sets the constants as well.''' + # Avoid tracebacks when locale is non-utf8 + if sys.version_info < (3,): + buffer = sys.stdin.read() else: - constants = json_dict_unicode_to_bytes(json.loads(MODULE_CONSTANTS)) - if constants is None: - constants = dict() - self.constants = constants + buffer = sys.stdin.buffer.read() + try: + params = json.loads(buffer.decode('utf-8')) + except ValueError: + # This helper used too early for fail_json to work. + print('{"msg": "Error: Module unable to decode valid JSON on stdin. Unable to figure out what parameters were passed", "failed": true}') + sys.exit(1) + + if sys.version_info < (3,): + params = json_dict_unicode_to_bytes(params) + + try: + self.params = params['ANSIBLE_MODULE_ARGS'] + self.constants = params['ANSIBLE_MODULE_CONSTANTS'] + except KeyError: + # This helper used too early for fail_json to work. + print('{"msg": "Error: Module unable to locate ANSIBLE_MODULE_ARGS and ANSIBLE_MODULE_CONSTANTS in json data from stdin. Unable to figure out what parameters were passed", "failed": true}') + sys.exit(1) def _log_to_syslog(self, msg): if HAS_SYSLOG: diff --git a/lib/ansible/plugins/strategy/__init__.py b/lib/ansible/plugins/strategy/__init__.py index 7b2a40a71a..f06d4f6f75 100644 --- a/lib/ansible/plugins/strategy/__init__.py +++ b/lib/ansible/plugins/strategy/__init__.py @@ -19,15 +19,16 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -from ansible.compat.six.moves import queue as Queue -from ansible.compat.six import iteritems, text_type, string_types - import json import time import zlib +from collections import defaultdict +from multiprocessing import Lock from jinja2.exceptions import UndefinedError +from ansible.compat.six.moves import queue as Queue +from ansible.compat.six import iteritems, text_type, string_types from ansible import constants as C from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable from ansible.executor.play_iterator import PlayIterator @@ -51,6 +52,8 @@ except ImportError: __all__ = ['StrategyBase'] +action_write_locks = defaultdict(Lock) + # TODO: this should probably be in the plugins/__init__.py, with # a smarter mechanism to set all of the attributes based on @@ -141,6 +144,20 @@ class StrategyBase: display.debug("entering _queue_task() for %s/%s" % (host, task)) + # Add a write lock for tasks. + # Maybe this should be added somewhere further up the call stack but + # this is the earliest in the code where we have task (1) extracted + # into its own variable and (2) there's only a single code path + # leading to the module being run. This is called by three + # functions: __init__.py::_do_handler_run(), linear.py::run(), and + # free.py::run() so we'd have to add to all three to do it there. + # The next common higher level is __init__.py::run() and that has + # tasks inside of play_iterator so we'd have to extract them to do it + # there. + if not action_write_locks[task.action]: + display.warning('Python defaultdict did not create the Lock for us. Creating manually') + action_write_locks[task.action] = Lock() + # and then queue the new task display.debug("%s - putting task (%s) in queue" % (host, task)) try: diff --git a/test/units/module_utils/basic/test__log_invocation.py b/test/units/module_utils/basic/test__log_invocation.py index 5e7524e360..34037f963c 100644 --- a/test/units/module_utils/basic/test__log_invocation.py +++ b/test/units/module_utils/basic/test__log_invocation.py @@ -22,18 +22,38 @@ __metaclass__ = type import sys import json +from io import BytesIO, StringIO + +from ansible.compat.six import PY3 +from ansible.utils.unicode import to_bytes from ansible.compat.tests import unittest from ansible.compat.tests.mock import MagicMock class TestModuleUtilsBasic(unittest.TestCase): + def setUp(self): + self.real_stdin = sys.stdin + args = json.dumps( + dict( + ANSIBLE_MODULE_ARGS=dict( + foo=False, bar=[1,2,3], bam="bam", baz=u'baz'), + ANSIBLE_MODULE_CONSTANTS=dict() + ) + ) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + + def tearDown(self): + sys.stdin = self.real_stdin + @unittest.skipIf(sys.version_info[0] >= 3, "Python 3 is not supported on targets (yet)") def test_module_utils_basic__log_invocation(self): from ansible.module_utils import basic # test basic log invocation - basic.MODULE_COMPLEX_ARGS = json.dumps(dict(foo=False, bar=[1,2,3], bam="bam", baz=u'baz')) - basic.MODULE_CONSTANTS = '{}' am = basic.AnsibleModule( argument_spec=dict( foo = dict(default=True, type='bool'), diff --git a/test/units/module_utils/basic/test_exit_json.py b/test/units/module_utils/basic/test_exit_json.py index ffb98e0b58..249dc380d9 100644 --- a/test/units/module_utils/basic/test_exit_json.py +++ b/test/units/module_utils/basic/test_exit_json.py @@ -23,8 +23,10 @@ __metaclass__ = type import copy import json import sys -from io import BytesIO +from io import BytesIO, StringIO +from ansible.compat.six import PY3 +from ansible.utils.unicode import to_bytes from ansible.compat.tests import unittest from ansible.module_utils import basic @@ -37,9 +39,13 @@ empty_invocation = {u'module_args': {}} class TestAnsibleModuleExitJson(unittest.TestCase): def setUp(self): - self.COMPLEX_ARGS = basic.MODULE_COMPLEX_ARGS - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + self.old_stdin = sys.stdin + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) self.old_stdout = sys.stdout self.fake_stream = BytesIO() @@ -48,8 +54,8 @@ class TestAnsibleModuleExitJson(unittest.TestCase): self.module = basic.AnsibleModule(argument_spec=dict()) def tearDown(self): - basic.MODULE_COMPLEX_ARGS = self.COMPLEX_ARGS sys.stdout = self.old_stdout + sys.stdin = self.old_stdin def test_exit_json_no_args_exits(self): with self.assertRaises(SystemExit) as ctx: @@ -118,19 +124,31 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase): ) def setUp(self): - self.COMPLEX_ARGS = basic.MODULE_COMPLEX_ARGS + self.old_stdin = sys.stdin + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + self.old_stdout = sys.stdout def tearDown(self): - basic.MODULE_COMPLEX_ARGS = self.COMPLEX_ARGS + sys.stdin = self.old_stdin sys.stdout = self.old_stdout def test_exit_json_removes_values(self): self.maxDiff = None for args, return_val, expected in self.dataset: sys.stdout = BytesIO() - basic.MODULE_COMPLEX_ARGS = json.dumps(args) - basic.MODULE_CONSTANTS = '{}' + params = dict(ANSIBLE_MODULE_ARGS=args, ANSIBLE_MODULE_CONSTANTS={}) + params = json.dumps(params) + if PY3: + sys.stdin = StringIO(params) + sys.stdin.buffer = BytesIO(to_bytes(params)) + else: + sys.stdin = BytesIO(to_bytes(params)) module = basic.AnsibleModule( argument_spec = dict( username=dict(), @@ -149,8 +167,13 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase): del expected['changed'] expected['failed'] = True sys.stdout = BytesIO() - basic.MODULE_COMPLEX_ARGS = json.dumps(args) - basic.MODULE_CONSTANTS = '{}' + params = dict(ANSIBLE_MODULE_ARGS=args, ANSIBLE_MODULE_CONSTANTS={}) + params = json.dumps(params) + if PY3: + sys.stdin = StringIO(params) + sys.stdin.buffer = BytesIO(to_bytes(params)) + else: + sys.stdin = BytesIO(to_bytes(params)) module = basic.AnsibleModule( argument_spec = dict( username=dict(), diff --git a/test/units/module_utils/basic/test_log.py b/test/units/module_utils/basic/test_log.py index 0a78ffb96d..0452ce7d90 100644 --- a/test/units/module_utils/basic/test_log.py +++ b/test/units/module_utils/basic/test_log.py @@ -21,7 +21,12 @@ from __future__ import (absolute_import, division) __metaclass__ = type import sys +import json import syslog +from io import BytesIO, StringIO + +from ansible.compat.six import PY3 +from ansible.utils.unicode import to_bytes from ansible.compat.tests import unittest from ansible.compat.tests.mock import patch, MagicMock @@ -41,10 +46,14 @@ except ImportError: class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase): def setUp(self): - self.complex_args_token = basic.MODULE_COMPLEX_ARGS - self.constants_sentinel = basic.MODULE_CONSTANTS - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + self.real_stdin = sys.stdin + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + self.am = basic.AnsibleModule( argument_spec = dict(), ) @@ -55,8 +64,7 @@ class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase): basic.has_journal = False def tearDown(self): - basic.MODULE_COMPLEX_ARGS = self.complex_args_token - basic.MODULE_CONSTANTS = self.constants_sentinel + sys.stdin = self.real_stdin basic.has_journal = self.has_journal def test_smoketest_syslog(self): @@ -75,17 +83,21 @@ class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase): class TestAnsibleModuleJournaldSmokeTest(unittest.TestCase): def setUp(self): - self.complex_args_token = basic.MODULE_COMPLEX_ARGS - self.constants_sentinel = basic.MODULE_CONSTANTS - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + self.real_stdin = sys.stdin + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + + self.am = basic.AnsibleModule( argument_spec = dict(), ) def tearDown(self): - basic.MODULE_COMPLEX_ARGS = self.complex_args_token - basic.MODULE_CONSTANTS = self.constants_sentinel + sys.stdin = self.real_stdin @unittest.skipUnless(basic.has_journal, 'python systemd bindings not installed') def test_smoketest_journal(self): @@ -121,10 +133,15 @@ class TestAnsibleModuleLogSyslog(unittest.TestCase): } def setUp(self): - self.complex_args_token = basic.MODULE_COMPLEX_ARGS - self.constants_sentinel = basic.MODULE_CONSTANTS - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + self.real_stdin = sys.stdin + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + + self.am = basic.AnsibleModule( argument_spec = dict(), ) @@ -134,8 +151,7 @@ class TestAnsibleModuleLogSyslog(unittest.TestCase): basic.has_journal = False def tearDown(self): - basic.MODULE_COMPLEX_ARGS = self.complex_args_token - basic.MODULE_CONSTANTS = self.constants_sentinel + sys.stdin = self.real_stdin basic.has_journal = self.has_journal @patch('syslog.syslog', autospec=True) @@ -176,10 +192,14 @@ class TestAnsibleModuleLogJournal(unittest.TestCase): } def setUp(self): - self.complex_args_token = basic.MODULE_COMPLEX_ARGS - self.constants_sentinel = basic.MODULE_CONSTANTS - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + self.real_stdin = sys.stdin + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + self.am = basic.AnsibleModule( argument_spec = dict(), ) @@ -198,8 +218,8 @@ class TestAnsibleModuleLogJournal(unittest.TestCase): self._fake_out_reload(basic) def tearDown(self): - basic.MODULE_COMPLEX_ARGS = self.complex_args_token - basic.MODULE_CONSTANTS = self.constants_sentinel + sys.stdin = self.real_stdin + basic.has_journal = self.has_journal if self.module_patcher: self.module_patcher.stop() diff --git a/test/units/module_utils/basic/test_run_command.py b/test/units/module_utils/basic/test_run_command.py index 8a17f7e55a..3c56365816 100644 --- a/test/units/module_utils/basic/test_run_command.py +++ b/test/units/module_utils/basic/test_run_command.py @@ -20,9 +20,13 @@ from __future__ import (absolute_import, division) __metaclass__ = type import errno +import json import sys import time -from io import BytesIO +from io import BytesIO, StringIO + +from ansible.compat.six import PY3 +from ansible.utils.unicode import to_bytes from ansible.compat.tests import unittest from ansible.compat.tests.mock import call, MagicMock, Mock, patch, sentinel @@ -61,8 +65,12 @@ class TestAnsibleModuleRunCommand(unittest.TestCase): if path == '/inaccessible': raise OSError(errno.EPERM, "Permission denied: '/inaccessible'") - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) self.module = AnsibleModule(argument_spec=dict()) self.module.fail_json = MagicMock(side_effect=SystemExit) diff --git a/test/units/module_utils/basic/test_safe_eval.py b/test/units/module_utils/basic/test_safe_eval.py index cb28e9063f..36e9e1e399 100644 --- a/test/units/module_utils/basic/test_safe_eval.py +++ b/test/units/module_utils/basic/test_safe_eval.py @@ -20,16 +20,32 @@ from __future__ import (absolute_import, division) __metaclass__ = type -from ansible.compat.tests import unittest +import sys +import json +from io import BytesIO, StringIO +from ansible.compat.tests import unittest +from ansible.compat.six import PY3 +from ansible.utils.unicode import to_bytes class TestAnsibleModuleExitJson(unittest.TestCase): + def setUp(self): + self.real_stdin = sys.stdin + + def tearDown(self): + sys.stdin = self.real_stdin + def test_module_utils_basic_safe_eval(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec=dict(), ) diff --git a/test/units/module_utils/test_basic.py b/test/units/module_utils/test_basic.py index 39d9efe065..f8c96c6536 100644 --- a/test/units/module_utils/test_basic.py +++ b/test/units/module_utils/test_basic.py @@ -21,14 +21,19 @@ from __future__ import (absolute_import, division) __metaclass__ = type import errno +import json import os import sys +from io import BytesIO, StringIO try: import builtins except ImportError: import __builtin__ as builtins +from ansible.compat.six import PY3 +from ansible.utils.unicode import to_bytes + from ansible.compat.tests import unittest from ansible.compat.tests.mock import patch, MagicMock, mock_open, Mock, call @@ -37,10 +42,10 @@ realimport = builtins.__import__ class TestModuleUtilsBasic(unittest.TestCase): def setUp(self): - pass + self.real_stdin = sys.stdin def tearDown(self): - pass + sys.stdin = self.real_stdin def clear_modules(self, mods): for mod in mods: @@ -266,8 +271,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_creation(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec=dict(), ) @@ -282,8 +292,13 @@ class TestModuleUtilsBasic(unittest.TestCase): req_to = (('bam', 'baz'),) # should test ok - basic.MODULE_COMPLEX_ARGS = '{"foo":"hello"}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo": "hello"}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = arg_spec, mutually_exclusive = mut_ex, @@ -297,8 +312,13 @@ class TestModuleUtilsBasic(unittest.TestCase): # FIXME: add asserts here to verify the basic config # fail, because a required param was not specified - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + self.assertRaises( SystemExit, basic.AnsibleModule, @@ -312,8 +332,13 @@ class TestModuleUtilsBasic(unittest.TestCase): ) # fail because of mutually exclusive parameters - basic.MODULE_COMPLEX_ARGS = '{"foo":"hello", "bar": "bad", "bam": "bad"}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo":"hello", "bar": "bad", "bam": "bad"}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + self.assertRaises( SystemExit, basic.AnsibleModule, @@ -327,8 +352,13 @@ class TestModuleUtilsBasic(unittest.TestCase): ) # fail because a param required due to another param was not specified - basic.MODULE_COMPLEX_ARGS = '{"bam":"bad"}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"bam": "bad"}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + self.assertRaises( SystemExit, basic.AnsibleModule, @@ -344,8 +374,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_load_file_common_arguments(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -394,8 +429,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_mls_enabled(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -415,8 +455,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_initial_context(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -430,8 +475,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_enabled(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -463,8 +513,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_default_context(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -500,8 +555,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_context(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -543,8 +603,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_is_special_selinux_path(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{"SELINUX_SPECIAL_FS": "nfs,nfsd,foos"}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={"SELINUX_SPECIAL_FS": "nfs,nfsd,foos"})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -585,20 +650,30 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_to_filesystem_str(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) self.assertEqual(am._to_filesystem_str(u'foo'), b'foo') self.assertEqual(am._to_filesystem_str(u'föö'), b'f\xc3\xb6\xc3\xb6') - + def test_module_utils_basic_ansible_module_user_and_group(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -613,8 +688,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_find_mount_point(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -638,8 +718,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_context_if_different(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -684,8 +769,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_owner_if_different(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -724,7 +814,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_group_if_different(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -763,8 +859,13 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_mode_if_different(self): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -852,8 +953,13 @@ class TestModuleUtilsBasic(unittest.TestCase): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) @@ -1031,8 +1137,13 @@ class TestModuleUtilsBasic(unittest.TestCase): from ansible.module_utils import basic - basic.MODULE_COMPLEX_ARGS = '{}' - basic.MODULE_CONSTANTS = '{}' + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + if PY3: + sys.stdin = StringIO(args) + sys.stdin.buffer = BytesIO(to_bytes(args)) + else: + sys.stdin = BytesIO(to_bytes(args)) + am = basic.AnsibleModule( argument_spec = dict(), ) diff --git a/test/units/plugins/action/test_action.py b/test/units/plugins/action/test_action.py index dc6324d924..d0ea6ac789 100644 --- a/test/units/plugins/action/test_action.py +++ b/test/units/plugins/action/test_action.py @@ -212,14 +212,15 @@ class TestActionBase(unittest.TestCase): # test python module formatting with patch.object(builtins, 'open', mock_open(read_data=to_bytes(python_module_replacers.strip(), encoding='utf-8'))) as m: - mock_task.args = dict(a=1, foo='fö〩') - mock_connection.module_implementation_preferences = ('',) - (style, shebang, data) = action_base._configure_module(mock_task.action, mock_task.args) - self.assertEqual(style, "new") - self.assertEqual(shebang, b"#!/usr/bin/python") + with patch.object(os, 'rename') as m: + mock_task.args = dict(a=1, foo='fö〩') + mock_connection.module_implementation_preferences = ('',) + (style, shebang, data) = action_base._configure_module(mock_task.action, mock_task.args) + self.assertEqual(style, "new") + self.assertEqual(shebang, b"#!/usr/bin/python") - # test module not found - self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args) + # test module not found + self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args) # test powershell module formatting with patch.object(builtins, 'open', mock_open(read_data=to_bytes(powershell_module_replacers.strip(), encoding='utf-8'))) as m: