From 97a34cf008501088ae012f24e4712e2bc6dd6d4c Mon Sep 17 00:00:00 2001 From: Ganesh Nalawade Date: Tue, 1 Aug 2017 22:02:18 +0530 Subject: [PATCH] Add options sub spec validation (#27119) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add aggregate parameter validation aggregate parameter validation will support checking each individual dict to resolve conditions for aliases, no_log, mutually_exclusive, required, type check, values, required_together, required_one_of and required_if conditions in argspec. It will also set default values. eg: tasks: - name: Configure interface attribute with aggregate net_interface: aggregate: - {name: ge-0/0/1, description: test-interface-1, duplex: full, state: present} - {name: ge-0/0/2, description: test-interface-2, active: False} register: response purge: Yes Usage: ``` from ansible.module_utils.network_common import AggregateCollection transform = AggregateCollection(module) param = transform(module.params.get('aggregate')) ``` Aggregate allows supports for `purge` parameter, it will instruct the module to remove resources from remote device that hasn’t been explicitly defined in aggregate. This is not supported by with_* iterators Also, it improves performace as compared to with_* iterator for network device that has seperate candidate and running datastore. For with_* iteration the sequence of operartion is load-config-1 (candidate db) -> commit (running db) -> load_config-2 (candidate db) -> commit (running db) ... With aggregate the sequence of operation is load-config-1 (candidate db) -> load-config-2 (candidate db) -> commit (running db) As commit is executed only once per task for aggregate it has huge perfomance benefit for large configurations. * Fix CI issues * Fix review comments * Add support for options validation for aliases, no_log, mutually_exclusive, required, type check, value check, required_together, required_one_of and required_if conditions in sub-argspec. * Add unit test for options in argspec. * Reverted aggregate implementaion. * Minor change * Add multi-level argspec support * Multi-level argspec support with module's top most conditionals options. * Fix unit test failure * Add parent context in errors for sub options * Resolve merge conflict * Fix CI issue --- lib/ansible/module_utils/basic.py | 224 +++++++++++++++++++------- test/units/module_utils/test_basic.py | 211 ++++++++++++++++++++++++ 2 files changed, 377 insertions(+), 58 deletions(-) diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index 456cb4589e..94aec93974 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -790,7 +790,13 @@ class AnsibleModule(object): self.argument_spec = argument_spec self.supports_check_mode = supports_check_mode self.check_mode = False + self.bypass_checks = bypass_checks self.no_log = no_log + self.check_invalid_arguments = check_invalid_arguments + self.mutually_exclusive = mutually_exclusive + self.required_together = required_together + self.required_one_of = required_one_of + self.required_if = required_if self.cleanup_files = [] self._debug = False self._diff = False @@ -806,6 +812,7 @@ class AnsibleModule(object): self._legal_inputs = ['_ansible_check_mode', '_ansible_no_log', '_ansible_debug', '_ansible_diff', '_ansible_verbosity', '_ansible_selinux_special_fs', '_ansible_module_name', '_ansible_version', '_ansible_syslog_facility', '_ansible_socket'] + self._options_context = list() if add_file_common_args: for k, v in FILE_COMMON_ARGUMENTS.items(): @@ -826,19 +833,7 @@ class AnsibleModule(object): # Save parameter values that should never be logged self.no_log_values = set() - # Use the argspec to determine which args are no_log - for arg_name, arg_opts in self.argument_spec.items(): - if arg_opts.get('no_log', False): - # Find the value for the no_log'd param - no_log_object = self.params.get(arg_name, None) - if no_log_object: - self.no_log_values.update(return_values(no_log_object)) - - if arg_opts.get('removed_in_version') is not None and arg_name in self.params: - self._deprecations.append({ - 'msg': "Param '%s' is deprecated. See the module docs for more information" % arg_name, - 'version': arg_opts.get('removed_in_version') - }) + self._handle_no_log_values() # check the locale as set by the current environment, and reset to # a known valid (LANG=C) if it's an invalid/unavailable locale @@ -876,6 +871,9 @@ class AnsibleModule(object): self._set_defaults(pre=False) + # deal with options sub-spec + self._handle_options() + if not self.no_log: self._log_invocation() @@ -1530,9 +1528,12 @@ class AnsibleModule(object): e = get_exception() self.fail_json(msg="An unknown error was encountered while attempting to validate the locale: %s" % e) - def _handle_aliases(self, spec=None): + def _handle_aliases(self, spec=None, param=None): # this uses exceptions as it happens before we can safely call fail_json aliases_results = {} # alias:canon + if param is None: + param = self.params + if spec is None: spec = self.argument_spec for (k, v) in spec.items(): @@ -1550,15 +1551,42 @@ class AnsibleModule(object): for alias in aliases: self._legal_inputs.append(alias) aliases_results[alias] = k - if alias in self.params: - self.params[k] = self.params[alias] + if alias in param: + param[k] = param[alias] return aliases_results - def _check_arguments(self, check_invalid_arguments): + def _handle_no_log_values(self, spec=None, param=None): + if spec is None: + spec = self.argument_spec + if param is None: + param = self.params + + # Use the argspec to determine which args are no_log + for arg_name, arg_opts in spec.items(): + if arg_opts.get('no_log', False): + # Find the value for the no_log'd param + no_log_object = param.get(arg_name, None) + if no_log_object: + self.no_log_values.update(return_values(no_log_object)) + + if arg_opts.get('removed_in_version') is not None and arg_name in param: + self._deprecations.append({ + 'msg': "Param '%s' is deprecated. See the module docs for more information" % arg_name, + 'version': arg_opts.get('removed_in_version') + }) + + def _check_arguments(self, check_invalid_arguments, spec=None, param=None, legal_inputs=None): self._syslog_facility = 'LOG_USER' unsupported_parameters = set() - for (k, v) in list(self.params.items()): + if spec is None: + spec = self.argument_spec + if param is None: + param = self.params + if legal_inputs is None: + legal_inputs = self._legal_inputs + + for (k, v) in list(param.items()): if k == '_ansible_check_mode' and v: self.check_mode = True @@ -1590,7 +1618,7 @@ class AnsibleModule(object): elif k == '_ansible_socket': self._socket_path = v - elif check_invalid_arguments and k not in self._legal_inputs: + elif check_invalid_arguments and k not in legal_inputs: unsupported_parameters.add(k) # clean up internal params: @@ -1598,44 +1626,57 @@ class AnsibleModule(object): del self.params[k] if unsupported_parameters: - self.fail_json(msg="Unsupported parameters for (%s) module: %s. Supported parameters include: %s" % (self._name, - ','.join(sorted(list(unsupported_parameters))), - ','.join(sorted(self.argument_spec.keys())))) + msg = "Unsupported parameters for (%s) module: %s" % (self._name, ','.join(sorted(list(unsupported_parameters)))) + if self._options_context: + msg += " found in %s." % " -> ".join(self._options_context) + msg += " Supported parameters include: %s" % (','.join(sorted(spec.keys()))) + self.fail_json(msg=msg) if self.check_mode and not self.supports_check_mode: self.exit_json(skipped=True, msg="remote module (%s) does not support check mode" % self._name) - def _count_terms(self, check): + def _count_terms(self, check, param=None): count = 0 + if param is None: + param = self.params for term in check: - if term in self.params: + if term in param: count += 1 return count - def _check_mutually_exclusive(self, spec): + def _check_mutually_exclusive(self, spec, param=None): if spec is None: return for check in spec: - count = self._count_terms(check) + count = self._count_terms(check, param) if count > 1: - self.fail_json(msg="parameters are mutually exclusive: %s" % (check,)) + msg = "parameters are mutually exclusive: %s" % (check,) + if self._options_context: + msg += " found in %s" % " -> ".join(self._options_context) + self.fail_json(msg=msg) - def _check_required_one_of(self, spec): + def _check_required_one_of(self, spec, param=None): if spec is None: return for check in spec: - count = self._count_terms(check) + count = self._count_terms(check, param) if count == 0: - self.fail_json(msg="one of the following is required: %s" % ','.join(check)) + msg = "one of the following is required: %s" % ','.join(check) + if self._options_context: + msg += " found in %s" % " -> ".join(self._options_context) + self.fail_json(msg=msg) - def _check_required_together(self, spec): + def _check_required_together(self, spec, param=None): if spec is None: return for check in spec: - counts = [self._count_terms([field]) for field in check] + counts = [self._count_terms([field], param) for field in check] non_zero = [c for c in counts if c > 0] if len(non_zero) > 0: if 0 in counts: - self.fail_json(msg="parameters are required together: %s" % (check,)) + msg = "parameters are required together: %s" % (check,) + if self._options_context: + msg += " found in %s" % " -> ".join(self._options_context) + self.fail_json(msg=msg) def _check_required_arguments(self, spec=None, param=None): ''' ensure all required arguments are present ''' @@ -1649,12 +1690,17 @@ class AnsibleModule(object): if required and k not in param: missing.append(k) if len(missing) > 0: - self.fail_json(msg="missing required arguments: %s" % ",".join(missing)) + msg = "missing required arguments: %s" % ",".join(missing) + if self._options_context: + msg += " found in %s" % " -> ".join(self._options_context) + self.fail_json(msg=msg) - def _check_required_if(self, spec): + def _check_required_if(self, spec, param=None): ''' ensure that parameters which conditionally required are present ''' if spec is None: return + if param is None: + param = self.params for sp in spec: missing = [] max_missing_count = 0 @@ -1669,13 +1715,16 @@ class AnsibleModule(object): if is_one_of: max_missing_count = len(requirements) - if key in self.params and self.params[key] == val: + if key in param and param[key] == val: for check in requirements: - count = self._count_terms((check,)) + count = self._count_terms((check,), param) if count == 0: missing.append(check) if len(missing) and len(missing) >= max_missing_count: - self.fail_json(msg="%s is %s but the following are missing: %s" % (key, val, ','.join(missing))) + msg = "%s is %s but the following are missing: %s" % (key, val, ','.join(missing)) + if self._options_context: + msg += " found in %s" % " -> ".join(self._options_context) + self.fail_json(msg=msg) def _check_argument_values(self, spec=None, param=None): ''' ensure all arguments have the requested values, and there are no stray arguments ''' @@ -1710,9 +1759,14 @@ class AnsibleModule(object): if param[k] not in choices: choices_str = ",".join([to_native(c) for c in choices]) msg = "value of %s must be one of: %s, got: %s" % (k, choices_str, param[k]) + if self._options_context: + msg += " found in %s" % " -> ".join(self._options_context) self.fail_json(msg=msg) else: - self.fail_json(msg="internal error: choices for argument %s are not iterable: %s" % (k, choices)) + msg = "internal error: choices for argument %s are not iterable: %s" % (k, choices) + if self._options_context: + msg += " found in %s" % " -> ".join(self._options_context) + self.fail_json(msg=msg) def safe_eval(self, value, locals=None, include_exceptions=False): @@ -1862,6 +1916,60 @@ class AnsibleModule(object): except ValueError: raise TypeError('%s cannot be converted to a Bit value' % type(value)) + def _handle_options(self, argument_spec=None, params=None): + ''' deal with options to create sub spec ''' + if argument_spec is None: + argument_spec = self.argument_spec + if params is None: + params = self.params + + for (k, v) in argument_spec.items(): + wanted = v.get('type', None) + if wanted == 'dict' or (wanted == 'list' and v.get('elements', '') == 'dict'): + spec = v.get('options', None) + if spec is None or not params[k]: + continue + + self._options_context.append(k) + + if isinstance(params[k], dict): + elements = [params[k]] + else: + elements = params[k] + + for param in elements: + if not isinstance(param, dict): + self.fail_json(msg="value of %s must be of type dict or list of dict" % k) + + self._set_fallbacks(spec, param) + options_aliases = self._handle_aliases(spec, param) + + self._handle_no_log_values(spec, param) + options_legal_inputs = list(spec.keys()) + list(options_aliases.keys()) + + self._check_arguments(self.check_invalid_arguments, spec, param, options_legal_inputs) + + # check exclusive early + if not self.bypass_checks: + self._check_mutually_exclusive(self.mutually_exclusive, param) + + self._set_defaults(pre=True, spec=spec, param=param) + + if not self.bypass_checks: + self._check_required_arguments(spec, param) + self._check_argument_types(spec, param) + self._check_argument_values(spec, param) + + self._check_required_together(self.required_together, param) + self._check_required_one_of(self.required_one_of, param) + self._check_required_if(self.required_if, param) + + self._set_defaults(pre=False, spec=spec, param=param) + + # handle multi level options (sub argspec) + self._handle_options(spec, param) + self._options_context.pop() + def _check_argument_types(self, spec=None, param=None): ''' ensure all arguments have the requested type ''' @@ -1902,41 +2010,41 @@ class AnsibleModule(object): e = get_exception() self.fail_json(msg="argument %s is of type %s and we were unable to convert to %s: %s" % (k, type(value), wanted, e)) - # deal with sub options to create sub spec - spec = None - if wanted == 'dict' or (wanted == 'list' and v.get('elements', '') == 'dict'): - spec = v.get('options', None) - if spec: - self._check_required_arguments(spec, param[k]) - self._check_argument_types(spec, param[k]) - self._check_argument_values(spec, param[k]) - - def _set_defaults(self, pre=True): - for (k, v) in self.argument_spec.items(): + def _set_defaults(self, pre=True, spec=None, param=None): + if spec is None: + spec = self.argument_spec + if param is None: + param = self.params + for (k, v) in spec.items(): default = v.get('default', None) if pre is True: # this prevents setting defaults on required items - if default is not None and k not in self.params: - self.params[k] = default + if default is not None and k not in param: + param[k] = default else: # make sure things without a default still get set None - if k not in self.params: - self.params[k] = default + if k not in param: + param[k] = default - def _set_fallbacks(self): - for (k, v) in self.argument_spec.items(): + def _set_fallbacks(self, spec=None, param=None): + if spec is None: + spec = self.argument_spec + if param is None: + param = self.params + + for (k, v) in spec.items(): fallback = v.get('fallback', (None,)) fallback_strategy = fallback[0] fallback_args = [] fallback_kwargs = {} - if k not in self.params and fallback_strategy is not None: + if k not in param and fallback_strategy is not None: for item in fallback[1:]: if isinstance(item, dict): fallback_kwargs = item else: fallback_args = item try: - self.params[k] = fallback_strategy(*fallback_args, **fallback_kwargs) + param = fallback_strategy(*fallback_args, **fallback_kwargs) except AnsibleFallbackNotFound: continue diff --git a/test/units/module_utils/test_basic.py b/test/units/module_utils/test_basic.py index a4a88d9988..3cedd042fa 100644 --- a/test/units/module_utils/test_basic.py +++ b/test/units/module_utils/test_basic.py @@ -342,6 +342,171 @@ class TestModuleUtilsBasic(ModuleTestCase): supports_check_mode=True, ) + def test_module_utils_basic_ansible_module_with_options_creation(self): + from ansible.module_utils import basic + + options_spec = dict( + foo=dict(required=True, aliases=['dup']), + bar=dict(), + bam=dict(), + baz=dict(), + bam1=dict(default='test') + ) + arg_spec = dict(foobar=dict(type='list', elements='dict', options=options_spec)) + mut_ex = (('bar', 'bam'),) + req_to = (('bam', 'baz'),) + + # should test ok + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': [{"foo": "hello"}, {"foo": "test"}]})) + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + am = basic.AnsibleModule( + argument_spec=arg_spec, + mutually_exclusive=mut_ex, + required_together=req_to, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True + ) + + # should test ok, handles aliases + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': [{"dup": "hello"}]})) + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + am = basic.AnsibleModule( + argument_spec=arg_spec, + mutually_exclusive=mut_ex, + required_together=req_to, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True + ) + + # fail, because a required param was not specified + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': [{}]})) + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + self.assertRaises( + SystemExit, + basic.AnsibleModule, + argument_spec=arg_spec, + mutually_exclusive=mut_ex, + required_together=req_to, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True + ) + + # fail because of mutually exclusive parameters + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': [{"foo": "hello", "bar": "bad", "bam": "bad"}]})) + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + self.assertRaises( + SystemExit, + basic.AnsibleModule, + argument_spec=arg_spec, + mutually_exclusive=mut_ex, + required_together=req_to, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True + ) + + # fail because a param required due to another param was not specified + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': [{"bam": "bad"}]})) + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + self.assertRaises( + SystemExit, + basic.AnsibleModule, + argument_spec=arg_spec, + mutually_exclusive=mut_ex, + required_together=req_to, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True + ) + + # fail because one of param is required + req_one_of = (('bar', 'bam'),) + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': [{"foo": "hello"}]})) + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + self.assertRaises( + SystemExit, + basic.AnsibleModule, + argument_spec=arg_spec, + required_one_of=req_one_of, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True + ) + + # fail because value of one param mandates presence of other param required + req_if = (('foo', 'hello', ('bam')),) + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': [{"foo": "hello"}]})) + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + self.assertRaises( + SystemExit, + basic.AnsibleModule, + argument_spec=arg_spec, + required_if=req_if, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True + ) + + # should test ok, the required param is set by default from spec + req_if = [('foo', 'hello', ('bam1',))] + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': [{"foo": "hello"}]})) + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + am = basic.AnsibleModule( + argument_spec=arg_spec, + required_if=req_if, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True + ) + + # should test ok, for options in dict format. + arg_spec = dict(foobar=dict(type='dict', options=options_spec)) + + # should test ok + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': {"foo": "hello"}})) + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + am = basic.AnsibleModule( + argument_spec=arg_spec, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True + ) + + # should fail, check for invalid agrument + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': {"foo1": "hello"}})) + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + self.assertRaises( + SystemExit, + basic.AnsibleModule, + argument_spec=arg_spec, + no_log=True, + check_invalid_arguments=True, + add_file_common_args=True, + supports_check_mode=True + ) + def test_module_utils_basic_ansible_module_type_check(self): from ansible.module_utils import basic @@ -387,6 +552,52 @@ class TestModuleUtilsBasic(ModuleTestCase): supports_check_mode=True, ) + def test_module_utils_basic_ansible_module_options_type_check(self): + from ansible.module_utils import basic + + options_spec = dict( + foo=dict(type='float'), + foo2=dict(type='float'), + foo3=dict(type='float'), + bar=dict(type='int'), + bar2=dict(type='int'), + ) + + arg_spec = dict(foobar=dict(type='list', elements='dict', options=options_spec)) + # should test ok + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': [{ + "foo": 123.0, # float + "foo2": 123, # int + "foo3": "123", # string + "bar": 123, # int + "bar2": "123", # string + }]})) + + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + am = basic.AnsibleModule( + argument_spec=arg_spec, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True, + ) + + # fail, because bar does not accept floating point numbers + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'foobar': [{"bar": 123.0}]})) + + with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None + self.assertRaises( + SystemExit, + basic.AnsibleModule, + argument_spec=arg_spec, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True, + ) + def test_module_utils_basic_ansible_module_load_file_common_arguments(self): from ansible.module_utils import basic basic._ANSIBLE_ARGS = None