1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

Cleanup combine_vars

* Dedupe combine_vars() code (removed from VariableManager)
* Fix merge_hash algorithm to preserve the type
* unittest combine_vars and merge_hash
This commit is contained in:
Toshio Kuratomi 2015-09-01 11:20:16 -07:00
parent 7fe495d619
commit aeff960d02
5 changed files with 158 additions and 92 deletions

View file

@ -20,38 +20,63 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import ast
from collections import MutableMapping
from six import string_types
from six import iteritems, string_types
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.parsing.splitter import parse_kv
from ansible.utils.unicode import to_unicode
def _validate_mutable_mappings(a, b):
"""
Internal convenience function to ensure arguments are MutableMappings
This checks that all arguments are MutableMappings or raises an error
:raises AnsibleError: if one of the arguments is not a MutableMapping
"""
# If this becomes generally needed, change the signature to operate on
# a variable number of arguments instead.
if not (isinstance(a, MutableMapping) and isinstance(b, MutableMapping)):
raise AnsibleError("failed to combine variables, expected dicts but"
" got a '{0}' and a '{1}'".format(
a.__class__.__name__, b.__class__.__name__))
def combine_vars(a, b):
"""
Return a copy of dictionaries of variables based on configured hash behavior
"""
if C.DEFAULT_HASH_BEHAVIOUR == "merge":
return merge_hash(a, b)
else:
# HASH_BEHAVIOUR == 'replace'
_validate_mutable_mappings(a, b)
result = a.copy()
result.update(b)
return result
def merge_hash(a, b):
''' recursively merges hash b into a
keys from b take precedence over keys from a '''
"""
Recursively merges hash b into a so that keys from b take precedence over keys from a
"""
result = {}
_validate_mutable_mappings(a, b)
result = a.copy()
for dicts in a, b:
# next, iterate over b keys and values
for k, v in dicts.iteritems():
for k, v in iteritems(b):
# if there's already such key in a
# and that key contains dict
if k in result and isinstance(result[k], dict):
# and that key contains a MutableMapping
if k in result and isinstance(result[k], MutableMapping):
# merge those dicts recursively
result[k] = merge_hash(a[k], v)
result[k] = merge_hash(result[k], v)
else:
# otherwise, just copy a value from b to a
# otherwise, just copy the value from b to a
result[k] = v
return result

View file

