diff --git a/lib/ansible/plugins/filter/mathstuff.py b/lib/ansible/plugins/filter/mathstuff.py index 403f752382..58ffa144a8 100644 --- a/lib/ansible/plugins/filter/mathstuff.py +++ b/lib/ansible/plugins/filter/mathstuff.py @@ -27,54 +27,93 @@ import collections import itertools import math -from ansible import errors +from jinja2.filters import environmentfilter + +from ansible.errors import AnsibleFilterError from ansible.module_utils import basic from ansible.module_utils.six import binary_type, text_type from ansible.module_utils.six.moves import zip, zip_longest -from ansible.module_utils._text import to_native +from ansible.module_utils._text import to_native, to_text + +try: + from jinja2.filters import do_unique + HAS_UNIQUE = True +except ImportError: + HAS_UNIQUE = False + +try: + from __main__ import display +except ImportError: + from ansible.utils.display import Display + display = Display() -def unique(a): - if isinstance(a, collections.Hashable): - c = set(a) - else: - c = [] - for x in a: - if x not in c: - c.append(x) +@environmentfilter +def unique(environment, a, case_sensitive=False, attribute=None): + + error = None + try: + if HAS_UNIQUE: + c = set(do_unique(environment, a, case_sensitive=case_sensitive, attribute=attribute)) + except Exception as e: + if case_sensitive or attribute: + raise AnsibleFilterError("Jinja2's unique filter failed and we cannot fall back to Ansible's version " + "as it does not support the parameters supplied", orig_exc=e) + else: + display.warning('Falling back to Ansible unique filter as Jinaj2 one failed: %s' % to_text(e)) + error = e + + if not HAS_UNIQUE or error: + + # handle Jinja2 specific attributes when using Ansible's version + if case_sensitive or attribute: + raise AnsibleFilterError("Ansible's unique filter does not support case_sensitive nor attribute parameters, " + "you need a newer version of Jinja2 that provides their version of the filter.") + + if isinstance(a, collections.Hashable): + c = set(a) + else: + c = [] + for x in a: + if x not in c: + c.append(x) return c -def intersect(a, b): +@environmentfilter +def intersect(environment, a, b): if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable): c = set(a) & set(b) else: - c = unique([x for x in a if x in b]) + c = unique(environment, [x for x in a if x in b]) return c -def difference(a, b): +@environmentfilter +def difference(environment, a, b): if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable): c = set(a) - set(b) else: - c = unique([x for x in a if x not in b]) + c = unique(environment, [x for x in a if x not in b]) return c -def symmetric_difference(a, b): +@environmentfilter +def symmetric_difference(environment, a, b): if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable): c = set(a) ^ set(b) else: - isect = intersect(a, b) - c = [x for x in union(a, b) if x not in isect] + isect = intersect(environment, a, b) + c = [x for x in union(environment, a, b) if x not in isect] return c -def union(a, b): +@environmentfilter +def union(environment, a, b): if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable): c = set(a) | set(b) else: - c = unique(a + b) + c = unique(environment, a + b) return c @@ -95,14 +134,14 @@ def logarithm(x, base=math.e): else: return math.log(x, base) except TypeError as e: - raise errors.AnsibleFilterError('log() can only be used on numbers: %s' % str(e)) + raise AnsibleFilterError('log() can only be used on numbers: %s' % str(e)) def power(x, y): try: return math.pow(x, y) except TypeError as e: - raise errors.AnsibleFilterError('pow() can only be used on numbers: %s' % str(e)) + raise AnsibleFilterError('pow() can only be used on numbers: %s' % str(e)) def inversepower(x, base=2): @@ -112,7 +151,7 @@ def inversepower(x, base=2): else: return math.pow(x, 1.0 / float(base)) except (ValueError, TypeError) as e: - raise errors.AnsibleFilterError('root() can only be used on numbers: %s' % str(e)) + raise AnsibleFilterError('root() can only be used on numbers: %s' % str(e)) def human_readable(size, isbits=False, unit=None): @@ -120,7 +159,7 @@ def human_readable(size, isbits=False, unit=None): try: return basic.bytes_to_human(size, isbits, unit) except Exception: - raise errors.AnsibleFilterError("human_readable() can't interpret following string: %s" % size) + raise AnsibleFilterError("human_readable() can't interpret following string: %s" % size) def human_to_bytes(size, default_unit=None, isbits=False): @@ -128,7 +167,7 @@ def human_to_bytes(size, default_unit=None, isbits=False): try: return basic.human_to_bytes(size, default_unit, isbits) except Exception: - raise errors.AnsibleFilterError("human_to_bytes() can't interpret following string: %s" % size) + raise AnsibleFilterError("human_to_bytes() can't interpret following string: %s" % size) def rekey_on_member(data, key, duplicates='error'): @@ -141,7 +180,7 @@ def rekey_on_member(data, key, duplicates='error'): value would be duplicated or to overwrite previous entries if that's the case. """ if duplicates not in ('error', 'overwrite'): - raise errors.AnsibleFilterError("duplicates parameter to rekey_on_member has unknown value: {0}".format(duplicates)) + raise AnsibleFilterError("duplicates parameter to rekey_on_member has unknown value: {0}".format(duplicates)) new_obj = {} @@ -150,24 +189,24 @@ def rekey_on_member(data, key, duplicates='error'): elif isinstance(data, collections.Iterable) and not isinstance(data, (text_type, binary_type)): iterate_over = data else: - raise errors.AnsibleFilterError("Type is not a valid list, set, or dict") + raise AnsibleFilterError("Type is not a valid list, set, or dict") for item in iterate_over: if not isinstance(item, collections.Mapping): - raise errors.AnsibleFilterError("List item is not a valid dict") + raise AnsibleFilterError("List item is not a valid dict") try: key_elem = item[key] except KeyError: - raise errors.AnsibleFilterError("Key {0} was not found".format(key)) + raise AnsibleFilterError("Key {0} was not found".format(key)) except Exception as e: - raise errors.AnsibleFilterError(to_native(e)) + raise AnsibleFilterError(to_native(e)) # Note: if new_obj[key_elem] exists it will always be a non-empty dict (it will at # minimun contain {key: key_elem} if new_obj.get(key_elem, None): if duplicates == 'error': - raise errors.AnsibleFilterError("Key {0} is not unique, cannot correctly turn into dict".format(key_elem)) + raise AnsibleFilterError("Key {0} is not unique, cannot correctly turn into dict".format(key_elem)) elif duplicates == 'overwrite': new_obj[key_elem] = item else: diff --git a/test/units/plugins/filter/test_mathstuff.py b/test/units/plugins/filter/test_mathstuff.py index 0107839163..f8bba423a6 100644 --- a/test/units/plugins/filter/test_mathstuff.py +++ b/test/units/plugins/filter/test_mathstuff.py @@ -4,9 +4,10 @@ # Make coding more python3-ish from __future__ import (absolute_import, division, print_function) __metaclass__ = type - import pytest +from jinja2 import Environment + import ansible.plugins.filter.mathstuff as ms from ansible.errors import AnsibleFilterError @@ -22,41 +23,43 @@ TWO_SETS_DATA = (([1, 2], [3, 4], ([], sorted([1, 2]), sorted([1, 2, 3, 4]), sor (['a', 'b', 'c'], ['d', 'c', 'e'], (['c'], sorted(['a', 'b']), sorted(['a', 'b', 'd', 'e']), sorted(['a', 'b', 'c', 'e', 'd']))), ) +env = Environment() + @pytest.mark.parametrize('data, expected', UNIQUE_DATA) class TestUnique: def test_unhashable(self, data, expected): - assert sorted(ms.unique(list(data))) == expected + assert sorted(ms.unique(env, list(data))) == expected def test_hashable(self, data, expected): - assert sorted(ms.unique(tuple(data))) == expected + assert sorted(ms.unique(env, tuple(data))) == expected @pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA) class TestIntersect: def test_unhashable(self, dataset1, dataset2, expected): - assert sorted(ms.intersect(list(dataset1), list(dataset2))) == expected[0] + assert sorted(ms.intersect(env, list(dataset1), list(dataset2))) == expected[0] def test_hashable(self, dataset1, dataset2, expected): - assert sorted(ms.intersect(tuple(dataset1), tuple(dataset2))) == expected[0] + assert sorted(ms.intersect(env, tuple(dataset1), tuple(dataset2))) == expected[0] @pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA) class TestDifference: def test_unhashable(self, dataset1, dataset2, expected): - assert sorted(ms.difference(list(dataset1), list(dataset2))) == expected[1] + assert sorted(ms.difference(env, list(dataset1), list(dataset2))) == expected[1] def test_hashable(self, dataset1, dataset2, expected): - assert sorted(ms.difference(tuple(dataset1), tuple(dataset2))) == expected[1] + assert sorted(ms.difference(env, tuple(dataset1), tuple(dataset2))) == expected[1] @pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA) class TestSymmetricDifference: def test_unhashable(self, dataset1, dataset2, expected): - assert sorted(ms.symmetric_difference(list(dataset1), list(dataset2))) == expected[2] + assert sorted(ms.symmetric_difference(env, list(dataset1), list(dataset2))) == expected[2] def test_hashable(self, dataset1, dataset2, expected): - assert sorted(ms.symmetric_difference(tuple(dataset1), tuple(dataset2))) == expected[2] + assert sorted(ms.symmetric_difference(env, tuple(dataset1), tuple(dataset2))) == expected[2] class TestMin: