diff --git a/lib/ansible/plugins/connection/__init__.py b/lib/ansible/plugins/connection/__init__.py index 536925abfc..e82f771b0d 100644 --- a/lib/ansible/plugins/connection/__init__.py +++ b/lib/ansible/plugins/connection/__init__.py @@ -17,7 +17,8 @@ from ansible.errors import AnsibleError from ansible.module_utils.six import string_types from ansible.module_utils._text import to_bytes, to_text from ansible.plugins import AnsiblePlugin -from ansible.plugins.loader import shell_loader +from ansible.plugins.loader import shell_loader, connection_loader +from ansible.utils.path import unfrackpath try: from __main__ import display @@ -280,3 +281,86 @@ class ConnectionBase(AnsiblePlugin): def reset(self): display.warning("Reset is not implemented for this connection") + + +class NetworkConnectionBase(ConnectionBase): + """ + A base class for network-style connections. + """ + + force_persistence = True + # Do not use _remote_is_local in other connections + _remote_is_local = True + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(NetworkConnectionBase, self).__init__(play_context, new_stdin, *args, **kwargs) + + self._network_os = self._play_context.network_os + + self._local = connection_loader.get('local', play_context, '/dev/null') + self._local.set_options() + + self._implementation_plugins = [] + + # reconstruct the socket_path and set instance values accordingly + self._ansible_playbook_pid = kwargs.get('ansible_playbook_pid') + self._update_connection_state() + + def __getattr__(self, name): + try: + return self.__dict__[name] + except KeyError: + if not name.startswith('_'): + for plugin in self._implementation_plugins: + method = getattr(plugin, name, None) + if method is not None: + return method + raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) + + def exec_command(self, cmd, in_data=None, sudoable=True): + return self._local.exec_command(cmd, in_data, sudoable) + + def put_file(self, in_path, out_path): + """Transfer a file from local to remote""" + return self._local.put_file(in_path, out_path) + + def fetch_file(self, in_path, out_path): + """Fetch a file from remote to local""" + return self._local.fetch_file(in_path, out_path) + + def reset(self): + ''' + Reset the connection + ''' + if self._socket_path: + display.vvvv('resetting persistent connection for socket_path %s' % self._socket_path, host=self._play_context.remote_addr) + self.close() + display.vvvv('reset call on connection instance', host=self._play_context.remote_addr) + + def close(self): + if self._connected: + self._connected = False + self._implementation_plugins = [] + + def _update_connection_state(self): + ''' + Reconstruct the connection socket_path and check if it exists + + If the socket path exists then the connection is active and set + both the _socket_path value to the path and the _connected value + to True. If the socket path doesn't exist, leave the socket path + value to None and the _connected value to False + ''' + ssh = connection_loader.get('ssh', class_only=True) + control_path = ssh._create_control_path( + self._play_context.remote_addr, self._play_context.port, + self._play_context.remote_user, self._play_context.connection, + self._ansible_playbook_pid + ) + + tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) + socket_path = unfrackpath(control_path % dict(directory=tmp_path)) + + if os.path.exists(socket_path): + self._connected = True + self._socket_path = socket_path diff --git a/lib/ansible/plugins/connection/httpapi.py b/lib/ansible/plugins/connection/httpapi.py index 5fae66b3ce..c44b5577fa 100644 --- a/lib/ansible/plugins/connection/httpapi.py +++ b/lib/ansible/plugins/connection/httpapi.py @@ -140,9 +140,6 @@ options: - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT """ -import os - -from ansible import constants as C from ansible.errors import AnsibleConnectionFailure from ansible.module_utils._text import to_bytes from ansible.module_utils.six import PY3 @@ -150,9 +147,8 @@ from ansible.module_utils.six.moves import cPickle from ansible.module_utils.six.moves.urllib.error import URLError from ansible.module_utils.urls import open_url from ansible.playbook.play_context import PlayContext -from ansible.plugins.loader import cliconf_loader, connection_loader, httpapi_loader -from ansible.plugins.connection import ConnectionBase -from ansible.utils.path import unfrackpath +from ansible.plugins.loader import cliconf_loader, httpapi_loader +from ansible.plugins.connection import NetworkConnectionBase try: from __main__ import display @@ -161,62 +157,24 @@ except ImportError: display = Display() -class Connection(ConnectionBase): +class Connection(NetworkConnectionBase): '''Network API connection''' transport = 'httpapi' has_pipelining = True - force_persistence = True - # Do not use _remote_is_local in other connections - _remote_is_local = True def __init__(self, play_context, new_stdin, *args, **kwargs): super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs) - self._matched_prompt = None - self._matched_pattern = None - self._last_response = None - self._history = list() + self._url = None + self._auth = None - self._local = connection_loader.get('local', play_context, '/dev/null') - self._local.set_options() - - self._implementation_plugins = [] - - self._ansible_playbook_pid = kwargs.get('ansible_playbook_pid') - - self._network_os = self._play_context.network_os if not self._network_os: raise AnsibleConnectionFailure( 'Unable to automatically determine host network os. Please ' 'manually configure ansible_network_os value for this host' ) - - self._url = None - self._auth = None - - # reconstruct the socket_path and set instance values accordingly - self._update_connection_state() - - def __getattr__(self, name): - try: - return self.__dict__[name] - except KeyError: - if not name.startswith('_'): - for plugin in self._implementation_plugins: - method = getattr(plugin, name, None) - if method: - return method - raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) - - def exec_command(self, cmd, in_data=None, sudoable=True): - return self._local.exec_command(cmd, in_data, sudoable) - - def put_file(self, in_path, out_path): - return self._local.put_file(in_path, out_path) - - def fetch_file(self, in_path, out_path): - return self._local.fetch_file(in_path, out_path) + display.display('network_os is set to %s' % self._network_os, log_only=True) def update_play_context(self, pc_data): """Updates the play context information for the connection""" @@ -230,7 +188,11 @@ class Connection(ConnectionBase): messages = ['updating play_context for connection'] if self._play_context.become ^ play_context.become: - self._httpapi.set_become(play_context) + self.set_become(play_context) + if play_context.become is True: + messages.append('authorizing connection') + else: + messages.append('deauthorizing connection') self._play_context = play_context return messages @@ -260,43 +222,6 @@ class Connection(ConnectionBase): self._connected = True - def _update_connection_state(self): - ''' - Reconstruct the connection socket_path and check if it exists - - If the socket path exists then the connection is active and set - both the _socket_path value to the path and the _connected value - to True. If the socket path doesn't exist, leave the socket path - value to None and the _connected value to False - ''' - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path( - self._play_context.remote_addr, self._play_context.port, - self._play_context.remote_user, self._play_context.connection, - self._ansible_playbook_pid - ) - - tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) - socket_path = unfrackpath(cp % dict(directory=tmp_path)) - - if os.path.exists(socket_path): - self._connected = True - self._socket_path = socket_path - - def reset(self): - ''' - Reset the connection - ''' - if self._socket_path: - display.vvvv('resetting persistent connection for socket_path %s' % self._socket_path, host=self._play_context.remote_addr) - self.close() - display.vvvv('reset call on connection instance', host=self._play_context.remote_addr) - - def close(self): - self._implementation_plugins = [] - if self._connected: - self._connected = False - def send(self, path, data, **kwargs): ''' Sends the command to the device over api diff --git a/lib/ansible/plugins/connection/netconf.py b/lib/ansible/plugins/connection/netconf.py index e3a096d980..17e47fe7ae 100644 --- a/lib/ansible/plugins/connection/netconf.py +++ b/lib/ansible/plugins/connection/netconf.py @@ -5,7 +5,6 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type - DOCUMENTATION = """ --- author: Ansible Networking Team @@ -101,6 +100,32 @@ options: key: host_key_auto_add env: - name: ANSIBLE_HOST_KEY_AUTO_ADD + look_for_keys: + default: True + description: 'TODO: write it' + env: + - name: ANSIBLE_PARAMIKO_LOOK_FOR_KEYS + ini: + - section: paramiko_connection + key: look_for_keys + type: boolean + host_key_checking: + description: 'Set this to "False" if you want to avoid host key checking by the underlying tools Ansible uses to connect to the host' + type: boolean + default: True + env: + - name: ANSIBLE_HOST_KEY_CHECKING + - name: ANSIBLE_SSH_HOST_KEY_CHECKING + - name: ANSIBLE_NETCONF_HOST_KEY_CHECKING + ini: + - section: defaults + key: host_key_checking + - section: paramiko_connection + key: host_key_checking + vars: + - name: ansible_host_key_checking + - name: ansible_ssh_host_key_checking + - name: ansible_netconf_host_key_checking persistent_connect_timeout: type: int description: @@ -110,8 +135,8 @@ options: will fail default: 30 ini: - section: persistent_connection - key: persistent_connect_timeout + - section: persistent_connection + key: connect_timeout env: - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT persistent_command_timeout: @@ -123,8 +148,8 @@ options: close default: 10 ini: - section: persistent_connection - key: persistent_command_timeout + - section: persistent_connection + key: command_timeout env: - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT """ @@ -133,13 +158,11 @@ import os import logging import json -from ansible import constants as C from ansible.errors import AnsibleConnectionFailure, AnsibleError from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils.parsing.convert_bool import BOOLEANS_TRUE from ansible.plugins.loader import netconf_loader -from ansible.plugins.connection import ConnectionBase, ensure_connect -from ansible.plugins.connection.local import Connection as LocalConnection +from ansible.plugins.connection import NetworkConnectionBase try: from ncclient import manager @@ -165,36 +188,21 @@ NETWORK_OS_DEVICE_PARAM_MAP = { } -class Connection(ConnectionBase): +class Connection(NetworkConnectionBase): """NetConf connections""" transport = 'netconf' has_pipelining = False - force_persistence = True - # Do not use _remote_is_local in other connections - _remote_is_local = True def __init__(self, play_context, new_stdin, *args, **kwargs): super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs) - self._network_os = self._play_context.network_os or 'default' + self._network_os = self._network_os or 'default' display.display('network_os is set to %s' % self._network_os, log_only=True) - self._netconf = None self._manager = None - self._connected = False - self._local = LocalConnection(play_context, new_stdin, *args, **kwargs) - - def __getattr__(self, name): - try: - return self.__dict__[name] - except KeyError: - if name.startswith('_'): - raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) - return getattr(self._netconf, name) - - def exec_command(self, request, in_data=None, sudoable=True): + def exec_command(self, cmd, in_data=None, sudoable=True): """Sends the request to the node and returns the reply The method accepts two forms of request. The first form is as a byte string that represents xml string be send over netconf session. @@ -202,7 +210,7 @@ class Connection(ConnectionBase): """ if self._manager: # to_ele operates on native strings - request = to_ele(to_native(request, errors='surrogate_or_strict')) + request = to_ele(to_native(cmd, errors='surrogate_or_strict')) if request is None: return 'unable to parse request' @@ -215,15 +223,7 @@ class Connection(ConnectionBase): return reply.data_xml else: - return self._local.exec_command(request, in_data, sudoable) - - def put_file(self, in_path, out_path): - """Transfer a file from local to remote""" - return self._local.put_file(in_path, out_path) - - def fetch_file(self, in_path, out_path): - """Fetch a file from remote to local""" - return self._local.fetch_file(in_path, out_path) + return super(Connection, self).exec_command(cmd, in_data, sudoable) def _connect(self): super(Connection, self)._connect() @@ -239,15 +239,14 @@ class Connection(ConnectionBase): if self._play_context.private_key_file: key_filename = os.path.expanduser(self._play_context.private_key_file) - network_os = self._play_context.network_os - - if not network_os: + if self._network_os == 'default': for cls in netconf_loader.all(class_only=True): network_os = cls.guess_network_os(self) if network_os: display.display('discovered network_os %s' % network_os, log_only=True) + self._network_os = network_os - device_params = {'name': (NETWORK_OS_DEVICE_PARAM_MAP.get(network_os) or network_os or 'default')} + device_params = {'name': NETWORK_OS_DEVICE_PARAM_MAP.get(self._network_os) or self._network_os} ssh_config = os.getenv('ANSIBLE_NETCONF_SSH_CONFIG', False) if ssh_config in BOOLEANS_TRUE: @@ -262,8 +261,8 @@ class Connection(ConnectionBase): username=self._play_context.remote_user, password=self._play_context.password, key_filename=str(key_filename), - hostkey_verify=C.HOST_KEY_CHECKING, - look_for_keys=C.PARAMIKO_LOOK_FOR_KEYS, + hostkey_verify=self.get_option('host_key_checking'), + look_for_keys=self.get_option('look_for_keys'), device_params=device_params, allow_agent=self._play_context.allow_agent, timeout=self._play_context.timeout, @@ -272,7 +271,7 @@ class Connection(ConnectionBase): except SSHUnknownHostError as exc: raise AnsibleConnectionFailure(str(exc)) except ImportError as exc: - raise AnsibleError("connection=netconf is not supported on {0}".format(network_os)) + raise AnsibleError("connection=netconf is not supported on {0}".format(self._network_os)) if not self._manager.connected: return 1, b'', b'not connected' @@ -281,26 +280,17 @@ class Connection(ConnectionBase): self._connected = True - self._netconf = netconf_loader.get(network_os, self) - if self._netconf: - display.display('loaded netconf plugin for network_os %s' % network_os, log_only=True) + netconf = netconf_loader.get(self._network_os, self) + if netconf: + display.display('loaded netconf plugin for network_os %s' % self._network_os, log_only=True) else: - self._netconf = netconf_loader.get("default", self) - display.display('unable to load netconf plugin for network_os %s, falling back to default plugin' % network_os) + netconf = netconf_loader.get("default", self) + display.display('unable to load netconf plugin for network_os %s, falling back to default plugin' % self._network_os) + self._implementation_plugins.append(netconf) return 0, to_bytes(self._manager.session_id, errors='surrogate_or_strict'), b'' - def reset(self): - ''' - Reset the connection - ''' - if self._socket_path: - display.vvvv('resetting persistent connection for socket_path %s' % self._socket_path, host=self._play_context.remote_addr) - self.close() - display.vvvv('reset call on connection instance', host=self._play_context.remote_addr) - def close(self): if self._manager: self._manager.close_session() - self._connected = False super(Connection, self).close() diff --git a/lib/ansible/plugins/connection/network_cli.py b/lib/ansible/plugins/connection/network_cli.py index 26a7974bad..5051ae132f 100644 --- a/lib/ansible/plugins/connection/network_cli.py +++ b/lib/ansible/plugins/connection/network_cli.py @@ -165,15 +165,13 @@ import os import socket import traceback -from ansible import constants as C from ansible.errors import AnsibleConnectionFailure from ansible.module_utils.six import BytesIO, PY3 from ansible.module_utils.six.moves import cPickle from ansible.module_utils._text import to_bytes, to_text from ansible.playbook.play_context import PlayContext +from ansible.plugins.connection import NetworkConnectionBase from ansible.plugins.loader import cliconf_loader, terminal_loader, connection_loader -from ansible.plugins.connection import ConnectionBase -from ansible.utils.path import unfrackpath try: from __main__ import display @@ -182,14 +180,11 @@ except ImportError: display = Display() -class Connection(ConnectionBase): +class Connection(NetworkConnectionBase): ''' CLI (shell) SSH connections on Paramiko ''' transport = 'network_cli' has_pipelining = True - force_persistence = True - # Do not use _remote_is_local in other connections - _remote_is_local = True def __init__(self, play_context, new_stdin, *args, **kwargs): super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs) @@ -197,33 +192,17 @@ class Connection(ConnectionBase): self._ssh_shell = None self._matched_prompt = None + self._matched_cmd_prompt = None self._matched_pattern = None self._last_response = None self._history = list() - self._play_context = play_context - - self._local = connection_loader.get('local', play_context, '/dev/null') - self._local.set_options() self._terminal = None - self._cliconf = None - - self._ansible_playbook_pid = kwargs.get('ansible_playbook_pid') + self.paramiko_conn = None if self._play_context.verbosity > 3: logging.getLogger('paramiko').setLevel(logging.DEBUG) - # reconstruct the socket_path and set instance values accordingly - self._update_connection_state() - - def __getattr__(self, name): - try: - return self.__dict__[name] - except KeyError: - if name.startswith('_'): - raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) - return getattr(self._cliconf, name) - def _get_log_channel(self): name = "p=%s u=%s | " % (os.getpid(), getpass.getuser()) name += "paramiko [%s]" % self._play_context.remote_addr @@ -253,13 +232,7 @@ class Connection(ConnectionBase): return self.send(command=cmd) else: - return self._local.exec_command(cmd, in_data, sudoable) - - def put_file(self, in_path, out_path): - return self._local.put_file(in_path, out_path) - - def fetch_file(self, in_path, out_path): - return self._local.fetch_file(in_path, out_path) + return super(Connection, self).exec_command(cmd, in_data, sudoable) def update_play_context(self, pc_data): """Updates the play context information for the connection""" @@ -272,14 +245,14 @@ class Connection(ConnectionBase): play_context.deserialize(pc_data) messages = ['updating play_context for connection'] - if self._play_context.become is False and play_context.become is True: - auth_pass = play_context.become_pass - self._terminal.on_become(passwd=auth_pass) - messages.append('authorizing connection') - - elif self._play_context.become is True and not play_context.become: - self._terminal.on_unbecome() - messages.append('deauthorizing connection') + if self._play_context.become ^ play_context.become: + if play_context.become is True: + auth_pass = play_context.become_pass + self._terminal.on_become(passwd=auth_pass) + messages.append('authorizing connection') + else: + self._terminal.on_unbecome() + messages.append('deauthorizing connection') self._play_context = play_context @@ -292,84 +265,55 @@ class Connection(ConnectionBase): ''' Connects to the remote device and starts the terminal ''' - if self.connected: - return + if not self.connected: + if not self._network_os: + raise AnsibleConnectionFailure( + 'Unable to automatically determine host network os. Please ' + 'manually configure ansible_network_os value for this host' + ) + display.display('network_os is set to %s' % self._network_os, log_only=True) - self.paramiko_conn = connection_loader.get('paramiko', self._play_context, '/dev/null') - self.paramiko_conn._set_log_channel(self._get_log_channel()) - self.paramiko_conn.set_options(direct={'look_for_keys': not bool(self._play_context.password and not self._play_context.private_key_file)}) - self.paramiko_conn.force_persistence = self.force_persistence - ssh = self.paramiko_conn._connect() + self.paramiko_conn = connection_loader.get('paramiko', self._play_context, '/dev/null') + self.paramiko_conn._set_log_channel(self._get_log_channel()) + self.paramiko_conn.set_options(direct={'look_for_keys': not bool(self._play_context.password and not self._play_context.private_key_file)}) + self.paramiko_conn.force_persistence = self.force_persistence + ssh = self.paramiko_conn._connect() - display.vvvv('ssh connection done, setting terminal', host=self._play_context.remote_addr) + host = self.get_option('host') + display.vvvv('ssh connection done, setting terminal', host=host) - self._ssh_shell = ssh.ssh.invoke_shell() - self._ssh_shell.settimeout(self.get_option('persistent_command_timeout')) + self._ssh_shell = ssh.ssh.invoke_shell() + self._ssh_shell.settimeout(self.get_option('persistent_command_timeout')) - network_os = self._play_context.network_os - if not network_os: - raise AnsibleConnectionFailure( - 'Unable to automatically determine host network os. Please ' - 'manually configure ansible_network_os value for this host' - ) + self._terminal = terminal_loader.get(self._network_os, self) + if not self._terminal: + raise AnsibleConnectionFailure('network os %s is not supported' % self._network_os) - self._terminal = terminal_loader.get(network_os, self) - if not self._terminal: - raise AnsibleConnectionFailure('network os %s is not supported' % network_os) + display.vvvv('loaded terminal plugin for network_os %s' % self._network_os, host=host) - display.vvvv('loaded terminal plugin for network_os %s' % network_os, host=self._play_context.remote_addr) + cliconf = cliconf_loader.get(self._network_os, self) + if cliconf: + display.vvvv('loaded cliconf plugin for network_os %s' % self._network_os, host=host) + self._implementation_plugins.append(cliconf) + else: + display.vvvv('unable to load cliconf for network_os %s' % self._network_os) - self._cliconf = cliconf_loader.get(network_os, self) - if self._cliconf: - display.vvvv('loaded cliconf plugin for network_os %s' % network_os, host=self._play_context.remote_addr) - else: - display.vvvv('unable to load cliconf for network_os %s' % network_os) + self.receive(prompts=self._terminal.terminal_initial_prompt, answer=self._terminal.terminal_initial_answer, + newline=self._terminal.terminal_inital_prompt_newline) - self.receive(prompts=self._terminal.terminal_initial_prompt, answer=self._terminal.terminal_initial_answer, - newline=self._terminal.terminal_inital_prompt_newline) + display.vvvv('firing event: on_open_shell()', host=host) + self._terminal.on_open_shell() - display.vvvv('firing event: on_open_shell()', host=self._play_context.remote_addr) - self._terminal.on_open_shell() + if self._play_context.become and self._play_context.become_method == 'enable': + display.vvvv('firing event: on_become', host=host) + auth_pass = self._play_context.become_pass + self._terminal.on_become(passwd=auth_pass) - if self._play_context.become and self._play_context.become_method == 'enable': - display.vvvv('firing event: on_become', host=self._play_context.remote_addr) - auth_pass = self._play_context.become_pass - self._terminal.on_become(passwd=auth_pass) - - display.vvvv('ssh connection has completed successfully', host=self._play_context.remote_addr) - self._connected = True + display.vvvv('ssh connection has completed successfully', host=host) + self._connected = True return self - def _update_connection_state(self): - ''' - Reconstruct the connection socket_path and check if it exists - - If the socket path exists then the connection is active and set - both the _socket_path value to the path and the _connected value - to True. If the socket path doesn't exist, leave the socket path - value to None and the _connected value to False - ''' - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(self._play_context.remote_addr, self._play_context.port, self._play_context.remote_user, self._play_context.connection, - self._ansible_playbook_pid) - - tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) - socket_path = unfrackpath(cp % dict(directory=tmp_path)) - - if os.path.exists(socket_path): - self._connected = True - self._socket_path = socket_path - - def reset(self): - ''' - Reset the connection - ''' - if self._socket_path: - display.vvvv('resetting persistent connection for socket_path %s' % self._socket_path, host=self._play_context.remote_addr) - self.close() - display.vvvv('reset call on connection instance', host=self._play_context.remote_addr) - def close(self): ''' Close the active connection to the device @@ -387,7 +331,7 @@ class Connection(ConnectionBase): self.paramiko_conn.close() self.paramiko_conn = None display.debug("ssh connection has been closed successfully") - self._connected = False + super(Connection, self).close() def receive(self, command=None, prompts=None, answer=None, newline=True, prompt_retry_check=False): ''' diff --git a/test/units/plugins/connection/test_netconf.py b/test/units/plugins/connection/test_netconf.py index ffd43b55ed..bed851b747 100644 --- a/test/units/plugins/connection/test_netconf.py +++ b/test/units/plugins/connection/test_netconf.py @@ -46,9 +46,11 @@ def import_mock(name, *args): if PY3: with patch('builtins.__import__', side_effect=import_mock): from ansible.plugins.connection import netconf + from ansible.plugins.loader import connection_loader else: with patch('__builtin__.__import__', side_effect=import_mock): from ansible.plugins.connection import netconf + from ansible.plugins.loader import connection_loader class TestNetconfConnectionClass(unittest.TestCase): @@ -68,7 +70,7 @@ class TestNetconfConnectionClass(unittest.TestCase): pc = PlayContext() new_stdin = StringIO() - conn = netconf.Connection(pc, new_stdin) + conn = connection_loader.get('netconf', pc, new_stdin) mock_manager = MagicMock() mock_manager.session_id = '123456789' diff --git a/test/units/plugins/connection/test_network_cli.py b/test/units/plugins/connection/test_network_cli.py index b939e24f17..9831b5d6e6 100644 --- a/test/units/plugins/connection/test_network_cli.py +++ b/test/units/plugins/connection/test_network_cli.py @@ -64,10 +64,10 @@ class TestConnectionClass(unittest.TestCase): @patch("ansible.plugins.connection.paramiko_ssh.Connection._connect") def test_network_cli__connect(self, mocked_super, mocked_terminal_loader): pc = PlayContext() + pc.network_os = 'ios' new_stdin = StringIO() conn = connection_loader.get('network_cli', pc, '/dev/null') - pc.network_os = 'ios' conn.ssh = MagicMock() conn.receive = MagicMock()