@ -33,11 +33,12 @@ except ImportError:
from ansible import constants as C
from ansible.cli import CLI
from ansible.errors import *
from ansible.errors import AnsibleError
from ansible.parsing import DataLoader
from ansible.plugins.cache import FactCache
from ansible.template import Templar
from ansible.utils.debug import debug
from ansible.utils.vars import combine_vars
from ansible.vars.hostvars import HostVars
CACHED_VARS = dict()
@ -104,50 +105,6 @@ class VariableManager:
return data
def _validate_both_dicts(self, a, b):
'''
Validates that both arguments are dictionaries, or an error is raised.
'''
if not (isinstance(a, MutableMapping) and isinstance(b, MutableMapping)):
raise AnsibleError("failed to combine variables, expected dicts but got a '%s' and a '%s'" % (type(a).__name__, type(b).__name__))
def _combine_vars(self, a, b):
'''
Combines dictionaries of variables, based on the hash behavior
'''
self._validate_both_dicts(a, b)
if C.DEFAULT_HASH_BEHAVIOUR == "merge":
return self._merge_dicts(a, b)
else:
result = a.copy()
result.update(b)
return result
def _merge_dicts(self, a, b):
'''
Recursively merges dict b into a, so that keys
from b take precedence over keys from a.
'''
result = dict()
self._validate_both_dicts(a, b)
for dicts in a, b:
# next, iterate over b keys and values
for k, v in dicts.iteritems():
# if there's already such key in a
# and that key contains dict
if k in result and isinstance(result[k], dict):
# merge those dicts recursively
result[k] = self._merge_dicts(a[k], v)
else:
# otherwise, just copy a value from b to a
result[k] = v
return result
def get_vars(self, loader, play=None, host=None, task=None, include_hostvars=True, use_cache=True):
'''
@ -181,13 +138,13 @@ class VariableManager:
# first we compile any vars specified in defaults/main.yml
# for all roles within the specified play
for role in play.get_roles():
all_vars = self._combine_vars(all_vars, role.get_default_vars())
all_vars = combine_vars(all_vars, role.get_default_vars())
# if we have a task in this context, and that task has a role, make
# sure it sees its defaults above any other roles, as we previously
# (v1) made sure each task had a copy of its roles default vars
if task and task._role is not None:
all_vars = self._combine_vars(all_vars, task._role.get_default_vars())
all_vars = combine_vars(all_vars, task._role.get_default_vars())
if host:
# next, if a host is specified, we load any vars from group_vars
@ -198,38 +155,38 @@ class VariableManager:
if 'all' in self._group_vars_files:
data = self._preprocess_vars(self._group_vars_files['all'])
for item in data:
all_vars = self._combine_vars(all_vars, item)
all_vars = combine_vars(all_vars, item)
for group in host.get_groups():
all_vars = self._combine_vars(all_vars, group.get_vars())
all_vars = combine_vars(all_vars, group.get_vars())
if group.name in self._group_vars_files and group.name != 'all':
data = self._preprocess_vars(self._group_vars_files[group.name])
for item in data:
all_vars = self._combine_vars(all_vars, item)
all_vars = combine_vars(all_vars, item)
host_name = host.get_name()
if host_name in self._host_vars_files:
data = self._preprocess_vars(self._host_vars_files[host_name])
for item in data:
all_vars = self._combine_vars(all_vars, self._host_vars_files[host_name])
all_vars = combine_vars(all_vars, self._host_vars_files[host_name])
# then we merge in vars specified for this host
all_vars = self._combine_vars(all_vars, host.get_vars())
all_vars = combine_vars(all_vars, host.get_vars())
# next comes the facts cache and the vars cache, respectively
try:
all_vars = self._combine_vars(all_vars, self._fact_cache.get(host.name, dict()))
all_vars = combine_vars(all_vars, self._fact_cache.get(host.name, dict()))
except KeyError:
pass
if play:
all_vars = self._combine_vars(all_vars, play.get_vars())
all_vars = combine_vars(all_vars, play.get_vars())
for vars_file_item in play.get_vars_files():
try:
# create a set of temporary vars here, which incorporate the
# extra vars so we can properly template the vars_files entries
temp_vars = self._combine_vars(all_vars, self._extra_vars)
temp_vars = combine_vars(all_vars, self._extra_vars)
templar = Templar(loader=loader, variables=temp_vars)
# we assume each item in the list is itself a list, as we
@ -246,26 +203,26 @@ class VariableManager:
data = self._preprocess_vars(loader.load_from_file(vars_file))
if data is not None:
for item in data:
all_vars = self._combine_vars(all_vars, item)
all_vars = combine_vars(all_vars, item)
break
else:
raise AnsibleError("vars file %s was not found" % vars_file_item)
except UndefinedError as e:
except UndefinedError:
continue
if not C.DEFAULT_PRIVATE_ROLE_VARS:
for role in play.get_roles():
all_vars = self._combine_vars(all_vars, role.get_vars())
all_vars = combine_vars(all_vars, role.get_vars())
if task:
if task._role:
all_vars = self._combine_vars(all_vars, task._role.get_vars())
all_vars = self._combine_vars(all_vars, task.get_vars())
all_vars = combine_vars(all_vars, task._role.get_vars())
all_vars = combine_vars(all_vars, task.get_vars())
if host:
all_vars = self._combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict()))
all_vars = combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict()))
all_vars = self._combine_vars(all_vars, self._extra_vars)
all_vars = combine_vars(all_vars, self._extra_vars)
# FIXME: make sure all special vars are here
# Finally, we create special vars
@ -345,7 +302,7 @@ class VariableManager:
for p in paths:
_found, results = self._load_inventory_file(path=p, loader=loader)
if results is not None:
data = self._combine_vars(data, results)
data = combine_vars(data, results)
else:
file_name, ext = os.path.splitext(path)

View file

View file

@ -0,0 +1,98 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2015, Toshio Kuraotmi <tkuratomi@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from collections import defaultdict
from ansible.compat.tests import mock, unittest
from ansible.errors import AnsibleError
from ansible.utils.vars import combine_vars, merge_hash
class TestVariableUtils(unittest.TestCase):
test_merge_data = (
dict(
a=dict(a=1),
b=dict(b=2),
result=dict(a=1, b=2)
),
dict(
a=dict(a=1, c=dict(foo='bar')),
b=dict(b=2, c=dict(baz='bam')),
result=dict(a=1, b=2, c=dict(foo='bar', baz='bam'))
),
dict(
a=defaultdict(a=1, c=defaultdict(foo='bar')),
b=dict(b=2, c=dict(baz='bam')),
result=defaultdict(a=1, b=2, c=defaultdict(foo='bar', baz='bam'))
),
)
test_replace_data = (
dict(
a=dict(a=1),
b=dict(b=2),
result=dict(a=1, b=2)
),
dict(
a=dict(a=1, c=dict(foo='bar')),
b=dict(b=2, c=dict(baz='bam')),
result=dict(a=1, b=2, c=dict(baz='bam'))
),
dict(
a=defaultdict(a=1, c=dict(foo='bar')),
b=dict(b=2, c=defaultdict(baz='bam')),
result=defaultdict(a=1, b=2, c=defaultdict(baz='bam'))
),
)
def setUp(self):
pass
def tearDown(self):
pass
def test_merge_hash(self):
for test in self.test_merge_data:
self.assertEqual(merge_hash(test['a'], test['b']), test['result'])
def test_improper_args(self):
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'replace'):
with self.assertRaises(AnsibleError):
combine_vars([1, 2, 3], dict(a=1))
with self.assertRaises(AnsibleError):
combine_vars(dict(a=1), [1, 2, 3])
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'merge'):
with self.assertRaises(AnsibleError):
combine_vars([1, 2, 3], dict(a=1))
with self.assertRaises(AnsibleError):
combine_vars(dict(a=1), [1, 2, 3])
def test_combine_vars_replace(self):
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'replace'):
for test in self.test_replace_data:
self.assertEqual(combine_vars(test['a'], test['b']), test['result'])
def test_combine_vars_merge(self):
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'merge'):
for test in self.test_merge_data:
self.assertEqual(combine_vars(test['a'], test['b']), test['result'])

View file

@ -48,20 +48,6 @@ class TestVariableManager(unittest.TestCase):
self.assertEqual(vars, dict(playbook_dir='.'))
self.assertEqual(
v._merge_dicts(
dict(a=1),
dict(b=2)
), dict(a=1, b=2)
)
self.assertEqual(
v._merge_dicts(
dict(a=1, c=dict(foo='bar')),
dict(b=2, c=dict(baz='bam'))
), dict(a=1, b=2, c=dict(foo='bar', baz='bam'))
)
def test_variable_manager_extra_vars(self):
fake_loader = DictDataLoader({})