diff --git a/lib/ansible/module_utils/network_common.py b/lib/ansible/module_utils/network_common.py index 82c0cc9f8f..add4ce18ea 100644 --- a/lib/ansible/module_utils/network_common.py +++ b/lib/ansible/module_utils/network_common.py @@ -25,13 +25,8 @@ # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -import socket -import struct -import signal - -from ansible.module_utils.basic import get_exception -from ansible.module_utils._text import to_bytes, to_native from ansible.module_utils.six import iteritems +from ansible.module_utils.basic import AnsibleFallbackNotFound def to_list(val): if isinstance(val, (list, tuple, set)): @@ -41,103 +36,116 @@ def to_list(val): else: return list() -class ComplexDict: +class ComplexDict(object): + """Transforms a dict to with an argument spec - def __init__(self, attrs): + This class will take a dict and apply an Ansible argument spec to the + values. The resulting dict will contain all of the keys in the param + with appropriate values set. + + Example:: + + argument_spec = dict( + command=dict(key=True), + display=dict(default='text', choices=['text', 'json']), + validate=dict(type='bool') + ) + transform = ComplexDict(argument_spec, module) + value = dict(command='foo') + result = transform(value) + print result + {'command': 'foo', 'display': 'text', 'validate': None} + + Supported argument spec: + * key - specifies how to map a single value to a dict + * read_from - read and apply the argument_spec from the module + * required - a value is required + * type - type of value (uses AnsibleModule type checker) + * fallback - implements fallback function + * choices - set of valid options + * default - default value + + """ + + def __init__(self, attrs, module): self._attributes = attrs + self._module = module self.attr_names = frozenset(self._attributes.keys()) + + self._has_key = False + for name, attr in iteritems(self._attributes): + if attr.get('read_from'): + spec = self._module.argument_spec.get(attr['read_from']) + if not spec: + raise ValueError('argument_spec %s does not exist' % attr['read_from']) + for key, value in iteritems(spec): + if key not in attr: + attr[key] = value + + if attr.get('key'): + if self._has_key: + raise ValueError('only one key value can be specified') + self_has_key = True + attr['required'] = True + + + def _dict(self, value): + obj = {} for name, attr in iteritems(self._attributes): if attr.get('key'): - attr['required'] = True + obj[name] = value + else: + obj[name] = attr.get('default') + return obj def __call__(self, value): - if isinstance(value, dict): - unknown = set(value.keys()).difference(self.attr_names) - if unknown: - raise ValueError('invalid keys: %s' % ','.join(unknown)) - for name, attr in iteritems(self._attributes): - if attr.get('required') and name not in value: - raise ValueError('missing required attribute %s' % name) - if not value.get(name): - value[name] = attr.get('default') - return value - else: - obj = {} - for name, attr in iteritems(self._attributes): - if attr.get('key'): - obj[name] = value - else: - obj[name] = attr.get('default') - return obj + if not isinstance(value, dict): + value = self._dict(value) + unknown = set(value).difference(self.attr_names) + if unknown: + raise ValueError('invalid keys: %s' % ','.join(unknown)) -class ComplexList: - - def __init__(self, attrs): - self._attributes = attrs - self.attr_names = frozenset(self._attributes.keys()) for name, attr in iteritems(self._attributes): - if attr.get('key'): - attr['required'] = True + if not value.get(name): + value[name] = attr.get('default') + if attr.get('fallback') and not value.get(name): + fallback = attr.get('fallback', (None,)) + fallback_strategy = fallback[0] + fallback_args = [] + fallback_kwargs = {} + if fallback_strategy is not None: + for item in fallback[1:]: + if isinstance(item, dict): + fallback_kwargs = item + else: + fallback_args = item + try: + value[name] = fallback_strategy(*fallback_args, **fallback_kwargs) + except AnsibleFallbackNotFound: + continue + + if attr.get('required') and value.get(name) is None: + raise ValueError('missing required attribute %s' % name) + + if 'choices' in attr: + if value[name] not in attr['choices']: + raise ValueError('%s must be one of %s, got %s' % \ + (name, ', '.join(attr['choices']), value[name])) + + if value[name] is not None: + value_type = attr.get('type', 'str') + type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type] + type_checker(value[name]) + + return value + +class ComplexList(ComplexDict): + """Extends ```ComplexDict``` to handle a list of dicts """ def __call__(self, values): - objects = list() - for value in values: - if isinstance(value, dict): - for name, attr in iteritems(self._attributes): - if attr.get('required') and name not in value: - raise ValueError('missing required attr %s' % name) - if not value.get(name): - value[name] = attr.get('default') - objects.append(value) - else: - obj = {} - for name, attr in iteritems(self._attributes): - if attr.get('key'): - obj[name] = value - else: - obj[name] = attr.get('default') - objects.append(obj) - return objects + if not isinstance(values, (list, tuple)): + raise TypeError('value must be an ordered iterable') + return [(super(ComplexList, self).__call__(v)) for v in values] -def send_data(s, data): - packed_len = struct.pack('!Q',len(data)) - return s.sendall(packed_len + data) - -def recv_data(s): - header_len = 8 # size of a packed unsigned long long - data = to_bytes("") - while len(data) < header_len: - d = s.recv(header_len - len(data)) - if not d: - return None - data += d - data_len = struct.unpack('!Q',data[:header_len])[0] - data = data[header_len:] - while len(data) < data_len: - d = s.recv(data_len - len(data)) - if not d: - return None - data += d - return data - -def exec_command(module, command): - try: - sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sf.connect(module._socket_path) - - data = "EXEC: %s" % command - send_data(sf, to_bytes(data.strip())) - - rc = int(recv_data(sf), 10) - stdout = recv_data(sf) - stderr = recv_data(sf) - except socket.error: - exc = get_exception() - sf.close() - module.fail_json(msg='unable to connect to socket', err=str(exc)) - - sf.close() - - return (rc, to_native(stdout), to_native(stderr)) diff --git a/lib/ansible/modules/network/eos/eos_command.py b/lib/ansible/modules/network/eos/eos_command.py index 11b0f59eda..ecce61c8c4 100644 --- a/lib/ansible/modules/network/eos/eos_command.py +++ b/lib/ansible/modules/network/eos/eos_command.py @@ -143,13 +143,14 @@ def to_lines(stdout): return lines def parse_commands(module, warnings): - transform = ComplexList(dict( + spec = dict( command=dict(key=True), output=dict(), prompt=dict(), response=dict() - )) + ) + transform = ComplexList(spec, module) commands = transform(module.params['commands']) for index, item in enumerate(commands): diff --git a/lib/ansible/modules/network/eos/eos_system.py b/lib/ansible/modules/network/eos/eos_system.py index ae83ef0fbc..bd4d48be3a 100644 --- a/lib/ansible/modules/network/eos/eos_system.py +++ b/lib/ansible/modules/network/eos/eos_system.py @@ -272,12 +272,12 @@ def map_params_to_obj(module): lookup_source = ComplexList(dict( interface=dict(key=True), vrf=dict() - )) + ), module) name_servers = ComplexList(dict( server=dict(key=True), vrf=dict(default='default') - )) + ), module) for arg, cast in [('lookup_source', lookup_source), ('name_servers', name_servers)]: if module.params[arg] is not None: diff --git a/lib/ansible/modules/network/ios/ios_command.py b/lib/ansible/modules/network/ios/ios_command.py index 2fa2bbb6fc..4abd06c6e4 100644 --- a/lib/ansible/modules/network/ios/ios_command.py +++ b/lib/ansible/modules/network/ios/ios_command.py @@ -149,7 +149,7 @@ def parse_commands(module, warnings): command=dict(key=True), prompt=dict(), response=dict() - )) + ), module) commands = command(module.params['commands']) for index, item in enumerate(commands): if module.check_mode and not item['command'].startswith('show'): diff --git a/lib/ansible/modules/network/ios/ios_system.py b/lib/ansible/modules/network/ios/ios_system.py index 3e2772dfda..1bc6edb53d 100644 --- a/lib/ansible/modules/network/ios/ios_system.py +++ b/lib/ansible/modules/network/ios/ios_system.py @@ -311,17 +311,17 @@ def map_params_to_obj(module): domain_name = ComplexList(dict( name=dict(key=True), vrf=dict() - )) + ), module) domain_search = ComplexList(dict( name=dict(key=True), vrf=dict() - )) + ), module) name_servers = ComplexList(dict( server=dict(key=True), vrf=dict() - )) + ), module) for arg, cast in [('domain_name', domain_name), ('domain_search', domain_search), diff --git a/lib/ansible/modules/network/iosxr/iosxr_command.py b/lib/ansible/modules/network/iosxr/iosxr_command.py index 5560321b92..a81ce18ade 100644 --- a/lib/ansible/modules/network/iosxr/iosxr_command.py +++ b/lib/ansible/modules/network/iosxr/iosxr_command.py @@ -163,7 +163,7 @@ def parse_commands(module, warnings): command=dict(key=True), prompt=dict(), response=dict() - )) + ), module) commands = command(module.params['commands']) for index, item in enumerate(commands): diff --git a/lib/ansible/modules/network/vyos/vyos_command.py b/lib/ansible/modules/network/vyos/vyos_command.py index 63218a94eb..2aefa622a2 100644 --- a/lib/ansible/modules/network/vyos/vyos_command.py +++ b/lib/ansible/modules/network/vyos/vyos_command.py @@ -152,8 +152,7 @@ def parse_commands(module, warnings): command=dict(key=True), prompt=dict(), response=dict(), - )) - + ), module) commands = command(module.params['commands']) for index, cmd in enumerate(commands):