From 251c9182fe5f0ada907959dc934a50696da579f9 Mon Sep 17 00:00:00 2001 From: Toshio Kuratomi Date: Wed, 6 Jul 2016 18:24:24 -0700 Subject: [PATCH] Refactor network and eos module_utils to use a subclass instead of factory function to create the NetworkModule --- lib/ansible/module_utils/eos.py | 6 +- lib/ansible/module_utils/network.py | 89 ++++++++++++++--------------- 2 files changed, 47 insertions(+), 48 deletions(-) diff --git a/lib/ansible/module_utils/eos.py b/lib/ansible/module_utils/eos.py index 2d87d93c45..aa371a91ff 100644 --- a/lib/ansible/module_utils/eos.py +++ b/lib/ansible/module_utils/eos.py @@ -19,8 +19,10 @@ import re -from ansible.module_utils.basic import json, get_exception, AnsibleModule -from ansible.module_utils.network import NetCli, NetworkError, get_module, Command +from ansible.module_utils.basic import json, AnsibleModule +# We make NetworkModule available here for module code to use. example: +# from ansible.module_utils.eos import NetworkModule +from ansible.module_utils.network import NetCli, NetworkError, NetworkModule, Command from ansible.module_utils.network import add_argument, register_transport, to_list from ansible.module_utils.netcfg import NetworkConfig from ansible.module_utils.urls import fetch_url, url_argument_spec diff --git a/lib/ansible/module_utils/network.py b/lib/ansible/module_utils/network.py index d3cedf997e..0b3e4450a3 100644 --- a/lib/ansible/module_utils/network.py +++ b/lib/ansible/module_utils/network.py @@ -48,24 +48,6 @@ def to_list(val): else: return list() -def connect(module): - try: - if not module.connected: - module.connection.connect(module.params) - if module.params['authorize']: - module.connection.authorize(module.params) - except NetworkError: - exc = get_exception() - module.fail_json(msg=exc.message) - -def disconnect(module): - try: - if module.connected: - module.connection.disconnect() - except NetworkError: - exc = get_exception() - module.fail_json(msg=exc.message) - class Command(object): @@ -157,15 +139,38 @@ class NetworkError(Exception): class NetworkModule(AnsibleModule): def __init__(self, *args, **kwargs): + connect_on_load = kwargs.pop('connect_on_load', True) + + argument_spec = NET_TRANSPORT_ARGS.copy() + argument_spec['transport']['choices'] = NET_CONNECTIONS.keys() + argument_spec.update(NET_CONNECTION_ARGS.copy()) + + if kwargs.get('argument_spec'): + argument_spec.update(kwargs['argument_spec']) + kwargs['argument_spec'] = argument_spec + super(NetworkModule, self).__init__(*args, **kwargs) self.connection = None self._cli = None self._config = None + try: + transport = self.params['transport'] or '__default__' + cls = NET_CONNECTIONS[transport] + self.connection = cls() + except KeyError: + self.fail_json(msg='Unknown transport or no default transport specified') + except (TypeError, NetworkError): + exc = get_exception() + self.fail_json(msg=exc.message) + + if connect_on_load: + self.connect() + @property def cli(self): if not self.connected: - connect(self) + self.connect(self) if self._cli: return self._cli self._cli = Cli(self.connection) @@ -174,7 +179,7 @@ class NetworkModule(AnsibleModule): @property def config(self): if not self.connected: - connect(self) + self.connect(self) if self._config: return self._config self._config = Config(self.connection) @@ -193,6 +198,24 @@ class NetworkModule(AnsibleModule): if self.params.get(key) is None and value is not None: self.params[key] = value + def connect(self): + try: + if not self.connected: + self.connection.connect(self.params) + if self.params['authorize']: + self.connection.authorize(self.params) + except NetworkError: + exc = get_exception() + self.fail_json(msg=exc.message) + + def disconnect(self): + try: + if self.connected: + self.connection.disconnect() + except NetworkError: + exc = get_exception() + self.fail_json(msg=exc.message) + class NetCli(object): """Basic paramiko-based ssh transport any NetworkModule can use.""" @@ -245,32 +268,6 @@ class NetCli(object): raise NetworkError(exc.message, commands=commands) -def get_module(connect_on_load=True, **kwargs): - argument_spec = NET_TRANSPORT_ARGS.copy() - argument_spec['transport']['choices'] = NET_CONNECTIONS.keys() - argument_spec.update(NET_CONNECTION_ARGS.copy()) - - if kwargs.get('argument_spec'): - argument_spec.update(kwargs['argument_spec']) - kwargs['argument_spec'] = argument_spec - - module = NetworkModule(**kwargs) - - try: - transport = module.params['transport'] or '__default__' - cls = NET_CONNECTIONS[transport] - module.connection = cls() - except KeyError: - module.fail_json(msg='Unknown transport or no default transport specified') - except (TypeError, NetworkError): - exc = get_exception() - module.fail_json(msg=exc.message) - - if connect_on_load: - connect(module) - - return module - def register_transport(transport, default=False): def register(cls): NET_CONNECTIONS[transport] = cls