From ccb8bcebd3a86ce6d30621cc85e32762b53dfe9a Mon Sep 17 00:00:00 2001 From: Matt Martz Date: Thu, 4 Jun 2015 11:34:56 -0500 Subject: [PATCH] Resync the v1 directory with v1_last. Fixes #11162 --- v1/ansible/constants.py | 8 +- v1/ansible/inventory/__init__.py | 4 +- v1/ansible/module_utils/basic.py | 147 ++++++++++++-------- v1/ansible/module_utils/cloudstack.py | 2 - v1/ansible/module_utils/facts.py | 48 ++++++- v1/ansible/module_utils/powershell.ps1 | 4 +- v1/ansible/module_utils/urls.py | 49 ++++--- v1/ansible/runner/connection_plugins/ssh.py | 67 ++------- v1/ansible/utils/__init__.py | 8 +- v1/ansible/utils/module_docs.py | 11 +- 10 files changed, 200 insertions(+), 148 deletions(-) diff --git a/v1/ansible/constants.py b/v1/ansible/constants.py index a9b4f40bb8..2cdc08d8ce 100644 --- a/v1/ansible/constants.py +++ b/v1/ansible/constants.py @@ -134,7 +134,10 @@ DEFAULT_SU_FLAGS = get_config(p, DEFAULTS, 'su_flags', 'ANSIBLE_SU_FLAG DEFAULT_SU_USER = get_config(p, DEFAULTS, 'su_user', 'ANSIBLE_SU_USER', 'root') DEFAULT_ASK_SU_PASS = get_config(p, DEFAULTS, 'ask_su_pass', 'ANSIBLE_ASK_SU_PASS', False, boolean=True) DEFAULT_GATHERING = get_config(p, DEFAULTS, 'gathering', 'ANSIBLE_GATHERING', 'implicit').lower() -DEFAULT_LOG_PATH = shell_expand_path(get_config(p, DEFAULTS, 'log_path', 'ANSIBLE_LOG_PATH', '')) +DEFAULT_LOG_PATH = shell_expand_path(get_config(p, DEFAULTS, 'log_path', 'ANSIBLE_LOG_PATH', '')) + +# selinux +DEFAULT_SELINUX_SPECIAL_FS = get_config(p, 'selinux', 'special_context_filesystems', None, 'fuse, nfs, vboxsf', islist=True) #TODO: get rid of ternary chain mess BECOME_METHODS = ['sudo','su','pbrun','pfexec','runas'] @@ -176,6 +179,9 @@ DEFAULT_LOAD_CALLBACK_PLUGINS = get_config(p, DEFAULTS, 'bin_ansible_callbacks' DEFAULT_FORCE_HANDLERS = get_config(p, DEFAULTS, 'force_handlers', 'ANSIBLE_FORCE_HANDLERS', False, boolean=True) +RETRY_FILES_ENABLED = get_config(p, DEFAULTS, 'retry_files_enabled', 'ANSIBLE_RETRY_FILES_ENABLED', True, boolean=True) +RETRY_FILES_SAVE_PATH = get_config(p, DEFAULTS, 'retry_files_save_path', 'ANSIBLE_RETRY_FILES_SAVE_PATH', '~/') + # CONNECTION RELATED ANSIBLE_SSH_ARGS = get_config(p, 'ssh_connection', 'ssh_args', 'ANSIBLE_SSH_ARGS', None) ANSIBLE_SSH_CONTROL_PATH = get_config(p, 'ssh_connection', 'control_path', 'ANSIBLE_SSH_CONTROL_PATH', "%(directory)s/ansible-ssh-%%h-%%p-%%r") diff --git a/v1/ansible/inventory/__init__.py b/v1/ansible/inventory/__init__.py index 2048046d3c..f012246e22 100644 --- a/v1/ansible/inventory/__init__.py +++ b/v1/ansible/inventory/__init__.py @@ -36,7 +36,7 @@ class Inventory(object): Host inventory for ansible. """ - __slots__ = [ 'host_list', 'groups', '_restriction', '_also_restriction', '_subset', + __slots__ = [ 'host_list', 'groups', '_restriction', '_also_restriction', '_subset', 'parser', '_vars_per_host', '_vars_per_group', '_hosts_cache', '_groups_list', '_pattern_cache', '_vault_password', '_vars_plugins', '_playbook_basedir'] @@ -53,7 +53,7 @@ class Inventory(object): self._vars_per_host = {} self._vars_per_group = {} self._hosts_cache = {} - self._groups_list = {} + self._groups_list = {} self._pattern_cache = {} # to be set by calling set_playbook_basedir by playbook code diff --git a/v1/ansible/module_utils/basic.py b/v1/ansible/module_utils/basic.py index 54a1a9cfff..e772a12efc 100644 --- a/v1/ansible/module_utils/basic.py +++ b/v1/ansible/module_utils/basic.py @@ -38,6 +38,8 @@ BOOLEANS_TRUE = ['yes', 'on', '1', 'true', 1] BOOLEANS_FALSE = ['no', 'off', '0', 'false', 0] BOOLEANS = BOOLEANS_TRUE + BOOLEANS_FALSE +SELINUX_SPECIAL_FS="<>" + # ansible modules can be written in any language. To simplify # development of Python modules, the functions available here # can be inserted in any module source automatically by including @@ -181,7 +183,8 @@ def get_distribution(): ''' return the distribution name ''' if platform.system() == 'Linux': try: - distribution = platform.linux_distribution()[0].capitalize() + supported_dists = platform._supported_dists + ('arch',) + distribution = platform.linux_distribution(supported_dists=supported_dists)[0].capitalize() if not distribution and os.path.isfile('/etc/system-release'): distribution = platform.linux_distribution(supported_dists=['system'])[0].capitalize() if 'Amazon' in distribution: @@ -334,7 +337,8 @@ class AnsibleModule(object): def __init__(self, argument_spec, bypass_checks=False, no_log=False, check_invalid_arguments=True, mutually_exclusive=None, required_together=None, - required_one_of=None, add_file_common_args=False, supports_check_mode=False): + required_one_of=None, add_file_common_args=False, supports_check_mode=False, + required_if=None): ''' common code for quickly building an ansible module in Python @@ -382,6 +386,7 @@ class AnsibleModule(object): self._check_argument_types() self._check_required_together(required_together) self._check_required_one_of(required_one_of) + self._check_required_if(required_if) self._set_defaults(pre=False) if not self.no_log: @@ -528,10 +533,10 @@ class AnsibleModule(object): path = os.path.dirname(path) return path - def is_nfs_path(self, path): + def is_special_selinux_path(self, path): """ - Returns a tuple containing (True, selinux_context) if the given path - is on a NFS mount point, otherwise the return will be (False, None). + Returns a tuple containing (True, selinux_context) if the given path is on a + NFS or other 'special' fs mount point, otherwise the return will be (False, None). """ try: f = open('/proc/mounts', 'r') @@ -542,9 +547,13 @@ class AnsibleModule(object): path_mount_point = self.find_mount_point(path) for line in mount_data: (device, mount_point, fstype, options, rest) = line.split(' ', 4) - if path_mount_point == mount_point and 'nfs' in fstype: - nfs_context = self.selinux_context(path_mount_point) - return (True, nfs_context) + + if path_mount_point == mount_point: + for fs in SELINUX_SPECIAL_FS.split(','): + if fs in fstype: + special_context = self.selinux_context(path_mount_point) + return (True, special_context) + return (False, None) def set_default_selinux_context(self, path, changed): @@ -562,9 +571,9 @@ class AnsibleModule(object): # Iterate over the current context instead of the # argument context, which may have selevel. - (is_nfs, nfs_context) = self.is_nfs_path(path) - if is_nfs: - new_context = nfs_context + (is_special_se, sp_context) = self.is_special_selinux_path(path) + if is_special_se: + new_context = sp_context else: for i in range(len(cur_context)): if len(context) > i: @@ -861,6 +870,7 @@ class AnsibleModule(object): locale.setlocale(locale.LC_ALL, 'C') os.environ['LANG'] = 'C' os.environ['LC_CTYPE'] = 'C' + os.environ['LC_MESSAGES'] = 'C' except Exception, e: self.fail_json(msg="An unknown error was encountered while attempting to validate the locale: %s" % e) @@ -950,6 +960,20 @@ class AnsibleModule(object): if len(missing) > 0: self.fail_json(msg="missing required arguments: %s" % ",".join(missing)) + def _check_required_if(self, spec): + ''' ensure that parameters which conditionally required are present ''' + if spec is None: + return + for (key, val, requirements) in spec: + missing = [] + if key in self.params and self.params[key] == val: + for check in requirements: + count = self._count_terms(check) + if count == 0: + missing.append(check) + if len(missing) > 0: + self.fail_json(msg="%s is %s but the following are missing: %s" % (key, val, ','.join(missing))) + def _check_argument_values(self): ''' ensure all arguments have the requested values, and there are no stray arguments ''' for (k,v) in self.argument_spec.iteritems(): @@ -1009,57 +1033,60 @@ class AnsibleModule(object): value = self.params[k] is_invalid = False - if wanted == 'str': - if not isinstance(value, basestring): - self.params[k] = str(value) - elif wanted == 'list': - if not isinstance(value, list): - if isinstance(value, basestring): - self.params[k] = value.split(",") - elif isinstance(value, int) or isinstance(value, float): - self.params[k] = [ str(value) ] - else: - is_invalid = True - elif wanted == 'dict': - if not isinstance(value, dict): - if isinstance(value, basestring): - if value.startswith("{"): - try: - self.params[k] = json.loads(value) - except: - (result, exc) = self.safe_eval(value, dict(), include_exceptions=True) - if exc is not None: - self.fail_json(msg="unable to evaluate dictionary for %s" % k) - self.params[k] = result - elif '=' in value: - self.params[k] = dict([x.strip().split("=", 1) for x in value.split(",")]) + try: + if wanted == 'str': + if not isinstance(value, basestring): + self.params[k] = str(value) + elif wanted == 'list': + if not isinstance(value, list): + if isinstance(value, basestring): + self.params[k] = value.split(",") + elif isinstance(value, int) or isinstance(value, float): + self.params[k] = [ str(value) ] else: - self.fail_json(msg="dictionary requested, could not parse JSON or key=value") - else: - is_invalid = True - elif wanted == 'bool': - if not isinstance(value, bool): - if isinstance(value, basestring): - self.params[k] = self.boolean(value) - else: - is_invalid = True - elif wanted == 'int': - if not isinstance(value, int): - if isinstance(value, basestring): - self.params[k] = int(value) - else: - is_invalid = True - elif wanted == 'float': - if not isinstance(value, float): - if isinstance(value, basestring): - self.params[k] = float(value) - else: - is_invalid = True - else: - self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k)) + is_invalid = True + elif wanted == 'dict': + if not isinstance(value, dict): + if isinstance(value, basestring): + if value.startswith("{"): + try: + self.params[k] = json.loads(value) + except: + (result, exc) = self.safe_eval(value, dict(), include_exceptions=True) + if exc is not None: + self.fail_json(msg="unable to evaluate dictionary for %s" % k) + self.params[k] = result + elif '=' in value: + self.params[k] = dict([x.strip().split("=", 1) for x in value.split(",")]) + else: + self.fail_json(msg="dictionary requested, could not parse JSON or key=value") + else: + is_invalid = True + elif wanted == 'bool': + if not isinstance(value, bool): + if isinstance(value, basestring): + self.params[k] = self.boolean(value) + else: + is_invalid = True + elif wanted == 'int': + if not isinstance(value, int): + if isinstance(value, basestring): + self.params[k] = int(value) + else: + is_invalid = True + elif wanted == 'float': + if not isinstance(value, float): + if isinstance(value, basestring): + self.params[k] = float(value) + else: + is_invalid = True + else: + self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k)) - if is_invalid: - self.fail_json(msg="argument %s is of invalid type: %s, required: %s" % (k, type(value), wanted)) + if is_invalid: + self.fail_json(msg="argument %s is of invalid type: %s, required: %s" % (k, type(value), wanted)) + except ValueError, e: + self.fail_json(msg="value of argument %s is not of type %s and we were unable to automatically convert" % (k, wanted)) def _set_defaults(self, pre=True): for (k,v) in self.argument_spec.iteritems(): diff --git a/v1/ansible/module_utils/cloudstack.py b/v1/ansible/module_utils/cloudstack.py index 82306b9a0b..e887367c2f 100644 --- a/v1/ansible/module_utils/cloudstack.py +++ b/v1/ansible/module_utils/cloudstack.py @@ -64,14 +64,12 @@ class AnsibleCloudStack: api_secret = self.module.params.get('secret_key') api_url = self.module.params.get('api_url') api_http_method = self.module.params.get('api_http_method') - api_timeout = self.module.params.get('api_timeout') if api_key and api_secret and api_url: self.cs = CloudStack( endpoint=api_url, key=api_key, secret=api_secret, - timeout=api_timeout, method=api_http_method ) else: diff --git a/v1/ansible/module_utils/facts.py b/v1/ansible/module_utils/facts.py index b223c5f5f7..1162e05b9c 100644 --- a/v1/ansible/module_utils/facts.py +++ b/v1/ansible/module_utils/facts.py @@ -99,8 +99,9 @@ class Facts(object): ('/etc/os-release', 'SuSE'), ('/etc/gentoo-release', 'Gentoo'), ('/etc/os-release', 'Debian'), + ('/etc/lsb-release', 'Mandriva'), ('/etc/os-release', 'NA'), - ('/etc/lsb-release', 'Mandriva')) + ) SELINUX_MODE_DICT = { 1: 'enforcing', 0: 'permissive', -1: 'disabled' } # A list of dicts. If there is a platform with more than one @@ -416,11 +417,13 @@ class Facts(object): self.facts['distribution_version'] = self.facts['distribution_version'] + '.' + release.group(1) elif name == 'Debian': data = get_file_content(path) - if 'Debian' in data or 'Raspbian' in data: + if 'Ubuntu' in data: + break # Ubuntu gets correct info from python functions + elif 'Debian' in data or 'Raspbian' in data: release = re.search("PRETTY_NAME=[^(]+ \(?([^)]+?)\)", data) if release: self.facts['distribution_release'] = release.groups()[0] - break + break elif name == 'Mandriva': data = get_file_content(path) if 'Mandriva' in data: @@ -2160,7 +2163,7 @@ class DarwinNetwork(GenericBsdIfconfigNetwork, Network): current_if['media'] = 'Unknown' # Mac does not give us this current_if['media_select'] = words[1] if len(words) > 2: - current_if['media_type'] = words[2][1:] + current_if['media_type'] = words[2][1:-1] if len(words) > 3: current_if['media_options'] = self.get_options(words[3]) @@ -2545,6 +2548,43 @@ class LinuxVirtual(Virtual): self.facts['virtualization_role'] = 'NA' return +class FreeBSDVirtual(Virtual): + """ + This is a FreeBSD-specific subclass of Virtual. It defines + - virtualization_type + - virtualization_role + """ + platform = 'FreeBSD' + + def __init__(self): + Virtual.__init__(self) + + def populate(self): + self.get_virtual_facts() + return self.facts + + def get_virtual_facts(self): + self.facts['virtualization_type'] = '' + self.facts['virtualization_role'] = '' + +class OpenBSDVirtual(Virtual): + """ + This is a OpenBSD-specific subclass of Virtual. It defines + - virtualization_type + - virtualization_role + """ + platform = 'OpenBSD' + + def __init__(self): + Virtual.__init__(self) + + def populate(self): + self.get_virtual_facts() + return self.facts + + def get_virtual_facts(self): + self.facts['virtualization_type'] = '' + self.facts['virtualization_role'] = '' class HPUXVirtual(Virtual): """ diff --git a/v1/ansible/module_utils/powershell.ps1 b/v1/ansible/module_utils/powershell.ps1 index ee7d3ddeca..9606f47783 100644 --- a/v1/ansible/module_utils/powershell.ps1 +++ b/v1/ansible/module_utils/powershell.ps1 @@ -65,7 +65,7 @@ Function Exit-Json($obj) $obj = New-Object psobject } - echo $obj | ConvertTo-Json -Depth 99 + echo $obj | ConvertTo-Json -Compress -Depth 99 Exit } @@ -89,7 +89,7 @@ Function Fail-Json($obj, $message = $null) Set-Attr $obj "msg" $message Set-Attr $obj "failed" $true - echo $obj | ConvertTo-Json -Depth 99 + echo $obj | ConvertTo-Json -Compress -Depth 99 Exit 1 } diff --git a/v1/ansible/module_utils/urls.py b/v1/ansible/module_utils/urls.py index d56cc89395..18317e86ae 100644 --- a/v1/ansible/module_utils/urls.py +++ b/v1/ansible/module_utils/urls.py @@ -50,6 +50,15 @@ try: except: HAS_SSL=False +HAS_MATCH_HOSTNAME = True +try: + from ssl import match_hostname, CertificateError +except ImportError: + try: + from backports.ssl_match_hostname import match_hostname, CertificateError + except ImportError: + HAS_MATCH_HOSTNAME = False + import httplib import os import re @@ -293,11 +302,13 @@ class SSLValidationHandler(urllib2.BaseHandler): connect_result = s.recv(4096) self.validate_proxy_response(connect_result) ssl_s = ssl.wrap_socket(s, ca_certs=tmp_ca_cert_path, cert_reqs=ssl.CERT_REQUIRED) + match_hostname(ssl_s.getpeercert(), self.hostname) else: self.module.fail_json(msg='Unsupported proxy scheme: %s. Currently ansible only supports HTTP proxies.' % proxy_parts.get('scheme')) else: s.connect((self.hostname, self.port)) ssl_s = ssl.wrap_socket(s, ca_certs=tmp_ca_cert_path, cert_reqs=ssl.CERT_REQUIRED) + match_hostname(ssl_s.getpeercert(), self.hostname) # close the ssl connection #ssl_s.unwrap() s.close() @@ -311,6 +322,9 @@ class SSLValidationHandler(urllib2.BaseHandler): 'Use validate_certs=no or make sure your managed systems have a valid CA certificate installed. ' + \ 'Paths checked for this platform: %s' % ", ".join(paths_checked) ) + except CertificateError: + self.module.fail_json(msg="SSL Certificate does not belong to %s. Make sure the url has a certificate that belongs to it or use validate_certs=no (insecure)" % self.hostname) + try: # cleanup the temp file created, don't worry # if it fails for some reason @@ -363,28 +377,29 @@ def fetch_url(module, url, data=None, headers=None, method=None, # FIXME: change the following to use the generic_urlparse function # to remove the indexed references for 'parsed' parsed = urlparse.urlparse(url) - if parsed[0] == 'https': - if not HAS_SSL and validate_certs: + if parsed[0] == 'https' and validate_certs: + if not HAS_SSL: if distribution == 'Redhat': module.fail_json(msg='SSL validation is not available in your version of python. You can use validate_certs=no, however this is unsafe and not recommended. You can also install python-ssl from EPEL') else: module.fail_json(msg='SSL validation is not available in your version of python. You can use validate_certs=no, however this is unsafe and not recommended') + if not HAS_MATCH_HOSTNAME: + module.fail_json(msg='Available SSL validation does not check that the certificate matches the hostname. You can install backports.ssl_match_hostname or update your managed machine to python-2.7.9 or newer. You could also use validate_certs=no, however this is unsafe and not recommended') - elif validate_certs: - # do the cert validation - netloc = parsed[1] - if '@' in netloc: - netloc = netloc.split('@', 1)[1] - if ':' in netloc: - hostname, port = netloc.split(':', 1) - port = int(port) - else: - hostname = netloc - port = 443 - # create the SSL validation handler and - # add it to the list of handlers - ssl_handler = SSLValidationHandler(module, hostname, port) - handlers.append(ssl_handler) + # do the cert validation + netloc = parsed[1] + if '@' in netloc: + netloc = netloc.split('@', 1)[1] + if ':' in netloc: + hostname, port = netloc.split(':', 1) + port = int(port) + else: + hostname = netloc + port = 443 + # create the SSL validation handler and + # add it to the list of handlers + ssl_handler = SSLValidationHandler(module, hostname, port) + handlers.append(ssl_handler) if parsed[0] != 'ftp': username = module.params.get('url_username', '') diff --git a/v1/ansible/runner/connection_plugins/ssh.py b/v1/ansible/runner/connection_plugins/ssh.py index ff7e8e03c8..036175f6a9 100644 --- a/v1/ansible/runner/connection_plugins/ssh.py +++ b/v1/ansible/runner/connection_plugins/ssh.py @@ -16,22 +16,21 @@ # along with Ansible. If not, see . # -import fcntl -import gettext -import hmac import os -import pipes -import pty -import pwd -import random import re -import select -import shlex import subprocess -import time +import shlex +import pipes +import random +import select +import fcntl +import hmac +import pwd +import gettext +import pty from hashlib import sha1 import ansible.constants as C -from ansible.callbacks import vvv, vv +from ansible.callbacks import vvv from ansible import errors from ansible import utils @@ -257,51 +256,7 @@ class Connection(object): vvv("EXEC previous known host file not found for %s" % host) return True - def exec_command(self, *args, **kwargs): - """ Wrapper around _exec_command to retry in the case of an ssh - failure - - Will retry if: - * an exception is caught - * ssh returns 255 - - Will not retry if - * remaining_tries is <2 - * retries limit reached - """ - remaining_tries = C.get_config( - C.p, 'ssh_connection', 'retries', - 'ANSIBLE_SSH_RETRIES', 3, integer=True) + 1 - cmd_summary = "%s %s..." % (args[0], str(kwargs)[:200]) - for attempt in xrange(remaining_tries): - pause = 2 ** attempt - 1 - if pause > 30: - pause = 30 - time.sleep(pause) - try: - return_tuple = self._exec_command(*args, **kwargs) - except Exception as e: - msg = ("ssh_retry: attempt: %d, caught exception(%s) from cmd " - "(%s).") % (attempt, e, cmd_summary) - vv(msg) - if attempt == remaining_tries - 1: - raise e - else: - continue - # 0 = success - # 1-254 = remote command return code - # 255 = failure from the ssh command itself - if return_tuple[0] != 255: - break - else: - msg = ('ssh_retry: attempt: %d, ssh return code is 255. cmd ' - '(%s).') % (attempt, cmd_summary) - vv(msg) - - return return_tuple - - - def _exec_command(self, cmd, tmp_path, become_user=None, sudoable=False, executable='/bin/sh', in_data=None): + def exec_command(self, cmd, tmp_path, become_user=None, sudoable=False, executable='/bin/sh', in_data=None): ''' run a command on the remote host ''' if sudoable and self.runner.become and self.runner.become_method not in self.become_methods_supported: diff --git a/v1/ansible/utils/__init__.py b/v1/ansible/utils/__init__.py index 7ed07a54c8..eb6fa2a712 100644 --- a/v1/ansible/utils/__init__.py +++ b/v1/ansible/utils/__init__.py @@ -1024,9 +1024,9 @@ def base_parser(constants=C, usage="", output_opts=False, runas_opts=False, if runas_opts: # priv user defaults to root later on to enable detecting when this option was given here - parser.add_option('-K', '--ask-sudo-pass', default=False, dest='ask_sudo_pass', action='store_true', + parser.add_option('-K', '--ask-sudo-pass', default=constants.DEFAULT_ASK_SUDO_PASS, dest='ask_sudo_pass', action='store_true', help='ask for sudo password (deprecated, use become)') - parser.add_option('--ask-su-pass', default=False, dest='ask_su_pass', action='store_true', + parser.add_option('--ask-su-pass', default=constants.DEFAULT_ASK_SU_PASS, dest='ask_su_pass', action='store_true', help='ask for su password (deprecated, use become)') parser.add_option("-s", "--sudo", default=constants.DEFAULT_SUDO, action="store_true", dest='sudo', help="run operations with sudo (nopasswd) (deprecated, use become)") @@ -1617,7 +1617,9 @@ def _load_vars_from_folder(folder_path, results, vault_password=None): names.sort() # do not parse hidden files or dirs, e.g. .svn/ - paths = [os.path.join(folder_path, name) for name in names if not name.startswith('.')] + paths = [os.path.join(folder_path, name) for name in names + if not name.startswith('.') + and os.path.splitext(name)[1] in C.YAML_FILENAME_EXTENSIONS] for path in paths: _found, results = _load_vars_from_path(path, results, vault_password=vault_password) return results diff --git a/v1/ansible/utils/module_docs.py b/v1/ansible/utils/module_docs.py index ee99af2cb5..c692057172 100644 --- a/v1/ansible/utils/module_docs.py +++ b/v1/ansible/utils/module_docs.py @@ -23,6 +23,8 @@ import ast import yaml import traceback +from collections import MutableMapping, MutableSet, MutableSequence + from ansible import utils # modules that are ok that they do not have documentation strings @@ -86,7 +88,14 @@ def get_docstring(filename, verbose=False): if not doc.has_key(key): doc[key] = value else: - doc[key].update(value) + if isinstance(doc[key], MutableMapping): + doc[key].update(value) + elif isinstance(doc[key], MutableSet): + doc[key].add(value) + elif isinstance(doc[key], MutableSequence): + doc[key] = sorted(frozenset(doc[key] + value)) + else: + raise Exception("Attempt to extend a documentation fragement of unknown type") if 'EXAMPLES' in (t.id for t in child.targets): plainexamples = child.value.s[1:] # Skip first empty line