mirror of
https://github.com/ansible-collections/community.general.git
synced 2024-09-14 20:13:21 +02:00
adds new common functions for declarative intent modules (#25210)
* adds new common functions for declarative intent modules * adds Entity and EntityCollection * adds dict_diff and dict_combine * update for CI PEP8 compliance * more CI PEP8 fixes * more PEP8 CI clean up * refactors the lambda assignments into top level classes this is to be in compliant the PEP8 CI sanity checks * one last pep8 ci fix
This commit is contained in:
parent
43468b825d
commit
3aa41eda0b
2 changed files with 287 additions and 26 deletions
|
@ -24,7 +24,10 @@
|
|||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
||||
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
from itertools import chain
|
||||
|
||||
from ansible.module_utils.six import iteritems
|
||||
from ansible.module_utils.basic import AnsibleFallbackNotFound
|
||||
from ansible.module_utils.six import iteritems
|
||||
|
||||
|
@ -38,7 +41,13 @@ def to_list(val):
|
|||
return list()
|
||||
|
||||
|
||||
class ComplexDict(object):
|
||||
def sort_list(val):
|
||||
if isinstance(val, list):
|
||||
return sorted(val)
|
||||
return val
|
||||
|
||||
|
||||
class Entity(object):
|
||||
"""Transforms a dict to with an argument spec
|
||||
|
||||
This class will take a dict and apply an Ansible argument spec to the
|
||||
|
@ -52,7 +61,7 @@ class ComplexDict(object):
|
|||
display=dict(default='text', choices=['text', 'json']),
|
||||
validate=dict(type='bool')
|
||||
)
|
||||
transform = ComplexDict(argument_spec, module)
|
||||
transform = Entity(module, argument_spec)
|
||||
value = dict(command='foo')
|
||||
result = transform(value)
|
||||
print result
|
||||
|
@ -66,31 +75,42 @@ class ComplexDict(object):
|
|||
* fallback - implements fallback function
|
||||
* choices - set of valid options
|
||||
* default - default value
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, attrs, module):
|
||||
self._attributes = attrs
|
||||
def __init__(self, module, attrs=None, args=[], keys=None, from_argspec=False):
|
||||
self._attributes = attrs or {}
|
||||
self._module = module
|
||||
|
||||
for arg in args:
|
||||
self._attributes[arg] = dict()
|
||||
if from_argspec:
|
||||
self._attributes[arg]['read_from'] = arg
|
||||
if keys and arg in keys:
|
||||
self._attributes[arg]['key'] = True
|
||||
|
||||
self.attr_names = frozenset(self._attributes.keys())
|
||||
|
||||
self._has_key = False
|
||||
_has_key = False
|
||||
|
||||
for name, attr in iteritems(self._attributes):
|
||||
if attr.get('read_from'):
|
||||
if attr['read_from'] not in self._module.argument_spec:
|
||||
module.fail_json(msg='argument %s does not exist' % attr['read_from'])
|
||||
spec = self._module.argument_spec.get(attr['read_from'])
|
||||
if not spec:
|
||||
raise ValueError('argument_spec %s does not exist' % attr['read_from'])
|
||||
for key, value in iteritems(spec):
|
||||
if key not in attr:
|
||||
attr[key] = value
|
||||
|
||||
if attr.get('key'):
|
||||
if self._has_key:
|
||||
raise ValueError('only one key value can be specified')
|
||||
self._has_key = True
|
||||
if _has_key:
|
||||
module.fail_json(msg='only one key value can be specified')
|
||||
_has_key = True
|
||||
attr['required'] = True
|
||||
|
||||
def _dict(self, value):
|
||||
def serialize(self):
|
||||
return self._attributes
|
||||
|
||||
def to_dict(self, value):
|
||||
obj = {}
|
||||
for name, attr in iteritems(self._attributes):
|
||||
if attr.get('key'):
|
||||
|
@ -99,16 +119,17 @@ class ComplexDict(object):
|
|||
obj[name] = attr.get('default')
|
||||
return obj
|
||||
|
||||
def __call__(self, value):
|
||||
def __call__(self, value, strict=True):
|
||||
if not isinstance(value, dict):
|
||||
value = self._dict(value)
|
||||
value = self.to_dict(value)
|
||||
|
||||
unknown = set(value).difference(self.attr_names)
|
||||
if unknown:
|
||||
raise ValueError('invalid keys: %s' % ','.join(unknown))
|
||||
if strict:
|
||||
unknown = set(value).difference(self.attr_names)
|
||||
if unknown:
|
||||
self._module.fail_json(msg='invalid keys: %s' % ','.join(unknown))
|
||||
|
||||
for name, attr in iteritems(self._attributes):
|
||||
if not value.get(name):
|
||||
if value.get(name) is None:
|
||||
value[name] = attr.get('default')
|
||||
|
||||
if attr.get('fallback') and not value.get(name):
|
||||
|
@ -128,24 +149,135 @@ class ComplexDict(object):
|
|||
continue
|
||||
|
||||
if attr.get('required') and value.get(name) is None:
|
||||
raise ValueError('missing required attribute %s' % name)
|
||||
self._module.fail_json(msg='missing required attribute %s' % name)
|
||||
|
||||
if 'choices' in attr:
|
||||
if value[name] not in attr['choices']:
|
||||
raise ValueError('%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name]))
|
||||
self._module.fail_json(msg='%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name]))
|
||||
|
||||
if value[name] is not None:
|
||||
value_type = attr.get('type', 'str')
|
||||
type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type]
|
||||
type_checker(value[name])
|
||||
elif value.get(name):
|
||||
value[name] = self._module.params[name]
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class ComplexList(ComplexDict):
|
||||
"""Extends ```ComplexDict``` to handle a list of dicts """
|
||||
class EntityCollection(Entity):
|
||||
"""Extends ```Entity``` to handle a list of dicts """
|
||||
|
||||
def __call__(self, values):
|
||||
if not isinstance(values, (list, tuple)):
|
||||
raise TypeError('value must be an ordered iterable')
|
||||
return [(super(ComplexList, self).__call__(v)) for v in values]
|
||||
def __call__(self, iterable, strict=True):
|
||||
if iterable is None:
|
||||
iterable = [super(EntityCollection, self).__call__(self._module.params, strict)]
|
||||
|
||||
if not isinstance(iterable, (list, tuple)):
|
||||
module.fail_json(msg='value must be an iterable')
|
||||
|
||||
return [(super(EntityCollection, self).__call__(i, strict)) for i in iterable]
|
||||
|
||||
|
||||
# these two are for backwards compatibility and can be removed once all of the
|
||||
# modules that use them are updated
|
||||
class ComplexDict(Entity):
|
||||
def __init__(self, attrs, module, *args, **kwargs):
|
||||
super(ComplexDict, self).__init__(module, attrs, *args, **kwargs)
|
||||
|
||||
|
||||
class ComplexList(EntityCollection):
|
||||
def __init__(self, attrs, module, *args, **kwargs):
|
||||
super(ComplexList, self).__init__(module, attrs, *args, **kwargs)
|
||||
|
||||
|
||||
def dict_diff(base, comparable):
|
||||
""" Generate a dict object of differences
|
||||
|
||||
This function will compare two dict objects and return the difference
|
||||
between them as a dict object. For scalar values, the key will reflect
|
||||
the updated value. If the key does not exist in `comparable`, then then no
|
||||
key will be returned. For lists, the value in comparable will wholly replace
|
||||
the value in base for the key. For dicts, the returned value will only
|
||||
return keys that are different.
|
||||
|
||||
:param base: dict object to base the diff on
|
||||
:param comparable: dict object to compare against base
|
||||
|
||||
:returns: new dict object with differences
|
||||
"""
|
||||
assert isinstance(base, dict), "`base` must be of type <dict>"
|
||||
assert isinstance(comparable, dict), "`comparable` must be of type <dict>"
|
||||
|
||||
updates = dict()
|
||||
|
||||
for key, value in iteritems(base):
|
||||
if isinstance(value, dict):
|
||||
item = comparable.get(key)
|
||||
if item is not None:
|
||||
updates[key] = dict_diff(value, comparable[key])
|
||||
else:
|
||||
comparable_value = comparable.get(key)
|
||||
if comparable_value is not None:
|
||||
if sort_list(base[key]) != sort_list(comparable_value):
|
||||
updates[key] = comparable_value
|
||||
|
||||
for key in set(comparable.keys()).difference(base.keys()):
|
||||
updates[key] = comparable.get(key)
|
||||
|
||||
return updates
|
||||
|
||||
|
||||
def dict_combine(base, other):
|
||||
""" Return a new dict object that combines base and other
|
||||
|
||||
This will create a new dict object that is a combination of the key/value
|
||||
pairs from base and other. When both keys exist, the value will be
|
||||
selected from other. If the value is a list object, the two lists will
|
||||
be combined and duplicate entries removed.
|
||||
|
||||
:param base: dict object to serve as base
|
||||
:param other: dict object to combine with base
|
||||
|
||||
:returns: new combined dict object
|
||||
"""
|
||||
assert isinstance(base, dict), "`base` must be of type <dict>"
|
||||
assert isinstance(other, dict), "`other` must be of type <dict>"
|
||||
|
||||
combined = dict()
|
||||
|
||||
for key, value in iteritems(base):
|
||||
if isinstance(value, dict):
|
||||
if key in other:
|
||||
item = other.get(key)
|
||||
if item is not None:
|
||||
combined[key] = dict_combine(value, other[key])
|
||||
else:
|
||||
combined[key] = item
|
||||
else:
|
||||
combined[key] = value
|
||||
elif isinstance(value, list):
|
||||
if key in other:
|
||||
item = other.get(key)
|
||||
if item is not None:
|
||||
combined[key] = list(set(chain(value, item)))
|
||||
else:
|
||||
combined[key] = item
|
||||
else:
|
||||
combined[key] = value
|
||||
else:
|
||||
if key in other:
|
||||
other_value = other.get(key)
|
||||
if other_value is not None:
|
||||
if sort_list(base[key]) != sort_list(other_value):
|
||||
combined[key] = other_value
|
||||
else:
|
||||
combined[key] = value
|
||||
else:
|
||||
combined[key] = other_value
|
||||
else:
|
||||
combined[key] = value
|
||||
|
||||
for key in set(other.keys()).difference(base.keys()):
|
||||
combined[key] = other.get(key)
|
||||
|
||||
return combined
|
||||
|
|
129
test/units/module_utils/test_network_common.py
Normal file
129
test/units/module_utils/test_network_common.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# (c) 2017 Red Hat, Inc.
|
||||
#
|
||||
# This file is part of Ansible
|
||||
#
|
||||
# Ansible is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# Ansible is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# Make coding more python3-ish
|
||||
from __future__ import (absolute_import, division)
|
||||
__metaclass__ = type
|
||||
|
||||
from ansible.compat.tests import unittest
|
||||
|
||||
from ansible.module_utils.network_common import to_list, sort_list
|
||||
from ansible.module_utils.network_common import dict_diff, dict_combine
|
||||
|
||||
|
||||
class TestModuleUtilsNetworkCommon(unittest.TestCase):
|
||||
|
||||
def test_to_list(self):
|
||||
for scalar in ('string', 1, True, False, None):
|
||||
self.assertTrue(isinstance(to_list(scalar), list))
|
||||
|
||||
for container in ([1, 2, 3], {'one': 1}):
|
||||
self.assertTrue(isinstance(to_list(container), list))
|
||||
|
||||
test_list = [1, 2, 3]
|
||||
self.assertNotEqual(id(test_list), id(to_list(test_list)))
|
||||
|
||||
def test_sort(self):
|
||||
data = [3, 1, 2]
|
||||
self.assertEqual([1, 2, 3], sort_list(data))
|
||||
|
||||
string_data = '123'
|
||||
self.assertEqual(string_data, sort_list(string_data))
|
||||
|
||||
def test_dict_diff(self):
|
||||
base = dict(obj2=dict(), b1=True, b2=False, b3=False,
|
||||
one=1, two=2, three=3, obj1=dict(key1=1, key2=2),
|
||||
l1=[1, 3], l2=[1, 2, 3], l4=[4],
|
||||
nested=dict(n1=dict(n2=2)))
|
||||
|
||||
other = dict(b1=True, b2=False, b3=True, b4=True,
|
||||
one=1, three=4, four=4, obj1=dict(key1=2),
|
||||
l1=[2, 1], l2=[3, 2, 1], l3=[1],
|
||||
nested=dict(n1=dict(n2=2, n3=3)))
|
||||
|
||||
result = dict_diff(base, other)
|
||||
|
||||
# string assertions
|
||||
self.assertNotIn('one', result)
|
||||
self.assertNotIn('two', result)
|
||||
self.assertEqual(result['three'], 4)
|
||||
self.assertEqual(result['four'], 4)
|
||||
|
||||
# dict assertions
|
||||
self.assertIn('obj1', result)
|
||||
self.assertIn('key1', result['obj1'])
|
||||
self.assertNotIn('key2', result['obj1'])
|
||||
|
||||
# list assertions
|
||||
self.assertEqual(result['l1'], [2, 1])
|
||||
self.assertNotIn('l2', result)
|
||||
self.assertEqual(result['l3'], [1])
|
||||
self.assertNotIn('l4', result)
|
||||
|
||||
# nested assertions
|
||||
self.assertIn('obj1', result)
|
||||
self.assertEqual(result['obj1']['key1'], 2)
|
||||
self.assertNotIn('key2', result['obj1'])
|
||||
|
||||
# bool assertions
|
||||
self.assertNotIn('b1', result)
|
||||
self.assertNotIn('b2', result)
|
||||
self.assertTrue(result['b3'])
|
||||
self.assertTrue(result['b4'])
|
||||
|
||||
def test_dict_combine(self):
|
||||
base = dict(obj2=dict(), b1=True, b2=False, b3=False,
|
||||
one=1, two=2, three=3, obj1=dict(key1=1, key2=2),
|
||||
l1=[1, 3], l2=[1, 2, 3], l4=[4],
|
||||
nested=dict(n1=dict(n2=2)))
|
||||
|
||||
other = dict(b1=True, b2=False, b3=True, b4=True,
|
||||
one=1, three=4, four=4, obj1=dict(key1=2),
|
||||
l1=[2, 1], l2=[3, 2, 1], l3=[1],
|
||||
nested=dict(n1=dict(n2=2, n3=3)))
|
||||
|
||||
result = dict_combine(base, other)
|
||||
|
||||
# string assertions
|
||||
self.assertIn('one', result)
|
||||
self.assertIn('two', result)
|
||||
self.assertEqual(result['three'], 4)
|
||||
self.assertEqual(result['four'], 4)
|
||||
|
||||
# dict assertions
|
||||
self.assertIn('obj1', result)
|
||||
self.assertIn('key1', result['obj1'])
|
||||
self.assertIn('key2', result['obj1'])
|
||||
|
||||
# list assertions
|
||||
self.assertEqual(result['l1'], [1, 2, 3])
|
||||
self.assertIn('l2', result)
|
||||
self.assertEqual(result['l3'], [1])
|
||||
self.assertIn('l4', result)
|
||||
|
||||
# nested assertions
|
||||
self.assertIn('obj1', result)
|
||||
self.assertEqual(result['obj1']['key1'], 2)
|
||||
self.assertIn('key2', result['obj1'])
|
||||
|
||||
# bool assertions
|
||||
self.assertIn('b1', result)
|
||||
self.assertIn('b2', result)
|
||||
self.assertTrue(result['b3'])
|
||||
self.assertTrue(result['b4'])
|
Loading…
Reference in a new issue