mirror of
https://github.com/ansible-collections/community.general.git
synced 2024-09-14 20:13:21 +02:00
Cleanup combine_vars
* Dedupe combine_vars() code (removed from VariableManager) * Fix merge_hash algorithm to preserve the type * unittest combine_vars and merge_hash
This commit is contained in:
parent
7fe495d619
commit
aeff960d02
5 changed files with 158 additions and 92 deletions
|
@ -20,38 +20,63 @@ from __future__ import (absolute_import, division, print_function)
|
|||
__metaclass__ = type
|
||||
|
||||
import ast
|
||||
from collections import MutableMapping
|
||||
|
||||
from six import string_types
|
||||
from six import iteritems, string_types
|
||||
|
||||
from ansible import constants as C
|
||||
from ansible.errors import AnsibleError
|
||||
from ansible.parsing.splitter import parse_kv
|
||||
from ansible.utils.unicode import to_unicode
|
||||
|
||||
def _validate_mutable_mappings(a, b):
|
||||
"""
|
||||
Internal convenience function to ensure arguments are MutableMappings
|
||||
|
||||
This checks that all arguments are MutableMappings or raises an error
|
||||
|
||||
:raises AnsibleError: if one of the arguments is not a MutableMapping
|
||||
"""
|
||||
|
||||
# If this becomes generally needed, change the signature to operate on
|
||||
# a variable number of arguments instead.
|
||||
|
||||
if not (isinstance(a, MutableMapping) and isinstance(b, MutableMapping)):
|
||||
raise AnsibleError("failed to combine variables, expected dicts but"
|
||||
" got a '{0}' and a '{1}'".format(
|
||||
a.__class__.__name__, b.__class__.__name__))
|
||||
|
||||
def combine_vars(a, b):
|
||||
"""
|
||||
Return a copy of dictionaries of variables based on configured hash behavior
|
||||
"""
|
||||
|
||||
if C.DEFAULT_HASH_BEHAVIOUR == "merge":
|
||||
return merge_hash(a, b)
|
||||
else:
|
||||
# HASH_BEHAVIOUR == 'replace'
|
||||
_validate_mutable_mappings(a, b)
|
||||
result = a.copy()
|
||||
result.update(b)
|
||||
return result
|
||||
|
||||
def merge_hash(a, b):
|
||||
''' recursively merges hash b into a
|
||||
keys from b take precedence over keys from a '''
|
||||
"""
|
||||
Recursively merges hash b into a so that keys from b take precedence over keys from a
|
||||
"""
|
||||
|
||||
result = {}
|
||||
_validate_mutable_mappings(a, b)
|
||||
result = a.copy()
|
||||
|
||||
for dicts in a, b:
|
||||
# next, iterate over b keys and values
|
||||
for k, v in dicts.iteritems():
|
||||
for k, v in iteritems(b):
|
||||
# if there's already such key in a
|
||||
# and that key contains dict
|
||||
if k in result and isinstance(result[k], dict):
|
||||
# and that key contains a MutableMapping
|
||||
if k in result and isinstance(result[k], MutableMapping):
|
||||
# merge those dicts recursively
|
||||
result[k] = merge_hash(a[k], v)
|
||||
result[k] = merge_hash(result[k], v)
|
||||
else:
|
||||
# otherwise, just copy a value from b to a
|
||||
# otherwise, just copy the value from b to a
|
||||
result[k] = v
|
||||
|
||||
return result
|
||||
|
|
|
@ -33,11 +33,12 @@ except ImportError:
|
|||
|
||||
from ansible import constants as C
|
||||
from ansible.cli import CLI
|
||||
from ansible.errors import *
|
||||
from ansible.errors import AnsibleError
|
||||
from ansible.parsing import DataLoader
|
||||
from ansible.plugins.cache import FactCache
|
||||
from ansible.template import Templar
|
||||
from ansible.utils.debug import debug
|
||||
from ansible.utils.vars import combine_vars
|
||||
from ansible.vars.hostvars import HostVars
|
||||
|
||||
CACHED_VARS = dict()
|
||||
|
@ -104,50 +105,6 @@ class VariableManager:
|
|||
|
||||
return data
|
||||
|
||||
def _validate_both_dicts(self, a, b):
|
||||
'''
|
||||
Validates that both arguments are dictionaries, or an error is raised.
|
||||
'''
|
||||
if not (isinstance(a, MutableMapping) and isinstance(b, MutableMapping)):
|
||||
raise AnsibleError("failed to combine variables, expected dicts but got a '%s' and a '%s'" % (type(a).__name__, type(b).__name__))
|
||||
|
||||
def _combine_vars(self, a, b):
|
||||
'''
|
||||
Combines dictionaries of variables, based on the hash behavior
|
||||
'''
|
||||
|
||||
self._validate_both_dicts(a, b)
|
||||
|
||||
if C.DEFAULT_HASH_BEHAVIOUR == "merge":
|
||||
return self._merge_dicts(a, b)
|
||||
else:
|
||||
result = a.copy()
|
||||
result.update(b)
|
||||
return result
|
||||
|
||||
def _merge_dicts(self, a, b):
|
||||
'''
|
||||
Recursively merges dict b into a, so that keys
|
||||
from b take precedence over keys from a.
|
||||
'''
|
||||
|
||||
result = dict()
|
||||
|
||||
self._validate_both_dicts(a, b)
|
||||
|
||||
for dicts in a, b:
|
||||
# next, iterate over b keys and values
|
||||
for k, v in dicts.iteritems():
|
||||
# if there's already such key in a
|
||||
# and that key contains dict
|
||||
if k in result and isinstance(result[k], dict):
|
||||
# merge those dicts recursively
|
||||
result[k] = self._merge_dicts(a[k], v)
|
||||
else:
|
||||
# otherwise, just copy a value from b to a
|
||||
result[k] = v
|
||||
|
||||
return result
|
||||
|
||||
def get_vars(self, loader, play=None, host=None, task=None, include_hostvars=True, use_cache=True):
|
||||
'''
|
||||
|
@ -181,13 +138,13 @@ class VariableManager:
|
|||
# first we compile any vars specified in defaults/main.yml
|
||||
# for all roles within the specified play
|
||||
for role in play.get_roles():
|
||||
all_vars = self._combine_vars(all_vars, role.get_default_vars())
|
||||
all_vars = combine_vars(all_vars, role.get_default_vars())
|
||||
|
||||
# if we have a task in this context, and that task has a role, make
|
||||
# sure it sees its defaults above any other roles, as we previously
|
||||
# (v1) made sure each task had a copy of its roles default vars
|
||||
if task and task._role is not None:
|
||||
all_vars = self._combine_vars(all_vars, task._role.get_default_vars())
|
||||
all_vars = combine_vars(all_vars, task._role.get_default_vars())
|
||||
|
||||
if host:
|
||||
# next, if a host is specified, we load any vars from group_vars
|
||||
|
@ -198,38 +155,38 @@ class VariableManager:
|
|||
if 'all' in self._group_vars_files:
|
||||
data = self._preprocess_vars(self._group_vars_files['all'])
|
||||
for item in data:
|
||||
all_vars = self._combine_vars(all_vars, item)
|
||||
all_vars = combine_vars(all_vars, item)
|
||||
|
||||
for group in host.get_groups():
|
||||
all_vars = self._combine_vars(all_vars, group.get_vars())
|
||||
all_vars = combine_vars(all_vars, group.get_vars())
|
||||
if group.name in self._group_vars_files and group.name != 'all':
|
||||
data = self._preprocess_vars(self._group_vars_files[group.name])
|
||||
for item in data:
|
||||
all_vars = self._combine_vars(all_vars, item)
|
||||
all_vars = combine_vars(all_vars, item)
|
||||
|
||||
host_name = host.get_name()
|
||||
if host_name in self._host_vars_files:
|
||||
data = self._preprocess_vars(self._host_vars_files[host_name])
|
||||
for item in data:
|
||||
all_vars = self._combine_vars(all_vars, self._host_vars_files[host_name])
|
||||
all_vars = combine_vars(all_vars, self._host_vars_files[host_name])
|
||||
|
||||
# then we merge in vars specified for this host
|
||||
all_vars = self._combine_vars(all_vars, host.get_vars())
|
||||
all_vars = combine_vars(all_vars, host.get_vars())
|
||||
|
||||
# next comes the facts cache and the vars cache, respectively
|
||||
try:
|
||||
all_vars = self._combine_vars(all_vars, self._fact_cache.get(host.name, dict()))
|
||||
all_vars = combine_vars(all_vars, self._fact_cache.get(host.name, dict()))
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if play:
|
||||
all_vars = self._combine_vars(all_vars, play.get_vars())
|
||||
all_vars = combine_vars(all_vars, play.get_vars())
|
||||
|
||||
for vars_file_item in play.get_vars_files():
|
||||
try:
|
||||
# create a set of temporary vars here, which incorporate the
|
||||
# extra vars so we can properly template the vars_files entries
|
||||
temp_vars = self._combine_vars(all_vars, self._extra_vars)
|
||||
temp_vars = combine_vars(all_vars, self._extra_vars)
|
||||
templar = Templar(loader=loader, variables=temp_vars)
|
||||
|
||||
# we assume each item in the list is itself a list, as we
|
||||
|
@ -246,26 +203,26 @@ class VariableManager:
|
|||
data = self._preprocess_vars(loader.load_from_file(vars_file))
|
||||
if data is not None:
|
||||
for item in data:
|
||||
all_vars = self._combine_vars(all_vars, item)
|
||||
all_vars = combine_vars(all_vars, item)
|
||||
break
|
||||
else:
|
||||
raise AnsibleError("vars file %s was not found" % vars_file_item)
|
||||
except UndefinedError as e:
|
||||
except UndefinedError:
|
||||
continue
|
||||
|
||||
if not C.DEFAULT_PRIVATE_ROLE_VARS:
|
||||
for role in play.get_roles():
|
||||
all_vars = self._combine_vars(all_vars, role.get_vars())
|
||||
all_vars = combine_vars(all_vars, role.get_vars())
|
||||
|
||||
if task:
|
||||
if task._role:
|
||||
all_vars = self._combine_vars(all_vars, task._role.get_vars())
|
||||
all_vars = self._combine_vars(all_vars, task.get_vars())
|
||||
all_vars = combine_vars(all_vars, task._role.get_vars())
|
||||
all_vars = combine_vars(all_vars, task.get_vars())
|
||||
|
||||
if host:
|
||||
all_vars = self._combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict()))
|
||||
all_vars = combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict()))
|
||||
|
||||
all_vars = self._combine_vars(all_vars, self._extra_vars)
|
||||
all_vars = combine_vars(all_vars, self._extra_vars)
|
||||
|
||||
# FIXME: make sure all special vars are here
|
||||
# Finally, we create special vars
|
||||
|
@ -345,7 +302,7 @@ class VariableManager:
|
|||
for p in paths:
|
||||
_found, results = self._load_inventory_file(path=p, loader=loader)
|
||||
if results is not None:
|
||||
data = self._combine_vars(data, results)
|
||||
data = combine_vars(data, results)
|
||||
|
||||
else:
|
||||
file_name, ext = os.path.splitext(path)
|
||||
|
|
0
test/units/utils/__init__.py
Normal file
0
test/units/utils/__init__.py
Normal file
98
test/units/utils/test_vars.py
Normal file
98
test/units/utils/test_vars.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
|
||||
# (c) 2015, Toshio Kuraotmi <tkuratomi@ansible.com>
|
||||
#
|
||||
# This file is part of Ansible
|
||||
#
|
||||
# Ansible is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# Ansible is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# Make coding more python3-ish
|
||||
from __future__ import (absolute_import, division, print_function)
|
||||
__metaclass__ = type
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
from ansible.compat.tests import mock, unittest
|
||||
from ansible.errors import AnsibleError
|
||||
|
||||
from ansible.utils.vars import combine_vars, merge_hash
|
||||
|
||||
class TestVariableUtils(unittest.TestCase):
|
||||
|
||||
test_merge_data = (
|
||||
dict(
|
||||
a=dict(a=1),
|
||||
b=dict(b=2),
|
||||
result=dict(a=1, b=2)
|
||||
),
|
||||
dict(
|
||||
a=dict(a=1, c=dict(foo='bar')),
|
||||
b=dict(b=2, c=dict(baz='bam')),
|
||||
result=dict(a=1, b=2, c=dict(foo='bar', baz='bam'))
|
||||
),
|
||||
dict(
|
||||
a=defaultdict(a=1, c=defaultdict(foo='bar')),
|
||||
b=dict(b=2, c=dict(baz='bam')),
|
||||
result=defaultdict(a=1, b=2, c=defaultdict(foo='bar', baz='bam'))
|
||||
),
|
||||
)
|
||||
test_replace_data = (
|
||||
dict(
|
||||
a=dict(a=1),
|
||||
b=dict(b=2),
|
||||
result=dict(a=1, b=2)
|
||||
),
|
||||
dict(
|
||||
a=dict(a=1, c=dict(foo='bar')),
|
||||
b=dict(b=2, c=dict(baz='bam')),
|
||||
result=dict(a=1, b=2, c=dict(baz='bam'))
|
||||
),
|
||||
dict(
|
||||
a=defaultdict(a=1, c=dict(foo='bar')),
|
||||
b=dict(b=2, c=defaultdict(baz='bam')),
|
||||
result=defaultdict(a=1, b=2, c=defaultdict(baz='bam'))
|
||||
),
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_merge_hash(self):
|
||||
for test in self.test_merge_data:
|
||||
self.assertEqual(merge_hash(test['a'], test['b']), test['result'])
|
||||
|
||||
def test_improper_args(self):
|
||||
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'replace'):
|
||||
with self.assertRaises(AnsibleError):
|
||||
combine_vars([1, 2, 3], dict(a=1))
|
||||
with self.assertRaises(AnsibleError):
|
||||
combine_vars(dict(a=1), [1, 2, 3])
|
||||
|
||||
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'merge'):
|
||||
with self.assertRaises(AnsibleError):
|
||||
combine_vars([1, 2, 3], dict(a=1))
|
||||
with self.assertRaises(AnsibleError):
|
||||
combine_vars(dict(a=1), [1, 2, 3])
|
||||
|
||||
def test_combine_vars_replace(self):
|
||||
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'replace'):
|
||||
for test in self.test_replace_data:
|
||||
self.assertEqual(combine_vars(test['a'], test['b']), test['result'])
|
||||
|
||||
def test_combine_vars_merge(self):
|
||||
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'merge'):
|
||||
for test in self.test_merge_data:
|
||||
self.assertEqual(combine_vars(test['a'], test['b']), test['result'])
|
|
@ -48,20 +48,6 @@ class TestVariableManager(unittest.TestCase):
|
|||
|
||||
self.assertEqual(vars, dict(playbook_dir='.'))
|
||||
|
||||
self.assertEqual(
|
||||
v._merge_dicts(
|
||||
dict(a=1),
|
||||
dict(b=2)
|
||||
), dict(a=1, b=2)
|
||||
)
|
||||
self.assertEqual(
|
||||
v._merge_dicts(
|
||||
dict(a=1, c=dict(foo='bar')),
|
||||
dict(b=2, c=dict(baz='bam'))
|
||||
), dict(a=1, b=2, c=dict(foo='bar', baz='bam'))
|
||||
)
|
||||
|
||||
|
||||
def test_variable_manager_extra_vars(self):
|
||||
fake_loader = DictDataLoader({})
|
||||
|
||||
|
|
Loading…
Reference in a new issue