From 54b45e9bd44ab258cacb8f7a7b7136e40398b636 Mon Sep 17 00:00:00 2001 From: Daniel Hokka Zakrisson Date: Mon, 10 Dec 2012 15:54:07 +0100 Subject: [PATCH] Allow intersecting host patterns by using & This allows patterns such as webservers:!debian:&datacenter1 to target hosts in the webservers group, that are not in the debian group, but are in the datacenter1 group. It also parses patterns left to right. --- lib/ansible/inventory/__init__.py | 63 ++++++++++++++----------------- test/TestInventory.py | 7 ++++ 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/lib/ansible/inventory/__init__.py b/lib/ansible/inventory/__init__.py index 6e1a356e9f..2316104b63 100644 --- a/lib/ansible/inventory/__init__.py +++ b/lib/ansible/inventory/__init__.py @@ -104,27 +104,12 @@ class Inventory(object): if isinstance(pattern, list): pattern = ';'.join(pattern) patterns = pattern.replace(";",":").split(":") - positive_patterns = [ p for p in patterns if not p.startswith("!") ] - negative_patterns = [ p for p in patterns if p.startswith("!") ] - - # find hosts matching positive patterns - hosts = self._get_hosts(positive_patterns) - - # exclude hosts mentioned in a negative pattern - if len(negative_patterns): - exclude_hosts = [ h.name for h in self._get_hosts(negative_patterns) ] - hosts = [ h for h in hosts if h.name not in exclude_hosts ] + hosts = self._get_hosts(patterns) # exclude hosts not in a subset, if defined if self._subset: - positive_subsetp = [ p for p in self._subset if not p.startswith("!") ] - negative_subsetp = [ p for p in self._subset if p.startswith("!") ] - if len(positive_subsetp): - positive_subset = [ h.name for h in self._get_hosts(positive_subsetp) ] - hosts = [ h for h in hosts if (h.name in positive_subset) ] - if len(negative_subsetp): - negative_subset = [ h.name for h in self._get_hosts(negative_subsetp) ] - hosts = [ h for h in hosts if (h.name not in negative_subset)] + subset = self._get_hosts(self._subset) + hosts.intersection_update(subset) # exclude hosts mentioned in any restriction (ex: failed hosts) if self._restriction is not None: @@ -135,27 +120,35 @@ class Inventory(object): return sorted(hosts, key=lambda x: x.name) def _get_hosts(self, patterns): + """ + finds hosts that match a list of patterns. Handles negative + matches as well as intersection matches. + """ + + hosts = set() + for p in patterns: + if p.startswith("!"): + # Discard excluded hosts + hosts.difference_update(self.__get_hosts(p)) + elif p.startswith("&"): + # Only leave the intersected hosts + hosts.intersection_update(self.__get_hosts(p)) + else: + # Get all hosts from both patterns + hosts.update(self.__get_hosts(p)) + return hosts + + def __get_hosts(self, pattern): """ - finds hosts that postively match a particular list of patterns. Does not + finds hosts that postively match a particular pattern. Does not take into account negative matches. """ - by_pattern = {} - for p in patterns: - (name, enumeration_details) = self._enumeration_info(p) - hpat = self._hosts_in_unenumerated_pattern(name) - hpat = sorted(hpat, key=lambda x: x.name) - by_pattern[p] = hpat + (name, enumeration_details) = self._enumeration_info(pattern) + hpat = self._hosts_in_unenumerated_pattern(name) + hpat = sorted(hpat, key=lambda x: x.name) - ranged = {} - for (pat, hosts) in by_pattern.iteritems(): - ranged[pat] = self._apply_ranges(pat, hosts) - - results = [] - for (pat, hosts) in ranged.iteritems(): - results.extend(hosts) - - return list(set(results)) + return set(self._apply_ranges(pattern, hpat)) def _enumeration_info(self, pattern): """ @@ -200,7 +193,7 @@ class Inventory(object): hosts = {} # ignore any negative checks here, this is handled elsewhere - pattern = pattern.replace("!","") + pattern = pattern.replace("!","").replace("&", "") groups = self.get_groups() for group in groups: diff --git a/test/TestInventory.py b/test/TestInventory.py index 5c444d9e96..1f053afd3b 100644 --- a/test/TestInventory.py +++ b/test/TestInventory.py @@ -177,6 +177,13 @@ class TestInventory(unittest.TestCase): hosts = inventory.list_hosts("nc[2-3]:florida[1-2]") self.compare(hosts, expected4, sort=False) + def test_complex_intersect(self): + inventory = self.complex_inventory() + hosts = inventory.list_hosts("nc:&redundantgroup:!rtp_c") + self.compare(hosts, ['rtp_a']) + hosts = inventory.list_hosts("nc:&triangle:!tri_c") + self.compare(hosts, ['tri_a', 'tri_b']) + ################################################### ### Inventory API tests