mirror of
https://github.com/ansible-collections/community.general.git
synced 2024-09-14 20:13:21 +02:00
Cleanup combine_vars
* Dedupe combine_vars() code (removed from VariableManager) * Fix merge_hash algorithm to preserve the type * unittest combine_vars and merge_hash
This commit is contained in:
parent
7fe495d619
commit
aeff960d02
5 changed files with 158 additions and 92 deletions
|
@ -20,38 +20,63 @@ from __future__ import (absolute_import, division, print_function)
|
||||||
__metaclass__ = type
|
__metaclass__ = type
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
|
from collections import MutableMapping
|
||||||
|
|
||||||
from six import string_types
|
from six import iteritems, string_types
|
||||||
|
|
||||||
from ansible import constants as C
|
from ansible import constants as C
|
||||||
|
from ansible.errors import AnsibleError
|
||||||
from ansible.parsing.splitter import parse_kv
|
from ansible.parsing.splitter import parse_kv
|
||||||
from ansible.utils.unicode import to_unicode
|
from ansible.utils.unicode import to_unicode
|
||||||
|
|
||||||
|
def _validate_mutable_mappings(a, b):
|
||||||
|
"""
|
||||||
|
Internal convenience function to ensure arguments are MutableMappings
|
||||||
|
|
||||||
|
This checks that all arguments are MutableMappings or raises an error
|
||||||
|
|
||||||
|
:raises AnsibleError: if one of the arguments is not a MutableMapping
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If this becomes generally needed, change the signature to operate on
|
||||||
|
# a variable number of arguments instead.
|
||||||
|
|
||||||
|
if not (isinstance(a, MutableMapping) and isinstance(b, MutableMapping)):
|
||||||
|
raise AnsibleError("failed to combine variables, expected dicts but"
|
||||||
|
" got a '{0}' and a '{1}'".format(
|
||||||
|
a.__class__.__name__, b.__class__.__name__))
|
||||||
|
|
||||||
def combine_vars(a, b):
|
def combine_vars(a, b):
|
||||||
|
"""
|
||||||
|
Return a copy of dictionaries of variables based on configured hash behavior
|
||||||
|
"""
|
||||||
|
|
||||||
if C.DEFAULT_HASH_BEHAVIOUR == "merge":
|
if C.DEFAULT_HASH_BEHAVIOUR == "merge":
|
||||||
return merge_hash(a, b)
|
return merge_hash(a, b)
|
||||||
else:
|
else:
|
||||||
|
# HASH_BEHAVIOUR == 'replace'
|
||||||
|
_validate_mutable_mappings(a, b)
|
||||||
result = a.copy()
|
result = a.copy()
|
||||||
result.update(b)
|
result.update(b)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def merge_hash(a, b):
|
def merge_hash(a, b):
|
||||||
''' recursively merges hash b into a
|
"""
|
||||||
keys from b take precedence over keys from a '''
|
Recursively merges hash b into a so that keys from b take precedence over keys from a
|
||||||
|
"""
|
||||||
|
|
||||||
result = {}
|
_validate_mutable_mappings(a, b)
|
||||||
|
result = a.copy()
|
||||||
|
|
||||||
for dicts in a, b:
|
|
||||||
# next, iterate over b keys and values
|
# next, iterate over b keys and values
|
||||||
for k, v in dicts.iteritems():
|
for k, v in iteritems(b):
|
||||||
# if there's already such key in a
|
# if there's already such key in a
|
||||||
# and that key contains dict
|
# and that key contains a MutableMapping
|
||||||
if k in result and isinstance(result[k], dict):
|
if k in result and isinstance(result[k], MutableMapping):
|
||||||
# merge those dicts recursively
|
# merge those dicts recursively
|
||||||
result[k] = merge_hash(a[k], v)
|
result[k] = merge_hash(result[k], v)
|
||||||
else:
|
else:
|
||||||
# otherwise, just copy a value from b to a
|
# otherwise, just copy the value from b to a
|
||||||
result[k] = v
|
result[k] = v
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -33,11 +33,12 @@ except ImportError:
|
||||||
|
|
||||||
from ansible import constants as C
|
from ansible import constants as C
|
||||||
from ansible.cli import CLI
|
from ansible.cli import CLI
|
||||||
from ansible.errors import *
|
from ansible.errors import AnsibleError
|
||||||
from ansible.parsing import DataLoader
|
from ansible.parsing import DataLoader
|
||||||
from ansible.plugins.cache import FactCache
|
from ansible.plugins.cache import FactCache
|
||||||
from ansible.template import Templar
|
from ansible.template import Templar
|
||||||
from ansible.utils.debug import debug
|
from ansible.utils.debug import debug
|
||||||
|
from ansible.utils.vars import combine_vars
|
||||||
from ansible.vars.hostvars import HostVars
|
from ansible.vars.hostvars import HostVars
|
||||||
|
|
||||||
CACHED_VARS = dict()
|
CACHED_VARS = dict()
|
||||||
|
@ -104,50 +105,6 @@ class VariableManager:
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _validate_both_dicts(self, a, b):
|
|
||||||
'''
|
|
||||||
Validates that both arguments are dictionaries, or an error is raised.
|
|
||||||
'''
|
|
||||||
if not (isinstance(a, MutableMapping) and isinstance(b, MutableMapping)):
|
|
||||||
raise AnsibleError("failed to combine variables, expected dicts but got a '%s' and a '%s'" % (type(a).__name__, type(b).__name__))
|
|
||||||
|
|
||||||
def _combine_vars(self, a, b):
|
|
||||||
'''
|
|
||||||
Combines dictionaries of variables, based on the hash behavior
|
|
||||||
'''
|
|
||||||
|
|
||||||
self._validate_both_dicts(a, b)
|
|
||||||
|
|
||||||
if C.DEFAULT_HASH_BEHAVIOUR == "merge":
|
|
||||||
return self._merge_dicts(a, b)
|
|
||||||
else:
|
|
||||||
result = a.copy()
|
|
||||||
result.update(b)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _merge_dicts(self, a, b):
|
|
||||||
'''
|
|
||||||
Recursively merges dict b into a, so that keys
|
|
||||||
from b take precedence over keys from a.
|
|
||||||
'''
|
|
||||||
|
|
||||||
result = dict()
|
|
||||||
|
|
||||||
self._validate_both_dicts(a, b)
|
|
||||||
|
|
||||||
for dicts in a, b:
|
|
||||||
# next, iterate over b keys and values
|
|
||||||
for k, v in dicts.iteritems():
|
|
||||||
# if there's already such key in a
|
|
||||||
# and that key contains dict
|
|
||||||
if k in result and isinstance(result[k], dict):
|
|
||||||
# merge those dicts recursively
|
|
||||||
result[k] = self._merge_dicts(a[k], v)
|
|
||||||
else:
|
|
||||||
# otherwise, just copy a value from b to a
|
|
||||||
result[k] = v
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_vars(self, loader, play=None, host=None, task=None, include_hostvars=True, use_cache=True):
|
def get_vars(self, loader, play=None, host=None, task=None, include_hostvars=True, use_cache=True):
|
||||||
'''
|
'''
|
||||||
|
@ -181,13 +138,13 @@ class VariableManager:
|
||||||
# first we compile any vars specified in defaults/main.yml
|
# first we compile any vars specified in defaults/main.yml
|
||||||
# for all roles within the specified play
|
# for all roles within the specified play
|
||||||
for role in play.get_roles():
|
for role in play.get_roles():
|
||||||
all_vars = self._combine_vars(all_vars, role.get_default_vars())
|
all_vars = combine_vars(all_vars, role.get_default_vars())
|
||||||
|
|
||||||
# if we have a task in this context, and that task has a role, make
|
# if we have a task in this context, and that task has a role, make
|
||||||
# sure it sees its defaults above any other roles, as we previously
|
# sure it sees its defaults above any other roles, as we previously
|
||||||
# (v1) made sure each task had a copy of its roles default vars
|
# (v1) made sure each task had a copy of its roles default vars
|
||||||
if task and task._role is not None:
|
if task and task._role is not None:
|
||||||
all_vars = self._combine_vars(all_vars, task._role.get_default_vars())
|
all_vars = combine_vars(all_vars, task._role.get_default_vars())
|
||||||
|
|
||||||
if host:
|
if host:
|
||||||
# next, if a host is specified, we load any vars from group_vars
|
# next, if a host is specified, we load any vars from group_vars
|
||||||
|
@ -198,38 +155,38 @@ class VariableManager:
|
||||||
if 'all' in self._group_vars_files:
|
if 'all' in self._group_vars_files:
|
||||||
data = self._preprocess_vars(self._group_vars_files['all'])
|
data = self._preprocess_vars(self._group_vars_files['all'])
|
||||||
for item in data:
|
for item in data:
|
||||||
all_vars = self._combine_vars(all_vars, item)
|
all_vars = combine_vars(all_vars, item)
|
||||||
|
|
||||||
for group in host.get_groups():
|
for group in host.get_groups():
|
||||||
all_vars = self._combine_vars(all_vars, group.get_vars())
|
all_vars = combine_vars(all_vars, group.get_vars())
|
||||||
if group.name in self._group_vars_files and group.name != 'all':
|
if group.name in self._group_vars_files and group.name != 'all':
|
||||||
data = self._preprocess_vars(self._group_vars_files[group.name])
|
data = self._preprocess_vars(self._group_vars_files[group.name])
|
||||||
for item in data:
|
for item in data:
|
||||||
all_vars = self._combine_vars(all_vars, item)
|
all_vars = combine_vars(all_vars, item)
|
||||||
|
|
||||||
host_name = host.get_name()
|
host_name = host.get_name()
|
||||||
if host_name in self._host_vars_files:
|
if host_name in self._host_vars_files:
|
||||||
data = self._preprocess_vars(self._host_vars_files[host_name])
|
data = self._preprocess_vars(self._host_vars_files[host_name])
|
||||||
for item in data:
|
for item in data:
|
||||||
all_vars = self._combine_vars(all_vars, self._host_vars_files[host_name])
|
all_vars = combine_vars(all_vars, self._host_vars_files[host_name])
|
||||||
|
|
||||||
# then we merge in vars specified for this host
|
# then we merge in vars specified for this host
|
||||||
all_vars = self._combine_vars(all_vars, host.get_vars())
|
all_vars = combine_vars(all_vars, host.get_vars())
|
||||||
|
|
||||||
# next comes the facts cache and the vars cache, respectively
|
# next comes the facts cache and the vars cache, respectively
|
||||||
try:
|
try:
|
||||||
all_vars = self._combine_vars(all_vars, self._fact_cache.get(host.name, dict()))
|
all_vars = combine_vars(all_vars, self._fact_cache.get(host.name, dict()))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if play:
|
if play:
|
||||||
all_vars = self._combine_vars(all_vars, play.get_vars())
|
all_vars = combine_vars(all_vars, play.get_vars())
|
||||||
|
|
||||||
for vars_file_item in play.get_vars_files():
|
for vars_file_item in play.get_vars_files():
|
||||||
try:
|
try:
|
||||||
# create a set of temporary vars here, which incorporate the
|
# create a set of temporary vars here, which incorporate the
|
||||||
# extra vars so we can properly template the vars_files entries
|
# extra vars so we can properly template the vars_files entries
|
||||||
temp_vars = self._combine_vars(all_vars, self._extra_vars)
|
temp_vars = combine_vars(all_vars, self._extra_vars)
|
||||||
templar = Templar(loader=loader, variables=temp_vars)
|
templar = Templar(loader=loader, variables=temp_vars)
|
||||||
|
|
||||||
# we assume each item in the list is itself a list, as we
|
# we assume each item in the list is itself a list, as we
|
||||||
|
@ -246,26 +203,26 @@ class VariableManager:
|
||||||
data = self._preprocess_vars(loader.load_from_file(vars_file))
|
data = self._preprocess_vars(loader.load_from_file(vars_file))
|
||||||
if data is not None:
|
if data is not None:
|
||||||
for item in data:
|
for item in data:
|
||||||
all_vars = self._combine_vars(all_vars, item)
|
all_vars = combine_vars(all_vars, item)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise AnsibleError("vars file %s was not found" % vars_file_item)
|
raise AnsibleError("vars file %s was not found" % vars_file_item)
|
||||||
except UndefinedError as e:
|
except UndefinedError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not C.DEFAULT_PRIVATE_ROLE_VARS:
|
if not C.DEFAULT_PRIVATE_ROLE_VARS:
|
||||||
for role in play.get_roles():
|
for role in play.get_roles():
|
||||||
all_vars = self._combine_vars(all_vars, role.get_vars())
|
all_vars = combine_vars(all_vars, role.get_vars())
|
||||||
|
|
||||||
if task:
|
if task:
|
||||||
if task._role:
|
if task._role:
|
||||||
all_vars = self._combine_vars(all_vars, task._role.get_vars())
|
all_vars = combine_vars(all_vars, task._role.get_vars())
|
||||||
all_vars = self._combine_vars(all_vars, task.get_vars())
|
all_vars = combine_vars(all_vars, task.get_vars())
|
||||||
|
|
||||||
if host:
|
if host:
|
||||||
all_vars = self._combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict()))
|
all_vars = combine_vars(all_vars, self._vars_cache.get(host.get_name(), dict()))
|
||||||
|
|
||||||
all_vars = self._combine_vars(all_vars, self._extra_vars)
|
all_vars = combine_vars(all_vars, self._extra_vars)
|
||||||
|
|
||||||
# FIXME: make sure all special vars are here
|
# FIXME: make sure all special vars are here
|
||||||
# Finally, we create special vars
|
# Finally, we create special vars
|
||||||
|
@ -345,7 +302,7 @@ class VariableManager:
|
||||||
for p in paths:
|
for p in paths:
|
||||||
_found, results = self._load_inventory_file(path=p, loader=loader)
|
_found, results = self._load_inventory_file(path=p, loader=loader)
|
||||||
if results is not None:
|
if results is not None:
|
||||||
data = self._combine_vars(data, results)
|
data = combine_vars(data, results)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
file_name, ext = os.path.splitext(path)
|
file_name, ext = os.path.splitext(path)
|
||||||
|
|
0
test/units/utils/__init__.py
Normal file
0
test/units/utils/__init__.py
Normal file
98
test/units/utils/test_vars.py
Normal file
98
test/units/utils/test_vars.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
|
||||||
|
# (c) 2015, Toshio Kuraotmi <tkuratomi@ansible.com>
|
||||||
|
#
|
||||||
|
# This file is part of Ansible
|
||||||
|
#
|
||||||
|
# Ansible is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# Ansible is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License
|
||||||
|
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
# Make coding more python3-ish
|
||||||
|
from __future__ import (absolute_import, division, print_function)
|
||||||
|
__metaclass__ = type
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from ansible.compat.tests import mock, unittest
|
||||||
|
from ansible.errors import AnsibleError
|
||||||
|
|
||||||
|
from ansible.utils.vars import combine_vars, merge_hash
|
||||||
|
|
||||||
|
class TestVariableUtils(unittest.TestCase):
|
||||||
|
|
||||||
|
test_merge_data = (
|
||||||
|
dict(
|
||||||
|
a=dict(a=1),
|
||||||
|
b=dict(b=2),
|
||||||
|
result=dict(a=1, b=2)
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
a=dict(a=1, c=dict(foo='bar')),
|
||||||
|
b=dict(b=2, c=dict(baz='bam')),
|
||||||
|
result=dict(a=1, b=2, c=dict(foo='bar', baz='bam'))
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
a=defaultdict(a=1, c=defaultdict(foo='bar')),
|
||||||
|
b=dict(b=2, c=dict(baz='bam')),
|
||||||
|
result=defaultdict(a=1, b=2, c=defaultdict(foo='bar', baz='bam'))
|
||||||
|
),
|
||||||
|
)
|
||||||
|
test_replace_data = (
|
||||||
|
dict(
|
||||||
|
a=dict(a=1),
|
||||||
|
b=dict(b=2),
|
||||||
|
result=dict(a=1, b=2)
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
a=dict(a=1, c=dict(foo='bar')),
|
||||||
|
b=dict(b=2, c=dict(baz='bam')),
|
||||||
|
result=dict(a=1, b=2, c=dict(baz='bam'))
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
a=defaultdict(a=1, c=dict(foo='bar')),
|
||||||
|
b=dict(b=2, c=defaultdict(baz='bam')),
|
||||||
|
result=defaultdict(a=1, b=2, c=defaultdict(baz='bam'))
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_merge_hash(self):
|
||||||
|
for test in self.test_merge_data:
|
||||||
|
self.assertEqual(merge_hash(test['a'], test['b']), test['result'])
|
||||||
|
|
||||||
|
def test_improper_args(self):
|
||||||
|
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'replace'):
|
||||||
|
with self.assertRaises(AnsibleError):
|
||||||
|
combine_vars([1, 2, 3], dict(a=1))
|
||||||
|
with self.assertRaises(AnsibleError):
|
||||||
|
combine_vars(dict(a=1), [1, 2, 3])
|
||||||
|
|
||||||
|
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'merge'):
|
||||||
|
with self.assertRaises(AnsibleError):
|
||||||
|
combine_vars([1, 2, 3], dict(a=1))
|
||||||
|
with self.assertRaises(AnsibleError):
|
||||||
|
combine_vars(dict(a=1), [1, 2, 3])
|
||||||
|
|
||||||
|
def test_combine_vars_replace(self):
|
||||||
|
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'replace'):
|
||||||
|
for test in self.test_replace_data:
|
||||||
|
self.assertEqual(combine_vars(test['a'], test['b']), test['result'])
|
||||||
|
|
||||||
|
def test_combine_vars_merge(self):
|
||||||
|
with mock.patch('ansible.constants.DEFAULT_HASH_BEHAVIOUR', 'merge'):
|
||||||
|
for test in self.test_merge_data:
|
||||||
|
self.assertEqual(combine_vars(test['a'], test['b']), test['result'])
|
|
@ -48,20 +48,6 @@ class TestVariableManager(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(vars, dict(playbook_dir='.'))
|
self.assertEqual(vars, dict(playbook_dir='.'))
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
v._merge_dicts(
|
|
||||||
dict(a=1),
|
|
||||||
dict(b=2)
|
|
||||||
), dict(a=1, b=2)
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
v._merge_dicts(
|
|
||||||
dict(a=1, c=dict(foo='bar')),
|
|
||||||
dict(b=2, c=dict(baz='bam'))
|
|
||||||
), dict(a=1, b=2, c=dict(foo='bar', baz='bam'))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_variable_manager_extra_vars(self):
|
def test_variable_manager_extra_vars(self):
|
||||||
fake_loader = DictDataLoader({})
|
fake_loader = DictDataLoader({})
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue