# -*- coding: utf-8 -*- # (c) 2020, Alexei Znamensky # Copyright: (c) 2020, Ansible Project # Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) from __future__ import absolute_import, division, print_function __metaclass__ = type from functools import partial, wraps import traceback from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.dict_transformations import dict_merge class ModuleHelperException(Exception): @staticmethod def _get_remove(key, kwargs): if key in kwargs: result = kwargs[key] del kwargs[key] return result return None def __init__(self, *args, **kwargs): self.msg = self._get_remove('msg', kwargs) or "Module failed with exception: {0}".format(self) self.update_output = self._get_remove('update_output', kwargs) or {} super(ModuleHelperException, self).__init__(*args) class ArgFormat(object): """ Argument formatter for use as a command line parameter. Used in CmdMixin. """ BOOLEAN = 0 PRINTF = 1 FORMAT = 2 @staticmethod def stars_deco(num): if num == 1: def deco(f): return lambda v: f(*v) return deco elif num == 2: def deco(f): return lambda v: f(**v) return deco return lambda f: f def __init__(self, name, fmt=None, style=FORMAT, stars=0): """ Creates a CLI-formatter for one specific argument. The argument may be a module parameter or just a named parameter for the CLI command execution. :param name: Name of the argument to be formatted :param fmt: Either a str to be formatted (using or not printf-style) or a callable that does that :param style: Whether arg_format (as str) should use printf-style formatting. Ignored if arg_format is None or not a str (should be callable). :param stars: A int with 0, 1 or 2 value, indicating to formatting the value as: value, *value or **value """ def printf_fmt(_fmt, v): try: return [_fmt % v] except TypeError as e: if e.args[0] != 'not all arguments converted during string formatting': raise return [_fmt] _fmts = { ArgFormat.BOOLEAN: lambda _fmt, v: ([_fmt] if bool(v) else []), ArgFormat.PRINTF: printf_fmt, ArgFormat.FORMAT: lambda _fmt, v: [_fmt.format(v)], } self.name = name self.stars = stars if fmt is None: fmt = "{0}" style = ArgFormat.FORMAT if isinstance(fmt, str): func = _fmts[style] self.arg_format = partial(func, fmt) elif isinstance(fmt, list) or isinstance(fmt, tuple): self.arg_format = lambda v: [_fmts[style](f, v)[0] for f in fmt] elif hasattr(fmt, '__call__'): self.arg_format = fmt else: raise TypeError('Parameter fmt must be either: a string, a list/tuple of ' 'strings or a function: type={0}, value={1}'.format(type(fmt), fmt)) if stars: self.arg_format = (self.stars_deco(stars))(self.arg_format) def to_text(self, value): if value is None: return [] func = self.arg_format return [str(p) for p in func(value)] def cause_changes(on_success=None, on_failure=None): def deco(func): if on_success is None and on_failure is None: return func @wraps(func) def wrapper(*args, **kwargs): try: self = args[0] func(*args, **kwargs) if on_success is not None: self.changed = on_success except Exception: if on_failure is not None: self.changed = on_failure raise return wrapper return deco def module_fails_on_exception(func): @wraps(func) def wrapper(self, *args, **kwargs): try: func(self, *args, **kwargs) except SystemExit: raise except ModuleHelperException as e: if e.update_output: self.update_output(e.update_output) self.module.fail_json(msg=e.msg, exception=traceback.format_exc(), output=self.output, vars=self.vars.output(), **self.output) except Exception as e: msg = "Module failed with exception: {0}".format(str(e).strip()) self.module.fail_json(msg=msg, exception=traceback.format_exc(), output=self.output, vars=self.vars.output(), **self.output) return wrapper class DependencyCtxMgr(object): def __init__(self, name, msg=None): self.name = name self.msg = msg self.has_it = False self.exc_type = None self.exc_val = None self.exc_tb = None def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.has_it = exc_type is None self.exc_type = exc_type self.exc_val = exc_val self.exc_tb = exc_tb return not self.has_it @property def text(self): return self.msg or str(self.exc_val) class VarMeta(object): NOTHING = object() def __init__(self, diff=False, output=True, change=None, fact=False): self.init = False self.initial_value = None self.value = None self.diff = diff self.change = diff if change is None else change self.output = output self.fact = fact def set(self, diff=None, output=None, change=None, fact=None, initial_value=NOTHING): if diff is not None: self.diff = diff if output is not None: self.output = output if change is not None: self.change = change if fact is not None: self.fact = fact if initial_value is not self.NOTHING: self.initial_value = initial_value def set_value(self, value): if not self.init: self.initial_value = value self.init = True self.value = value return self @property def has_changed(self): return self.change and (self.initial_value != self.value) @property def diff_result(self): return None if not (self.diff and self.has_changed) else { 'before': self.initial_value, 'after': self.value, } def __str__(self): return "".format( self.value, self.initial_value, self.diff, self.output, self.change ) class ModuleHelper(object): _output_conflict_list = ('msg', 'exception', 'output', 'vars', 'changed') _dependencies = [] module = None facts_name = None output_params = () diff_params = () change_params = () facts_params = () class VarDict(object): def __init__(self): self._data = dict() self._meta = dict() def __getitem__(self, item): return self._data[item] def __setitem__(self, key, value): self.set(key, value) def __getattr__(self, item): try: return self._data[item] except KeyError: return getattr(self._data, item) def __setattr__(self, key, value): if key in ('_data', '_meta'): super(ModuleHelper.VarDict, self).__setattr__(key, value) else: self.set(key, value) def meta(self, name): return self._meta[name] def set_meta(self, name, **kwargs): self.meta(name).set(**kwargs) def set(self, name, value, **kwargs): if name in ('_data', '_meta'): raise ValueError("Names _data and _meta are reserved for use by ModuleHelper") self._data[name] = value if name in self._meta: meta = self.meta(name) else: meta = VarMeta(**kwargs) meta.set_value(value) self._meta[name] = meta def output(self): return dict((k, v) for k, v in self._data.items() if self.meta(k).output) def diff(self): diff_results = [(k, self.meta(k).diff_result) for k in self._data] diff_results = [dr for dr in diff_results if dr[1] is not None] if diff_results: before = dict((dr[0], dr[1]['before']) for dr in diff_results) after = dict((dr[0], dr[1]['after']) for dr in diff_results) return {'before': before, 'after': after} return None def facts(self): facts_result = dict((k, v) for k, v in self._data.items() if self._meta[k].fact) return facts_result if facts_result else None def change_vars(self): return [v for v in self._data if self.meta(v).change] def has_changed(self, v): return self._meta[v].has_changed def __init__(self, module=None): self.vars = ModuleHelper.VarDict() self._changed = False if module: self.module = module if not isinstance(self.module, AnsibleModule): self.module = AnsibleModule(**self.module) for name, value in self.module.params.items(): self.vars.set( name, value, diff=name in self.diff_params, output=name in self.output_params, change=None if not self.change_params else name in self.change_params, fact=name in self.facts_params, ) def update_vars(self, meta=None, **kwargs): if meta is None: meta = {} for k, v in kwargs.items(): self.vars.set(k, v, **meta) def update_output(self, **kwargs): self.update_vars(meta={"output": True}, **kwargs) def update_facts(self, **kwargs): self.update_vars(meta={"fact": True}, **kwargs) def __init_module__(self): pass def __run__(self): raise NotImplementedError() def __quit_module__(self): pass def _vars_changed(self): return any(self.vars.has_changed(v) for v in self.vars.change_vars()) @property def changed(self): return self._changed @changed.setter def changed(self, value): self._changed = value def has_changed(self): return self.changed or self._vars_changed() @property def output(self): result = dict(self.vars.output()) if self.facts_name: facts = self.vars.facts() if facts is not None: result['ansible_facts'] = {self.facts_name: facts} if self.module._diff: diff = result.get('diff', {}) vars_diff = self.vars.diff() or {} result['diff'] = dict_merge(dict(diff), vars_diff) for varname in result: if varname in self._output_conflict_list: result["_" + varname] = result[varname] del result[varname] return result @module_fails_on_exception def run(self): self.fail_on_missing_deps() self.__init_module__() self.__run__() self.__quit_module__() self.module.exit_json(changed=self.has_changed(), **self.output) @classmethod def dependency(cls, name, msg): cls._dependencies.append(DependencyCtxMgr(name, msg)) return cls._dependencies[-1] def fail_on_missing_deps(self): for d in self._dependencies: if not d.has_it: self.module.fail_json(changed=False, exception="\n".join(traceback.format_exception(d.exc_type, d.exc_val, d.exc_tb)), msg=d.text, **self.output) class StateMixin(object): state_param = 'state' default_state = None def _state(self): state = self.module.params.get(self.state_param) return self.default_state if state is None else state def _method(self, state): return "{0}_{1}".format(self.state_param, state) def __run__(self): state = self._state() self.vars.state = state # resolve aliases if state not in self.module.params: aliased = [name for name, param in self.module.argument_spec.items() if state in param.get('aliases', [])] if aliased: state = aliased[0] self.vars.effective_state = state method = self._method(state) if not hasattr(self, method): return self.__state_fallback__() func = getattr(self, method) return func() def __state_fallback__(self): raise ValueError("Cannot find method: {0}".format(self._method(self._state()))) class CmdMixin(object): """ Mixin for mapping module options to running a CLI command with its arguments. """ command = None command_args_formats = {} run_command_fixed_options = {} check_rc = False force_lang = "C" @property def module_formats(self): result = {} for param in self.module.params.keys(): result[param] = ArgFormat(param) return result @property def custom_formats(self): result = {} for param, fmt_spec in self.command_args_formats.items(): result[param] = ArgFormat(param, **fmt_spec) return result def _calculate_args(self, extra_params=None, params=None): def add_arg_formatted_param(_cmd_args, arg_format, _value): args = list(arg_format.to_text(_value)) return _cmd_args + args def find_format(_param): return self.custom_formats.get(_param, self.module_formats.get(_param)) extra_params = extra_params or dict() cmd_args = list([self.command]) if isinstance(self.command, str) else list(self.command) try: cmd_args[0] = self.module.get_bin_path(cmd_args[0], required=True) except ValueError: pass param_list = params if params else self.module.params.keys() for param in param_list: if isinstance(param, dict): if len(param) != 1: raise ModuleHelperException("run_command parameter as a dict must " "contain only one key: {0}".format(param)) _param = list(param.keys())[0] fmt = find_format(_param) value = param[_param] elif isinstance(param, str): if param in self.module.argument_spec: fmt = find_format(param) value = self.module.params[param] elif param in extra_params: fmt = find_format(param) value = extra_params[param] else: self.module.deprecate("Cannot determine value for parameter: {0}. " "From version 4.0.0 onwards this will generate an exception".format(param), version="4.0.0", collection_name="community.general") continue else: raise ModuleHelperException("run_command parameter must be either a str or a dict: {0}".format(param)) cmd_args = add_arg_formatted_param(cmd_args, fmt, value) return cmd_args def process_command_output(self, rc, out, err): return rc, out, err def run_command(self, extra_params=None, params=None, *args, **kwargs): self.vars.cmd_args = self._calculate_args(extra_params, params) options = dict(self.run_command_fixed_options) options['check_rc'] = options.get('check_rc', self.check_rc) options.update(kwargs) env_update = dict(options.get('environ_update', {})) if self.force_lang: env_update.update({ 'LANGUAGE': self.force_lang, 'LC_ALL': self.force_lang, }) self.update_output(force_lang=self.force_lang) options['environ_update'] = env_update rc, out, err = self.module.run_command(self.vars.cmd_args, *args, **options) self.update_output(rc=rc, stdout=out, stderr=err) return self.process_command_output(rc, out, err) class StateModuleHelper(StateMixin, ModuleHelper): pass class CmdModuleHelper(CmdMixin, ModuleHelper): pass class CmdStateModuleHelper(CmdMixin, StateMixin, ModuleHelper): pass