1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

Cleanup combine_vars

* Dedupe combine_vars() code (removed from VariableManager)
* Fix merge_hash algorithm to preserve the type
* unittest combine_vars and merge_hash
This commit is contained in:
Toshio Kuratomi 2015-09-01 11:20:16 -07:00
parent 7fe495d619
commit aeff960d02
5 changed files with 158 additions and 92 deletions

View file

@ -20,38 +20,63 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import ast import ast
from collections import MutableMapping
from six import string_types from six import iteritems, string_types
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.parsing.splitter import parse_kv from ansible.parsing.splitter import parse_kv
from ansible.utils.unicode import to_unicode from ansible.utils.unicode import to_unicode
def _validate_mutable_mappings(a, b):
"""
Internal convenience function to ensure arguments are MutableMappings
This checks that all arguments are MutableMappings or raises an error
:raises AnsibleError: if one of the arguments is not a MutableMapping
"""
# If this becomes generally needed, change the signature to operate on
# a variable number of arguments instead.
if not (isinstance(a, MutableMapping) and isinstance(b, MutableMapping)):
raise AnsibleError("failed to combine variables, expected dicts but"
" got a '{0}' and a '{1}'".format(
a.__class__.__name__, b.__class__.__name__))
def combine_vars(a, b): def combine_vars(a, b):
"""
Return a copy of dictionaries of variables based on configured hash behavior
"""
if C.DEFAULT_HASH_BEHAVIOUR == "merge": if C.DEFAULT_HASH_BEHAVIOUR == "merge":
return merge_hash(a, b) return merge_hash(a, b)
else: else:
# HASH_BEHAVIOUR == 'replace'
_validate_mutable_mappings(a, b)
result = a.copy() result = a.copy()
result.update(b) result.update(b)
return result return result
def merge_hash(a, b): def merge_hash(a, b):
''' recursively merges hash b into a """
keys from b take precedence over keys from a ''' Recursively merges hash b into a so that keys from b take precedence over keys from a
"""
result = {} _validate_mutable_mappings(a, b)
result = a.copy()
for dicts in a, b:
# next, iterate over b keys and values # next, iterate over b keys and values
for k, v in dicts.iteritems(): for k, v in iteritems(b):
# if there's already such key in a # if there's already such key in a
# and that key contains dict # and that key contains a MutableMapping
if k in result and isinstance(result[k], dict): if k in result and isinstance(result[k], MutableMapping):
# merge those dicts recursively # merge those dicts recursively
result[k] = merge_hash(a[k], v) result[k] = merge_hash(result[k], v)
else: else:
# otherwise, just copy a value from b to a # otherwise, just copy the value from b to a
result[k] = v result[k] = v
return result return result

View file

@ -33,11 +33,12 @@ except ImportError:
from ansible import constants as C from ansible import constants as C
from ansible.cli import CLI from ansible.cli import CLI
from ansible.errors import * from ansible.errors import AnsibleError
from ansible.parsing import DataLoader from ansible.parsing import DataLoader
from ansible.plugins.cache import FactCache from ansible.plugins.cache import FactCache
from ansible.template import Templar from ansible.template import Templar
from ansible.utils.debug import debug from ansible.utils.debug import debug
from ansible.utils.vars import combine_vars
from ansible.vars.hostvars import HostVars from ansible.vars.hostvars import HostVars
CACHED_VARS = dict() CACHED_VARS = dict()
@ -104,50 +105,6 @@ class VariableManager:
return data return data
def _validate_both_dicts(self, a, b):
'''
Validates that both arguments are dictionaries, or an error is raised.
'''
if not (isinstance(a, MutableMapping) and isinstance(b, MutableMapping)):
raise AnsibleError("failed to combine variables, expected dicts but got a '%s' and a '%s'" % (type(a).__name__, type(b).__name__))
def _combine_vars(self, a, b):
'''
Combines dictionaries of variables, based on the hash behavior
'''
self._validate_both_dicts(a, b)
if C.DEFAULT_HASH_BEHAVIOUR == "merge":
return self._merge_dicts(a, b)
else:
result = a.copy()
result.update(b)
return result
def _merge_dicts(self, a, b):
'''
Recursively merges dict b into a, so that keys
from b take precedence over keys from a.
'''
result = dict()
self._validate_both_dicts(a, b)
for dicts in a, b:
# next, iterate over b keys and values
for k, v in dicts.iteritems():
# if there's already such key in a
# and that key contains dict
if k in result and isinstance(result[k], dict):
# merge those dicts recursively
result[k] = self._merge_dicts(a[k], v)
else:
# otherwise, just copy a value from b to a
result[k] = v
return result
def get_vars(self, loader, play=None, host=None, task=None, include_hostvars=True, use_cache=True): def get_vars(self, loader, play=None, host=None, task=None, include_hostvars=True, use_cache=True):
''' '''
@ -181,13 +138,13 @@ class VariableManager:
# first we compile any vars specified in defaults/main.yml # first we compile any vars specified in defaults/main.yml
# for all roles within the specified play # for all roles within the specified play
for role in play.get_roles(): for role in play.get_roles():
all_vars = self._combine_vars(all_vars, role.get_default_vars()) all_vars = combine_vars(all_vars, role.get_default_vars())
# if we have a task in this context, and that task has a role, make # if we have a task in this context, and that task has a role, make
# sure it sees its defaults above any other roles, as we previously # sure it sees its defaults above any other roles, as we previously
# (v1) made sure each task had a copy of its roles default vars # (v1) made sure each task had a copy of its roles default vars
if task and task._role is not None: if task and task._role is not None:
all_vars = self._combine_vars(all_vars, task._role.get_default_vars()) all_vars = combine_vars(all_vars, task._role.get_default_vars())
if host: if host:
# next, if a host is specified, we load any vars from group_vars # next, if a host is specified, we load any vars from group_vars
@ -198,38 +155,38 @@ class VariableManager:
if 'all' in self._group_vars_files: if 'all' in self._group_vars_files:
data = self._preprocess_vars(self._group_vars_files['all']) data = self._preprocess_vars(self._group_vars_files['all'])
for item in data: for item in data:
all_vars = self._combine_vars(all_vars, item) all_vars = combine_vars(all_vars, item)
for group in host.get_groups(): for group in host.get_groups():
all_vars = self._combine_vars(all_vars, group.get_vars()) all_vars = combine_vars(all_vars, group.get_vars())
if group.name in self._group_vars_files and group.name != 'all': if group.name in self._group_vars_files and group.name != 'all':
data = self._preprocess_vars(self._group_vars_files[group.name]) data = self._preprocess_vars(self._group_vars_files[group.name])
for item in data: for item in data:
all_vars = self._combine_vars(all_vars, item) all_vars = combine_vars(all_vars, item)
host_name = host.get_name() host_name = host.get_name()
if host_name in self._host_vars_files: if host_name in self._host_vars_files:
data = self._preprocess_vars(self._host_vars_files[host_name]) data = self._preprocess_vars(self._host_vars_files[host_name])
for item in data: for item in data:
all_vars = self._combine_vars(all_vars, self._host_vars_files[host_name]) all_vars = combine_vars(all_vars, self._host_vars_files[host_name])
# then we merge in vars specified for this host # then we merge in vars specified for this host
all_vars = self._combine_vars(all_vars, host.get_vars()) all_vars = combine_vars(all_vars, host.get_vars())
# next comes the facts cache and the vars cache, respectively # next comes the facts cache and the vars cache, respectively
try: try:
all_vars = self._combine_vars(all_vars, self._fact_cache.get(host.name, dict())) all_vars = combine_vars(all_vars, self._fact_cache.get(host.name, dict()))
except KeyError: except KeyError:
pass pass
if play: if play:
all_vars = self._combine_vars(all_vars, play.get_vars()) all_vars = combine_vars(all_vars, play.get_vars())
for vars_file_item in play.get_vars_files(): for vars_file_item in play.get_vars_files():
try: try:
# create a set of temporary vars here, which incorporate the # create a set of temporary vars here, which incorporate the
# extra vars so we can properly template the vars_files entries # extra vars so we can properly template the vars_files entries
temp_vars = self._combine_vars(all_vars, self._extra_vars) temp_vars = combine_vars(all_vars, self._extra_vars)
templar = Templar(loader=loader, variables=temp_vars) templar = Templar(loader=loader, variables=temp_vars)
# we assume each item in the list is itself a list, as we # we assume each item in the list is itself a list, as we
@ -246,26 +203,26 @@ class VariableManager:
data = self._preprocess_vars(loader.load_from_file(vars_file)) data = self._preprocess_vars(loader.load_from_file(vars_file))
if data is not None: if data is not None:
for item in data: for item in data:
all_vars = self._combine_vars(all_vars, item) all_vars = combine_vars(all_vars, item)
break break
else: else:
raise AnsibleError("vars file %s was not found" % vars_file_item) raise AnsibleError("vars file %s was not found" % vars_file_item)
except UndefinedError as e: except UndefinedError:
continue continue
if not C.DEFAULT_PRIVATE_ROLE_VARS: if not C.DEFAULT_PRIVATE_ROLE_VARS:
for role in play.get_roles(): for role in play.get_roles():
all_vars = self._combine_vars(all_vars, role.get_vars()) all_vars = combine_vars(all_vars, role.get_vars())
if task: if task:
if task._role: if task._role:
all_vars = self._combine_vars(all_vars, task._role.get_vars()) all_vars = combine_vars(all_vars, task._role.get_vars())
all_vars = self._combine_vars(all_vars, task.get_vars()) all_vars = combine_vars(all_vars, task.get_vars())
if host: if host:
all_vars = self._combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict())) all_vars = combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict()))
all_vars = self._combine_vars(all_vars, self._extra_vars) all_vars = combine_vars(all_vars, self._extra_vars)
# FIXME: make sure all special vars are here # FIXME: make sure all special vars are here
# Finally, we create special vars # Finally, we create special vars
@ -345,7 +302,7 @@ class VariableManager:
for p in paths: for p in paths:
_found, results = self._load_inventory_file(path=p, loader=loader) _found, results = self._load_inventory_file(path=p, loader=loader)
if results is not None: if results is not None:
data = self._combine_vars(data, results) data = combine_vars(data, results)
else: else:
file_name, ext = os.path.splitext(path) file_name, ext = os.path.splitext(path)

View file

View file

@ -0,0 +1,98 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2015, Toshio Kuraotmi <tkuratomi@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from collections import defaultdict
from ansible.compat.tests import mock, unittest
from ansible.errors import AnsibleError
from ansible.utils.vars import combine_vars, merge_hash
class TestVariableUtils(unittest.TestCase):
test_merge_data = (
dict(
a=dict(a=1),
b=dict(b=2),
result=dict(a=1, b=2)
),
dict(
a=dict(a=1, c=dict(foo='bar')),
b=dict(b=2, c=dict(baz='bam')),
result=dict(a=1, b=2, c=dict(foo='bar', baz='bam'))
),
dict(
a=defaultdict(a=1, c=defaultdict(foo='bar')),
b=dict(b=2, c=dict(baz='bam')),
result=defaultdict(a=1, b=2, c=defaultdict(foo='bar', baz='bam'))
),
)
test_replace_data = (
dict(
a=dict(a=1),
b=dict(b=2),
result=dict(a=1, b=2)
),
dict(
a=dict(a=1, c=dict(foo='bar')),
b=dict(b=2, c=dict(baz='bam')),
result=dict(a=1, b=2, c=dict(baz='bam'))
),
dict(
a=defaultdict(a=1, c=dict(foo='bar')),
b=dict(b=2, c=defaultdict(baz='bam')),
result=defaultdict(a=1, b=2, c=defaultdict(baz='bam'))
),
)
def setUp(self):
pass
def tearDown(self):
pass
def test_merge_hash(self):
for test in self.test_merge_data:
self.assertEqual(merge_hash(test['a'], test['b']), test['result'])
def test_improper_args(self):
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'replace'):
with self.assertRaises(AnsibleError):
combine_vars([1, 2, 3], dict(a=1))
with self.assertRaises(AnsibleError):
combine_vars(dict(a=1), [1, 2, 3])
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'merge'):
with self.assertRaises(AnsibleError):
combine_vars([1, 2, 3], dict(a=1))
with self.assertRaises(AnsibleError):
combine_vars(dict(a=1), [1, 2, 3])
def test_combine_vars_replace(self):
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'replace'):
for test in self.test_replace_data:
self.assertEqual(combine_vars(test['a'], test['b']), test['result'])
def test_combine_vars_merge(self):
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'merge'):
for test in self.test_merge_data:
self.assertEqual(combine_vars(test['a'], test['b']), test['result'])

View file

@ -48,20 +48,6 @@ class TestVariableManager(unittest.TestCase):
self.assertEqual(vars, dict(playbook_dir='.')) self.assertEqual(vars, dict(playbook_dir='.'))
self.assertEqual(
v._merge_dicts(
dict(a=1),
dict(b=2)
), dict(a=1, b=2)
)
self.assertEqual(
v._merge_dicts(
dict(a=1, c=dict(foo='bar')),
dict(b=2, c=dict(baz='bam'))
), dict(a=1, b=2, c=dict(foo='bar', baz='bam'))
)
def test_variable_manager_extra_vars(self): def test_variable_manager_extra_vars(self):
fake_loader = DictDataLoader({}) fake_loader = DictDataLoader({})