diff --git a/v2/ansible/__init__.py b/v2/ansible/__init__.py index ae8ccff595..26869775ea 100644 --- a/v2/ansible/__init__.py +++ b/v2/ansible/__init__.py @@ -18,3 +18,5 @@ # Make coding more python3-ish from __future__ import (absolute_import, division, print_function) __metaclass__ = type + +__version__ = '1.v2' diff --git a/v2/ansible/constants.py b/v2/ansible/constants.py index e74720b8a6..6adcdd0a9f 100644 --- a/v2/ansible/constants.py +++ b/v2/ansible/constants.py @@ -104,6 +104,7 @@ YAML_FILENAME_EXTENSIONS = [ "", ".yml", ".yaml", ".json" ] DEFAULTS='defaults' # configurable things +DEFAULT_DEBUG = get_config(p, DEFAULTS, 'debug', 'ANSIBLE_DEBUG', False, boolean=True) DEFAULT_HOST_LIST = shell_expand_path(get_config(p, DEFAULTS, 'hostfile', 'ANSIBLE_HOSTS', '/etc/ansible/hosts')) DEFAULT_MODULE_PATH = get_config(p, DEFAULTS, 'library', 'ANSIBLE_LIBRARY', None) DEFAULT_ROLES_PATH = shell_expand_path(get_config(p, DEFAULTS, 'roles_path', 'ANSIBLE_ROLES_PATH', '/etc/ansible/roles')) diff --git a/v2/ansible/errors/__init__.py b/v2/ansible/errors/__init__.py index 2813507df2..7effe41df7 100644 --- a/v2/ansible/errors/__init__.py +++ b/v2/ansible/errors/__init__.py @@ -21,7 +21,7 @@ __metaclass__ = type import os -from ansible.parsing.yaml.strings import * +from ansible.errors.yaml_strings import * class AnsibleError(Exception): ''' @@ -45,12 +45,12 @@ class AnsibleError(Exception): self._obj = obj self._show_content = show_content - if isinstance(self._obj, AnsibleBaseYAMLObject): + if obj and isinstance(obj, AnsibleBaseYAMLObject): extended_error = self._get_extended_error() if extended_error: - self.message = '%s\n\n%s' % (message, extended_error) + self.message = 'ERROR! %s\n\n%s' % (message, extended_error) else: - self.message = message + self.message = 'ERROR! 
%s' % message def __str__(self): return self.message @@ -98,8 +98,9 @@ class AnsibleError(Exception): (target_line, prev_line) = self._get_error_lines_from_file(src_file, line_number - 1) if target_line: stripped_line = target_line.replace(" ","") - arrow_line = (" " * (col_number-1)) + "^" - error_message += "%s\n%s\n%s\n" % (prev_line.rstrip(), target_line.rstrip(), arrow_line) + arrow_line = (" " * (col_number-1)) + "^ here" + #header_line = ("=" * 73) + error_message += "\nThe offending line appears to be:\n\n%s\n%s\n%s\n" % (prev_line.rstrip(), target_line.rstrip(), arrow_line) # common error/remediation checking here: # check for unquoted vars starting lines @@ -158,3 +159,11 @@ class AnsibleModuleError(AnsibleRuntimeError): class AnsibleConnectionFailure(AnsibleRuntimeError): ''' the transport / connection_plugin had a fatal error ''' pass + +class AnsibleFilterError(AnsibleRuntimeError): + ''' a templating failure ''' + pass + +class AnsibleUndefinedVariable(AnsibleRuntimeError): + ''' a templating failure ''' + pass diff --git a/v2/ansible/parsing/yaml/strings.py b/v2/ansible/errors/yaml_strings.py similarity index 100% rename from v2/ansible/parsing/yaml/strings.py rename to v2/ansible/errors/yaml_strings.py diff --git a/v2/ansible/executor/connection_info.py b/v2/ansible/executor/connection_info.py new file mode 100644 index 0000000000..8f53c3fe34 --- /dev/null +++ b/v2/ansible/executor/connection_info.py @@ -0,0 +1,171 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import pipes +import random + +from ansible import constants as C + + +__all__ = ['ConnectionInformation'] + + +class ConnectionInformation: + + ''' + This class is used to consolidate the connection information for + hosts in a play and child tasks, where the task may override some + connection/authentication information. + ''' + + def __init__(self, play=None, options=None): + # FIXME: implement the new methodology here for supporting + # various different auth escalation methods (becomes, etc.) + + self.connection = C.DEFAULT_TRANSPORT + self.remote_user = 'root' + self.password = '' + self.port = 22 + self.su = False + self.su_user = '' + self.su_pass = '' + self.sudo = False + self.sudo_user = '' + self.sudo_pass = '' + self.verbosity = 0 + self.only_tags = set(['all']) + self.skip_tags = set() + + if play: + self.set_play(play) + + if options: + self.set_options(options) + + def set_play(self, play): + ''' + Configures this connection information instance with data from + the play class. 
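A minimal usage sketch, not part of this patch: play-level settings are applied first via set_play(), then command-line options override them via set_options(); the FakeOptions class here is a hypothetical stand-in for the parsed CLI options.

    class FakeOptions(object):
        verbosity = 1
        connection = 'paramiko'
        tags = 'web,db'

    ci = ConnectionInformation(options=FakeOptions())
    assert ci.connection == 'paramiko'                  # CLI value wins
    assert ci.only_tags == set(['all', 'web', 'db'])    # comma list is split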
+ ''' + + if play.connection: + self.connection = play.connection + + self.remote_user = play.remote_user + self.password = '' + self.port = int(play.port) if play.port else 22 + self.su = play.su + self.su_user = play.su_user + self.su_pass = play.su_pass + self.sudo = play.sudo + self.sudo_user = play.sudo_user + self.sudo_pass = play.sudo_pass + + def set_options(self, options): + ''' + Configures this connection information instance with data from + options specified by the user on the command line. These have a + higher precedence than those set on the play or host. + ''' + + # FIXME: set other values from options here? + + self.verbosity = options.verbosity + if options.connection: + self.connection = options.connection + + # get the tag info from options, converting a comma-separated list + # of values into a proper list if need be. We check to see if the + # options have the attribute, as it is not always added via the CLI + if hasattr(options, 'tags'): + if isinstance(options.tags, list): + self.only_tags.update(options.tags) + elif isinstance(options.tags, basestring): + self.only_tags.update(options.tags.split(',')) + + if hasattr(options, 'skip_tags'): + if isinstance(options.skip_tags, list): + self.skip_tags.update(options.skip_tags) + elif isinstance(options.skip_tags, basestring): + self.skip_tags.update(options.skip_tags.split(',')) + + def copy(self, ci): + ''' + Copies the connection info from another connection info object, used + when merging in data from task overrides. + ''' + + self.connection = ci.connection + self.remote_user = ci.remote_user + self.password = ci.password + self.port = ci.port + self.su = ci.su + self.su_user = ci.su_user + self.su_pass = ci.su_pass + self.sudo = ci.sudo + self.sudo_user = ci.sudo_user + self.sudo_pass = ci.sudo_pass + self.verbosity = ci.verbosity + self.only_tags = ci.only_tags.copy() + self.skip_tags = ci.skip_tags.copy() + + def set_task_override(self, task): + ''' + Sets attributes from the task if they are set, which will override + those from the play. + ''' + + new_info = ConnectionInformation() + new_info.copy(self) + + for attr in ('connection', 'remote_user', 'su', 'su_user', 'su_pass', 'sudo', 'sudo_user', 'sudo_pass'): + if hasattr(task, attr): + attr_val = getattr(task, attr) + if attr_val: + setattr(new_info, attr, attr_val) + + return new_info + + def make_sudo_cmd(self, sudo_exe, executable, cmd): + """ + Helper function for wrapping commands with sudo. + + Rather than detect if sudo wants a password this time, -k makes + sudo always ask for a password if one is required. Passing a quoted + compound command to sudo (or sudo -s) directly doesn't work, so we + shellquote it with pipes.quote() and pass the quoted string to the + user's shell. We loop reading output until we see the randomly- + generated sudo prompt set with the -p option. 
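For concreteness, a sketch of the command this method builds, assuming sudo_exe='sudo', executable=None, sudo_user='root', cmd='whoami', and a stock C.DEFAULT_SUDO_FLAGS of '-H' (random key elided):

    sudo -k && sudo -H -S -p "[sudo via ansible, key=...] password: " \
        -u root $SHELL -c 'echo SUDO-SUCCESS-...; whoami'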
+ """ + + randbits = ''.join(chr(random.randint(ord('a'), ord('z'))) for x in xrange(32)) + prompt = '[sudo via ansible, key=%s] password: ' % randbits + success_key = 'SUDO-SUCCESS-%s' % randbits + + sudocmd = '%s -k && %s %s -S -p "%s" -u %s %s -c %s' % ( + sudo_exe, sudo_exe, C.DEFAULT_SUDO_FLAGS, prompt, + self.sudo_user, executable or '$SHELL', + pipes.quote('echo %s; %s' % (success_key, cmd)) + ) + + #return ('/bin/sh -c ' + pipes.quote(sudocmd), prompt, success_key) + return (sudocmd, prompt, success_key) + diff --git a/v2/ansible/executor/manager.py b/v2/ansible/executor/manager.py new file mode 100644 index 0000000000..33a76e143b --- /dev/null +++ b/v2/ansible/executor/manager.py @@ -0,0 +1,66 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from multiprocessing.managers import SyncManager, BaseProxy +from ansible.playbook.handler import Handler +from ansible.playbook.task import Task +from ansible.playbook.play import Play +from ansible.errors import AnsibleError + +__all__ = ['AnsibleManager'] + + +class VariableManagerWrapper: + ''' + This class simply acts as a wrapper around the VariableManager class, + since manager proxies expect a new object to be returned rather than + any existing one. Using this wrapper, a shared proxy can be created + and an existing VariableManager class assigned to it, which can then + be accessed through the exposed proxy methods. + ''' + + def __init__(self): + self._vm = None + + def get_vars(self, loader, play=None, host=None, task=None): + return self._vm.get_vars(loader=loader, play=play, host=host, task=task) + + def set_variable_manager(self, vm): + self._vm = vm + + def set_host_variable(self, host, varname, value): + self._vm.set_host_variable(host, varname, value) + + def set_host_facts(self, host, facts): + self._vm.set_host_facts(host, facts) + +class AnsibleManager(SyncManager): + ''' + This is our custom manager class, which exists only so we may register + the new proxy below + ''' + pass + +AnsibleManager.register( + typeid='VariableManagerWrapper', + callable=VariableManagerWrapper, +) + diff --git a/v2/ansible/executor/module_common.py b/v2/ansible/executor/module_common.py new file mode 100644 index 0000000000..e438099295 --- /dev/null +++ b/v2/ansible/executor/module_common.py @@ -0,0 +1,185 @@ +# (c) 2013-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# from python and deps +from cStringIO import StringIO +import inspect +import json +import os +import shlex + +# from Ansible +from ansible import __version__ +from ansible import constants as C +from ansible.errors import AnsibleError +from ansible.parsing.utils.jsonify import jsonify + +REPLACER = "#<<INCLUDE_ANSIBLE_MODULE_COMMON>>" +REPLACER_ARGS = "\"<<INCLUDE_ANSIBLE_MODULE_ARGS>>\"" +REPLACER_COMPLEX = "\"<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>\"" +REPLACER_WINDOWS = "# POWERSHELL_COMMON" +REPLACER_VERSION = "\"<<ANSIBLE_VERSION>>\"" + +class ModuleReplacer(object): + + """ + The Replacer is used to insert chunks of code into modules before + transfer. Rather than doing classical python imports, this allows for more + efficient transfer in a no-bootstrapping scenario by not moving extra files + over the wire, and also takes care of embedding arguments in the transferred + modules. + + This version is done in such a way that local imports can still be + used in the module code, so IDEs don't have to be aware of what is going on. + + Example: + + from ansible.module_utils.basic import * + + ... will result in the insertion of basic.py into the module + + from the module_utils/ directory in the source tree. + + All modules are required to import at least basic, though there will also + be other snippets. + + # POWERSHELL_COMMON + + Also results in the inclusion of the common code in powershell.ps1 + + """ + + # ****************************************************************************** + + def __init__(self, strip_comments=False): + # FIXME: these members need to be prefixed with '_' and the rest of the file fixed + this_file = inspect.getfile(inspect.currentframe()) + # we've moved the module_common relative to the snippets, so fix the path + self.snippet_path = os.path.join(os.path.dirname(this_file), '..', 'module_utils') + self.strip_comments = strip_comments + + # ****************************************************************************** + + + def slurp(self, path): + if not os.path.exists(path): + raise AnsibleError("imported module support code does not exist at %s" % path) + fd = open(path) + data = fd.read() + fd.close() + return data + + def _find_snippet_imports(self, module_data, module_path): + """ + Given the source of the module, convert it to a Jinja2 template to insert + module code and return whether it's a new or old style module. + """ + + module_style = 'old' + if REPLACER in module_data: + module_style = 'new' + elif 'from ansible.module_utils.'
in module_data: + module_style = 'new' + elif 'WANT_JSON' in module_data: + module_style = 'non_native_want_json' + + output = StringIO() + lines = module_data.split('\n') + snippet_names = [] + + for line in lines: + + if REPLACER in line: + output.write(self.slurp(os.path.join(self.snippet_path, "basic.py"))) + snippet_names.append('basic') + if REPLACER_WINDOWS in line: + ps_data = self.slurp(os.path.join(self.snippet_path, "powershell.ps1")) + output.write(ps_data) + snippet_names.append('powershell') + elif line.startswith('from ansible.module_utils.'): + tokens=line.split(".") + import_error = False + if len(tokens) != 3: + import_error = True + if " import *" not in line: + import_error = True + if import_error: + raise AnsibleError("error importing module in %s, expecting format like 'from ansible.module_utils.basic import *'" % module_path) + snippet_name = tokens[2].split()[0] + snippet_names.append(snippet_name) + output.write(self.slurp(os.path.join(self.snippet_path, snippet_name + ".py"))) + else: + # skip comment and blank lines when stripping is enabled; the old + # 'pass' here was a no-op due to and/or precedence, so such lines + # were written out anyway + if self.strip_comments and (line.startswith("#") or line == ''): + continue + output.write(line) + output.write("\n") + + if not module_path.endswith(".ps1"): + # Unixy modules + if len(snippet_names) > 0 and not 'basic' in snippet_names: + raise AnsibleError("missing required import in %s: from ansible.module_utils.basic import *" % module_path) + else: + # Windows modules + if len(snippet_names) > 0 and not 'powershell' in snippet_names: + raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path) + + return (output.getvalue(), module_style) + + # ****************************************************************************** + + def modify_module(self, module_path, module_args): + + with open(module_path) as f: + + # read in the module source + module_data = f.read() + + (module_data, module_style) = self._find_snippet_imports(module_data, module_path) + + #module_args_json = jsonify(module_args) + module_args_json = json.dumps(module_args) + encoded_args = repr(module_args_json.encode('utf-8')) + + # these strings should be part of the 'basic' snippet which is required to be included + module_data = module_data.replace(REPLACER_VERSION, repr(__version__)) + module_data = module_data.replace(REPLACER_ARGS, "''") + module_data = module_data.replace(REPLACER_COMPLEX, encoded_args) + + # FIXME: we're not passing around an inject dictionary anymore, so + # this needs to be fixed with whatever method we use for vars + # like this moving forward + #if module_style == 'new': + # facility = C.DEFAULT_SYSLOG_FACILITY + # if 'ansible_syslog_facility' in inject: + # facility = inject['ansible_syslog_facility'] + # module_data = module_data.replace('syslog.LOG_USER', "syslog.%s" % facility) + + lines = module_data.split("\n") + shebang = None + if lines[0].startswith("#!"): + shebang = lines[0].strip() + args = shlex.split(str(shebang[2:])) + interpreter = args[0] + interpreter_config = 'ansible_%s_interpreter' % os.path.basename(interpreter) + + # FIXME: more inject stuff here...
+ #if interpreter_config in inject: + # lines[0] = shebang = "#!%s %s" % (inject[interpreter_config], " ".join(args[1:])) + # module_data = "\n".join(lines) + + return (module_data, module_style, shebang) + diff --git a/v2/ansible/executor/play_iterator.py b/v2/ansible/executor/play_iterator.py new file mode 100644 index 0000000000..076d676255 --- /dev/null +++ b/v2/ansible/executor/play_iterator.py @@ -0,0 +1,266 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.errors import * +from ansible.playbook.task import Task + +from ansible.utils.boolean import boolean + +__all__ = ['PlayIterator'] + + +# the primary running states for the play iteration +ITERATING_SETUP = 0 +ITERATING_TASKS = 1 +ITERATING_RESCUE = 2 +ITERATING_ALWAYS = 3 +ITERATING_COMPLETE = 4 + +# the failure states for the play iteration +FAILED_NONE = 0 +FAILED_SETUP = 1 +FAILED_TASKS = 2 +FAILED_RESCUE = 3 +FAILED_ALWAYS = 4 + +class PlayState: + + ''' + A helper class, which keeps track of the task iteration + state for a given playbook. This is used in the PlaybookIterator + class on a per-host basis. + ''' + + # FIXME: this class is the representation of a finite state machine, + # so we really should have a well defined state representation + # documented somewhere... + + def __init__(self, parent_iterator, host): + ''' + Create the initial state, which tracks the running state as well + as the failure state, which are used when executing block branches + (rescue/always) + ''' + + self._run_state = ITERATING_SETUP + self._failed_state = FAILED_NONE + self._task_list = parent_iterator._play.compile() + self._gather_facts = parent_iterator._play.gather_facts + self._host = host + + self._cur_block = None + self._cur_role = None + self._cur_task_pos = 0 + self._cur_rescue_pos = 0 + self._cur_always_pos = 0 + self._cur_handler_pos = 0 + + def next(self, peek=False): + ''' + Determines and returns the next available task from the playbook, + advancing through the list of plays as it goes. If peek is set to True, + the internal state is not stored. 
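In short: peek=True computes the upcoming task without saving the advanced state, so a later plain next() call returns that same task and consumes it. A hypothetical usage, with state obtained from a PlayIterator:

    upcoming = state.next(peek=True)   # look ahead; internal state unchanged
    task = state.next()                # same task again, now consumed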
+ ''' + + task = None + + # save this locally so that we can peek at the next task + # without updating the internal state of the iterator + run_state = self._run_state + failed_state = self._failed_state + cur_block = self._cur_block + cur_role = self._cur_role + cur_task_pos = self._cur_task_pos + cur_rescue_pos = self._cur_rescue_pos + cur_always_pos = self._cur_always_pos + cur_handler_pos = self._cur_handler_pos + + + while True: + if run_state == ITERATING_SETUP: + if failed_state == FAILED_SETUP: + run_state = ITERATING_COMPLETE + else: + run_state = ITERATING_TASKS + + if self._gather_facts == 'smart' and not self._host.gathered_facts or boolean(self._gather_facts): + self._host.set_gathered_facts(True) + task = Task() + task.action = 'setup' + break + elif run_state == ITERATING_TASKS: + # if there is any failure state besides FAILED_NONE, we should + # change to some other running state + if failed_state != FAILED_NONE or cur_task_pos > len(self._task_list) - 1: + # if there is a block (and there always should be), start running + # the rescue portion if it exists (and if we haven't failed that + # already), or the always portion (if it exists and we didn't fail + # there too). Otherwise, we're done iterating. + if cur_block: + if failed_state != FAILED_RESCUE and cur_block.rescue: + run_state = ITERATING_RESCUE + cur_rescue_pos = 0 + elif failed_state != FAILED_ALWAYS and cur_block.always: + run_state = ITERATING_ALWAYS + cur_always_pos = 0 + else: + run_state = ITERATING_COMPLETE + else: + run_state = ITERATING_COMPLETE + else: + task = self._task_list[cur_task_pos] + if cur_block is not None and cur_block != task._block: + run_state = ITERATING_ALWAYS + continue + else: + cur_block = task._block + cur_task_pos += 1 + + # Break out of the while loop now that we have our task + break + + elif run_state == ITERATING_RESCUE: + # If we're iterating through the rescue tasks, make sure we haven't + # failed yet. If so, move on to the always block or if not get the + # next rescue task (if one exists) + if failed_state == FAILED_RESCUE or cur_block.rescue is None or cur_rescue_pos > len(cur_block.rescue) - 1: + run_state = ITERATING_ALWAYS + else: + task = cur_block.rescue[cur_rescue_pos] + cur_rescue_pos += 1 + break + + elif run_state == ITERATING_ALWAYS: + # If we're iterating through the always tasks, make sure we haven't + # failed yet. 
If so, we're done iterating otherwise get the next always + # task (if one exists) + if failed_state == FAILED_ALWAYS or cur_block.always is None or cur_always_pos > len(cur_block.always) - 1: + cur_block = None + if failed_state == FAILED_ALWAYS or cur_task_pos > len(self._task_list) - 1: + run_state = ITERATING_COMPLETE + else: + run_state = ITERATING_TASKS + else: + task = cur_block.always[cur_always_pos] + cur_always_pos += 1 + break + + elif run_state == ITERATING_COMPLETE: + # done iterating, return None to signify that + return None + + if task._role: + # if we had a current role, mark that role as completed + if cur_role and task._role != cur_role and not peek: + cur_role._completed = True + + cur_role = task._role + + # if the current role has not had its task run flag set, mark + # clear the completed flag so we can correctly determine if the + # role was run + if not cur_role._had_task_run and not peek: + cur_role._completed = False + + # If we're not just peeking at the next task, save the internal state + if not peek: + self._run_state = run_state + self._failed_state = failed_state + self._cur_block = cur_block + self._cur_role = cur_role + self._cur_task_pos = cur_task_pos + self._cur_rescue_pos = cur_rescue_pos + self._cur_always_pos = cur_always_pos + self._cur_handler_pos = cur_handler_pos + + return task + + def mark_failed(self): + ''' + Escalates the failed state relative to the running state. + ''' + if self._run_state == ITERATING_SETUP: + self._failed_state = FAILED_SETUP + elif self._run_state == ITERATING_TASKS: + self._failed_state = FAILED_TASKS + elif self._run_state == ITERATING_RESCUE: + self._failed_state = FAILED_RESCUE + elif self._run_state == ITERATING_ALWAYS: + self._failed_state = FAILED_ALWAYS + + +class PlayIterator: + + ''' + The main iterator class, which keeps the state of the playbook + on a per-host basis using the above PlaybookState class. + ''' + + def __init__(self, inventory, play): + self._play = play + self._inventory = inventory + self._host_entries = dict() + self._first_host = None + + # Build the per-host dictionary of playbook states, using a copy + # of the play object so we can post_validate it to ensure any templated + # fields are filled in without modifying the original object, since + # post_validate() saves the templated values. 
+ + # FIXME: this is a hacky way of doing this, the iterator should + # instead get the loader and variable manager directly + # as args to __init__ + all_vars = inventory._variable_manager.get_vars(loader=inventory._loader, play=play) + new_play = play.copy() + new_play.post_validate(all_vars, fail_on_undefined=False) + + for host in inventory.get_hosts(new_play.hosts): + if self._first_host is None: + self._first_host = host + self._host_entries[host.get_name()] = PlayState(parent_iterator=self, host=host) + + # FIXME: remove, probably not required anymore + #def get_next_task(self, peek=False): + # ''' returns the next task for host[0] ''' + # + # first_entry = self._host_entries[self._first_host.get_name()] + # if not peek: + # for entry in self._host_entries: + # if entry != self._first_host.get_name(): + # target_entry = self._host_entries[entry] + # if target_entry._cur_task_pos == first_entry._cur_task_pos: + # target_entry.next() + # return first_entry.next(peek=peek) + + def get_next_task_for_host(self, host, peek=False): + ''' fetch the next task for the given host ''' + if host.get_name() not in self._host_entries: + raise AnsibleError("invalid host (%s) specified for playbook iteration" % host) + + return self._host_entries[host.get_name()].next(peek=peek) + + def mark_host_failed(self, host): + ''' mark the given host as failed ''' + if host.get_name() not in self._host_entries: + raise AnsibleError("invalid host (%s) specified for playbook iteration" % host) + + self._host_entries[host.get_name()].mark_failed() + diff --git a/v2/ansible/executor/playbook_executor.py b/v2/ansible/executor/playbook_executor.py index 7031e51142..9c5a0b714a 100644 --- a/v2/ansible/executor/playbook_executor.py +++ b/v2/ansible/executor/playbook_executor.py @@ -19,17 +19,110 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import signal + +from ansible import constants as C +from ansible.errors import * +from ansible.executor.task_queue_manager import TaskQueueManager +from ansible.playbook import Playbook + +from ansible.utils.debug import debug + class PlaybookExecutor: - def __init__(self, list_of_plays=[]): - # self.tqm = TaskQueueManager(forks) - assert False + ''' + This is the primary class for executing playbooks, and thus the + basis for bin/ansible-playbook operation. + ''' - def run(self): - # for play in list_of_plays: - # for block in play.blocks: - # # block must know it’s playbook class and context - # tqm.enqueue(block) - # tqm.go()... - assert False + def __init__(self, playbooks, inventory, variable_manager, loader, options): + self._playbooks = playbooks + self._inventory = inventory + self._variable_manager = variable_manager + self._loader = loader + self._options = options + self._tqm = TaskQueueManager(inventory=inventory, callback='default', variable_manager=variable_manager, loader=loader, options=options) + + def run(self): + + ''' + Run the given playbook, based on the settings in the play which + may limit the runs to serialized groups, etc. + ''' + + signal.signal(signal.SIGINT, self._cleanup) + + try: + for playbook_path in self._playbooks: + pb = Playbook.load(playbook_path, variable_manager=self._variable_manager, loader=self._loader) + + # FIXME: playbook entries are just plays, so we should rename them + for play in pb.get_entries(): + self._inventory.remove_restriction() + + # Create a temporary copy of the play here, so we can run post_validate + # on it without the templating changes affecting the original object. 
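A worked example of the batching done in _get_serialized_batches() below, written as an equivalent list computation with hypothetical host names: ten hosts with serial '30%' gives int(0.30 * 10) = 3, so batches of 3, 3, 3 and 1; a serial of 0 or less puts all hosts in one batch.

    hosts = ['host%d' % i for i in range(10)]
    serial = int((30 / 100.0) * len(hosts))    # '30%' of 10 hosts -> 3
    batches = [hosts[i:i + serial] for i in range(0, len(hosts), serial)]
    # -> [[host0..host2], [host3..host5], [host6..host8], [host9]]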
+ all_vars = self._variable_manager.get_vars(loader=self._loader, play=play) + new_play = play.copy() + new_play.post_validate(all_vars, fail_on_undefined=False) + + result = True + for batch in self._get_serialized_batches(new_play): + if len(batch) == 0: + raise AnsibleError("No hosts matched the list specified in the play", obj=play._ds) + # restrict the inventory to the hosts in the serialized batch + self._inventory.restrict_to_hosts(batch) + # and run it... + result = self._tqm.run(play=play) + if not result: + break + + if not result: + # FIXME: do something here, to signify the playbook execution failed + self._cleanup() + return 1 + except: + self._cleanup() + raise + + self._cleanup() + return 0 + + def _cleanup(self, signum=None, framenum=None): + self._tqm.cleanup() + + def _get_serialized_batches(self, play): + ''' + Returns a list of hosts, subdivided into batches based on + the serial size specified in the play. + ''' + + # make sure we have a unique list of hosts + all_hosts = self._inventory.get_hosts(play.hosts) + + # check to see if the serial number was specified as a percentage, + # and convert it to an integer value based on the number of hosts + if isinstance(play.serial, basestring) and play.serial.endswith('%'): + serial_pct = int(play.serial.replace("%","")) + serial = int((serial_pct/100.0) * len(all_hosts)) + else: + serial = int(play.serial) + + # if the serial count was not specified or is invalid, default to + # a list of all hosts, otherwise split the list of hosts into chunks + # which are based on the serial size + if serial <= 0: + return [all_hosts] + else: + serialized_batches = [] + + while len(all_hosts) > 0: + play_hosts = [] + for x in range(serial): + if len(all_hosts) > 0: + play_hosts.append(all_hosts.pop(0)) + + serialized_batches.append(play_hosts) + + return serialized_batches diff --git a/v2/ansible/executor/playbook_iterator.py b/v2/ansible/executor/playbook_iterator.py deleted file mode 100644 index 88bec5a331..0000000000 --- a/v2/ansible/executor/playbook_iterator.py +++ /dev/null @@ -1,125 +0,0 @@ -# (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -# Make coding more python3-ish -from __future__ import (absolute_import, division, print_function) -__metaclass__ = type - -class PlaybookState: - - ''' - A helper class, which keeps track of the task iteration - state for a given playbook. This is used in the PlaybookIterator - class on a per-host basis. - ''' - def __init__(self, parent_iterator): - self._parent_iterator = parent_iterator - self._cur_play = 0 - self._task_list = None - self._cur_task_pos = 0 - self._done = False - - def next(self, peek=False): - ''' - Determines and returns the next available task from the playbook, - advancing through the list of plays as it goes. 
- ''' - - task = None - - # we save these locally so that we can peek at the next task - # without updating the internal state of the iterator - cur_play = self._cur_play - task_list = self._task_list - cur_task_pos = self._cur_task_pos - - while True: - # when we hit the end of the playbook entries list, we set a flag - # and return None to indicate we're there - # FIXME: accessing the entries and parent iterator playbook members - # should be done through accessor functions - if self._done or cur_play > len(self._parent_iterator._playbook._entries) - 1: - self._done = True - return None - - # initialize the task list by calling the .compile() method - # on the play, which will call compile() for all child objects - if task_list is None: - task_list = self._parent_iterator._playbook._entries[cur_play].compile() - - # if we've hit the end of this plays task list, move on to the next - # and reset the position values for the next iteration - if cur_task_pos > len(task_list) - 1: - cur_play += 1 - task_list = None - cur_task_pos = 0 - continue - else: - # FIXME: do tag/conditional evaluation here and advance - # the task position if it should be skipped without - # returning a task - task = task_list[cur_task_pos] - cur_task_pos += 1 - - # Skip the task if it is the member of a role which has already - # been run, unless the role allows multiple executions - if task._role: - # FIXME: this should all be done via member functions - # instead of direct access to internal variables - if task._role.has_run() and not task._role._metadata._allow_duplicates: - continue - - # Break out of the while loop now that we have our task - break - - # If we're not just peeking at the next task, save the internal state - if not peek: - self._cur_play = cur_play - self._task_list = task_list - self._cur_task_pos = cur_task_pos - - return task - -class PlaybookIterator: - - ''' - The main iterator class, which keeps the state of the playbook - on a per-host basis using the above PlaybookState class. 
- ''' - - def __init__(self, inventory, log_manager, playbook): - self._playbook = playbook - self._log_manager = log_manager - self._host_entries = dict() - self._first_host = None - - # build the per-host dictionary of playbook states - for host in inventory.get_hosts(): - if self._first_host is None: - self._first_host = host - self._host_entries[host.get_name()] = PlaybookState(parent_iterator=self) - - def get_next_task(self, peek=False): - ''' returns the next task for host[0] ''' - return self._host_entries[self._first_host.get_name()].next(peek=peek) - - def get_next_task_for_host(self, host, peek=False): - ''' fetch the next task for the given host ''' - if host.get_name() not in self._host_entries: - raise AnsibleError("invalid host specified for playbook iteration") - - return self._host_entries[host.get_name()].next(peek=peek) diff --git a/v2/test/parsing/yaml/__init__.py b/v2/ansible/executor/process/__init__.py similarity index 100% rename from v2/test/parsing/yaml/__init__.py rename to v2/ansible/executor/process/__init__.py diff --git a/v2/ansible/executor/process/result.py b/v2/ansible/executor/process/result.py new file mode 100644 index 0000000000..cb858017f9 --- /dev/null +++ b/v2/ansible/executor/process/result.py @@ -0,0 +1,157 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import Queue +import multiprocessing +import os +import signal +import sys +import time +import traceback + +HAS_ATFORK=True +try: + from Crypto.Random import atfork +except ImportError: + HAS_ATFORK=False + +from ansible.executor.task_result import TaskResult +from ansible.playbook.handler import Handler +from ansible.playbook.task import Task + +from ansible.utils.debug import debug + +__all__ = ['ResultProcess'] + + +class ResultProcess(multiprocessing.Process): + ''' + The result worker thread, which reads results from the results + queue and fires off callbacks/etc. as necessary. 
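Results leave this process as tagged tuples on the final queue, e.g. ('host_task_ok', result) or ('notify_handler', handler_name, host). A hypothetical consumer (final_q, notified_handlers and run_callback assumed from context) dispatches on the tag:

    msg = final_q.get()
    if msg[0] == 'notify_handler':
        _, handler_name, host = msg
        notified_handlers[handler_name].append(host)
    else:
        run_callback(msg[0], msg[1])   # e.g. 'host_task_failed', TaskResult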
+ ''' + + def __init__(self, final_q, workers): + + # takes a task queue manager as the sole param: + self._final_q = final_q + self._workers = workers + self._cur_worker = 0 + self._terminated = False + + super(ResultProcess, self).__init__() + + def _send_result(self, result): + debug("sending result: %s" % (result,)) + self._final_q.put(result, block=False) + debug("done sending result") + + def _read_worker_result(self): + result = None + starting_point = self._cur_worker + while True: + (worker_prc, main_q, rslt_q) = self._workers[self._cur_worker] + self._cur_worker += 1 + if self._cur_worker >= len(self._workers): + self._cur_worker = 0 + + try: + if not rslt_q.empty(): + debug("worker %d has data to read" % self._cur_worker) + result = rslt_q.get(block=False) + debug("got a result from worker %d: %s" % (self._cur_worker, result)) + break + except Queue.Empty: + pass + + if self._cur_worker == starting_point: + break + + return result + + def terminate(self): + self._terminated = True + super(ResultProcess, self).terminate() + + def run(self): + ''' + The main thread execution, which reads from the results queue + indefinitely and sends callbacks/etc. when results are received. + ''' + + if HAS_ATFORK: + atfork() + + while True: + try: + result = self._read_worker_result() + if result is None: + time.sleep(0.1) + continue + + host_name = result._host.get_name() + + # send callbacks, execute other options based on the result status + if result.is_failed(): + self._send_result(('host_task_failed', result)) + elif result.is_unreachable(): + self._send_result(('host_unreachable', result)) + elif result.is_skipped(): + self._send_result(('host_task_skipped', result)) + else: + self._send_result(('host_task_ok', result)) + + # if this task is notifying a handler, do it now + if result._task.notify: + # The shared dictionary for notified handlers is a proxy, which + # does not detect when sub-objects within the proxy are modified. 
+ # So, per the docs, we reassign the list so the proxy picks up and + # notifies all other threads + for notify in result._task.notify: + self._send_result(('notify_handler', notify, result._host)) + + if 'add_host' in result._result: + # this task added a new host (add_host module) + self._send_result(('add_host', result)) + elif 'add_group' in result._result: + # this task added a new group (group_by module) + self._send_result(('add_group', result)) + elif 'ansible_facts' in result._result: + # if this task is registering facts, do that now + if result._task.action in ('set_fact', 'include_vars'): + for (key, value) in result._result['ansible_facts'].iteritems(): + self._send_result(('set_host_var', result._host, key, value)) + else: + self._send_result(('set_host_facts', result._host, result._result['ansible_facts'])) + + # if this task is registering a result, do it now + if result._task.register: + self._send_result(('set_host_var', result._host, result._task.register, result._result)) + + except Queue.Empty: + pass + except (KeyboardInterrupt, IOError, EOFError): + break + except: + # FIXME: we should probably send a proper callback here instead of + # simply dumping a stack trace on the screen + traceback.print_exc() + break + diff --git a/v2/ansible/executor/process/worker.py b/v2/ansible/executor/process/worker.py new file mode 100644 index 0000000000..3419d4ec0a --- /dev/null +++ b/v2/ansible/executor/process/worker.py @@ -0,0 +1,144 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import Queue +import multiprocessing +import os +import signal +import sys +import time +import traceback + +HAS_ATFORK=True +try: + from Crypto.Random import atfork +except ImportError: + HAS_ATFORK=False + +from ansible.errors import AnsibleError, AnsibleConnectionFailure +from ansible.executor.task_executor import TaskExecutor +from ansible.executor.task_result import TaskResult +from ansible.playbook.handler import Handler +from ansible.playbook.task import Task + +from ansible.utils.debug import debug + +__all__ = ['ExecutorProcess'] + + +class WorkerProcess(multiprocessing.Process): + ''' + The worker thread class, which uses TaskExecutor to run tasks + read from a job queue and pushes results into a results queue + for reading later. 
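The queue protocol implied here: the manager enqueues 4-tuples on a worker's main_q and reads TaskResult objects back off rslt_q. A hypothetical exchange (objects assumed from context):

    main_q.put((host, task, job_vars, connection_info), block=False)
    # ... worker runs TaskExecutor(host, task, job_vars, ...).run() ...
    task_result = rslt_q.get()   # a TaskResult wrapping the executor result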
+ ''' + + def __init__(self, tqm, main_q, rslt_q, loader, new_stdin): + + # takes a task queue manager as the sole param: + self._main_q = main_q + self._rslt_q = rslt_q + self._loader = loader + + # dupe stdin, if we have one + try: + fileno = sys.stdin.fileno() + except ValueError: + fileno = None + + self._new_stdin = new_stdin + if not new_stdin and fileno is not None: + try: + self._new_stdin = os.fdopen(os.dup(fileno)) + except OSError, e: + # couldn't dupe stdin, most likely because it's + # not a valid file descriptor, so we just rely on + # using the one that was passed in + pass + + if self._new_stdin: + sys.stdin = self._new_stdin + + super(WorkerProcess, self).__init__() + + def run(self): + ''' + Called when the process is started, and loops indefinitely + until an error is encountered (typically an IOerror from the + queue pipe being disconnected). During the loop, we attempt + to pull tasks off the job queue and run them, pushing the result + onto the results queue. We also remove the host from the blocked + hosts list, to signify that they are ready for their next task. + ''' + + if HAS_ATFORK: + atfork() + + while True: + task = None + try: + if not self._main_q.empty(): + debug("there's work to be done!") + (host, task, job_vars, connection_info) = self._main_q.get(block=False) + debug("got a task/handler to work on: %s" % task) + + new_connection_info = connection_info.set_task_override(task) + + # execute the task and build a TaskResult from the result + debug("running TaskExecutor() for %s/%s" % (host, task)) + executor_result = TaskExecutor(host, task, job_vars, new_connection_info, self._loader).run() + debug("done running TaskExecutor() for %s/%s" % (host, task)) + task_result = TaskResult(host, task, executor_result) + + # put the result on the result queue + debug("sending task result") + self._rslt_q.put(task_result, block=False) + debug("done sending task result") + + else: + time.sleep(0.1) + + except Queue.Empty: + pass + except (IOError, EOFError, KeyboardInterrupt): + break + except AnsibleConnectionFailure: + try: + if task: + task_result = TaskResult(host, task, dict(unreachable=True)) + self._rslt_q.put(task_result, block=False) + except: + # FIXME: most likely an abort, catch those kinds of errors specifically + break + except Exception, e: + debug("WORKER EXCEPTION: %s" % e) + debug("WORKER EXCEPTION: %s" % traceback.format_exc()) + try: + if task: + task_result = TaskResult(host, task, dict(failed=True, exception=traceback.format_exc(), stdout='')) + self._rslt_q.put(task_result, block=False) + except: + # FIXME: most likely an abort, catch those kinds of errors specifically + break + + debug("WORKER PROCESS EXITING") + + diff --git a/v2/ansible/executor/task_executor.py b/v2/ansible/executor/task_executor.py index 878c15c489..324a4f37ea 100644 --- a/v2/ansible/executor/task_executor.py +++ b/v2/ansible/executor/task_executor.py @@ -19,14 +19,282 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +from ansible import constants as C +from ansible.errors import AnsibleError, AnsibleParserError +from ansible.executor.connection_info import ConnectionInformation +from ansible.playbook.task import Task +from ansible.plugins import lookup_loader, connection_loader, action_loader + +from ansible.utils.debug import debug + +__all__ = ['TaskExecutor'] + +import json +import time + class TaskExecutor: - def __init__(self, task, host): - pass + ''' + This is the main worker class for the executor pipeline, which + handles 
loading an action plugin to actually dispatch the task to + a given host. This class roughly corresponds to the old Runner() + class. + ''' - def run(self): - # returns TaskResult - pass + def __init__(self, host, task, job_vars, connection_info, loader): + self._host = host + self._task = task + self._job_vars = job_vars + self._connection_info = connection_info + self._loader = loader - + def run(self): + ''' + The main executor entrypoint, where we determine if the specified + task requires looping and either runs the task with the loop items + or executes it directly. + ''' + debug("in run()") + + try: + items = self._get_loop_items() + if items: + if len(items) > 0: + item_results = self._run_loop(items) + res = dict(results=item_results) + else: + res = dict(changed=False, skipped=True, skipped_reason='No items in the list', results=[]) + else: + debug("calling self._execute()") + res = self._execute() + debug("_execute() done") + + debug("dumping result to json") + result = json.dumps(res) + debug("done dumping result, returning") + return result + except AnsibleError, e: + return dict(failed=True, msg=str(e)) + + def _get_loop_items(self): + ''' + Loads a lookup plugin to handle the with_* portion of a task (if specified), + and returns the items result. + ''' + + items = None + if self._task.loop and self._task.loop in lookup_loader: + items = lookup_loader.get(self._task.loop).run(terms=self._task.loop_args, variables=self._job_vars) + + return items + + def _run_loop(self, items): + ''' + Runs the task with the loop items specified and collates the result + into an array named 'results' which is inserted into the final result + along with the item for which the loop ran. + ''' + + results = [] + + # FIXME: squash items into a flat list here for those modules + # which support it (yum, apt, etc.) but make it smarter + # than it is today?
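The squashing the FIXME above refers to: package modules accept comma-joined name lists, so a loop over plain package names can collapse into a single module invocation. A hypothetical helper sketch:

    def squash_items(action, items):
        if action in ('yum', 'apt') and all(isinstance(i, basestring) for i in items):
            return [','.join(items)]    # one call with name=pkg1,pkg2,...
        return items                    # otherwise, one invocation per item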
+ for item in items: + # make copies of the job vars and task so we can add the item to + # the variables and re-validate the task with the item variable + task_vars = self._job_vars.copy() + task_vars['item'] = item + + try: + tmp_task = self._task.copy() + tmp_task.post_validate(task_vars) + except AnsibleParserError, e: + results.append(dict(failed=True, msg=str(e))) + continue + + # now we swap the internal task with the re-validated copy, execute, + # and swap them back so we can do the next iteration cleanly + (self._task, tmp_task) = (tmp_task, self._task) + res = self._execute() + (self._task, tmp_task) = (tmp_task, self._task) + + # FIXME: we should be sending back a callback result for each item in the loop here + + # now update the result with the item info, and append the result + # to the list of results + res['item'] = item + results.append(res) + + return results + + def _execute(self): + ''' + The primary workhorse of the executor system, this runs the task + on the specified host (which may be the delegated_to host) and handles + the retry/until and block rescue/always execution + ''' + + self._connection = self._get_connection() + self._handler = self._get_action_handler(connection=self._connection) + + # check to see if this task should be skipped, due to it being a member of a + # role which has already run (and whether that role allows duplicate execution) + if self._task._role and self._task._role.has_run(): + # If there is no metadata, the default behavior is to not allow duplicates, + # if there is metadata, check to see if the allow_duplicates flag was set to true + if self._task._role._metadata is None or self._task._role._metadata and not self._task._role._metadata.allow_duplicates: + debug("task belongs to a role which has already run, but does not allow duplicate execution") + return dict(skipped=True, skip_reason='This role has already been run, but does not allow duplicates') + + if not self._task.evaluate_conditional(self._job_vars): + debug("when evaluation failed, skipping this task") + return dict(skipped=True, skip_reason='Conditional check failed') + + if not self._task.evaluate_tags(self._connection_info.only_tags, self._connection_info.skip_tags): + debug("Tags don't match, skipping this task") + return dict(skipped=True, skip_reason='Skipped due to specified tags') + + retries = self._task.retries + if retries <= 0: + retries = 1 + + delay = self._task.delay + if delay < 0: + delay = 1 + + debug("starting attempt loop") + result = None + for attempt in range(retries): + if attempt > 0: + # FIXME: this should use the callback mechanism + print("FAILED - RETRYING: %s (%d retries left)" % (self._task, retries-attempt)) + result['attempts'] = attempt + 1 + + debug("running the handler") + result = self._handler.run(task_vars=self._job_vars) + debug("handler run complete") + + if self._task.async > 0: + # the async_wrapper module returns dumped JSON via its stdout + # response, so we parse it here and replace the result + try: + result = json.loads(result.get('stdout')) + except ValueError, e: + return dict(failed=True, msg="The async task did not return valid JSON: %s" % str(e)) + + if self._task.poll > 0: + result = self._poll_async_result(result=result) + + if self._task.until: + # make a copy of the job vars here, in case we need to update them + vars_copy = self._job_vars.copy() + # now update them with the registered value, if it is set + if self._task.register: + vars_copy[self._task.register] = result + # now create a pseudo task, and assign the
value of the until parameter + # to the when param, so we can use evaluate_conditional() + pseudo_task = Task() + pseudo_task.when = self._task.until + if pseudo_task.evaluate_conditional(vars_copy): + break + elif 'failed' not in result and result.get('rc', 0) == 0: + # if the result is not failed, stop trying + break + + if attempt < retries - 1: + time.sleep(delay) + + debug("attempt loop complete, returning result") + return result + + def _poll_async_result(self, result): + ''' + Polls for the specified JID to be complete + ''' + + async_jid = result.get('ansible_job_id') + if async_jid is None: + return dict(failed=True, msg="No job id was returned by the async task") + + # Create a new pseudo-task to run the async_status module, and run + # that (with a sleep for "poll" seconds between each retry) until the + # async time limit is exceeded. + + async_task = Task().load(dict(action='async_status jid=%s' % async_jid)) + + # Because this is an async task, the action handler is async. However, + # we need the 'normal' action handler for the status check, so get it + # now via the action_loader + normal_handler = action_loader.get( + 'normal', + task=async_task, + connection=self._connection, + connection_info=self._connection_info, + loader=self._loader + ) + + time_left = self._task.async + while time_left > 0: + time.sleep(self._task.poll) + + async_result = normal_handler.run() + if int(async_result.get('finished', 0)) == 1 or 'failed' in async_result or 'skipped' in async_result: + break + + time_left -= self._task.poll + + if int(async_result.get('finished', 0)) != 1: + return dict(failed=True, msg="async task did not complete within the requested time") + else: + return async_result + + def _get_connection(self): + ''' + Reads the connection property for the host, and returns the + correct connection object from the list of connection plugins + ''' + + # FIXME: delegate_to calculation should be done here + # FIXME: calculation of connection params/auth stuff should be done here + + # FIXME: add all port/connection type munging here (accelerated mode, + # fixing up options for ssh, etc.)?
and 'smart' conversion + conn_type = self._connection_info.connection + if conn_type == 'smart': + conn_type = 'ssh' + + connection = connection_loader.get(conn_type, self._host, self._connection_info) + if not connection: + raise AnsibleError("the connection plugin '%s' was not found" % conn_type) + + connection.connect() + + return connection + + def _get_action_handler(self, connection): + ''' + Returns the correct action plugin to handle the requested task action + ''' + + if self._task.action in action_loader: + if self._task.async != 0: + raise AnsibleError("async mode is not supported with the %s module" % self._task.action) + handler_name = self._task.action + elif self._task.async == 0: + handler_name = 'normal' + else: + handler_name = 'async' + + handler = action_loader.get( + handler_name, + task=self._task, + connection=connection, + connection_info=self._connection_info, + loader=self._loader + ) + if not handler: + raise AnsibleError("the handler '%s' was not found" % handler_name) + + return handler diff --git a/v2/ansible/executor/task_queue_manager.py b/v2/ansible/executor/task_queue_manager.py index a79235bfd0..72ff04d53d 100644 --- a/v2/ansible/executor/task_queue_manager.py +++ b/v2/ansible/executor/task_queue_manager.py @@ -19,18 +19,191 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -class TaskQueueManagerHostPlaybookIterator: +import multiprocessing +import os +import socket +import sys - def __init__(self, host, playbook): - pass +from ansible.errors import AnsibleError +from ansible.executor.connection_info import ConnectionInformation +#from ansible.executor.manager import AnsibleManager +from ansible.executor.play_iterator import PlayIterator +from ansible.executor.process.worker import WorkerProcess +from ansible.executor.process.result import ResultProcess +from ansible.plugins import callback_loader, strategy_loader - def get_next_task(self): - assert False +from ansible.utils.debug import debug - def is_blocked(self): - # depending on strategy, either - # ‘linear’ -- all prev tasks must be completed for all hosts - # ‘free’ -- this host doesn’t have any more work to do - assert False +__all__ = ['TaskQueueManager'] +class TaskQueueManager: + + ''' + This class handles the multiprocessing requirements of Ansible by + creating a pool of worker forks, a result handler fork, and a + manager object with shared datastructures/queues for coordinating + work between all processes. + + The queue manager is responsible for loading the play strategy plugin, + which dispatches the Play's tasks to hosts. + ''' + + def __init__(self, inventory, callback, variable_manager, loader, options): + + self._inventory = inventory + self._variable_manager = variable_manager + self._loader = loader + self._options = options + + # a special flag to help us exit cleanly + self._terminated = False + + # create and start the multiprocessing manager + #self._manager = AnsibleManager() + #self._manager.start() + + # this dictionary is used to keep track of notified handlers + self._notified_handlers = dict() + + # dictionaries to keep track of failed/unreachable hosts + self._failed_hosts = dict() + self._unreachable_hosts = dict() + + self._final_q = multiprocessing.Queue() + + # FIXME: hard-coded the default callback plugin here, which + # should be configurable.
+ self._callback = callback_loader.get(callback) + + # create the pool of worker threads, based on the number of forks specified + try: + fileno = sys.stdin.fileno() + except ValueError: + fileno = None + + self._workers = [] + for i in range(self._options.forks): + # duplicate stdin, if possible + new_stdin = None + if fileno is not None: + try: + new_stdin = os.fdopen(os.dup(fileno)) + except OSError, e: + # couldn't dupe stdin, most likely because it's + # not a valid file descriptor, so we just rely on + # using the one that was passed in + pass + + main_q = multiprocessing.Queue() + rslt_q = multiprocessing.Queue() + + prc = WorkerProcess(self, main_q, rslt_q, loader, new_stdin) + prc.start() + + self._workers.append((prc, main_q, rslt_q)) + + self._result_prc = ResultProcess(self._final_q, self._workers) + self._result_prc.start() + + def _initialize_notified_handlers(self, handlers): + ''' + Clears and initializes the shared notified handlers dict with entries + for each handler in the play, which is an empty array that will contain + inventory hostnames for those hosts triggering the handler. + ''' + + # Zero the dictionary first by removing any entries there. + # Proxied dicts don't support iteritems, so we have to use keys() + for key in self._notified_handlers.keys(): + del self._notified_handlers[key] + + # FIXME: there is a block compile helper for this... + handler_list = [] + for handler_block in handlers: + handler_list.extend(handler_block.compile()) + + # then initialize it with the handler names from the handler list + for handler in handler_list: + self._notified_handlers[handler.get_name()] = [] + + def run(self, play): + ''' + Iterates over the roles/tasks in a play, using the given (or default) + strategy for queueing tasks. The default is the linear strategy, which + operates like classic Ansible by keeping all hosts in lock-step with + a given task (meaning no hosts move on to the next task until all hosts + are done with the current task). + ''' + + connection_info = ConnectionInformation(play, self._options) + self._callback.set_connection_info(connection_info) + + # run final validation on the play now, to make sure fields are templated + # FIXME: is this even required?
Everything is validated and merged at the + task level, so nothing else in the play needs to be templated + #all_vars = self._vmw.get_vars(loader=self._dlw, play=play) + #all_vars = self._vmw.get_vars(loader=self._loader, play=play) + #play.post_validate(all_vars=all_vars) + + self._callback.playbook_on_play_start(play.name) + + # initialize the shared dictionary containing the notified handlers + self._initialize_notified_handlers(play.handlers) + + # load the specified strategy (or the default linear one) + strategy = strategy_loader.get(play.strategy, self) + if strategy is None: + raise AnsibleError("Invalid play strategy specified: %s" % play.strategy, obj=play._ds) + + # build the iterator + iterator = PlayIterator(inventory=self._inventory, play=play) + + # and run the play using the strategy + return strategy.run(iterator, connection_info) + + def cleanup(self): + debug("RUNNING CLEANUP") + + self.terminate() + + self._final_q.close() + self._result_prc.terminate() + + for (worker_prc, main_q, rslt_q) in self._workers: + rslt_q.close() + main_q.close() + worker_prc.terminate() + + def get_inventory(self): + return self._inventory + + def get_callback(self): + return self._callback + + def get_variable_manager(self): + return self._variable_manager + + def get_loader(self): + return self._loader + + def get_server_pipe(self): + return self._server_pipe + + def get_client_pipe(self): + return self._client_pipe + + def get_pending_results(self): + return self._pending_results + + def get_allow_processing(self): + return self._allow_processing + + def get_notified_handlers(self): + return self._notified_handlers + + def get_workers(self): + return self._workers[:] + + def terminate(self): + self._terminated = True diff --git a/v2/ansible/executor/task_result.py b/v2/ansible/executor/task_result.py index 785fc45992..d911713651 100644 --- a/v2/ansible/executor/task_result.py +++ b/v2/ansible/executor/task_result.py @@ -19,3 +19,40 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +from ansible.parsing import DataLoader + +class TaskResult: + ''' + This class is responsible for interpreting the resulting data + from an executed task, and provides helper methods for determining + the result of a given task. + ''' + + def __init__(self, host, task, return_data): + self._host = host + self._task = task + if isinstance(return_data, dict): + self._result = return_data.copy() + else: + self._result = DataLoader().load(return_data) + + def is_changed(self): + return self._check_key('changed') + + def is_skipped(self): + return self._check_key('skipped') + + def is_failed(self): + return self._check_key('failed') or self._result.get('rc', 0) != 0 + + def is_unreachable(self): + return self._check_key('unreachable') + + def _check_key(self, key): + if 'results' in self._result: + flag = False + for res in self._result.get('results', []): + flag |= res.get(key, False) + return flag + else: + return self._result.get(key, False) diff --git a/v2/ansible/inventory/__init__.py b/v2/ansible/inventory/__init__.py index 5ad688eaf0..0c43133b92 100644 --- a/v2/ansible/inventory/__init__.py +++ b/v2/ansible/inventory/__init__.py @@ -16,73 +16,661 @@ # along with Ansible. If not, see .
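Before the new inventory module: a hedged illustration of the aggregation rule in TaskResult._check_key above. Single-result tasks are checked directly; looped tasks return one dict per item under a 'results' key, and the flag is OR'd across items. check_key below is a standalone stand-in and the sample dicts are hypothetical module returns:

def check_key(result, key):
    if 'results' in result:
        # looped task: any item setting the flag sets it for the whole task
        flag = False
        for res in result.get('results', []):
            flag |= res.get(key, False)
        return flag
    return result.get(key, False)

single = {'changed': True, 'rc': 0}
looped = {'results': [{'changed': False}, {'changed': True, 'failed': False}]}
assert check_key(single, 'changed') is True
assert check_key(looped, 'changed') is True
assert check_key(looped, 'failed') is False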
############################################# +import fnmatch +import os +import sys +import re +import stat +import subprocess -# Make coding more python3-ish -from __future__ import (absolute_import, division, print_function) -__metaclass__ = type +from ansible import constants as C +from ansible.errors import * + +from ansible.inventory.ini import InventoryParser +from ansible.inventory.script import InventoryScript +from ansible.inventory.dir import InventoryDirectory +from ansible.inventory.group import Group +from ansible.inventory.host import Host +from ansible.plugins import vars_loader +from ansible.utils.vars import combine_vars + +# FIXME: these defs need to be somewhere else +def is_executable(path): + '''is the given path executable?''' + return (stat.S_IXUSR & os.stat(path)[stat.ST_MODE] + or stat.S_IXGRP & os.stat(path)[stat.ST_MODE] + or stat.S_IXOTH & os.stat(path)[stat.ST_MODE]) + +class Inventory(object): + """ + Host inventory for ansible. + """ + + #__slots__ = [ 'host_list', 'groups', '_restriction', '_also_restriction', '_subset', + # 'parser', '_vars_per_host', '_vars_per_group', '_hosts_cache', '_groups_list', + # '_pattern_cache', '_vault_password', '_vars_plugins', '_playbook_basedir'] + + def __init__(self, loader, variable_manager, host_list=C.DEFAULT_HOST_LIST): + + # the host file, or script path, or list of hosts + # if a list, inventory data will NOT be loaded + self.host_list = host_list + self._loader = loader + self._variable_manager = variable_manager + + # caching to avoid repeated calculations, particularly with + # external inventory scripts. + + self._vars_per_host = {} + self._vars_per_group = {} + self._hosts_cache = {} + self._groups_list = {} + self._pattern_cache = {} + + # to be set by calling set_playbook_basedir by playbook code + self._playbook_basedir = None + + # the inventory object holds a list of groups + self.groups = [] + + # a list of host(names) to restrict current inquiries to + self._restriction = None + self._also_restriction = None + self._subset = None + + if isinstance(host_list, basestring): + if "," in host_list: + host_list = host_list.split(",") + host_list = [ h for h in host_list if h and h.strip() ] + + if host_list is None: + self.parser = None + elif isinstance(host_list, list): + self.parser = None + all = Group('all') + self.groups = [ all ] + ipv6_re = re.compile('\[([a-f:A-F0-9]*[%[0-z]+]?)\](?::(\d+))?') + for x in host_list: + m = ipv6_re.match(x) + if m: + all.add_host(Host(m.groups()[0], m.groups()[1])) + else: + if ":" in x: + tokens = x.rsplit(":", 1) + # if there is ':' in the address, then this is an ipv6 + if ':' in tokens[0]: + all.add_host(Host(x)) + else: + all.add_host(Host(tokens[0], tokens[1])) + else: + all.add_host(Host(x)) + elif os.path.exists(host_list): + if os.path.isdir(host_list): + # Ensure basedir is inside the directory + self.host_list = os.path.join(self.host_list, "") + self.parser = InventoryDirectory(filename=host_list) + self.groups = self.parser.groups.values() + else: + # check to see if the specified file starts with a + # shebang (#!/), so if an error is raised by the parser + # class we can show a more apropos error + shebang_present = False + try: + inv_file = open(host_list) + first_line = inv_file.readlines()[0] + inv_file.close() + if first_line.startswith('#!'): + shebang_present = True + except: + pass + + # FIXME: utils is_executable + if is_executable(host_list): + try: + self.parser = InventoryScript(filename=host_list) + self.groups = self.parser.groups.values()
+ except: + if not shebang_present: + raise AnsibleError("The file %s is marked as executable, but failed to execute correctly. " % host_list + \ + "If this is not supposed to be an executable script, correct this with `chmod -x %s`." % host_list) + else: + raise + else: + try: + self.parser = InventoryParser(filename=host_list) + self.groups = self.parser.groups.values() + except: + if shebang_present: + raise AnsibleError("The file %s looks like it should be an executable inventory script, but is not marked executable. " % host_list + \ + "Perhaps you want to correct this with `chmod +x %s`?" % host_list) + else: + raise + + vars_loader.add_directory(self.basedir(), with_subdir=True) + else: + raise AnsibleError("Unable to find an inventory file, specify one with -i ?") + + self._vars_plugins = [ x for x in vars_loader.all(self) ] + + # FIXME: shouldn't be required, since the group/host vars file + # management will be done in VariableManager + # get group vars from group_vars/ files and vars plugins + for group in self.groups: + # FIXME: combine_vars + group.vars = combine_vars(group.vars, self.get_group_variables(group.name)) + + # get host vars from host_vars/ files and vars plugins + for host in self.get_hosts(): + # FIXME: combine_vars + host.vars = combine_vars(host.vars, self.get_host_variables(host.name)) + + + def _match(self, str, pattern_str): + try: + if pattern_str.startswith('~'): + return re.search(pattern_str[1:], str) + else: + return fnmatch.fnmatch(str, pattern_str) + except Exception, e: + raise AnsibleError('invalid host pattern: %s' % pattern_str) + + def _match_list(self, items, item_attr, pattern_str): + results = [] + try: + if not pattern_str.startswith('~'): + pattern = re.compile(fnmatch.translate(pattern_str)) + else: + pattern = re.compile(pattern_str[1:]) + except Exception, e: + raise AnsibleError('invalid host pattern: %s' % pattern_str) + + for item in items: + if pattern.match(getattr(item, item_attr)): + results.append(item) + return results -class Inventory: - def __init__(self, host_list=C.DEFAULT_HOST_LIST, vault_password=None): - pass def get_hosts(self, pattern="all"): - pass - def clear_pattern_cache(self): - # Possibly not needed? - pass - def groups_for_host(self, host): - pass - def groups_list(self): - pass - def get_groups(self): - pass - def get_host(self, hostname): - pass - def get_group(self, groupname): - pass - def get_group_variables(self, groupname, update_cached=False, vault_password=None): - pass - def get_variables(self, hostname, update_cached=False, vault_password=None): - pass - def get_host_variables(self, hostname, update_cached=False, vault_password=None): - pass - def add_group(self, group): - pass - def list_hosts(self, pattern="all"): - pass - def list_groups(self): - pass - def get_restriction(self): - pass - def restrict_to(self, restriction): - pass - def also_restrict_to(self, restriction): - pass - def subset(self, subset_pattern): + """ + find all host names matching a pattern string, taking into account any inventory restrictions or + applied subsets.
""" + + # process patterns + if isinstance(pattern, list): + pattern = ';'.join(pattern) + patterns = pattern.replace(";",":").split(":") + hosts = self._get_hosts(patterns) + + # exclude hosts not in a subset, if defined + if self._subset: + subset = self._get_hosts(self._subset) + hosts = [ h for h in hosts if h in subset ] + + # exclude hosts mentioned in any restriction (ex: failed hosts) + if self._restriction is not None: + hosts = [ h for h in hosts if h in self._restriction ] + if self._also_restriction is not None: + hosts = [ h for h in hosts if h in self._also_restriction ] + + return hosts + + def _get_hosts(self, patterns): + """ + finds hosts that match a list of patterns. Handles negative + matches as well as intersection matches. + """ + + # Host specifiers should be sorted to ensure consistent behavior + pattern_regular = [] + pattern_intersection = [] + pattern_exclude = [] + for p in patterns: + if p.startswith("!"): + pattern_exclude.append(p) + elif p.startswith("&"): + pattern_intersection.append(p) + elif p: + pattern_regular.append(p) + + # if no regular pattern was given, hence only exclude and/or intersection + # make that magically work + if pattern_regular == []: + pattern_regular = ['all'] + + # when applying the host selectors, run those without the "&" or "!" + # first, then the &s, then the !s. + patterns = pattern_regular + pattern_intersection + pattern_exclude + + hosts = [] + + for p in patterns: + # avoid resolving a pattern that is a plain host + if p in self._hosts_cache: + hosts.append(self.get_host(p)) + else: + that = self.__get_hosts(p) + if p.startswith("!"): + hosts = [ h for h in hosts if h not in that ] + elif p.startswith("&"): + hosts = [ h for h in hosts if h in that ] + else: + to_append = [ h for h in that if h.name not in [ y.name for y in hosts ] ] + hosts.extend(to_append) + return hosts + + def __get_hosts(self, pattern): + """ + finds hosts that positively match a particular pattern. Does not + take into account negative matches. + """ + + if pattern in self._pattern_cache: + return self._pattern_cache[pattern] + + (name, enumeration_details) = self._enumeration_info(pattern) + hpat = self._hosts_in_unenumerated_pattern(name) + result = self._apply_ranges(pattern, hpat) + self._pattern_cache[pattern] = result + return result + + def _enumeration_info(self, pattern): + """ + returns (pattern, limits) taking a regular pattern and finding out + which parts of it correspond to start/stop offsets. limits is + a tuple of (start, stop) or None + """ + + # Do not parse regexes for enumeration info + if pattern.startswith('~'): + return (pattern, None) + + # The regex used to match on the range, which can be [x] or [x-y]. + pattern_re = re.compile("^(.*)\[([-]?[0-9]+)(?:(?:-)([0-9]+))?\](.*)$") + m = pattern_re.match(pattern) + if m: + (target, first, last, rest) = m.groups() + first = int(first) + if last: + if first < 0: + raise errors.AnsibleError("invalid range: negative indices cannot be used as the first item in a range") + last = int(last) + else: + last = first + return (target, (first, last)) + else: + return (pattern, None) + + def _apply_ranges(self, pat, hosts): + """ + given a pattern like foo, that matches hosts, return all of hosts + given a pattern like foo[0:5], where foo matches hosts, return the first 6 hosts + """ + + # If there are no hosts to select from, just return the + # empty set. This prevents trying to do selections on an empty set. 
+ # issue#6258 + if not hosts: + return hosts + + (loose_pattern, limits) = self._enumeration_info(pat) + if not limits: + return hosts + + (left, right) = limits + + if left == '': + left = 0 + if right == '': + right = 0 + left = int(left) + right = int(right) + try: + if left != right: + return hosts[left:right] + else: + return [ hosts[left] ] + except IndexError: + raise AnsibleError("no hosts matching the pattern '%s' were found" % pat) + + def _create_implicit_localhost(self, pattern): + new_host = Host(pattern) + new_host.set_variable("ansible_python_interpreter", sys.executable) + new_host.set_variable("ansible_connection", "local") + new_host.ipv4_address = '127.0.0.1' + + ungrouped = self.get_group("ungrouped") + if ungrouped is None: + self.add_group(Group('ungrouped')) + ungrouped = self.get_group('ungrouped') + self.get_group('all').add_child_group(ungrouped) + ungrouped.add_host(new_host) + return new_host + + def _hosts_in_unenumerated_pattern(self, pattern): + """ Get all host names matching the pattern """ + + results = [] + hosts = [] + hostnames = set() + + # ignore any negative checks here, this is handled elsewhere + pattern = pattern.replace("!","").replace("&", "") + + def __append_host_to_results(host): + if host not in results and host.name not in hostnames: + hostnames.add(host.name) + results.append(host) + + groups = self.get_groups() + for group in groups: + if pattern == 'all': + for host in group.get_hosts(): + __append_host_to_results(host) + else: + if self._match(group.name, pattern): + for host in group.get_hosts(): + __append_host_to_results(host) + else: + matching_hosts = self._match_list(group.get_hosts(), 'name', pattern) + for host in matching_hosts: + __append_host_to_results(host) + + if pattern in ["localhost", "127.0.0.1"] and len(results) == 0: + new_host = self._create_implicit_localhost(pattern) + results.append(new_host) + return results + + def clear_pattern_cache(self): + ''' called exclusively by the add_host plugin to allow patterns to be recalculated ''' + self._pattern_cache = {} + + def groups_for_host(self, host): + if host in self._hosts_cache: + return self._hosts_cache[host].get_groups() + else: + return [] + + def groups_list(self): + if not self._groups_list: + groups = {} + for g in self.groups: + groups[g.name] = [h.name for h in g.get_hosts()] + ancestors = g.get_ancestors() + for a in ancestors: + if a.name not in groups: + groups[a.name] = [h.name for h in a.get_hosts()] + self._groups_list = groups + return self._groups_list + + def get_groups(self): + return self.groups + + def get_host(self, hostname): + if hostname not in self._hosts_cache: + self._hosts_cache[hostname] = self._get_host(hostname) + return self._hosts_cache[hostname] + + def _get_host(self, hostname): + if hostname in ['localhost','127.0.0.1']: + for host in self.get_group('all').get_hosts(): + if host.name in ['localhost', '127.0.0.1']: + return host + return self._create_implicit_localhost(hostname) + else: + for group in self.groups: + for host in group.get_hosts(): + if hostname == host.name: + return host + return None + + def get_group(self, groupname): + for group in self.groups: + if group.name == groupname: + return group + return None + + def get_group_variables(self, groupname, update_cached=False, vault_password=None): + if groupname not in self._vars_per_group or update_cached: + self._vars_per_group[groupname] = self._get_group_variables(groupname, vault_password=vault_password) + return self._vars_per_group[groupname] + + def
_get_group_variables(self, groupname, vault_password=None): + + group = self.get_group(groupname) + if group is None: + raise AnsibleError("group not found: %s" % groupname) + + vars = {} + + # plugin.get_group_vars retrieves just vars for specific group + vars_results = [ plugin.get_group_vars(group, vault_password=vault_password) for plugin in self._vars_plugins if hasattr(plugin, 'get_group_vars')] + for updated in vars_results: + if updated is not None: + # FIXME: combine_vars + vars = combine_vars(vars, updated) + + # Read group_vars/ files + # FIXME: combine_vars + vars = combine_vars(vars, self.get_group_vars(group)) + + return vars + + def get_variables(self, hostname, update_cached=False, vault_password=None): + + host = self.get_host(hostname) + if not host: + raise AnsibleError("host not found: %s" % hostname) + return host.get_variables() + + def get_host_variables(self, hostname, update_cached=False, vault_password=None): + + if hostname not in self._vars_per_host or update_cached: + self._vars_per_host[hostname] = self._get_host_variables(hostname, vault_password=vault_password) + return self._vars_per_host[hostname] + + def _get_host_variables(self, hostname, vault_password=None): + + host = self.get_host(hostname) + if host is None: + raise AnsibleError("host not found: %s" % hostname) + + vars = {} + + # plugin.run retrieves all vars (also from groups) for host + vars_results = [ plugin.run(host, vault_password=vault_password) for plugin in self._vars_plugins if hasattr(plugin, 'run')] + for updated in vars_results: + if updated is not None: + # FIXME: combine_vars + vars = combine_vars(vars, updated) + + # plugin.get_host_vars retrieves just vars for specific host + vars_results = [ plugin.get_host_vars(host, vault_password=vault_password) for plugin in self._vars_plugins if hasattr(plugin, 'get_host_vars')] + for updated in vars_results: + if updated is not None: + # FIXME: combine_vars + vars = combine_vars(vars, updated) + + # still need to check InventoryParser per host vars + # which actually means InventoryScript per host, + # which is not performant + if self.parser is not None: + # FIXME: combine_vars + vars = combine_vars(vars, self.parser.get_host_variables(host)) + + # Read host_vars/ files + # FIXME: combine_vars + vars = combine_vars(vars, self.get_host_vars(host)) + + return vars + + def add_group(self, group): + if group.name not in self.groups_list(): + self.groups.append(group) + self._groups_list = None # invalidate internal cache + else: + raise AnsibleError("group already in inventory: %s" % group.name) + + def list_hosts(self, pattern="all"): + + """ return a list of hostnames for a pattern """ + + result = [ h for h in self.get_hosts(pattern) ] + if len(result) == 0 and pattern in ["localhost", "127.0.0.1"]: + result = [pattern] + return result + + def list_groups(self): + return sorted([ g.name for g in self.groups ], key=lambda x: x) + + def restrict_to_hosts(self, restriction): + """ + Restrict list operations to the hosts given in restriction. This is used + to exclude failed hosts in main playbook code, don't use this for other + reasons. + """ + if not isinstance(restriction, list): + restriction = [ restriction ] + self._restriction = restriction + + def also_restrict_to(self, restriction): + """ + Works like restrict_to_hosts but offers an additional restriction. Playbooks use this + to implement serial behavior.
+ """ + if not isinstance(restriction, list): + restriction = [ restriction ] + self._also_restriction = restriction + + def subset(self, subset_pattern): + """ Limits inventory results to a subset of inventory that matches a given pattern, such as to select a given geographic of numeric slice amongst - a previous 'hosts' selection that only select roles, or vice versa... + a previous 'hosts' selection that only select roles, or vice versa. Corresponds to --limit parameter to ansible-playbook - """ - pass - def lift_restriction(self): - # HACK -- - pass - def lift_also_restriction(self): - # HACK -- dead host skipping - pass - def is_file(self): - pass - def basedir(self): - pass - def src(self): - pass - def playbook_basedir(self): - pass - def set_playbook_basedir(self, dir): - pass - def get_host_vars(self, host, new_pb_basedir=False): - pass - def get_group_vars(self, group, new_pb_basedir=False): - pass + """ + if subset_pattern is None: + self._subset = None + else: + subset_pattern = subset_pattern.replace(',',':') + subset_pattern = subset_pattern.replace(";",":").split(":") + results = [] + # allow Unix style @filename data + for x in subset_pattern: + if x.startswith("@"): + fd = open(x[1:]) + results.extend(fd.read().split("\n")) + fd.close() + else: + results.append(x) + self._subset = results + + def remove_restriction(self): + """ Do not restrict list operations """ + self._restriction = None + + def lift_also_restriction(self): + """ Clears the also restriction """ + self._also_restriction = None + + def is_file(self): + """ did inventory come from a file? """ + if not isinstance(self.host_list, basestring): + return False + return os.path.exists(self.host_list) + + def basedir(self): + """ if inventory came from a file, what's the directory? """ + if not self.is_file(): + return None + dname = os.path.dirname(self.host_list) + if dname is None or dname == '' or dname == '.': + cwd = os.getcwd() + return os.path.abspath(cwd) + return os.path.abspath(dname) + + def src(self): + """ if inventory came from a file, what's the directory and file name? """ + if not self.is_file(): + return None + return self.host_list + + def playbook_basedir(self): + """ returns the directory of the current playbook """ + return self._playbook_basedir + + def set_playbook_basedir(self, dir): + """ + sets the base directory of the playbook so inventory can use it as a + basedir for host_ and group_vars, and other things. 
+ """ + # Only update things if dir is a different playbook basedir + if dir != self._playbook_basedir: + self._playbook_basedir = dir + # get group vars from group_vars/ files + for group in self.groups: + # FIXME: combine_vars + group.vars = combine_vars(group.vars, self.get_group_vars(group, new_pb_basedir=True)) + # get host vars from host_vars/ files + for host in self.get_hosts(): + # FIXME: combine_vars + host.vars = combine_vars(host.vars, self.get_host_vars(host, new_pb_basedir=True)) + # invalidate cache + self._vars_per_host = {} + self._vars_per_group = {} + + def get_host_vars(self, host, new_pb_basedir=False): + """ Read host_vars/ files """ + return self._get_hostgroup_vars(host=host, group=None, new_pb_basedir=new_pb_basedir) + + def get_group_vars(self, group, new_pb_basedir=False): + """ Read group_vars/ files """ + return self._get_hostgroup_vars(host=None, group=group, new_pb_basedir=new_pb_basedir) + + def _get_hostgroup_vars(self, host=None, group=None, new_pb_basedir=False): + """ + Loads variables from group_vars/ and host_vars/ in directories parallel + to the inventory base directory or in the same directory as the playbook. Variables in the playbook + dir will win over the inventory dir if files are in both. + """ + + results = {} + scan_pass = 0 + _basedir = self.basedir() + + # look in both the inventory base directory and the playbook base directory + # unless we do an update for a new playbook base dir + if not new_pb_basedir: + basedirs = [_basedir, self._playbook_basedir] + else: + basedirs = [self._playbook_basedir] + + for basedir in basedirs: + + # this can happen from particular API usages, particularly if not run + # from /usr/bin/ansible-playbook + if basedir is None: + continue + + scan_pass = scan_pass + 1 + + # it's not an eror if the directory does not exist, keep moving + if not os.path.exists(basedir): + continue + + # save work of second scan if the directories are the same + if _basedir == self._playbook_basedir and scan_pass != 1: + continue + + # FIXME: these should go to VariableManager + if group and host is None: + # load vars in dir/group_vars/name_of_group + base_path = os.path.join(basedir, "group_vars/%s" % group.name) + self._variable_manager.add_group_vars_file(base_path, self._loader) + elif host and group is None: + # same for hostvars in dir/host_vars/name_of_host + base_path = os.path.join(basedir, "host_vars/%s" % host.name) + self._variable_manager.add_host_vars_file(base_path, self._loader) + + # all done, results is a dictionary of variables for this particular host. + return results diff --git a/v2/ansible/inventory/dir.py b/v2/ansible/inventory/dir.py new file mode 100644 index 0000000000..9ac23fff89 --- /dev/null +++ b/v2/ansible/inventory/dir.py @@ -0,0 +1,229 @@ +# (c) 2013, Daniel Hokka Zakrisson +# (c) 2014, Serge van Ginderachter +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . 
+ + ############################################# + +import os +import ansible.constants as C +from ansible.inventory.host import Host +from ansible.inventory.group import Group +from ansible.inventory.ini import InventoryParser +from ansible.inventory.script import InventoryScript +from ansible import utils +from ansible import errors + +class InventoryDirectory(object): + ''' Host inventory parser for ansible using a directory of inventories. ''' + + def __init__(self, filename=C.DEFAULT_HOST_LIST): + self.names = os.listdir(filename) + self.names.sort() + self.directory = filename + self.parsers = [] + self.hosts = {} + self.groups = {} + + for i in self.names: + + # Skip files that end with certain extensions or characters + if any(i.endswith(ext) for ext in ("~", ".orig", ".bak", ".ini", ".retry", ".pyc", ".pyo")): + continue + # Skip hidden files + if i.startswith('.') and not i.startswith('./'): + continue + # These are things inside of an inventory basedir + if i in ("host_vars", "group_vars", "vars_plugins"): + continue + fullpath = os.path.join(self.directory, i) + if os.path.isdir(fullpath): + parser = InventoryDirectory(filename=fullpath) + elif utils.is_executable(fullpath): + parser = InventoryScript(filename=fullpath) + else: + parser = InventoryParser(filename=fullpath) + self.parsers.append(parser) + + # retrieve all groups and hosts from the parser and add them to + # self, don't look at group lists yet, to avoid + # recursion trouble, but just make sure all objects exist in self + newgroups = parser.groups.values() + for group in newgroups: + for host in group.hosts: + self._add_host(host) + for group in newgroups: + self._add_group(group) + + # now check the objects lists so they contain only objects from + # self; membership data in groups is already fine (except all & + # ungrouped, see later), but might still reference objects not in self + for group in self.groups.values(): + # iterate on a copy of the lists, as those lists get changed in + # the loop + # list with group's child group objects: + for child in group.child_groups[:]: + if child != self.groups[child.name]: + group.child_groups.remove(child) + group.child_groups.append(self.groups[child.name]) + # list with group's parent group objects: + for parent in group.parent_groups[:]: + if parent != self.groups[parent.name]: + group.parent_groups.remove(parent) + group.parent_groups.append(self.groups[parent.name]) + # list with group's host objects: + for host in group.hosts[:]: + if host != self.hosts[host.name]: + group.hosts.remove(host) + group.hosts.append(self.hosts[host.name]) + # also check here that the group that contains host, is + # also contained in the host's group list + if group not in self.hosts[host.name].groups: + self.hosts[host.name].groups.append(group) + + # extra checks on special groups all and ungrouped + # remove hosts from 'ungrouped' if they became members of other groups + if 'ungrouped' in self.groups: + ungrouped = self.groups['ungrouped'] + # loop on a copy of ungrouped hosts, as we want to change that list + for host in ungrouped.hosts[:]: + if len(host.groups) > 1: + host.groups.remove(ungrouped) + ungrouped.hosts.remove(host) + + # remove hosts from 'all' if they became members of other groups + # all should only contain direct children, not grandchildren + # direct children should have depth == 1 + if 'all' in self.groups: + allgroup = self.groups['all'] + # loop on a copy of all's child groups, as we want to change that list + for group in allgroup.child_groups[:]: + # groups
might once have been added to all, and later be added + # to another group: we need to remove the link with all then + if len(group.parent_groups) > 1 and allgroup in group.parent_groups: + # real children of all have just 1 parent, all + # this one has more, so not a direct child of all anymore + group.parent_groups.remove(allgroup) + allgroup.child_groups.remove(group) + elif allgroup not in group.parent_groups: + # this group was once added to all, but doesn't list it as + # a parent any more; the info in the group is the correct + # info + allgroup.child_groups.remove(group) + + + def _add_group(self, group): + """ Merge an existing group or add a new one; + Track parent and child groups, and hosts of the new one """ + + if group.name not in self.groups: + # it's brand new, add him! + self.groups[group.name] = group + if self.groups[group.name] != group: + # different object, merge + self._merge_groups(self.groups[group.name], group) + + def _add_host(self, host): + if host.name not in self.hosts: + # Papa's got a brand new host + self.hosts[host.name] = host + if self.hosts[host.name] != host: + # different object, merge + self._merge_hosts(self.hosts[host.name], host) + + def _merge_groups(self, group, newgroup): + """ Merge all of instance newgroup into group, + update parent/child relationships + group lists may still contain group objects that exist in self with + same name, but were instantiated as different objects in some other + inventory parser; these are handled later """ + + # name + if group.name != newgroup.name: + raise errors.AnsibleError("Cannot merge group %s with %s" % (group.name, newgroup.name)) + + # depth + group.depth = max([group.depth, newgroup.depth]) + + # hosts list (host objects are by now already added to self.hosts) + for host in newgroup.hosts: + grouphosts = dict([(h.name, h) for h in group.hosts]) + if host.name in grouphosts: + # same host name but different object, merge + self._merge_hosts(grouphosts[host.name], host) + else: + # new membership, add host to group from self + # group from self will also be added again to host.groups, but + # as different object + group.add_host(self.hosts[host.name]) + # now remove the old group object from host.groups + for hostgroup in [g for g in host.groups]: + if hostgroup.name == group.name and hostgroup != self.groups[group.name]: + self.hosts[host.name].groups.remove(hostgroup) + + + # group child membership relation + for newchild in newgroup.child_groups: + # dict with existing child groups: + childgroups = dict([(g.name, g) for g in group.child_groups]) + # check if child of new group is already known as a child + if newchild.name not in childgroups: + self.groups[group.name].add_child_group(newchild) + + # group parent membership relation + for newparent in newgroup.parent_groups: + # dict with existing parent groups: + parentgroups = dict([(g.name, g) for g in group.parent_groups]) + # check if parent of new group is already known as a parent + if newparent.name not in parentgroups: + if newparent.name not in self.groups: + # group does not exist yet in self, import him + self.groups[newparent.name] = newparent + # group now exists but not yet as a parent here + self.groups[newparent.name].add_child_group(group) + + # variables + group.vars = utils.combine_vars(group.vars, newgroup.vars) + + def _merge_hosts(self, host, newhost): + """ Merge all of instance newhost into host """ + + # name + if host.name != newhost.name: + raise errors.AnsibleError("Cannot merge host %s with %s" % (host.name,
newhost.name)) + + # group membership relation + for newgroup in newhost.groups: + # dict with existing groups: + hostgroups = dict([(g.name, g) for g in host.groups]) + # check if new group is already known as a group + if newgroup.name not in hostgroups: + if newgroup.name not in self.groups: + # group does not exist yet in self, import him + self.groups[newgroup.name] = newgroup + # group now exists but doesn't have host yet + self.groups[newgroup.name].add_host(host) + + # variables + host.vars = utils.combine_vars(host.vars, newhost.vars) + + def get_host_variables(self, host): + """ Gets additional host variables from all inventories """ + vars = {} + for i in self.parsers: + vars.update(i.get_host_variables(host)) + return vars + diff --git a/v2/ansible/inventory/expand_hosts.py b/v2/ansible/inventory/expand_hosts.py new file mode 100644 index 0000000000..f129740935 --- /dev/null +++ b/v2/ansible/inventory/expand_hosts.py @@ -0,0 +1,116 @@ +# (c) 2012, Zettar Inc. +# Written by Chin Fang +# +# This file is part of Ansible +# +# This module is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this software. If not, see . +# + +''' +This module is for enhancing ansible's inventory parsing capability such +that it can deal with hostnames specified using a simple pattern in the +form of [beg:end], example: [1:5], [a:c], [D:G]. If beg is not specified, +it defaults to 0. + +If beg is given and is left-zero-padded, e.g. '001', it is taken as a +formatting hint when the range is expanded. e.g. [001:010] is to be +expanded into 001, 002 ...009, 010. + +Note that when beg is specified with left zero padding, then the length of +end must be the same as that of beg, else an exception is raised. +''' +import string + +from ansible import errors + +def detect_range(line = None): + ''' + A helper function that checks a given host line to see if it contains + a range pattern described in the docstring above. + + Returns True if the given line contains a pattern, else False. + ''' + if 0 <= line.find("[") < line.find(":") < line.find("]"): + return True + else: + return False + +def expand_hostname_range(line = None): + ''' + A helper function that expands a given line that contains a pattern + specified in top docstring, and returns a list of the expanded + hostnames. + + The '[' and ']' characters are used to maintain the pseudo-code + appearance. They are replaced in this function with '|' to ease + string splitting. + + References: http://ansible.github.com/patterns.html#hosts-and-groups + ''' + all_hosts = [] + if line: + # A hostname such as db[1:6]-node is considered to consist of + # three parts: + # head: 'db' + # nrange: [1:6]; range() is a built-in. Can't use the name + # tail: '-node' + + # Add support for multiple ranges in a host so: + # db[01:10:3]node-[01:10] + # - to do this we split off at the first [...] set, getting the list + # of hosts and then repeat until none left. + # - also add an optional third parameter which contains the step.
(Default: 1) + # so range can be [01:10:2] -> 01 03 05 07 09 + # FIXME: make this work for alphabetic sequences too. + + (head, nrange, tail) = line.replace('[','|',1).replace(']','|',1).split('|') + bounds = nrange.split(":") + if len(bounds) != 2 and len(bounds) != 3: + raise errors.AnsibleError("host range incorrectly specified") + beg = bounds[0] + end = bounds[1] + if len(bounds) == 2: + step = 1 + else: + step = bounds[2] + if not beg: + beg = "0" + if not end: + raise errors.AnsibleError("host range end value missing") + if beg[0] == '0' and len(beg) > 1: + rlen = len(beg) # range length formatting hint + if rlen != len(end): + raise errors.AnsibleError("host range format incorrectly specified!") + fill = lambda _: str(_).zfill(rlen) # range sequence + else: + fill = str + + try: + i_beg = string.ascii_letters.index(beg) + i_end = string.ascii_letters.index(end) + if i_beg > i_end: + raise errors.AnsibleError("host range format incorrectly specified!") + seq = string.ascii_letters[i_beg:i_end+1] + except ValueError: # not an alpha range + seq = range(int(beg), int(end)+1, int(step)) + + for rseq in seq: + hname = ''.join((head, fill(rseq), tail)) + + if detect_range(hname): + all_hosts.extend( expand_hostname_range( hname ) ) + else: + all_hosts.append(hname) + + return all_hosts diff --git a/v2/ansible/inventory/group.py b/v2/ansible/inventory/group.py new file mode 100644 index 0000000000..87d6f64dfc --- /dev/null +++ b/v2/ansible/inventory/group.py @@ -0,0 +1,159 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . 
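A hedged usage sketch for detect_range and expand_hostname_range above, assuming the v2 tree is importable as the ansible package; the hostnames are made up:

from ansible.inventory.expand_hosts import detect_range, expand_hostname_range

assert detect_range("db[1:3]-node")
assert expand_hostname_range("db[1:3]-node") == ["db1-node", "db2-node", "db3-node"]
# a left-zero-padded beg is a formatting hint, preserved in the output
assert expand_hostname_range("web[01:03]") == ["web01", "web02", "web03"]
# an optional third field gives the step
assert expand_hostname_range("node[01:05:2]") == ["node01", "node03", "node05"]
# alphabetic ranges work as well
assert expand_hostname_range("rack-[a:c]") == ["rack-a", "rack-b", "rack-c"]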
+ +from ansible.utils.debug import debug + +class Group: + ''' a group of ansible hosts ''' + + #__slots__ = [ 'name', 'hosts', 'vars', 'child_groups', 'parent_groups', 'depth', '_hosts_cache' ] + + def __init__(self, name=None): + + self.depth = 0 + self.name = name + self.hosts = [] + self.vars = {} + self.child_groups = [] + self.parent_groups = [] + self._hosts_cache = None + + #self.clear_hosts_cache() + #if self.name is None: + # raise Exception("group name is required") + + def __repr__(self): + return self.get_name() + + def __getstate__(self): + return self.serialize() + + def __setstate__(self, data): + return self.deserialize(data) + + def serialize(self): + parent_groups = [] + for parent in self.parent_groups: + parent_groups.append(parent.serialize()) + + result = dict( + name=self.name, + vars=self.vars.copy(), + parent_groups=parent_groups, + depth=self.depth, + ) + + debug("serializing group, result is: %s" % result) + return result + + def deserialize(self, data): + debug("deserializing group, data is: %s" % data) + self.__init__() + self.name = data.get('name') + self.vars = data.get('vars', dict()) + + parent_groups = data.get('parent_groups', []) + for parent_data in parent_groups: + g = Group() + g.deserialize(parent_data) + self.parent_groups.append(g) + + def get_name(self): + return self.name + + def add_child_group(self, group): + + if self == group: + raise Exception("can't add group to itself") + + # don't add if it's already there + if not group in self.child_groups: + self.child_groups.append(group) + + # update the depth of the child + group.depth = max([self.depth+1, group.depth]) + + # update the depth of the grandchildren + group._check_children_depth() + + # now add self to child's parent_groups list, but only if there + # isn't already a group with the same name + if not self.name in [g.name for g in group.parent_groups]: + group.parent_groups.append(self) + + self.clear_hosts_cache() + + def _check_children_depth(self): + + for group in self.child_groups: + group.depth = max([self.depth+1, group.depth]) + group._check_children_depth() + + def add_host(self, host): + + self.hosts.append(host) + host.add_group(self) + self.clear_hosts_cache() + + def set_variable(self, key, value): + + self.vars[key] = value + + def clear_hosts_cache(self): + + self._hosts_cache = None + for g in self.parent_groups: + g.clear_hosts_cache() + + def get_hosts(self): + + if self._hosts_cache is None: + self._hosts_cache = self._get_hosts() + + return self._hosts_cache + + def _get_hosts(self): + + hosts = [] + seen = {} + for kid in self.child_groups: + kid_hosts = kid.get_hosts() + for kk in kid_hosts: + if kk not in seen: + seen[kk] = 1 + hosts.append(kk) + for mine in self.hosts: + if mine not in seen: + seen[mine] = 1 + hosts.append(mine) + return hosts + + def get_vars(self): + return self.vars.copy() + + def _get_ancestors(self): + + results = {} + for g in self.parent_groups: + results[g.name] = g + results.update(g._get_ancestors()) + return results + + def get_ancestors(self): + + return self._get_ancestors().values() + diff --git a/v2/ansible/inventory/host.py b/v2/ansible/inventory/host.py new file mode 100644 index 0000000000..414ec34b96 --- /dev/null +++ b/v2/ansible/inventory/host.py @@ -0,0 +1,127 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of 
the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible import constants as C +from ansible.inventory.group import Group +from ansible.utils.vars import combine_vars + +__all__ = ['Host'] + +class Host: + ''' a single ansible host ''' + + #__slots__ = [ 'name', 'vars', 'groups' ] + + def __getstate__(self): + return self.serialize() + + def __setstate__(self, data): + return self.deserialize(data) + + def serialize(self): + groups = [] + for group in self.groups: + groups.append(group.serialize()) + + return dict( + name=self.name, + vars=self.vars.copy(), + ipv4_address=self.ipv4_address, + ipv6_address=self.ipv6_address, + port=self.port, + gathered_facts=self._gathered_facts, + groups=groups, + ) + + def deserialize(self, data): + self.__init__() + + self.name = data.get('name') + self.vars = data.get('vars', dict()) + self.ipv4_address = data.get('ipv4_address', '') + self.ipv6_address = data.get('ipv6_address', '') + self.port = data.get('port') + + groups = data.get('groups', []) + for group_data in groups: + g = Group() + g.deserialize(group_data) + self.groups.append(g) + + def __init__(self, name=None, port=None): + + self.name = name + self.vars = {} + self.groups = [] + + self.ipv4_address = name + self.ipv6_address = name + + if port and port != C.DEFAULT_REMOTE_PORT: + self.port = int(port) + else: + self.port = C.DEFAULT_REMOTE_PORT + + self._gathered_facts = False + + def __repr__(self): + return self.get_name() + + def get_name(self): + return self.name + + @property + def gathered_facts(self): + return self._gathered_facts + + def set_gathered_facts(self, gathered): + self._gathered_facts = gathered + + def add_group(self, group): + + self.groups.append(group) + + def set_variable(self, key, value): + + self.vars[key]=value + + def get_groups(self): + + groups = {} + for g in self.groups: + groups[g.name] = g + ancestors = g.get_ancestors() + for a in ancestors: + groups[a.name] = a + return groups.values() + + def get_vars(self): + + results = {} + groups = self.get_groups() + for group in sorted(groups, key=lambda g: g.depth): + results = combine_vars(results, group.get_vars()) + results = combine_vars(results, self.vars) + results['inventory_hostname'] = self.name + results['inventory_hostname_short'] = self.name.split('.')[0] + results['group_names'] = sorted([ g.name for g in groups if g.name != 'all']) + return results + diff --git a/v2/ansible/inventory/ini.py b/v2/ansible/inventory/ini.py new file mode 100644 index 0000000000..ef3f162aa3 --- /dev/null +++ b/v2/ansible/inventory/ini.py @@ -0,0 +1,215 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +############################################# + +import ast +import shlex +import re + +from ansible import constants as C +from ansible.errors import * +from ansible.inventory.host import Host +from ansible.inventory.group import Group +from ansible.inventory.expand_hosts import detect_range +from ansible.inventory.expand_hosts import expand_hostname_range + +class InventoryParser(object): + """ + Host inventory for ansible. + """ + + def __init__(self, filename=C.DEFAULT_HOST_LIST): + + with open(filename) as fh: + self.lines = fh.readlines() + self.groups = {} + self.hosts = {} + self._parse() + + def _parse(self): + + self._parse_base_groups() + self._parse_group_children() + self._add_allgroup_children() + self._parse_group_variables() + return self.groups + + @staticmethod + def _parse_value(v): + if "#" not in v: + try: + return ast.literal_eval(v) + # Using explicit exceptions. + # Likely a string that literal_eval does not like. We will then just set it. + except ValueError: + # For some reason this was thought to be malformed. + pass + except SyntaxError: + # Is this a hash with an equals at the end? + pass + return v + + # [webservers] + # alpha + # beta:2345 + # gamma sudo=True user=root + # delta asdf=jkl favcolor=red + + def _add_allgroup_children(self): + + for group in self.groups.values(): + if group.depth == 0 and group.name != 'all': + self.groups['all'].add_child_group(group) + + + def _parse_base_groups(self): + # FIXME: refactor + + ungrouped = Group(name='ungrouped') + all = Group(name='all') + all.add_child_group(ungrouped) + + self.groups = dict(all=all, ungrouped=ungrouped) + active_group_name = 'ungrouped' + + for line in self.lines: + line = self._before_comment(line).strip() + if line.startswith("[") and line.endswith("]"): + active_group_name = line.replace("[","").replace("]","") + if ":vars" in line or ":children" in line: + active_group_name = active_group_name.rsplit(":", 1)[0] + if active_group_name not in self.groups: + new_group = self.groups[active_group_name] = Group(name=active_group_name) + active_group_name = None + elif active_group_name not in self.groups: + new_group = self.groups[active_group_name] = Group(name=active_group_name) + elif line.startswith(";") or line == '': + pass + elif active_group_name: + tokens = shlex.split(line) + if len(tokens) == 0: + continue + hostname = tokens[0] + port = C.DEFAULT_REMOTE_PORT + # Two cases to check: + # 0. A hostname that contains a range pseudo-code and a port + # 1.
A hostname that contains just a port + if hostname.count(":") > 1: + # Possibly an IPv6 address, or maybe a host line with multiple ranges + # IPv6 with Port XXX:XXX::XXX.port + # FQDN foo.example.com + if hostname.count(".") == 1: + (hostname, port) = hostname.rsplit(".", 1) + elif ("[" in hostname and + "]" in hostname and + ":" in hostname and + (hostname.rindex("]") < hostname.rindex(":")) or + ("]" not in hostname and ":" in hostname)): + (hostname, port) = hostname.rsplit(":", 1) + + hostnames = [] + if detect_range(hostname): + hostnames = expand_hostname_range(hostname) + else: + hostnames = [hostname] + + for hn in hostnames: + host = None + if hn in self.hosts: + host = self.hosts[hn] + else: + host = Host(name=hn, port=port) + self.hosts[hn] = host + if len(tokens) > 1: + for t in tokens[1:]: + if t.startswith('#'): + break + try: + (k,v) = t.split("=", 1) + except ValueError, e: + raise AnsibleError("Invalid ini entry: %s - %s" % (t, str(e))) + if k == 'ansible_ssh_host': + host.ipv4_address = self._parse_value(v) + else: + host.set_variable(k, self._parse_value(v)) + self.groups[active_group_name].add_host(host) + + # [southeast:children] + # atlanta + # raleigh + + def _parse_group_children(self): + group = None + + for line in self.lines: + line = line.strip() + if line is None or line == '': + continue + if line.startswith("[") and ":children]" in line: + line = line.replace("[","").replace(":children]","") + group = self.groups.get(line, None) + if group is None: + group = self.groups[line] = Group(name=line) + elif line.startswith("#") or line.startswith(";"): + pass + elif line.startswith("["): + group = None + elif group: + kid_group = self.groups.get(line, None) + if kid_group is None: + raise AnsibleError("child group is not defined: (%s)" % line) + else: + group.add_child_group(kid_group) + + + # [webservers:vars] + # http_port=1234 + # maxRequestsPerChild=200 + + def _parse_group_variables(self): + group = None + for line in self.lines: + line = line.strip() + if line.startswith("[") and ":vars]" in line: + line = line.replace("[","").replace(":vars]","") + group = self.groups.get(line, None) + if group is None: + raise AnsibleError("can't add vars to undefined group: %s" % line) + elif line.startswith("#") or line.startswith(";"): + pass + elif line.startswith("["): + group = None + elif line == '': + pass + elif group: + if "=" not in line: + raise AnsibleError("variables assigned to group must be in key=value form") + else: + (k, v) = [e.strip() for e in line.split("=", 1)] + group.set_variable(k, self._parse_value(v)) + + def get_host_variables(self, host): + return {} + + def _before_comment(self, msg): + ''' what's the part of a string before a comment? ''' + msg = msg.replace("\#","**NOT_A_COMMENT**") + msg = msg.split("#")[0] + msg = msg.replace("**NOT_A_COMMENT**","#") + return msg + diff --git a/v2/ansible/inventory/script.py b/v2/ansible/inventory/script.py new file mode 100644 index 0000000000..6239be0140 --- /dev/null +++ b/v2/ansible/inventory/script.py @@ -0,0 +1,150 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version.
+# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +############################################# + +import os +import subprocess +import ansible.constants as C +from ansible.inventory.host import Host +from ansible.inventory.group import Group +from ansible.module_utils.basic import json_dict_unicode_to_bytes +from ansible import utils +from ansible import errors +import sys + + +class InventoryScript(object): + ''' Host inventory parser for ansible using external inventory scripts. ''' + + def __init__(self, filename=C.DEFAULT_HOST_LIST): + + # Support inventory scripts that are not prefixed with some + # path information but happen to be in the current working + # directory when '.' is not in PATH. + self.filename = os.path.abspath(filename) + cmd = [ self.filename, "--list" ] + try: + sp = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + except OSError, e: + raise errors.AnsibleError("problem running %s (%s)" % (' '.join(cmd), e)) + (stdout, stderr) = sp.communicate() + self.data = stdout + # see comment about _meta below + self.host_vars_from_top = None + self.groups = self._parse(stderr) + + + def _parse(self, err): + + all_hosts = {} + + # not passing from_remote because data from CMDB is trusted + self.raw = utils.parse_json(self.data) + self.raw = json_dict_unicode_to_bytes(self.raw) + + all = Group('all') + groups = dict(all=all) + group = None + + + if 'failed' in self.raw: + sys.stderr.write(err + "\n") + raise errors.AnsibleError("failed to parse executable inventory script results: %s" % self.raw) + + for (group_name, data) in self.raw.items(): + + # in Ansible 1.3 and later, a "_meta" subelement may contain + # a variable "hostvars" which contains a hash for each host + # if this "hostvars" exists at all then do not call --host for each + # host. This is for efficiency and scripts should still return data + # if called with --host for backwards compat with 1.2 and earlier. 
+ + if group_name == '_meta': + if 'hostvars' in data: + self.host_vars_from_top = data['hostvars'] + continue + + if group_name != all.name: + group = groups[group_name] = Group(group_name) + else: + group = all + host = None + + if not isinstance(data, dict): + data = {'hosts': data} + # if it has neither of those subkeys, it's the simplified syntax: a host with vars + elif not any(k in data for k in ('hosts','vars')): + data = {'hosts': [group_name], 'vars': data} + + if 'hosts' in data: + if not isinstance(data['hosts'], list): + raise errors.AnsibleError("You defined a group \"%s\" with bad " + "data for the host list:\n %s" % (group_name, data)) + + for hostname in data['hosts']: + if not hostname in all_hosts: + all_hosts[hostname] = Host(hostname) + host = all_hosts[hostname] + group.add_host(host) + + if 'vars' in data: + if not isinstance(data['vars'], dict): + raise errors.AnsibleError("You defined a group \"%s\" with bad " + "data for variables:\n %s" % (group_name, data)) + + for k, v in data['vars'].iteritems(): + if group.name == all.name: + all.set_variable(k, v) + else: + group.set_variable(k, v) + + # Separate loop to ensure all groups are defined + for (group_name, data) in self.raw.items(): + if group_name == '_meta': + continue + if isinstance(data, dict) and 'children' in data: + for child_name in data['children']: + if child_name in groups: + groups[group_name].add_child_group(groups[child_name]) + + for group in groups.values(): + if group.depth == 0 and group.name != 'all': + all.add_child_group(group) + + return groups + + def get_host_variables(self, host): + """ Runs