diff --git a/lib/ansible/inventory/group.py b/lib/ansible/inventory/group.py index 4847e6fbd8..859781cc4d 100644 --- a/lib/ansible/inventory/group.py +++ b/lib/ansible/inventory/group.py @@ -19,6 +19,8 @@ __metaclass__ = type from ansible.errors import AnsibleError +from itertools import chain + class Group: ''' a group of ansible hosts ''' @@ -80,6 +82,38 @@ class Group: g.deserialize(parent_data) self.parent_groups.append(g) + def _walk_relationship(self, rel): + ''' + Given `rel` that is an iterable property of Group, + consitituting a directed acyclic graph among all groups, + Returns a set of all groups in full tree + A B C + | / | / + | / | / + D -> E + | / vertical connections + | / are directed upward + F + Called on F, returns set of (A, B, C, D, E) + ''' + seen = set([]) + unprocessed = set(getattr(self, rel)) + + while unprocessed: + seen.update(unprocessed) + unprocessed = set(chain.from_iterable( + getattr(g, rel) for g in unprocessed + )) + unprocessed.difference_update(seen) + + return seen + + def get_ancestors(self): + return self._walk_relationship('parent_groups') + + def get_descendants(self): + return self._walk_relationship('child_groups') + @property def host_names(self): if self._hosts is None: @@ -96,6 +130,17 @@ class Group: # don't add if it's already there if group not in self.child_groups: + + # prepare list of group's new ancestors this edge creates + start_ancestors = group.get_ancestors() + new_ancestors = self.get_ancestors() + if group in new_ancestors: + raise AnsibleError( + "Adding group '%s' as child to '%s' creates a recursive " + "dependency loop." % (group.name, self.name)) + new_ancestors.add(self) + new_ancestors.difference_update(start_ancestors) + self.child_groups.append(group) # update the depth of the child @@ -109,18 +154,28 @@ class Group: if self.name not in [g.name for g in group.parent_groups]: group.parent_groups.append(self) for h in group.get_hosts(): - h.populate_ancestors() + h.populate_ancestors(additions=new_ancestors) self.clear_hosts_cache() def _check_children_depth(self): - try: - for group in self.child_groups: - group.depth = max([self.depth + 1, group.depth]) - group._check_children_depth() - except RuntimeError: - raise AnsibleError("The group named '%s' has a recursive dependency loop." % self.name) + depth = self.depth + start_depth = self.depth # self.depth could change over loop + seen = set([]) + unprocessed = set(self.child_groups) + + while unprocessed: + seen.update(unprocessed) + depth += 1 + to_process = unprocessed.copy() + unprocessed = set([]) + for g in to_process: + if g.depth < depth: + g.depth = depth + unprocessed.update(g.child_groups) + if depth - start_depth > len(seen): + raise AnsibleError("The group named '%s' has a recursive dependency loop." % self.name) def add_host(self, host): if host.name not in self.host_names: @@ -147,8 +202,8 @@ class Group: def clear_hosts_cache(self): self._hosts_cache = None - for g in self.parent_groups: - g.clear_hosts_cache() + for g in self.get_ancestors(): + g._hosts_cache = None def get_hosts(self): @@ -160,8 +215,8 @@ class Group: hosts = [] seen = {} - for kid in self.child_groups: - kid_hosts = kid.get_hosts() + for kid in self.get_descendants(): + kid_hosts = kid.hosts for kk in kid_hosts: if kk not in seen: seen[kk] = 1 @@ -179,18 +234,6 @@ class Group: def get_vars(self): return self.vars.copy() - def _get_ancestors(self): - - results = {} - for g in self.parent_groups: - results[g.name] = g - results.update(g._get_ancestors()) - return results - - def get_ancestors(self): - - return self._get_ancestors().values() - def set_priority(self, priority): try: self.priority = int(priority) diff --git a/lib/ansible/inventory/host.py b/lib/ansible/inventory/host.py index 647e00dc1c..4327273945 100644 --- a/lib/ansible/inventory/host.py +++ b/lib/ansible/inventory/host.py @@ -101,17 +101,22 @@ class Host: def get_name(self): return self.name - def populate_ancestors(self): + def populate_ancestors(self, additions=None): # populate ancestors - for group in self.groups: - self.add_group(group) + if additions is None: + for group in self.groups: + self.add_group(group) + else: + for group in additions: + if group not in self.groups: + self.groups.append(group) def add_group(self, group): # populate ancestors first for oldg in group.get_ancestors(): if oldg not in self.groups: - self.add_group(oldg) + self.groups.append(oldg) # actually add group if group not in self.groups: diff --git a/test/units/plugins/inventory/test_group.py b/test/units/plugins/inventory/test_group.py new file mode 100644 index 0000000000..086c7cf798 --- /dev/null +++ b/test/units/plugins/inventory/test_group.py @@ -0,0 +1,125 @@ +# Copyright 2018 Alan Rominger +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +from ansible.compat.tests import unittest + +from ansible.inventory.group import Group +from ansible.inventory.host import Host +from ansible.errors import AnsibleError + + +class TestGroup(unittest.TestCase): + + def test_depth_update(self): + A = Group('A') + B = Group('B') + Z = Group('Z') + A.add_child_group(B) + A.add_child_group(Z) + self.assertEqual(A.depth, 0) + self.assertEqual(Z.depth, 1) + self.assertEqual(B.depth, 1) + + def test_depth_update_dual_branches(self): + alpha = Group('alpha') + A = Group('A') + alpha.add_child_group(A) + B = Group('B') + A.add_child_group(B) + Z = Group('Z') + alpha.add_child_group(Z) + beta = Group('beta') + B.add_child_group(beta) + Z.add_child_group(beta) + + self.assertEqual(alpha.depth, 0) # apex + self.assertEqual(beta.depth, 3) # alpha -> A -> B -> beta + + omega = Group('omega') + omega.add_child_group(alpha) + + # verify that both paths are traversed to get the max depth value + self.assertEqual(B.depth, 3) # omega -> alpha -> A -> B + self.assertEqual(beta.depth, 4) # B -> beta + + def test_depth_recursion(self): + A = Group('A') + B = Group('B') + A.add_child_group(B) + # hypothetical of adding B as child group to A + A.parent_groups.append(B) + B.child_groups.append(A) + # can't update depths of groups, because of loop + with self.assertRaises(AnsibleError): + B._check_children_depth() + + def test_loop_detection(self): + A = Group('A') + B = Group('B') + C = Group('C') + A.add_child_group(B) + B.add_child_group(C) + with self.assertRaises(AnsibleError): + C.add_child_group(A) + + def test_populates_descendant_hosts(self): + A = Group('A') + B = Group('B') + C = Group('C') + h = Host('h') + C.add_host(h) + A.add_child_group(B) # B is child of A + B.add_child_group(C) # C is descendant of A + A.add_child_group(B) + self.assertEqual(set(h.groups), set([C, B, A])) + h2 = Host('h2') + C.add_host(h2) + self.assertEqual(set(h2.groups), set([C, B, A])) + + def test_ancestor_example(self): + # see docstring for Group._walk_relationship + groups = {} + for name in ['A', 'B', 'C', 'D', 'E', 'F']: + groups[name] = Group(name) + # first row + groups['A'].add_child_group(groups['D']) + groups['B'].add_child_group(groups['D']) + groups['B'].add_child_group(groups['E']) + groups['C'].add_child_group(groups['D']) + # second row + groups['D'].add_child_group(groups['E']) + groups['D'].add_child_group(groups['F']) + groups['E'].add_child_group(groups['F']) + + self.assertEqual( + set(groups['F'].get_ancestors()), + set([ + groups['A'], groups['B'], groups['C'], groups['D'], groups['E'] + ]) + ) + + def test_ancestors_recursive_loop_safe(self): + ''' + The get_ancestors method may be referenced before circular parenting + checks, so the method is expected to be stable even with loops + ''' + A = Group('A') + B = Group('B') + A.parent_groups.append(B) + B.parent_groups.append(A) + # finishes in finite time + self.assertEqual(A.get_ancestors(), set([A, B]))