From 9754c67138f77264652606ac26d6e220903dd258 Mon Sep 17 00:00:00 2001 From: Matt Martz Date: Wed, 13 May 2015 10:58:46 -0500 Subject: [PATCH 1/2] Use a decorator to ensure jit connection, instead of an explicit call to _connect --- lib/ansible/executor/task_executor.py | 1 - lib/ansible/plugins/connections/__init__.py | 12 +++++++++++- lib/ansible/plugins/connections/paramiko_ssh.py | 8 ++++++-- lib/ansible/plugins/connections/ssh.py | 6 +++++- lib/ansible/plugins/connections/winrm.py | 6 +++++- 5 files changed, 27 insertions(+), 6 deletions(-) diff --git a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py index 69cbb63f47..8de8f7027a 100644 --- a/lib/ansible/executor/task_executor.py +++ b/lib/ansible/executor/task_executor.py @@ -210,7 +210,6 @@ class TaskExecutor: # get the connection and the handler for this execution self._connection = self._get_connection(variables) self._connection.set_host_overrides(host=self._host) - self._connection._connect() self._handler = self._get_action_handler(connection=self._connection, templar=templar) diff --git a/lib/ansible/plugins/connections/__init__.py b/lib/ansible/plugins/connections/__init__.py index 897bc58982..da0775530d 100644 --- a/lib/ansible/plugins/connections/__init__.py +++ b/lib/ansible/plugins/connections/__init__.py @@ -22,6 +22,7 @@ __metaclass__ = type from abc import ABCMeta, abstractmethod, abstractproperty +from functools import wraps from six import with_metaclass from ansible import constants as C @@ -32,7 +33,16 @@ from ansible.errors import AnsibleError # which may want to output display/logs too from ansible.utils.display import Display -__all__ = ['ConnectionBase'] +__all__ = ['ConnectionBase', 'ensure_connect'] + + +def ensure_connect(func): + @wraps(func) + def wrapped(self, *args, **kwargs): + self._connect() + return func(self, *args, **kwargs) + return wrapped + class ConnectionBase(with_metaclass(ABCMeta, object)): ''' diff --git a/lib/ansible/plugins/connections/paramiko_ssh.py b/lib/ansible/plugins/connections/paramiko_ssh.py index 0d7a82c34b..8beaecf492 100644 --- a/lib/ansible/plugins/connections/paramiko_ssh.py +++ b/lib/ansible/plugins/connections/paramiko_ssh.py @@ -41,7 +41,7 @@ from binascii import hexlify from ansible import constants as C from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound -from ansible.plugins.connections import ConnectionBase +from ansible.plugins.connections import ConnectionBase, ensure_connect from ansible.utils.path import makedirs_safe AUTHENTICITY_MSG=""" @@ -61,6 +61,7 @@ with warnings.catch_warnings(): except ImportError: pass + class MyAddPolicy(object): """ Based on AutoAddPolicy in paramiko so we can determine when keys are added @@ -188,6 +189,7 @@ class Connection(ConnectionBase): return ssh + @ensure_connect def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): ''' run a command on the remote host ''' @@ -248,6 +250,7 @@ class Connection(ConnectionBase): return (chan.recv_exit_status(), '', no_prompt_out + stdout, no_prompt_out + stderr) + @ensure_connect def put_file(self, in_path, out_path): ''' transfer a file from local to remote ''' @@ -272,9 +275,10 @@ class Connection(ConnectionBase): if cache_key in SFTP_CONNECTION_CACHE: return SFTP_CONNECTION_CACHE[cache_key] else: - result = SFTP_CONNECTION_CACHE[cache_key] = self.connect().ssh.open_sftp() + result = SFTP_CONNECTION_CACHE[cache_key] = self._connect().ssh.open_sftp() return result + @ensure_connect def fetch_file(self, in_path, out_path): ''' save a remote file to the specified path ''' diff --git a/lib/ansible/plugins/connections/ssh.py b/lib/ansible/plugins/connections/ssh.py index b3ada343c0..5a435093d0 100644 --- a/lib/ansible/plugins/connections/ssh.py +++ b/lib/ansible/plugins/connections/ssh.py @@ -34,7 +34,8 @@ from hashlib import sha1 from ansible import constants as C from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound -from ansible.plugins.connections import ConnectionBase +from ansible.plugins.connections import ConnectionBase, ensure_connect + class Connection(ConnectionBase): ''' ssh based connections ''' @@ -269,6 +270,7 @@ class Connection(ConnectionBase): self._display.vvv("EXEC previous known host file not found for {0}".format(host)) return True + @ensure_connect def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): ''' run a command on the remote host ''' @@ -390,6 +392,7 @@ class Connection(ConnectionBase): return (p.returncode, '', no_prompt_out + stdout, no_prompt_err + stderr) + @ensure_connect def put_file(self, in_path, out_path): ''' transfer a file from local to remote ''' self._display.vvv("PUT {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr) @@ -425,6 +428,7 @@ class Connection(ConnectionBase): if returncode != 0: raise AnsibleError("failed to transfer file to {0}:\n{1}\n{2}".format(out_path, stdout, stderr)) + @ensure_connect def fetch_file(self, in_path, out_path): ''' fetch a file from remote to local ''' self._display.vvv("FETCH {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr) diff --git a/lib/ansible/plugins/connections/winrm.py b/lib/ansible/plugins/connections/winrm.py index f16da0f6e6..ee28749189 100644 --- a/lib/ansible/plugins/connections/winrm.py +++ b/lib/ansible/plugins/connections/winrm.py @@ -42,10 +42,11 @@ except ImportError: from ansible import constants as C from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound -from ansible.plugins.connections import ConnectionBase +from ansible.plugins.connections import ConnectionBase, ensure_connect from ansible.plugins import shell_loader from ansible.utils.path import makedirs_safe + class Connection(ConnectionBase): '''WinRM connections over HTTP/HTTPS.''' @@ -151,6 +152,7 @@ class Connection(ConnectionBase): self.protocol = self._winrm_connect() return self + @ensure_connect def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): cmd = cmd.encode('utf-8') @@ -172,6 +174,7 @@ class Connection(ConnectionBase): raise AnsibleError("failed to exec cmd %s" % cmd) return (result.status_code, '', result.std_out.encode('utf-8'), result.std_err.encode('utf-8')) + @ensure_connect def put_file(self, in_path, out_path): self._display.vvv("PUT %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr) if not os.path.exists(in_path): @@ -210,6 +213,7 @@ class Connection(ConnectionBase): traceback.print_exc() raise AnsibleError("failed to transfer file to %s" % out_path) + @ensure_connect def fetch_file(self, in_path, out_path): out_path = out_path.replace('\\', '/') self._display.vvv("FETCH %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr) From bce281014cfc8aaa2675c129ca3117a360041e5c Mon Sep 17 00:00:00 2001 From: Matt Martz Date: Thu, 4 Jun 2015 13:27:18 -0500 Subject: [PATCH 2/2] Decorate the ConnectionBase methods, switch to calling super from individual connection classes --- lib/ansible/plugins/connections/__init__.py | 3 +++ lib/ansible/plugins/connections/local.py | 7 +++++++ lib/ansible/plugins/connections/paramiko_ssh.py | 11 +++++++---- lib/ansible/plugins/connections/ssh.py | 13 +++++++++---- lib/ansible/plugins/connections/winrm.py | 10 ++++++---- 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/lib/ansible/plugins/connections/__init__.py b/lib/ansible/plugins/connections/__init__.py index da0775530d..1d3a2bdeed 100644 --- a/lib/ansible/plugins/connections/__init__.py +++ b/lib/ansible/plugins/connections/__init__.py @@ -92,16 +92,19 @@ class ConnectionBase(with_metaclass(ABCMeta, object)): """Connect to the host we've been initialized with""" pass + @ensure_connect @abstractmethod def exec_command(self, cmd, tmp_path, executable=None, in_data=None): """Run a command on the remote host""" pass + @ensure_connect @abstractmethod def put_file(self, in_path, out_path): """Transfer a file from local to remote""" pass + @ensure_connect @abstractmethod def fetch_file(self, in_path, out_path): """Fetch a file from remote to local""" diff --git a/lib/ansible/plugins/connections/local.py b/lib/ansible/plugins/connections/local.py index 1dc6076b0d..85bc51de0a 100644 --- a/lib/ansible/plugins/connections/local.py +++ b/lib/ansible/plugins/connections/local.py @@ -49,6 +49,8 @@ class Connection(ConnectionBase): def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): ''' run a command on the local host ''' + super(Connection, self).exec_command(cmd, tmp_path, executable=executable, in_data=in_data) + debug("in local.exec_command()") # su requires to be run from a terminal, and therefore isn't supported here (yet?) #if self._connection_info.su: @@ -108,6 +110,8 @@ class Connection(ConnectionBase): def put_file(self, in_path, out_path): ''' transfer a file from local to local ''' + super(Connection, self).put_file(in_path, out_path) + #vvv("PUT {0} TO {1}".format(in_path, out_path), host=self.host) self._display.vvv("{0} PUT {1} TO {2}".format(self._connection_info.remote_addr, in_path, out_path)) if not os.path.exists(in_path): @@ -123,6 +127,9 @@ class Connection(ConnectionBase): def fetch_file(self, in_path, out_path): ''' fetch a file from local to local -- for copatibility ''' + + super(Connection, self).fetch_file(in_path, out_path) + #vvv("FETCH {0} TO {1}".format(in_path, out_path), host=self.host) self._display.vvv("{0} FETCH {1} TO {2}".format(self._connection_info.remote_addr, in_path, out_path)) self.put_file(in_path, out_path) diff --git a/lib/ansible/plugins/connections/paramiko_ssh.py b/lib/ansible/plugins/connections/paramiko_ssh.py index 8beaecf492..5a5259c5fc 100644 --- a/lib/ansible/plugins/connections/paramiko_ssh.py +++ b/lib/ansible/plugins/connections/paramiko_ssh.py @@ -41,7 +41,7 @@ from binascii import hexlify from ansible import constants as C from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound -from ansible.plugins.connections import ConnectionBase, ensure_connect +from ansible.plugins.connections import ConnectionBase from ansible.utils.path import makedirs_safe AUTHENTICITY_MSG=""" @@ -189,10 +189,11 @@ class Connection(ConnectionBase): return ssh - @ensure_connect def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): ''' run a command on the remote host ''' + super(Connection, self).exec_command(cmd, tmp_path, executable=executable, in_data=in_data) + if in_data: raise AnsibleError("Internal Error: this module does not support optimized module pipelining") @@ -250,10 +251,11 @@ class Connection(ConnectionBase): return (chan.recv_exit_status(), '', no_prompt_out + stdout, no_prompt_out + stderr) - @ensure_connect def put_file(self, in_path, out_path): ''' transfer a file from local to remote ''' + super(Connection, self).put_file(in_path, out_path) + self._display.vvv("PUT %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr) if not os.path.exists(in_path): @@ -278,10 +280,11 @@ class Connection(ConnectionBase): result = SFTP_CONNECTION_CACHE[cache_key] = self._connect().ssh.open_sftp() return result - @ensure_connect def fetch_file(self, in_path, out_path): ''' save a remote file to the specified path ''' + super(Connection, self).fetch_file(in_path, out_path) + self._display.vvv("FETCH %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr) try: diff --git a/lib/ansible/plugins/connections/ssh.py b/lib/ansible/plugins/connections/ssh.py index 5a435093d0..e2251ca5b0 100644 --- a/lib/ansible/plugins/connections/ssh.py +++ b/lib/ansible/plugins/connections/ssh.py @@ -34,7 +34,7 @@ from hashlib import sha1 from ansible import constants as C from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound -from ansible.plugins.connections import ConnectionBase, ensure_connect +from ansible.plugins.connections import ConnectionBase class Connection(ConnectionBase): @@ -270,10 +270,11 @@ class Connection(ConnectionBase): self._display.vvv("EXEC previous known host file not found for {0}".format(host)) return True - @ensure_connect def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): ''' run a command on the remote host ''' + super(Connection, self).exec_command(cmd, tmp_path, executable=executable, in_data=in_data) + ssh_cmd = self._password_cmd() ssh_cmd += ("ssh", "-C") if not in_data: @@ -392,9 +393,11 @@ class Connection(ConnectionBase): return (p.returncode, '', no_prompt_out + stdout, no_prompt_err + stderr) - @ensure_connect def put_file(self, in_path, out_path): ''' transfer a file from local to remote ''' + + super(Connection, self).put_file(in_path, out_path) + self._display.vvv("PUT {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr) if not os.path.exists(in_path): raise AnsibleFileNotFound("file or module does not exist: {0}".format(in_path)) @@ -428,9 +431,11 @@ class Connection(ConnectionBase): if returncode != 0: raise AnsibleError("failed to transfer file to {0}:\n{1}\n{2}".format(out_path, stdout, stderr)) - @ensure_connect def fetch_file(self, in_path, out_path): ''' fetch a file from remote to local ''' + + super(Connection, self).fetch_file(in_path, out_path) + self._display.vvv("FETCH {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr) cmd = self._password_cmd() diff --git a/lib/ansible/plugins/connections/winrm.py b/lib/ansible/plugins/connections/winrm.py index ee28749189..2bc1ee0053 100644 --- a/lib/ansible/plugins/connections/winrm.py +++ b/lib/ansible/plugins/connections/winrm.py @@ -42,7 +42,7 @@ except ImportError: from ansible import constants as C from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound -from ansible.plugins.connections import ConnectionBase, ensure_connect +from ansible.plugins.connections import ConnectionBase from ansible.plugins import shell_loader from ansible.utils.path import makedirs_safe @@ -152,8 +152,8 @@ class Connection(ConnectionBase): self.protocol = self._winrm_connect() return self - @ensure_connect def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): + super(Connection, self).exec_command(cmd, tmp_path, executable=executable, in_data,in_data) cmd = cmd.encode('utf-8') cmd_parts = shlex.split(cmd, posix=False) @@ -174,8 +174,9 @@ class Connection(ConnectionBase): raise AnsibleError("failed to exec cmd %s" % cmd) return (result.status_code, '', result.std_out.encode('utf-8'), result.std_err.encode('utf-8')) - @ensure_connect def put_file(self, in_path, out_path): + super(Connection, self).put_file(in_path, out_path) + self._display.vvv("PUT %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr) if not os.path.exists(in_path): raise AnsibleFileNotFound("file or module does not exist: %s" % in_path) @@ -213,8 +214,9 @@ class Connection(ConnectionBase): traceback.print_exc() raise AnsibleError("failed to transfer file to %s" % out_path) - @ensure_connect def fetch_file(self, in_path, out_path): + super(Connection, self).fetch_file(in_path, out_path) + out_path = out_path.replace('\\', '/') self._display.vvv("FETCH %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr) buffer_size = 2**19 # 0.5MB chunks