From ff9f5d7dc8dbd2cedb3e1cac3ce7f9f5332e1624 Mon Sep 17 00:00:00 2001 From: James Cammarata Date: Fri, 4 Sep 2015 16:41:38 -0400 Subject: [PATCH] Starting to add additional unit tests for VariableManager Required some rewiring in inventory code to make sure we're using the DataLoader class for some data file operations, which makes mocking them much easier. Also identified two corner cases not currently handled by the code, related to inventory variable sources and which one "wins". Also noticed we weren't properly merging variables from multiple group/host_var file locations (inventory directory vs. playbook directory locations) so fixed as well. --- lib/ansible/cli/__init__.py | 5 +- lib/ansible/cli/adhoc.py | 7 +- lib/ansible/cli/playbook.py | 8 +- lib/ansible/inventory/__init__.py | 8 +- lib/ansible/inventory/dir.py | 7 +- lib/ansible/inventory/host.py | 11 +- lib/ansible/inventory/ini.py | 13 ++- lib/ansible/parsing/__init__.py | 40 ++++++- lib/ansible/utils/path.py | 6 +- lib/ansible/vars/__init__.py | 41 ++++--- test/units/mock/loader.py | 4 + test/units/parsing/test_data_loader.py | 3 +- test/units/vars/test_variable_manager.py | 134 +++++++++++++++++++++-- 13 files changed, 233 insertions(+), 54 deletions(-) diff --git a/lib/ansible/cli/__init__.py b/lib/ansible/cli/__init__.py index e223180903..8906f8134f 100644 --- a/lib/ansible/cli/__init__.py +++ b/lib/ansible/cli/__init__.py @@ -34,7 +34,6 @@ from ansible import constants as C from ansible.errors import AnsibleError, AnsibleOptionsError from ansible.utils.unicode import to_bytes from ansible.utils.display import Display -from ansible.utils.path import is_executable class SortedOptParser(optparse.OptionParser): '''Optparser which sorts the options by opt before outputting --help''' @@ -479,7 +478,7 @@ class CLI(object): return t @staticmethod - def read_vault_password_file(vault_password_file): + def read_vault_password_file(vault_password_file, loader): """ Read a vault password from a file or if executable, execute the script and retrieve password from STDOUT @@ -489,7 +488,7 @@ class CLI(object): if not os.path.exists(this_path): raise AnsibleError("The vault password file %s was not found" % this_path) - if is_executable(this_path): + if loader.is_executable(this_path): try: # STDERR not captured to make it easier for users to prompt for input in their scripts p = subprocess.Popen(this_path, stdout=subprocess.PIPE) diff --git a/lib/ansible/cli/adhoc.py b/lib/ansible/cli/adhoc.py index 80318eadf2..9b6957cfbb 100644 --- a/lib/ansible/cli/adhoc.py +++ b/lib/ansible/cli/adhoc.py @@ -95,13 +95,16 @@ class AdHocCLI(CLI): (sshpass, becomepass) = self.ask_passwords() passwords = { 'conn_pass': sshpass, 'become_pass': becomepass } + loader = DataLoader() + if self.options.vault_password_file: # read vault_pass from a file - vault_pass = CLI.read_vault_password_file(self.options.vault_password_file) + vault_pass = CLI.read_vault_password_file(self.options.vault_password_file, loader=loader) + loader.set_vault_password(vault_pass) elif self.options.ask_vault_pass: vault_pass = self.ask_vault_passwords(ask_vault_pass=True, ask_new_vault_pass=False, confirm_new=False)[0] + loader.set_vault_password(vault_pass) - loader = DataLoader(vault_password=vault_pass) variable_manager = VariableManager() variable_manager.extra_vars = load_extra_vars(loader=loader, options=self.options) diff --git a/lib/ansible/cli/playbook.py b/lib/ansible/cli/playbook.py index d8e7f6761b..306c134790 100644 --- a/lib/ansible/cli/playbook.py +++ b/lib/ansible/cli/playbook.py @@ -89,13 +89,15 @@ class PlaybookCLI(CLI): (sshpass, becomepass) = self.ask_passwords() passwords = { 'conn_pass': sshpass, 'become_pass': becomepass } + loader = DataLoader() + if self.options.vault_password_file: # read vault_pass from a file - vault_pass = CLI.read_vault_password_file(self.options.vault_password_file) + vault_pass = CLI.read_vault_password_file(self.options.vault_password_file, loader=loader) + loader.set_vault_password(vault_pass) elif self.options.ask_vault_pass: vault_pass = self.ask_vault_passwords(ask_vault_pass=True, ask_new_vault_pass=False, confirm_new=False)[0] - - loader = DataLoader(vault_password=vault_pass) + loader.set_vault_password(vault_pass) # initial error check, to make sure all specified playbooks are accessible # before we start running anything through the playbook executor diff --git a/lib/ansible/inventory/__init__.py b/lib/ansible/inventory/__init__.py index c6614439f6..2eda3e6649 100644 --- a/lib/ansible/inventory/__init__.py +++ b/lib/ansible/inventory/__init__.py @@ -104,9 +104,9 @@ class Inventory(object): all.add_host(Host(tokens[0], tokens[1])) else: all.add_host(Host(x)) - elif os.path.exists(host_list): + elif self._loader.path_exists(host_list): #TODO: switch this to a plugin loader and a 'condition' per plugin on which it should be tried, restoring 'inventory pllugins' - if os.path.isdir(host_list): + if self._loader.is_directory(host_list): # Ensure basedir is inside the directory host_list = os.path.join(self.host_list, "") self.parser = InventoryDirectory(loader=self._loader, filename=host_list) @@ -595,14 +595,14 @@ class Inventory(object): """ did inventory come from a file? """ if not isinstance(self.host_list, basestring): return False - return os.path.exists(self.host_list) + return self._loader.path_exists(self.host_list) def basedir(self): """ if inventory came from a file, what's the directory? """ dname = self.host_list if not self.is_file(): dname = None - elif os.path.isdir(self.host_list): + elif self._loader.is_directory(self.host_list): dname = self.host_list else: dname = os.path.dirname(self.host_list) diff --git a/lib/ansible/inventory/dir.py b/lib/ansible/inventory/dir.py index e456a950d4..261ab02fdb 100644 --- a/lib/ansible/inventory/dir.py +++ b/lib/ansible/inventory/dir.py @@ -29,7 +29,6 @@ from ansible.inventory.host import Host from ansible.inventory.group import Group from ansible.utils.vars import combine_vars -from ansible.utils.path import is_executable from ansible.inventory.ini import InventoryParser as InventoryINIParser from ansible.inventory.script import InventoryScript @@ -54,7 +53,7 @@ def get_file_parser(hostsfile, loader): except: pass - if is_executable(hostsfile): + if loader.is_executable(hostsfile): try: parser = InventoryScript(loader=loader, filename=hostsfile) processed = True @@ -65,10 +64,10 @@ def get_file_parser(hostsfile, loader): if not processed: try: - parser = InventoryINIParser(filename=hostsfile) + parser = InventoryINIParser(loader=loader, filename=hostsfile) processed = True except Exception as e: - if shebang_present and not is_executable(hostsfile): + if shebang_present and not loader.is_executable(hostsfile): myerr.append("The file %s looks like it should be an executable inventory script, but is not marked executable. " % hostsfile + \ "Perhaps you want to correct this with `chmod +x %s`?" % hostsfile) else: diff --git a/lib/ansible/inventory/host.py b/lib/ansible/inventory/host.py index 43a96d54bf..ea14745434 100644 --- a/lib/ansible/inventory/host.py +++ b/lib/ansible/inventory/host.py @@ -114,12 +114,15 @@ class Host: def get_vars(self): results = {} - groups = self.get_groups() - for group in sorted(groups, key=lambda g: g.depth): - results = combine_vars(results, group.get_vars()) results = combine_vars(results, self.vars) results['inventory_hostname'] = self.name results['inventory_hostname_short'] = self.name.split('.')[0] - results['group_names'] = sorted([ g.name for g in groups if g.name != 'all']) + results['group_names'] = sorted([ g.name for g in self.get_groups() if g.name != 'all']) return results + def get_group_vars(self): + results = {} + groups = self.get_groups() + for group in sorted(groups, key=lambda g: g.depth): + results = combine_vars(results, group.get_vars()) + return results diff --git a/lib/ansible/inventory/ini.py b/lib/ansible/inventory/ini.py index 2769632ef2..34b3a3acb4 100644 --- a/lib/ansible/inventory/ini.py +++ b/lib/ansible/inventory/ini.py @@ -37,7 +37,8 @@ class InventoryParser(object): with their associated hosts and variable settings. """ - def __init__(self, filename=C.DEFAULT_HOST_LIST): + def __init__(self, loader, filename=C.DEFAULT_HOST_LIST): + self._loader = loader self.filename = filename # Start with an empty host list and the default 'all' and @@ -53,8 +54,14 @@ class InventoryParser(object): # Read in the hosts, groups, and variables defined in the # inventory file. - with open(filename) as fh: - self._parse(fh.readlines()) + if loader: + (data, private) = loader._get_file_contents(filename) + data = data.split('\n') + else: + with open(filename) as fh: + data = fh.readlines() + + self._parse(data) # Finally, add all top-level groups (including 'ungrouped') as # children of 'all'. diff --git a/lib/ansible/parsing/__init__.py b/lib/ansible/parsing/__init__.py index 7fc999623c..11aad338c1 100644 --- a/lib/ansible/parsing/__init__.py +++ b/lib/ansible/parsing/__init__.py @@ -22,6 +22,7 @@ __metaclass__ = type import copy import json import os +import stat from yaml import load, YAMLError from six import text_type @@ -56,11 +57,15 @@ class DataLoader(): ds = dl.load_from_file('/path/to/file') ''' - def __init__(self, vault_password=None): + def __init__(self): self._basedir = '.' - self._vault_password = vault_password self._FILE_CACHE = dict() + # initialize the vault stuff with an empty password + self.set_vault_password(None) + + def set_vault_password(self, vault_password): + self._vault_password = vault_password self._vault = VaultLib(password=vault_password) def load(self, data, file_name='', show_content=True): @@ -130,6 +135,11 @@ class DataLoader(): path = self.path_dwim(path) return os.listdir(path) + def is_executable(self, path): + '''is the given path executable?''' + path = self.path_dwim(path) + return (stat.S_IXUSR & os.stat(path)[stat.ST_MODE] or stat.S_IXGRP & os.stat(path)[stat.ST_MODE] or stat.S_IXOTH & os.stat(path)[stat.ST_MODE]) + def _safe_load(self, stream, file_name=None): ''' Implements yaml.safe_load(), except using our custom loader class. ''' @@ -249,3 +259,29 @@ class DataLoader(): return candidate + def read_vault_password_file(self, vault_password_file): + """ + Read a vault password from a file or if executable, execute the script and + retrieve password from STDOUT + """ + + this_path = os.path.realpath(os.path.expanduser(vault_password_file)) + if not os.path.exists(this_path): + raise AnsibleError("The vault password file %s was not found" % this_path) + + if self.is_executable(this_path): + try: + # STDERR not captured to make it easier for users to prompt for input in their scripts + p = subprocess.Popen(this_path, stdout=subprocess.PIPE) + except OSError as e: + raise AnsibleError("Problem running vault password script %s (%s). If this is not a script, remove the executable bit from the file." % (' '.join(this_path), e)) + stdout, stderr = p.communicate() + self.set_vault_password(stdout.strip('\r\n')) + else: + try: + f = open(this_path, "rb") + self.set_vault_password(f.read().strip()) + f.close() + except (OSError, IOError) as e: + raise AnsibleError("Could not read vault password file %s: %s" % (this_path, e)) + diff --git a/lib/ansible/utils/path.py b/lib/ansible/utils/path.py index 0f9e641b74..ffac578243 100644 --- a/lib/ansible/utils/path.py +++ b/lib/ansible/utils/path.py @@ -22,11 +22,7 @@ import stat from time import sleep from errno import EEXIST -__all__ = ['is_executable', 'unfrackpath'] - -def is_executable(path): - '''is the given path executable?''' - return (stat.S_IXUSR & os.stat(path)[stat.ST_MODE] or stat.S_IXGRP & os.stat(path)[stat.ST_MODE] or stat.S_IXOTH & os.stat(path)[stat.ST_MODE]) +__all__ = ['unfrackpath'] def unfrackpath(path): ''' diff --git a/lib/ansible/vars/__init__.py b/lib/ansible/vars/__init__.py index ebb045e425..aa5b7ea531 100644 --- a/lib/ansible/vars/__init__.py +++ b/lib/ansible/vars/__init__.py @@ -119,11 +119,11 @@ class VariableManager: - host_vars_files[host] (if there is a host context) - host->get_vars (if there is a host context) - fact_cache[host] (if there is a host context) - - vars_cache[host] (if there is a host context) - play vars (if there is a play context) - play vars_files (if there's no host context, ignore file names that cannot be templated) - task->get_vars (if there is a task context) + - vars_cache[host] (if there is a host context) - extra vars ''' @@ -152,29 +152,34 @@ class VariableManager: # files and then any vars from host_vars files which may apply to # this host or the groups it belongs to - # we merge in the special 'all' group_vars first, if they exist + # we merge in vars from groups specified in the inventory (INI or script) + all_vars = combine_vars(all_vars, host.get_group_vars()) + + # then we merge in the special 'all' group_vars first, if they exist if 'all' in self._group_vars_files: data = self._preprocess_vars(self._group_vars_files['all']) for item in data: all_vars = combine_vars(all_vars, item) for group in host.get_groups(): - all_vars = combine_vars(all_vars, group.get_vars()) if group.name in self._group_vars_files and group.name != 'all': - data = self._preprocess_vars(self._group_vars_files[group.name]) + for data in self._group_vars_files[group.name]: + data = self._preprocess_vars(data) + for item in data: + all_vars = combine_vars(all_vars, item) + + # then we merge in vars from the host specified in the inventory (INI or script) + all_vars = combine_vars(all_vars, host.get_vars()) + + # then we merge in the host_vars/ file, if it exists + host_name = host.get_name() + if host_name in self._host_vars_files: + for data in self._host_vars_files[host_name]: + data = self._preprocess_vars(data) for item in data: all_vars = combine_vars(all_vars, item) - host_name = host.get_name() - if host_name in self._host_vars_files: - data = self._preprocess_vars(self._host_vars_files[host_name]) - for item in data: - all_vars = combine_vars(all_vars, self._host_vars_files[host_name]) - - # then we merge in vars specified for this host - all_vars = combine_vars(all_vars, host.get_vars()) - - # next comes the facts cache and the vars cache, respectively + # finally, the facts cache for this host, if it exists try: host_facts = self._fact_cache.get(host.name, dict()) for k in host_facts.keys(): @@ -333,7 +338,9 @@ class VariableManager: (name, data) = self._load_inventory_file(path, loader) if data: - self._host_vars_files[name] = data + if name not in self._host_vars_files: + self._host_vars_files[name] = [] + self._host_vars_files[name].append(data) return data else: return dict() @@ -347,7 +354,9 @@ class VariableManager: (name, data) = self._load_inventory_file(path, loader) if data: - self._group_vars_files[name] = data + if name not in self._group_vars_files: + self._group_vars_files[name] = [] + self._group_vars_files[name].append(data) return data else: return dict() diff --git a/test/units/mock/loader.py b/test/units/mock/loader.py index 88f3970913..db69b862a7 100644 --- a/test/units/mock/loader.py +++ b/test/units/mock/loader.py @@ -57,6 +57,10 @@ class DictDataLoader(DataLoader): def list_directory(self, path): return [x for x in self._known_directories] + def is_executable(self, path): + # FIXME: figure out a way to make paths return true for this + return False + def _add_known_directory(self, directory): if directory not in self._known_directories: self._known_directories.append(directory) diff --git a/test/units/parsing/test_data_loader.py b/test/units/parsing/test_data_loader.py index b9c37cdd0c..810cea41d2 100644 --- a/test/units/parsing/test_data_loader.py +++ b/test/units/parsing/test_data_loader.py @@ -66,7 +66,8 @@ class TestDataLoader(unittest.TestCase): class TestDataLoaderWithVault(unittest.TestCase): def setUp(self): - self._loader = DataLoader(vault_password='ansible') + self._loader = DataLoader() + self._loader.set_vault_password('ansible') def tearDown(self): pass diff --git a/test/units/vars/test_variable_manager.py b/test/units/vars/test_variable_manager.py index 18454bf444..a29f7c075a 100644 --- a/test/units/vars/test_variable_manager.py +++ b/test/units/vars/test_variable_manager.py @@ -19,11 +19,13 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +from collections import defaultdict from six import iteritems from ansible.compat.tests import unittest from ansible.compat.tests.mock import patch, MagicMock - +from ansible.inventory import Inventory +from ansible.playbook.play import Play from ansible.vars import VariableManager from units.mock.loader import DictDataLoader @@ -68,20 +70,27 @@ class TestVariableManager(unittest.TestCase): fake_loader = DictDataLoader({ "host_vars/hostname1.yml": """ foo: bar - """ + """, + "other_path/host_vars/hostname1.yml": """ + foo: bam + baa: bat + """, }) v = VariableManager() v.add_host_vars_file("host_vars/hostname1.yml", loader=fake_loader) + v.add_host_vars_file("other_path/host_vars/hostname1.yml", loader=fake_loader) self.assertIn("hostname1", v._host_vars_files) - self.assertEqual(v._host_vars_files["hostname1"], dict(foo="bar")) + self.assertEqual(v._host_vars_files["hostname1"], [dict(foo="bar"), dict(foo="bam", baa="bat")]) mock_host = MagicMock() mock_host.get_name.return_value = "hostname1" mock_host.get_vars.return_value = dict() mock_host.get_groups.return_value = () + mock_host.get_group_vars.return_value = dict() - self.assertEqual(v.get_vars(loader=fake_loader, host=mock_host, use_cache=False).get("foo"), "bar") + self.assertEqual(v.get_vars(loader=fake_loader, host=mock_host, use_cache=False).get("foo"), "bam") + self.assertEqual(v.get_vars(loader=fake_loader, host=mock_host, use_cache=False).get("baa"), "bat") def test_variable_manager_group_vars_file(self): fake_loader = DictDataLoader({ @@ -90,15 +99,19 @@ class TestVariableManager(unittest.TestCase): """, "group_vars/somegroup.yml": """ bam: baz + """, + "other_path/group_vars/somegroup.yml": """ + baa: bat """ }) v = VariableManager() v.add_group_vars_file("group_vars/all.yml", loader=fake_loader) v.add_group_vars_file("group_vars/somegroup.yml", loader=fake_loader) + v.add_group_vars_file("other_path/group_vars/somegroup.yml", loader=fake_loader) self.assertIn("somegroup", v._group_vars_files) - self.assertEqual(v._group_vars_files["all"], dict(foo="bar")) - self.assertEqual(v._group_vars_files["somegroup"], dict(bam="baz")) + self.assertEqual(v._group_vars_files["all"], [dict(foo="bar")]) + self.assertEqual(v._group_vars_files["somegroup"], [dict(bam="baz"), dict(baa="bat")]) mock_group = MagicMock() mock_group.name = "somegroup" @@ -109,10 +122,11 @@ class TestVariableManager(unittest.TestCase): mock_host.get_name.return_value = "hostname1" mock_host.get_vars.return_value = dict() mock_host.get_groups.return_value = (mock_group,) + mock_host.get_group_vars.return_value = dict() vars = v.get_vars(loader=fake_loader, host=mock_host, use_cache=False) self.assertEqual(vars.get("foo"), "bar") - self.assertEqual(vars.get("bam"), "baz") + self.assertEqual(vars.get("baa"), "bat") def test_variable_manager_play_vars(self): fake_loader = DictDataLoader({}) @@ -150,3 +164,109 @@ class TestVariableManager(unittest.TestCase): v = VariableManager() self.assertEqual(v.get_vars(loader=fake_loader, task=mock_task, use_cache=False).get("foo"), "bar") + def test_variable_manager_precedence(self): + ''' + Tests complex variations and combinations of get_vars() with different + objects to modify the context under which variables are merged. + ''' + + v = VariableManager() + v._fact_cache = defaultdict(dict) + + fake_loader = DictDataLoader({ + # inventory1 + '/etc/ansible/inventory1': """ + [group2:children] + group1 + + [group1] + host1 host_var=host_var_from_inventory_host1 + + [group1:vars] + group_var = group_var_from_inventory_group1 + + [group2:vars] + group_var = group_var_from_inventory_group2 + """, + + # role defaults_only1 + '/etc/ansible/roles/defaults_only1/defaults/main.yml': """ + default_var: "default_var_from_defaults_only1" + host_var: "host_var_from_defaults_only1" + group_var: "group_var_from_defaults_only1" + group_var_all: "group_var_all_from_defaults_only1" + extra_var: "extra_var_from_defaults_only1" + """, + '/etc/ansible/roles/defaults_only1/tasks/main.yml': """ + - debug: msg="here i am" + """, + + # role defaults_only2 + '/etc/ansible/roles/defaults_only2/defaults/main.yml': """ + default_var: "default_var_from_defaults_only2" + host_var: "host_var_from_defaults_only2" + group_var: "group_var_from_defaults_only2" + group_var_all: "group_var_all_from_defaults_only2" + extra_var: "extra_var_from_defaults_only2" + """, + }) + + inv1 = Inventory(loader=fake_loader, variable_manager=v, host_list='/etc/ansible/inventory1') + inv1.set_playbook_basedir('./') + + play1 = Play.load(dict( + hosts=['all'], + roles=['defaults_only1', 'defaults_only2'], + ), loader=fake_loader, variable_manager=v) + + # first we assert that the defaults as viewed as a whole are the merged results + # of the defaults from each role, with the last role defined "winning" when + # there is a variable naming conflict + res = v.get_vars(loader=fake_loader, play=play1) + self.assertEqual(res['default_var'], 'default_var_from_defaults_only2') + + # next, we assert that when vars are viewed from the context of a task within a + # role, that task will see its own role defaults before any other role's + blocks = play1.compile() + task = blocks[1].block[0] + res = v.get_vars(loader=fake_loader, play=play1, task=task) + self.assertEqual(res['default_var'], 'default_var_from_defaults_only1') + + # next we assert the precendence of inventory variables + v.set_inventory(inv1) + h1 = inv1.get_host('host1') + + res = v.get_vars(loader=fake_loader, play=play1, host=h1) + self.assertEqual(res['group_var'], 'group_var_from_inventory_group1') + self.assertEqual(res['host_var'], 'host_var_from_inventory_host1') + + # next we test with group_vars/ files loaded + fake_loader.push("/etc/ansible/group_vars/all", """ + group_var_all: group_var_all_from_group_vars_all + """) + fake_loader.push("/etc/ansible/group_vars/group1", """ + group_var: group_var_from_group_vars_group1 + """) + fake_loader.push("/etc/ansible/group_vars/group3", """ + # this is a dummy, which should not be used anywhere + group_var: group_var_from_group_vars_group3 + """) + fake_loader.push("/etc/ansible/host_vars/host1", """ + host_var: host_var_from_host_vars_host1 + """) + + v.add_group_vars_file("/etc/ansible/group_vars/all", loader=fake_loader) + v.add_group_vars_file("/etc/ansible/group_vars/group1", loader=fake_loader) + v.add_group_vars_file("/etc/ansible/group_vars/group2", loader=fake_loader) + v.add_host_vars_file("/etc/ansible/host_vars/host1", loader=fake_loader) + + res = v.get_vars(loader=fake_loader, play=play1, host=h1) + self.assertEqual(res['group_var'], 'group_var_from_group_vars_group1') + self.assertEqual(res['group_var_all'], 'group_var_all_from_group_vars_all') + self.assertEqual(res['host_var'], 'host_var_from_host_vars_host1') + + # add in the fact cache + v._fact_cache['host1'] = dict(fact_cache_var="fact_cache_var_from_fact_cache") + + res = v.get_vars(loader=fake_loader, play=play1, host=h1) + self.assertEqual(res['fact_cache_var'], 'fact_cache_var_from_fact_cache')