From 9006d4557dc96a45bb4c041eba8b5f7ef23a517b Mon Sep 17 00:00:00 2001 From: Michael DeHaan Date: Tue, 17 Jul 2012 22:33:36 -0400 Subject: [PATCH] Added code to allow insertion of boilerplate into modules to make them able to share lots of code, the result should be a huge reduction of total ansible source, at a slight cost of difficulty in original module development. We need to apply this now to all modules, but may need to have some exemptions to things like command, which should subclass this module. --- hacking/test-module | 18 ++++++ lib/ansible/module_common.py | 114 +++++++++++++++++++++++++++++++++ lib/ansible/runner/__init__.py | 15 +++-- library/slurp | 60 ++++------------- 4 files changed, 154 insertions(+), 53 deletions(-) create mode 100644 lib/ansible/module_common.py diff --git a/hacking/test-module b/hacking/test-module index cdf9dfbf33..50241c1368 100755 --- a/hacking/test-module +++ b/hacking/test-module @@ -31,6 +31,7 @@ import os import subprocess import traceback from ansible import utils +from ansible import module_common try: import json @@ -54,6 +55,23 @@ argsfile = open(argspath, 'w') argsfile.write(args) argsfile.close() +module_fh = open(modfile) +module_data = module_fh.read() +included_boilerplate = module_data.find(module_common.REPLACER) != -1 +module_fh.close() + +if included_boilerplate: + module_data = module_data.replace(module_common.REPLACER, module_common.MODULE_COMMON) + modfile2_path = os.path.expanduser("~/.ansible_module_generated") + print "* including generated source, if any, saving to: %s" % modfile2_path + print "* this will offset any line numbers in tracebacks!" + modfile2 = open(modfile2_path, 'w') + modfile2.write(module_data) + modfile2.close() + modfile = modfile2_path +else: + print "* module boilerplate substitution not requested in module, tracebacks will be unaltered" + os.system("chmod +x %s" % modfile) cmd = subprocess.Popen("%s %s" % (modfile, argspath), shell=True, diff --git a/lib/ansible/module_common.py b/lib/ansible/module_common.py new file mode 100644 index 0000000000..69273d907f --- /dev/null +++ b/lib/ansible/module_common.py @@ -0,0 +1,114 @@ +# (c) 2012, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +REPLACER = "#<>" + +MODULE_COMMON = """ + +# == BEGIN DYNAMICALLY INSERTED CODE == + +# ansible modules can be written in any language. To simplify +# development of Python modules, the functions available here +# can be inserted in any module source automatically by including +# #<> on a blank line by itself inside +# of an ansible module. The source of this common code lives +# in lib/ansible/module_common.py + +try: + import json +except ImportError: + import simplejson as json +import os +import shlex +import subprocess +import sys +import syslog + +class AnsibleModule(object): + + def __init__(self, argument_spec, bypass_checks=False, no_log=False): + ''' + @argument_spec: a hash of argument names, where the values are none if + the types are NOT checked, or a list of valid types where the argument + must be one of those values. All possible arguments must be listed. + + @required_arguments: a list of arguments that must be sent to the module + ''' + + self.argument_spec = argument_spec + (self.params, self.args) = self._load_params() + if not bypass_checks: + self._check_required_arguments() + self._check_argument_types() + if not no_log: + self._log_invocation() + + def _check_required_arguments(self): + ''' ensure all required arguments are present ''' + missing = [] + for (k,v) in self.argument_spec.iteritems(): + (type_spec, required) = v + if required and k not in self.params: + missing.append(k) + if len(missing) > 0: + self.fail_json(msg="missing required arguments: %s" % ",".join(missing)) + + def _check_argument_types(self): + ''' ensure all arguments have the requested values, and there are no stray arguments ''' + for (k,v) in self.argument_spec.iteritems(): + (type_spec, required) = v + if type_spec is not None: + if type(spec) == list: + if v not in spec: + self.fail_json(msg="value of %s must be one of: %s, recieved: %s" % (k, ",".join(spec), v)) + else: + self.fail_json(msg="internal error: do not know how to interpret argument_spec") + + def _load_params(self): + ''' read the input and return a dictionary and the arguments string ''' + if len(sys.argv) == 2 and os.path.exists(sys.argv[1]): + argfile = sys.argv[1] + args = open(argfile, 'r').read() + else: + args = ' '.join(sys.argv[1:]) + items = shlex.split(args) + params = {} + for x in items: + (k, v) = x.split("=") + params[k] = v + return (params, args) + + def _log_invocation(self): + ''' log that ansible ran the module ''' + syslog.openlog('ansible-%s' % os.path.basename(__file__)) + syslog.syslog(syslog.LOG_NOTICE, 'Invoked with %s' % self.args) + + def exit_json(self, rc=0, **kwargs): + ''' return from the module, without error ''' + kwargs['rc'] = rc + print json.dumps(kwargs) + sys.exit(rc) + + def fail_json(self, **kwargs): + ''' return from the module, with an error message ''' + assert 'msg' in kwargs, "implementation error -- msg to explain the error is required" + kwargs['failed'] = True + self.exit_json(rc=1, **kwargs) + +# == END DYNAMICALLY INSERTED CODE === + +""" diff --git a/lib/ansible/runner/__init__.py b/lib/ansible/runner/__init__.py index 27889c1ccc..bcb0612fe1 100644 --- a/lib/ansible/runner/__init__.py +++ b/lib/ansible/runner/__init__.py @@ -36,6 +36,7 @@ import ansible.constants as C import ansible.inventory from ansible import utils from ansible import errors +from ansible import module_common import poller import connection from ansible import callbacks as ans_callbacks @@ -721,16 +722,20 @@ class Runner(object): # use the correct python interpreter for the host host_variables = self.inventory.get_variables(conn.host) + + module_data = "" + with open(in_path) as f: + module_data = f.read() + module_data = module_data.replace(module_common.REPLACER, module_common.MODULE_COMMON) + if 'ansible_python_interpreter' in host_variables: interpreter = host_variables['ansible_python_interpreter'] - with open(in_path) as f: - module_lines = f.readlines() + module_lines = module_data.split('\n') if '#!' and 'python' in module_lines[0]: module_lines[0] = "#!%s" % interpreter - self._transfer_str(conn, tmp, module, '\n'.join(module_lines)) - else: - conn.put_file(in_path, out_path) + module_data = "\n".join(module_lines) + self._transfer_str(conn, tmp, module, module_data) return out_path # ***************************************************** diff --git a/library/slurp b/library/slurp index acb6b5f83f..b23f1da6a4 100755 --- a/library/slurp +++ b/library/slurp @@ -17,60 +17,24 @@ # You should have received a copy of the GNU General Public License # along with Ansible. If not, see . -import sys -import os -import shlex import base64 -import syslog -try: - import json -except ImportError: - import simplejson as json +# this is magic, see lib/ansible/module_common.py +#<> -# =========================================== -# convert arguments of form a=b c=d -# to a dictionary +module = AnsibleModule( + argument_spec = dict( + src=(None,True), + ) +) +source = module.params['src'] -if len(sys.argv) == 1: - sys.exit(1) -argfile = sys.argv[1] -if not os.path.exists(argfile): - sys.exit(1) - -args = open(argfile, 'r').read() -items = shlex.split(args) -syslog.openlog('ansible-%s' % os.path.basename(__file__)) -syslog.syslog(syslog.LOG_NOTICE, 'Invoked with %s' % args) - -params = {} -for x in items: - (k, v) = x.split("=") - params[k] = v -source = os.path.expanduser(params['src']) - -# ========================================== - -# raise an error if there is no template metadata if not os.path.exists(source): - print json.dumps(dict( - failed = 1, - msg = "file not found: %s" % source - )) - sys.exit(1) - + module.fail_json(msg="file not found: %s" % source) if not os.access(source, os.R_OK): - print json.dumps(dict( - failed = 1, - msg = "file is not readable: %s" % source - )) - sys.exit(1) + module.fail_json(msg="file is not readable: %s" % source) -# ========================================== +data = base64.b64encode(file(source).read()) -data = file(source).read() -data = base64.b64encode(data) - -print json.dumps(dict(content=data, encoding='base64')) -sys.exit(0) +module.exit_json(content=data, encoding='base64')