diff --git a/changelogs/fragments/improved-fieldattribute-inheritance.yaml b/changelogs/fragments/improved-fieldattribute-inheritance.yaml new file mode 100644 index 0000000000..b4e8674776 --- /dev/null +++ b/changelogs/fragments/improved-fieldattribute-inheritance.yaml @@ -0,0 +1,3 @@ +minor_changes: +- inheritance - Improve ``FieldAttribute`` inheritance, by using a sentinel + instead of ``None`` to indicate that the option has not been explicitly set diff --git a/lib/ansible/parsing/mod_args.py b/lib/ansible/parsing/mod_args.py index 7ef4e3f801..5de5ed918f 100644 --- a/lib/ansible/parsing/mod_args.py +++ b/lib/ansible/parsing/mod_args.py @@ -25,6 +25,7 @@ from ansible.module_utils._text import to_text from ansible.parsing.splitter import parse_kv, split_args from ansible.plugins.loader import module_loader, action_loader from ansible.template import Templar +from ansible.utils.sentinel import Sentinel # For filtering out modules correctly below @@ -258,7 +259,7 @@ class ModuleArgsParser: thing = None action = None - delegate_to = self._task_ds.get('delegate_to', None) + delegate_to = self._task_ds.get('delegate_to', Sentinel) args = dict() # This is the standard YAML form for command-type modules. We grab diff --git a/lib/ansible/playbook/base.py b/lib/ansible/playbook/base.py index 17e15e4a96..f1b7ad6ac8 100644 --- a/lib/ansible/playbook/base.py +++ b/lib/ansible/playbook/base.py @@ -20,18 +20,24 @@ from ansible.errors import AnsibleParserError, AnsibleUndefinedVariable, Ansible from ansible.module_utils._text import to_text, to_native from ansible.playbook.attribute import Attribute, FieldAttribute from ansible.parsing.dataloader import DataLoader -from ansible.utils.vars import combine_vars, isidentifier, get_unique_id from ansible.utils.display import Display +from ansible.utils.sentinel import Sentinel +from ansible.utils.vars import combine_vars, isidentifier, get_unique_id display = Display() def _generic_g(prop_name, self): try: - return self._attributes[prop_name] + value = self._attributes[prop_name] except KeyError: raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, prop_name)) + if value is Sentinel: + value = self._attr_defaults[prop_name] + + return value + def _generic_g_method(prop_name, self): try: @@ -55,6 +61,9 @@ def _generic_g_parent(prop_name, self): except KeyError: raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, prop_name)) + if value is Sentinel: + value = self._attr_defaults[prop_name] + return value @@ -105,7 +114,8 @@ class BaseMeta(type): dst_dict[attr_name] = property(getter, setter, deleter) dst_dict['_valid_attrs'][attr_name] = value - dst_dict['_attributes'][attr_name] = value.default + dst_dict['_attributes'][attr_name] = Sentinel + dst_dict['_attr_defaults'][attr_name] = value.default if value.alias is not None: dst_dict[value.alias] = property(getter, setter, deleter) @@ -125,9 +135,10 @@ class BaseMeta(type): _process_parents(parent.__bases__, new_dst_dict) # create some additional class attributes - dct['_attributes'] = dict() - dct['_valid_attrs'] = dict() - dct['_alias_attrs'] = dict() + dct['_attributes'] = {} + dct['_attr_defaults'] = {} + dct['_valid_attrs'] = {} + dct['_alias_attrs'] = {} # now create the attributes based on the FieldAttributes # available, including from parent (and grandparent) objects @@ -158,10 +169,11 @@ class FieldAttributeBase(with_metaclass(BaseMeta, object)): # it was initialized as a class param in the meta class, so we # need a unique object here (all members contained within are # unique already). - self._attributes = self._attributes.copy() - for key, value in self._attributes.items(): + self._attributes = self.__class__._attributes.copy() + self._attr_defaults = self.__class__._attr_defaults.copy() + for key, value in self._attr_defaults.items(): if callable(value): - self._attributes[key] = value() + self._attr_defaults[key] = value() # and init vars, avoid using defaults in field declaration as it lives across plays self.vars = dict() @@ -312,6 +324,7 @@ class FieldAttributeBase(with_metaclass(BaseMeta, object)): if name in self._alias_attrs: continue new_me._attributes[name] = shallowcopy(self._attributes[name]) + new_me._attr_defaults[name] = shallowcopy(self._attr_defaults[name]) new_me._loader = self._loader new_me._variable_manager = self._variable_manager @@ -482,6 +495,12 @@ class FieldAttributeBase(with_metaclass(BaseMeta, object)): if not isinstance(new_value, list): new_value = [new_value] + # Due to where _extend_value may run for some attributes + # it is possible to end up with Sentinel in the list of values + # ensure we strip them + value[:] = [v for v in value if v is not Sentinel] + new_value[:] = [v for v in new_value if v is not Sentinel] + if prepend: combined = new_value + value else: @@ -583,7 +602,7 @@ class Base(FieldAttributeBase): _ignore_unreachable = FieldAttribute(isa='bool') _check_mode = FieldAttribute(isa='bool') _diff = FieldAttribute(isa='bool') - _any_errors_fatal = FieldAttribute(isa='bool') + _any_errors_fatal = FieldAttribute(isa='bool', default=C.ANY_ERRORS_FATAL) # explicitly invoke a debugger on tasks _debugger = FieldAttribute(isa='string') diff --git a/lib/ansible/playbook/block.py b/lib/ansible/playbook/block.py index 50483a9a48..9f5a8cf133 100644 --- a/lib/ansible/playbook/block.py +++ b/lib/ansible/playbook/block.py @@ -27,6 +27,7 @@ from ansible.playbook.conditional import Conditional from ansible.playbook.helpers import load_list_of_tasks from ansible.playbook.role import Role from ansible.playbook.taggable import Taggable +from ansible.utils.sentinel import Sentinel class Block(Base, Become, Conditional, Taggable): @@ -311,51 +312,45 @@ class Block(Base, Become, Conditional, Taggable): else: _parent = self._parent._parent - if _parent and (value is None or extend): + if _parent and (value is Sentinel or extend): try: if getattr(_parent, 'statically_loaded', True): if hasattr(_parent, '_get_parent_attribute'): parent_value = _parent._get_parent_attribute(attr) else: - parent_value = _parent._attributes.get(attr, None) + parent_value = _parent._attributes.get(attr, Sentinel) if extend: value = self._extend_value(value, parent_value, prepend) else: value = parent_value except AttributeError: pass - if self._role and (value is None or extend): + if self._role and (value is Sentinel or extend): try: - if hasattr(self._role, '_get_parent_attribute'): - parent_value = self._role.get_parent_attribute(attr) - else: - parent_value = self._role._attributes.get(attr, None) + parent_value = self._role._attributes.get(attr, Sentinel) if extend: value = self._extend_value(value, parent_value, prepend) else: value = parent_value dep_chain = self.get_dep_chain() - if dep_chain and (value is None or extend): + if dep_chain and (value is Sentinel or extend): dep_chain.reverse() for dep in dep_chain: - if hasattr(dep, '_get_parent_attribute'): - dep_value = dep._get_parent_attribute(attr) - else: - dep_value = dep._attributes.get(attr, None) + dep_value = dep._attributes.get(attr, Sentinel) if extend: value = self._extend_value(value, dep_value, prepend) else: value = dep_value - if value is not None and not extend: + if value is not Sentinel and not extend: break except AttributeError: pass - if self._play and (value is None or extend): + if self._play and (value is Sentinel or extend): try: - play_value = self._play._attributes.get(attr, None) - if play_value is not None: + play_value = self._play._attributes.get(attr, Sentinel) + if play_value is not Sentinel: if extend: value = self._extend_value(value, play_value, prepend) else: diff --git a/lib/ansible/playbook/play.py b/lib/ansible/playbook/play.py index 6dbd384d38..cc73c5e721 100644 --- a/lib/ansible/playbook/play.py +++ b/lib/ansible/playbook/play.py @@ -94,7 +94,7 @@ class Play(Base, Taggable, Become): def get_name(self): ''' return the name of the Play ''' - return self._attributes.get('name') + return self.name @staticmethod def load(data, variable_manager=None, loader=None, vars=None): diff --git a/lib/ansible/playbook/role/__init__.py b/lib/ansible/playbook/role/__init__.py index e7f87d25f6..3d5008346d 100644 --- a/lib/ansible/playbook/role/__init__.py +++ b/lib/ansible/playbook/role/__init__.py @@ -183,19 +183,16 @@ class Role(Base, Become, Conditional, Taggable): if parent_role: self.add_parent(parent_role) - # copy over all field attributes, except for when and tags, which - # are special cases and need to preserve pre-existing values + # copy over all field attributes from the RoleInclude + # update self._attributes directly, to avoid squashing for (attr_name, _) in iteritems(self._valid_attrs): - if attr_name not in ('when', 'tags'): - setattr(self, attr_name, getattr(role_include, attr_name)) - - current_when = getattr(self, 'when')[:] - current_when.extend(role_include.when) - setattr(self, 'when', current_when) - - current_tags = getattr(self, 'tags')[:] - current_tags.extend(role_include.tags) - setattr(self, 'tags', current_tags) + if attr_name in ('when', 'tags'): + self._attributes[attr_name] = self._extend_value( + self._attributes[attr_name], + role_include._attributes[attr_name], + ) + else: + self._attributes[attr_name] = role_include._attributes[attr_name] # dynamically load any plugins from the role directory for name, obj in get_all_plugin_loaders(): diff --git a/lib/ansible/playbook/task.py b/lib/ansible/playbook/task.py index 3a7e939561..a0b775dd99 100644 --- a/lib/ansible/playbook/task.py +++ b/lib/ansible/playbook/task.py @@ -37,6 +37,7 @@ from ansible.playbook.loop_control import LoopControl from ansible.playbook.role import Role from ansible.playbook.taggable import Taggable from ansible.utils.display import Display +from ansible.utils.sentinel import Sentinel __all__ = ['Task'] @@ -438,13 +439,13 @@ class Task(Base, Conditional, Taggable, Become): else: _parent = self._parent._parent - if _parent and (value is None or extend): + if _parent and (value is Sentinel or extend): if getattr(_parent, 'statically_loaded', True): # vars are always inheritable, other attributes might not be for the parent but still should be for other ancestors if attr != 'vars' and hasattr(_parent, '_get_parent_attribute'): parent_value = _parent._get_parent_attribute(attr) else: - parent_value = _parent._attributes.get(attr, None) + parent_value = _parent._attributes.get(attr, Sentinel) if extend: value = self._extend_value(value, parent_value, prepend) @@ -455,14 +456,6 @@ class Task(Base, Conditional, Taggable, Become): return value - def _get_attr_any_errors_fatal(self): - value = self._attributes['any_errors_fatal'] - if value is None: - value = self._get_parent_attribute('any_errors_fatal') - if value is None: - value = C.ANY_ERRORS_FATAL - return value - def get_dep_chain(self): if self._parent: return self._parent.get_dep_chain() diff --git a/lib/ansible/playbook/task_include.py b/lib/ansible/playbook/task_include.py index 5a11f4c635..9296e656d4 100644 --- a/lib/ansible/playbook/task_include.py +++ b/lib/ansible/playbook/task_include.py @@ -25,6 +25,7 @@ from ansible.playbook.attribute import FieldAttribute from ansible.playbook.block import Block from ansible.playbook.task import Task from ansible.utils.display import Display +from ansible.utils.sentinel import Sentinel __all__ = ['TaskInclude'] @@ -84,7 +85,7 @@ class TaskInclude(Task): diff = set(ds.keys()).difference(TaskInclude.VALID_INCLUDE_KEYWORDS) for k in diff: # This check doesn't handle ``include`` as we have no idea at this point if it is static or not - if ds[k] is not None and ds['action'] in ('include_tasks', 'include_role'): + if ds[k] is not Sentinel and ds['action'] in ('include_tasks', 'include_role'): if C.INVALID_TASK_ATTRIBUTE_FAILED: raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (k, self.__class__.__name__), obj=ds) else: diff --git a/lib/ansible/utils/sentinel.py b/lib/ansible/utils/sentinel.py new file mode 100644 index 0000000000..ca4f82764e --- /dev/null +++ b/lib/ansible/utils/sentinel.py @@ -0,0 +1,68 @@ +# Copyright (c) 2019 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + + +class Sentinel: + """ + Object which can be used to mark whether an entry as being special + + A sentinel value demarcates a value or marks an entry as having a special meaning. In C, the + Null byte is used as a sentinel for the end of a string. In Python, None is often used as + a Sentinel in optional parameters to mean that the parameter was not set by the user. + + You should use None as a Sentinel value any Python code where None is not a valid entry. If + None is a valid entry, though, then you need to create a different value, which is the purpose + of this class. + + Example of using Sentinel as a default parameter value:: + + def confirm_big_red_button(tristate=Sentinel): + if tristate is Sentinel: + print('You must explicitly press the big red button to blow up the base') + elif tristate is True: + print('Countdown to destruction activated') + elif tristate is False: + print('Countdown stopped') + elif tristate is None: + print('Waiting for more input') + + Example of using Sentinel to tell whether a dict which has a default value has been changed:: + + values = {'one': Sentinel, 'two': Sentinel} + defaults = {'one': 1, 'two': 2} + + # [.. Other code which does things including setting a new value for 'one' ..] + values['one'] = None + # [..] + + print('You made changes to:') + for key, value in values.items(): + if value is Sentinel: + continue + print('%s: %s' % (key, value) + """ + + def __new__(cls): + """ + Return the cls itself. This makes both equality and identity True for comparing the class + to an instance of the class, preventing common usage errors. + + Preferred usage:: + + a = Sentinel + if a is Sentinel: + print('Sentinel value') + + However, these are True as well, eliminating common usage errors:: + + if Sentinel is Sentinel(): + print('Sentinel value') + + if Sentinel == Sentinel(): + print('Sentinel value') + """ + return cls diff --git a/test/integration/targets/check_mode/check_mode-not-on-cli.yml b/test/integration/targets/check_mode/check_mode-not-on-cli.yml new file mode 100644 index 0000000000..0c7af038e1 --- /dev/null +++ b/test/integration/targets/check_mode/check_mode-not-on-cli.yml @@ -0,0 +1,37 @@ +--- +# Run withhout --check +- hosts: localhost + gather_facts: False + tasks: + - command: 'echo ran' + register: command_out + + - debug: var=command_out + - name: check that this did not run in check mode + assert: + that: + - '"ran" in command_out["stdout"]' + +- hosts: localhost + gather_facts: False + check_mode: True + tasks: + - command: 'echo ran' + register: command_out + + - name: check that play level check_mode overrode the cli + assert: + that: + - '"check mode" in command_out["msg"]' + +- hosts: localhost + gather_facts: False + tasks: + - command: 'echo ran' + register: command_out + check_mode: True + + - name: check that task level check_mode overrode the cli + assert: + that: + - '"check mode" in command_out["msg"]' diff --git a/test/integration/targets/check_mode/check_mode-on-cli.yml b/test/integration/targets/check_mode/check_mode-on-cli.yml new file mode 100644 index 0000000000..8091198bb7 --- /dev/null +++ b/test/integration/targets/check_mode/check_mode-on-cli.yml @@ -0,0 +1,36 @@ +--- +# Run with --check +- hosts: localhost + gather_facts: False + tasks: + - command: 'echo ran' + register: command_out + + - name: check that this did not run in check mode + assert: + that: + - '"check mode" in command_out["msg"]' + +- hosts: localhost + gather_facts: False + check_mode: False + tasks: + - command: 'echo ran' + register: command_out + + - name: check that play level check_mode overrode the cli + assert: + that: + - '"ran" in command_out["stdout"]' + +- hosts: localhost + gather_facts: False + tasks: + - command: 'echo ran' + register: command_out + check_mode: False + + - name: check that task level check_mode overrode the cli + assert: + that: + - '"ran" in command_out["stdout"]' diff --git a/test/integration/targets/check_mode/check_mode.yml b/test/integration/targets/check_mode/check_mode.yml index 0f4fde4da0..a5777506ff 100644 --- a/test/integration/targets/check_mode/check_mode.yml +++ b/test/integration/targets/check_mode/check_mode.yml @@ -1,4 +1,5 @@ -- hosts: testhost +- name: Test that check works with check_mode specified in roles + hosts: testhost vars: - output_dir: . roles: diff --git a/test/integration/targets/check_mode/runme.sh b/test/integration/targets/check_mode/runme.sh index 184ecb7c27..954ac6ff58 100755 --- a/test/integration/targets/check_mode/runme.sh +++ b/test/integration/targets/check_mode/runme.sh @@ -3,3 +3,5 @@ set -eux ansible-playbook check_mode.yml -i ../../inventory -v --check "$@" +ansible-playbook check_mode-on-cli.yml -i ../../inventory -v --check "$@" +ansible-playbook check_mode-not-on-cli.yml -i ../../inventory -v "$@" diff --git a/test/units/parsing/test_mod_args.py b/test/units/parsing/test_mod_args.py index 04aa0d02b5..7aa2161e42 100644 --- a/test/units/parsing/test_mod_args.py +++ b/test/units/parsing/test_mod_args.py @@ -9,6 +9,7 @@ import pytest from ansible.errors import AnsibleParserError from ansible.parsing.mod_args import ModuleArgsParser +from ansible.utils.sentinel import Sentinel class TestModArgsDwim: @@ -37,7 +38,7 @@ class TestModArgsDwim: assert args == dict( _raw_params='echo hi', ) - assert to is None + assert to is Sentinel def test_basic_command(self): m = ModuleArgsParser(dict(command='echo hi')) @@ -48,7 +49,7 @@ class TestModArgsDwim: assert args == dict( _raw_params='echo hi', ) - assert to is None + assert to is Sentinel def test_shell_with_modifiers(self): m = ModuleArgsParser(dict(shell='/bin/foo creates=/tmp/baz removes=/tmp/bleep')) @@ -61,7 +62,7 @@ class TestModArgsDwim: removes='/tmp/bleep', _raw_params='/bin/foo', ) - assert to is None + assert to is Sentinel def test_normal_usage(self): m = ModuleArgsParser(dict(copy='src=a dest=b')) @@ -70,7 +71,7 @@ class TestModArgsDwim: assert mod, 'copy' assert args, dict(src='a', dest='b') - assert to is None + assert to is Sentinel def test_complex_args(self): m = ModuleArgsParser(dict(copy=dict(src='a', dest='b'))) @@ -79,7 +80,7 @@ class TestModArgsDwim: assert mod, 'copy' assert args, dict(src='a', dest='b') - assert to is None + assert to is Sentinel def test_action_with_complex(self): m = ModuleArgsParser(dict(action=dict(module='copy', src='a', dest='b'))) @@ -88,7 +89,7 @@ class TestModArgsDwim: assert mod == 'copy' assert args == dict(src='a', dest='b') - assert to is None + assert to is Sentinel def test_action_with_complex_and_complex_args(self): m = ModuleArgsParser(dict(action=dict(module='copy', args=dict(src='a', dest='b')))) @@ -97,7 +98,7 @@ class TestModArgsDwim: assert mod == 'copy' assert args == dict(src='a', dest='b') - assert to is None + assert to is Sentinel def test_local_action_string(self): m = ModuleArgsParser(dict(local_action='copy src=a dest=b'))