From aeff960d028644c19dd845e51ced14a9bd3709c5 Mon Sep 17 00:00:00 2001 From: Toshio Kuratomi Date: Tue, 1 Sep 2015 11:20:16 -0700 Subject: [PATCH] Cleanup combine_vars * Dedupe combine_vars() code (removed from VariableManager) * Fix merge_hash algorithm to preserve the type * unittest combine_vars and merge_hash --- lib/ansible/utils/vars.py | 55 +++++++++---- lib/ansible/vars/__init__.py | 83 +++++--------------- test/units/utils/__init__.py | 0 test/units/utils/test_vars.py | 98 ++++++++++++++++++++++++ test/units/vars/test_variable_manager.py | 14 ---- 5 files changed, 158 insertions(+), 92 deletions(-) create mode 100644 test/units/utils/__init__.py create mode 100644 test/units/utils/test_vars.py diff --git a/lib/ansible/utils/vars.py b/lib/ansible/utils/vars.py index ba232c2b68..d202a1b18f 100644 --- a/lib/ansible/utils/vars.py +++ b/lib/ansible/utils/vars.py @@ -20,39 +20,64 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type import ast +from collections import MutableMapping -from six import string_types +from six import iteritems, string_types from ansible import constants as C +from ansible.errors import AnsibleError from ansible.parsing.splitter import parse_kv from ansible.utils.unicode import to_unicode +def _validate_mutable_mappings(a, b): + """ + Internal convenience function to ensure arguments are MutableMappings + + This checks that all arguments are MutableMappings or raises an error + + :raises AnsibleError: if one of the arguments is not a MutableMapping + """ + + # If this becomes generally needed, change the signature to operate on + # a variable number of arguments instead. + + if not (isinstance(a, MutableMapping) and isinstance(b, MutableMapping)): + raise AnsibleError("failed to combine variables, expected dicts but" + " got a '{0}' and a '{1}'".format( + a.__class__.__name__, b.__class__.__name__)) + def combine_vars(a, b): + """ + Return a copy of dictionaries of variables based on configured hash behavior + """ if C.DEFAULT_HASH_BEHAVIOUR == "merge": return merge_hash(a, b) else: + # HASH_BEHAVIOUR == 'replace' + _validate_mutable_mappings(a, b) result = a.copy() result.update(b) return result def merge_hash(a, b): - ''' recursively merges hash b into a - keys from b take precedence over keys from a ''' + """ + Recursively merges hash b into a so that keys from b take precedence over keys from a + """ - result = {} + _validate_mutable_mappings(a, b) + result = a.copy() - for dicts in a, b: - # next, iterate over b keys and values - for k, v in dicts.iteritems(): - # if there's already such key in a - # and that key contains dict - if k in result and isinstance(result[k], dict): - # merge those dicts recursively - result[k] = merge_hash(a[k], v) - else: - # otherwise, just copy a value from b to a - result[k] = v + # next, iterate over b keys and values + for k, v in iteritems(b): + # if there's already such key in a + # and that key contains a MutableMapping + if k in result and isinstance(result[k], MutableMapping): + # merge those dicts recursively + result[k] = merge_hash(result[k], v) + else: + # otherwise, just copy the value from b to a + result[k] = v return result diff --git a/lib/ansible/vars/__init__.py b/lib/ansible/vars/__init__.py index bcddc2e77f..30b3d5d84b 100644 --- a/lib/ansible/vars/__init__.py +++ b/lib/ansible/vars/__init__.py @@ -33,11 +33,12 @@ except ImportError: from ansible import constants as C from ansible.cli import CLI -from ansible.errors import * +from ansible.errors import AnsibleError from ansible.parsing import DataLoader from ansible.plugins.cache import FactCache from ansible.template import Templar from ansible.utils.debug import debug +from ansible.utils.vars import combine_vars from ansible.vars.hostvars import HostVars CACHED_VARS = dict() @@ -104,50 +105,6 @@ class VariableManager: return data - def _validate_both_dicts(self, a, b): - ''' - Validates that both arguments are dictionaries, or an error is raised. - ''' - if not (isinstance(a, MutableMapping) and isinstance(b, MutableMapping)): - raise AnsibleError("failed to combine variables, expected dicts but got a '%s' and a '%s'" % (type(a).__name__, type(b).__name__)) - - def _combine_vars(self, a, b): - ''' - Combines dictionaries of variables, based on the hash behavior - ''' - - self._validate_both_dicts(a, b) - - if C.DEFAULT_HASH_BEHAVIOUR == "merge": - return self._merge_dicts(a, b) - else: - result = a.copy() - result.update(b) - return result - - def _merge_dicts(self, a, b): - ''' - Recursively merges dict b into a, so that keys - from b take precedence over keys from a. - ''' - - result = dict() - - self._validate_both_dicts(a, b) - - for dicts in a, b: - # next, iterate over b keys and values - for k, v in dicts.iteritems(): - # if there's already such key in a - # and that key contains dict - if k in result and isinstance(result[k], dict): - # merge those dicts recursively - result[k] = self._merge_dicts(a[k], v) - else: - # otherwise, just copy a value from b to a - result[k] = v - - return result def get_vars(self, loader, play=None, host=None, task=None, include_hostvars=True, use_cache=True): ''' @@ -181,13 +138,13 @@ class VariableManager: # first we compile any vars specified in defaults/main.yml # for all roles within the specified play for role in play.get_roles(): - all_vars = self._combine_vars(all_vars, role.get_default_vars()) + all_vars = combine_vars(all_vars, role.get_default_vars()) # if we have a task in this context, and that task has a role, make # sure it sees its defaults above any other roles, as we previously # (v1) made sure each task had a copy of its roles default vars if task and task._role is not None: - all_vars = self._combine_vars(all_vars, task._role.get_default_vars()) + all_vars = combine_vars(all_vars, task._role.get_default_vars()) if host: # next, if a host is specified, we load any vars from group_vars @@ -198,38 +155,38 @@ class VariableManager: if 'all' in self._group_vars_files: data = self._preprocess_vars(self._group_vars_files['all']) for item in data: - all_vars = self._combine_vars(all_vars, item) + all_vars = combine_vars(all_vars, item) for group in host.get_groups(): - all_vars = self._combine_vars(all_vars, group.get_vars()) + all_vars = combine_vars(all_vars, group.get_vars()) if group.name in self._group_vars_files and group.name != 'all': data = self._preprocess_vars(self._group_vars_files[group.name]) for item in data: - all_vars = self._combine_vars(all_vars, item) + all_vars = combine_vars(all_vars, item) host_name = host.get_name() if host_name in self._host_vars_files: data = self._preprocess_vars(self._host_vars_files[host_name]) for item in data: - all_vars = self._combine_vars(all_vars, self._host_vars_files[host_name]) + all_vars = combine_vars(all_vars, self._host_vars_files[host_name]) # then we merge in vars specified for this host - all_vars = self._combine_vars(all_vars, host.get_vars()) + all_vars = combine_vars(all_vars, host.get_vars()) # next comes the facts cache and the vars cache, respectively try: - all_vars = self._combine_vars(all_vars, self._fact_cache.get(host.name, dict())) + all_vars = combine_vars(all_vars, self._fact_cache.get(host.name, dict())) except KeyError: pass if play: - all_vars = self._combine_vars(all_vars, play.get_vars()) + all_vars = combine_vars(all_vars, play.get_vars()) for vars_file_item in play.get_vars_files(): try: # create a set of temporary vars here, which incorporate the # extra vars so we can properly template the vars_files entries - temp_vars = self._combine_vars(all_vars, self._extra_vars) + temp_vars = combine_vars(all_vars, self._extra_vars) templar = Templar(loader=loader, variables=temp_vars) # we assume each item in the list is itself a list, as we @@ -246,26 +203,26 @@ class VariableManager: data = self._preprocess_vars(loader.load_from_file(vars_file)) if data is not None: for item in data: - all_vars = self._combine_vars(all_vars, item) + all_vars = combine_vars(all_vars, item) break else: raise AnsibleError("vars file %s was not found" % vars_file_item) - except UndefinedError as e: + except UndefinedError: continue if not C.DEFAULT_PRIVATE_ROLE_VARS: for role in play.get_roles(): - all_vars = self._combine_vars(all_vars, role.get_vars()) + all_vars = combine_vars(all_vars, role.get_vars()) if task: if task._role: - all_vars = self._combine_vars(all_vars, task._role.get_vars()) - all_vars = self._combine_vars(all_vars, task.get_vars()) + all_vars = combine_vars(all_vars, task._role.get_vars()) + all_vars = combine_vars(all_vars, task.get_vars()) if host: - all_vars = self._combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict())) + all_vars = combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict())) - all_vars = self._combine_vars(all_vars, self._extra_vars) + all_vars = combine_vars(all_vars, self._extra_vars) # FIXME: make sure all special vars are here # Finally, we create special vars @@ -345,7 +302,7 @@ class VariableManager: for p in paths: _found, results = self._load_inventory_file(path=p, loader=loader) if results is not None: - data = self._combine_vars(data, results) + data = combine_vars(data, results) else: file_name, ext = os.path.splitext(path) diff --git a/test/units/utils/__init__.py b/test/units/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/units/utils/test_vars.py b/test/units/utils/test_vars.py new file mode 100644 index 0000000000..aba05c41d4 --- /dev/null +++ b/test/units/utils/test_vars.py @@ -0,0 +1,98 @@ +# (c) 2012-2014, Michael DeHaan +# (c) 2015, Toshio Kuraotmi +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from collections import defaultdict + +from ansible.compat.tests import mock, unittest +from ansible.errors import AnsibleError + +from ansible.utils.vars import combine_vars, merge_hash + +class TestVariableUtils(unittest.TestCase): + + test_merge_data = ( + dict( + a=dict(a=1), + b=dict(b=2), + result=dict(a=1, b=2) + ), + dict( + a=dict(a=1, c=dict(foo='bar')), + b=dict(b=2, c=dict(baz='bam')), + result=dict(a=1, b=2, c=dict(foo='bar', baz='bam')) + ), + dict( + a=defaultdict(a=1, c=defaultdict(foo='bar')), + b=dict(b=2, c=dict(baz='bam')), + result=defaultdict(a=1, b=2, c=defaultdict(foo='bar', baz='bam')) + ), + ) + test_replace_data = ( + dict( + a=dict(a=1), + b=dict(b=2), + result=dict(a=1, b=2) + ), + dict( + a=dict(a=1, c=dict(foo='bar')), + b=dict(b=2, c=dict(baz='bam')), + result=dict(a=1, b=2, c=dict(baz='bam')) + ), + dict( + a=defaultdict(a=1, c=dict(foo='bar')), + b=dict(b=2, c=defaultdict(baz='bam')), + result=defaultdict(a=1, b=2, c=defaultdict(baz='bam')) + ), + ) + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_merge_hash(self): + for test in self.test_merge_data: + self.assertEqual(merge_hash(test['a'], test['b']), test['result']) + + def test_improper_args(self): + with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'replace'): + with self.assertRaises(AnsibleError): + combine_vars([1, 2, 3], dict(a=1)) + with self.assertRaises(AnsibleError): + combine_vars(dict(a=1), [1, 2, 3]) + + with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'merge'): + with self.assertRaises(AnsibleError): + combine_vars([1, 2, 3], dict(a=1)) + with self.assertRaises(AnsibleError): + combine_vars(dict(a=1), [1, 2, 3]) + + def test_combine_vars_replace(self): + with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'replace'): + for test in self.test_replace_data: + self.assertEqual(combine_vars(test['a'], test['b']), test['result']) + + def test_combine_vars_merge(self): + with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'merge'): + for test in self.test_merge_data: + self.assertEqual(combine_vars(test['a'], test['b']), test['result']) diff --git a/test/units/vars/test_variable_manager.py b/test/units/vars/test_variable_manager.py index 9d500d04d8..0d8e0770e9 100644 --- a/test/units/vars/test_variable_manager.py +++ b/test/units/vars/test_variable_manager.py @@ -48,20 +48,6 @@ class TestVariableManager(unittest.TestCase): self.assertEqual(vars, dict(playbook_dir='.')) - self.assertEqual( - v._merge_dicts( - dict(a=1), - dict(b=2) - ), dict(a=1, b=2) - ) - self.assertEqual( - v._merge_dicts( - dict(a=1, c=dict(foo='bar')), - dict(b=2, c=dict(baz='bam')) - ), dict(a=1, b=2, c=dict(foo='bar', baz='bam')) - ) - - def test_variable_manager_extra_vars(self): fake_loader = DictDataLoader({})