From 58a4d2f7b4551b6222f9def382cac87970fc3aea Mon Sep 17 00:00:00 2001 From: Tim Bielawa Date: Sat, 22 Sep 2012 02:07:49 -0400 Subject: [PATCH] Add 'pause' action plugin and support plugins skipping the host loop. --- lib/ansible/runner/__init__.py | 21 +++- lib/ansible/runner/action_plugins/pause.py | 131 +++++++++++++++++++++ lib/ansible/utils.py | 31 +++-- 3 files changed, 172 insertions(+), 11 deletions(-) create mode 100644 lib/ansible/runner/action_plugins/pause.py diff --git a/lib/ansible/runner/__init__.py b/lib/ansible/runner/__init__.py index e8f24f0948..ccc1ee1bf9 100644 --- a/lib/ansible/runner/__init__.py +++ b/lib/ansible/runner/__init__.py @@ -563,7 +563,26 @@ class Runner(object): hosts = [ (self,x) for x in hosts ] results = None - if self.forks > 1: + + # Check if this is an action plugin. Some of them are designed + # to be ran once per group of hosts. Example module: pause, + # run once per hostgroup, rather than pausing once per each + # host. + p = self.action_plugins.get(self.module_name, None) + if p and getattr(p, 'BYPASS_HOST_LOOP', None): + # Expose the current hostgroup to the bypassing plugins + self.host_set = hosts + # We aren't iterating over all the hosts in this + # group. So, just pick the first host in our group to + # construct the conn object with. + result_data = self._executor(hosts[0][1]).result + # Create a ResultData item for each host in this group + # using the returned result. If we didn't do this we would + # get false reports of dark hosts. + results = [ ReturnData(host=h[1], result=result_data, comm_ok=True) \ + for h in hosts ] + del self.host_set + elif self.forks > 1: results = self._parallel_exec(hosts) else: results = [ self._executor(h[1]) for h in hosts ] diff --git a/lib/ansible/runner/action_plugins/pause.py b/lib/ansible/runner/action_plugins/pause.py new file mode 100644 index 0000000000..e3fbdfdbd0 --- /dev/null +++ b/lib/ansible/runner/action_plugins/pause.py @@ -0,0 +1,131 @@ +# Copyright 2012, Tim Bielawa +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +from ansible.callbacks import vv +from ansible.errors import AnsibleError as ae +from ansible.runner.return_data import ReturnData +from ansible.utils import getch +from termios import tcflush, TCIFLUSH +import datetime +import sys +import time + + +class ActionModule(object): + ''' pauses execution for a length or time, or until input is received ''' + + PAUSE_TYPES = ['seconds', 'minutes', 'prompt', ''] + BYPASS_HOST_LOOP = True + + def __init__(self, runner): + self.runner = runner + # Set defaults + self.duration_unit = 'minutes' + self.prompt = None + self.seconds = None + self.result = {'changed': False, + 'rc': 0, + 'stderr': '', + 'stdout': '', + 'start': None, + 'stop': None, + 'delta': None, + } + + def run(self, conn, tmp, module_name, module_args, inject): + ''' run the pause actionmodule ''' + args = self.runner.module_args + hosts = ', '.join(map(lambda x: x[1], self.runner.host_set)) + + (self.pause_type, sep, pause_params) = args.partition('=') + + if self.pause_type == '': + self.pause_type = 'prompt' + elif not self.pause_type in self.PAUSE_TYPES: + raise ae("invalid parameter for pause, '%s'. must be one of: %s" % \ + (self.pause_type, ", ".join(self.PAUSE_TYPES))) + + # error checking + if self.pause_type in ['minutes', 'seconds']: + try: + int(pause_params) + except ValueError: + raise ae("value given to %s parameter invalid: '%s', must be an integer" % \ + self.pause_type, pause_params) + + # The time() command operates in seconds so we need to + # recalculate for minutes=X values. + if self.pause_type == 'minutes': + self.seconds = int(pause_params) * 60 + elif self.pause_type == 'seconds': + self.seconds = int(pause_params) + self.duration_unit = 'seconds' + else: + # if no args are given we pause with a prompt + if pause_params == '': + self.prompt = "[%s]\nPress enter to continue: " % hosts + else: + self.prompt = "[%s]\n%s: " % (hosts, pause_params) + + vv("created 'pause' ActionModule: pause_type=%s, duration_unit=%s, calculated_seconds=%s, prompt=%s" % \ + (self.pause_type, self.duration_unit, self.seconds, self.prompt)) + + try: + self._start() + if not self.pause_type == 'prompt': + print "[%s]\nPausing for %s seconds" % (hosts, self.seconds) + time.sleep(self.seconds) + else: + # Clear out any unflushed buffered input which would + # otherwise be consumed by raw_input() prematurely. + tcflush(sys.stdin, TCIFLUSH) + raw_input(self.prompt) + except KeyboardInterrupt: + while True: + print '\nAction? (a)bort/(c)ontinue: ' + c = getch() + if c == 'c': + # continue playbook evaluation + break + elif c == 'a': + # abort further playbook evaluation + raise ae('user requested abort!') + finally: + self._stop() + + return ReturnData(conn=conn, result=self.result) + + def _start(self): + ''' mark the time of execution for duration calculations later ''' + self.start = time.time() + self.result['start'] = str(datetime.datetime.now()) + if not self.pause_type == 'prompt': + print "(^C-c = continue early, ^C-a = abort)" + + def _stop(self): + ''' calculate the duration we actually paused for and then + finish building the task result string ''' + duration = time.time() - self.start + self.result['stop'] = str(datetime.datetime.now()) + self.result['delta'] = int(duration) + + if self.duration_unit == 'minutes': + duration = round(duration / 60.0, 2) + else: + duration = round(duration, 2) + + self.result['stdout'] = "Paused for %s %s" % (duration, self.duration_unit) diff --git a/lib/ansible/utils.py b/lib/ansible/utils.py index 76459a8a3a..d6eee06516 100644 --- a/lib/ansible/utils.py +++ b/lib/ansible/utils.py @@ -33,6 +33,8 @@ import imp import glob import subprocess import stat +import termios +import tty VERBOSITY=0 @@ -375,7 +377,7 @@ def _gitinfo(): ''' returns a string containing git branch, commit id and commit date ''' result = None repo_path = os.path.join(os.path.dirname(__file__), '..', '..', '.git') - + if os.path.exists(repo_path): # Check if the .git is a file. If it is a file, it means that we are in a submodule structure. if os.path.isfile(repo_path): @@ -397,7 +399,7 @@ def _gitinfo(): commit = f.readline()[:10] f.close() date = time.localtime(os.stat(branch_path).st_mtime) - if time.daylight == 0: + if time.daylight == 0: offset = time.timezone else: offset = time.altzone @@ -414,6 +416,17 @@ def version(prog): result = result + " {0}".format(gitinfo) return result +def getch(): + ''' read in a single character ''' + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(sys.stdin.fileno()) + ch = sys.stdin.read(1) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + return ch + #################################################################### # option handling code for /usr/bin/ansible and ansible-playbook # below this line @@ -429,7 +442,7 @@ def increment_debug(option, opt, value, parser): global VERBOSITY VERBOSITY += 1 -def base_parser(constants=C, usage="", output_opts=False, runas_opts=False, +def base_parser(constants=C, usage="", output_opts=False, runas_opts=False, async_opts=False, connect_opts=False, subset_opts=False): ''' create an options parser for any ansible script ''' @@ -515,15 +528,15 @@ def last_non_blank_line(buf): if (len(line) > 0): return line # shouldn't occur unless there's no output - return "" + return "" def filter_leading_non_json_lines(buf): - ''' + ''' used to avoid random output from SSH at the top of JSON output, like messages from tcagetattr, or where dropbear spews MOTD on every single command (which is nuts). - + need to filter anything which starts not with '{', '[', ', '=' or is an empty line. - filter only leading lines since multiline JSON is valid. + filter only leading lines since multiline JSON is valid. ''' filtered_lines = StringIO.StringIO() @@ -536,12 +549,10 @@ def filter_leading_non_json_lines(buf): def import_plugins(directory): modules = {} - for path in glob.glob(os.path.join(directory, '*.py')): + for path in glob.glob(os.path.join(directory, '*.py')): if path.startswith("_"): continue name, ext = os.path.splitext(os.path.basename(path)) if not name.startswith("_"): modules[name] = imp.load_source(name, path) return modules - -