diff --git a/lib/ansible/plugins/connection/ssh.py b/lib/ansible/plugins/connection/ssh.py index 9b7f0e8d3c..6ad68957ee 100644 --- a/lib/ansible/plugins/connection/ssh.py +++ b/lib/ansible/plugins/connection/ssh.py @@ -29,6 +29,7 @@ import socket import subprocess import time +from functools import wraps from ansible import constants as C from ansible.compat import selectors from ansible.compat.six import PY3, text_type, binary_type @@ -51,6 +52,54 @@ except ImportError: SSHPASS_AVAILABLE = None +def _ssh_retry(func): + """ + Decorator to retry ssh/scp/sftp in the case of a connection failure + + Will retry if: + * an exception is caught + * ssh returns 255 + Will not retry if + * remaining_tries is <2 + * retries limit reached + """ + @wraps(func) + def wrapped(self, *args, **kwargs): + remaining_tries = int(C.ANSIBLE_SSH_RETRIES) + 1 + cmd_summary = "%s..." % args[0] + for attempt in range(remaining_tries): + try: + return_tuple = func(self, *args, **kwargs) + display.vvv(return_tuple, host=self.host) + # 0 = success + # 1-254 = remote command return code + # 255 = failure from the ssh command itself + if return_tuple[0] != 255: + break + else: + raise AnsibleConnectionFailure("Failed to connect to the host via ssh: %s" % to_native(return_tuple[2])) + except (AnsibleConnectionFailure, Exception) as e: + if attempt == remaining_tries - 1: + raise + else: + pause = 2 ** attempt - 1 + if pause > 30: + pause = 30 + + if isinstance(e, AnsibleConnectionFailure): + msg = "ssh_retry: attempt: %d, ssh return code is 255. cmd (%s), pausing for %d seconds" % (attempt, cmd_summary, pause) + else: + msg = "ssh_retry: attempt: %d, caught exception(%s) from cmd (%s), pausing for %d seconds" % (attempt, e, cmd_summary, pause) + + display.vv(msg, host=self.host) + + time.sleep(pause) + continue + + return return_tuple + return wrapped + + class Connection(ConnectionBase): ''' ssh based connections ''' @@ -352,6 +401,7 @@ class Connection(ConnectionBase): return b''.join(output), remainder + @_ssh_retry def _run(self, cmd, in_data, sudoable=True, checkrc=True): ''' Starts the command and communicates with it until it ends. @@ -618,28 +668,6 @@ class Connection(ConnectionBase): return (p.returncode, b_stdout, b_stderr) - def _exec_command(self, cmd, in_data=None, sudoable=True): - ''' run a command on the remote host ''' - - super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable) - - display.vvv(u"ESTABLISH SSH CONNECTION FOR USER: {0}".format(self._play_context.remote_user), host=self._play_context.remote_addr) - - - # we can only use tty when we are not pipelining the modules. piping - # data into /usr/bin/python inside a tty automatically invokes the - # python interactive-mode but the modules are not compatible with the - # interactive-mode ("unexpected indent" mainly because of empty lines) - if not in_data and sudoable: - args = ('ssh', '-tt', self.host, cmd) - else: - args = ('ssh', self.host, cmd) - - cmd = self._build_command(*args) - (returncode, stdout, stderr) = self._run(cmd, in_data, sudoable=sudoable) - - return (returncode, stdout, stderr) - def _file_transport_command(self, in_path, out_path, sftp_action): # scp and sftp require square brackets for IPv6 addresses, but # accept them for hostnames and IPv4 addresses too. @@ -674,7 +702,6 @@ class Connection(ConnectionBase): methods = ['sftp'] success = False - res = None for method in methods: returncode = stdout = stderr = None if method == 'sftp': @@ -693,77 +720,58 @@ class Connection(ConnectionBase): if sftp_action == 'get': # we pass sudoable=False to disable pty allocation, which # would end up mixing stdout/stderr and screwing with newlines - (returncode, stdout, stderr) = self._exec_command('dd if=%s bs=%s' % (in_path, BUFSIZE), sudoable=False) + (returncode, stdout, stderr) = self.exec_command('dd if=%s bs=%s' % (in_path, BUFSIZE), sudoable=False) out_file = open(to_bytes(out_path, errors='surrogate_or_strict'), 'wb+') out_file.write(stdout) out_file.close() else: in_data = open(to_bytes(in_path, errors='surrogate_or_strict'), 'rb').read() in_data = to_bytes(in_data, nonstring='passthru') - (returncode, stdout, stderr) = self._exec_command('dd of=%s bs=%s' % (out_path, BUFSIZE), in_data=in_data) + (returncode, stdout, stderr) = self.exec_command('dd of=%s bs=%s' % (out_path, BUFSIZE), in_data=in_data) # Check the return code and rollover to next method if failed if returncode == 0: - success = True - break + return (returncode, stdout, stderr) else: # If not in smart mode, the data will be printed by the raise below if len(methods) > 1: display.warning(msg='%s transfer mechanism failed on %s. Use ANSIBLE_DEBUG=1 to see detailed information' % (method, host)) display.debug(msg='%s' % to_native(stdout)) display.debug(msg='%s' % to_native(stderr)) - res = (returncode, stdout, stderr) - if not success: - raise AnsibleError("failed to transfer file {0} to {1}:\n{2}\n{3}"\ - .format(to_native(in_path), to_native(out_path), to_native(res[1]), to_native(res[2]))) + if returncode == 255: + raise AnsibleConnectionFailure("Failed to connect to the host via %s: %s" % (method, to_native(stderr))) + else: + raise AnsibleError("failed to transfer file to {0} {1}:\n{2}\n{3}"\ + .format(to_native(in_path), to_native(out_path), to_native(stdout), to_native(stderr))) # # Main public methods # - def exec_command(self, *args, **kwargs): - """ - Wrapper around _exec_command to retry in the case of an ssh failure + def exec_command(self, cmd, in_data=None, sudoable=True): + ''' run a command on the remote host ''' - Will retry if: - * an exception is caught - * ssh returns 255 - Will not retry if - * remaining_tries is <2 - * retries limit reached - """ + super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable) - remaining_tries = int(C.ANSIBLE_SSH_RETRIES) + 1 - cmd_summary = "%s..." % args[0] - for attempt in range(remaining_tries): - try: - return_tuple = self._exec_command(*args, **kwargs) - # 0 = success - # 1-254 = remote command return code - # 255 = failure from the ssh command itself - if return_tuple[0] != 255: - break - else: - raise AnsibleConnectionFailure("Failed to connect to the host via ssh: %s" % to_native(return_tuple[2])) - except (AnsibleConnectionFailure, Exception) as e: - if attempt == remaining_tries - 1: - raise - else: - pause = 2 ** attempt - 1 - if pause > 30: - pause = 30 + display.vvv(u"ESTABLISH SSH CONNECTION FOR USER: {0}".format(self._play_context.remote_user), host=self._play_context.remote_addr) - if isinstance(e, AnsibleConnectionFailure): - msg = "ssh_retry: attempt: %d, ssh return code is 255. cmd (%s), pausing for %d seconds" % (attempt, cmd_summary, pause) - else: - msg = "ssh_retry: attempt: %d, caught exception(%s) from cmd (%s), pausing for %d seconds" % (attempt, e, cmd_summary, pause) - display.vv(msg, host=self.host) + # we can only use tty when we are not pipelining the modules. piping + # data into /usr/bin/python inside a tty automatically invokes the + # python interactive-mode but the modules are not compatible with the + # interactive-mode ("unexpected indent" mainly because of empty lines) - time.sleep(pause) - continue + ssh_executable = self._play_context.ssh_executable - return return_tuple + if not in_data and sudoable: + args = (ssh_executable, '-tt', self.host, cmd) + else: + args = (ssh_executable, self.host, cmd) + + cmd = self._build_command(*args) + (returncode, stdout, stderr) = self._run(cmd, in_data, sudoable=sudoable) + + return (returncode, stdout, stderr) def put_file(self, in_path, out_path): ''' transfer a file from local to remote ''' @@ -774,7 +782,7 @@ class Connection(ConnectionBase): if not os.path.exists(to_bytes(in_path, errors='surrogate_or_strict')): raise AnsibleFileNotFound("file or module does not exist: {0}".format(to_native(in_path))) - self._file_transport_command(in_path, out_path, 'put') + return self._file_transport_command(in_path, out_path, 'put') def fetch_file(self, in_path, out_path): ''' fetch a file from remote to local ''' @@ -782,7 +790,7 @@ class Connection(ConnectionBase): super(Connection, self).fetch_file(in_path, out_path) display.vvv(u"FETCH {0} TO {1}".format(in_path, out_path), host=self.host) - self._file_transport_command(in_path, out_path, 'get') + return self._file_transport_command(in_path, out_path, 'get') def reset(self): # If we have a persistent ssh connection (ControlPersist), we can ask it to stop listening. diff --git a/test/units/plugins/connection/test_ssh.py b/test/units/plugins/connection/test_ssh.py index a9e2107fb0..7fa9ba5db3 100644 --- a/test/units/plugins/connection/test_ssh.py +++ b/test/units/plugins/connection/test_ssh.py @@ -25,7 +25,7 @@ from io import StringIO import pytest from ansible.compat.tests import unittest -from ansible.compat.tests.mock import patch, MagicMock +from ansible.compat.tests.mock import patch, MagicMock, PropertyMock from ansible import constants as C from ansible.compat.selectors import SelectorKey, EVENT_READ @@ -72,7 +72,7 @@ class TestConnectionBaseClass(unittest.TestCase): conn = ssh.Connection(pc, new_stdin) conn._build_command('ssh') - def test_plugins_connection_ssh__exec_command(self): + def test_plugins_connection_ssh_exec_command(self): pc = PlayContext() new_stdin = StringIO() conn = ssh.Connection(pc, new_stdin) @@ -82,8 +82,8 @@ class TestConnectionBaseClass(unittest.TestCase): conn._run = MagicMock() conn._run.return_value = (0, 'stdout', 'stderr') - res, stdout, stderr = conn._exec_command('ssh') - res, stdout, stderr = conn._exec_command('ssh', 'this is some data') + res, stdout, stderr = conn.exec_command('ssh') + res, stdout, stderr = conn.exec_command('ssh', 'this is some data') def test_plugins_connection_ssh__examine_output(self): pc = PlayContext() @@ -193,36 +193,8 @@ class TestConnectionBaseClass(unittest.TestCase): self.assertTrue(conn._flags['become_nopasswd_error']) @patch('time.sleep') - def test_plugins_connection_ssh_exec_command(self, mock_sleep): - pc = PlayContext() - new_stdin = StringIO() - conn = ssh.Connection(pc, new_stdin) - conn._build_command = MagicMock() - conn._exec_command = MagicMock() - - C.ANSIBLE_SSH_RETRIES = 9 - - # test a regular, successful execution - conn._exec_command.return_value = (0, b'stdout', b'') - res = conn.exec_command('ssh', 'some data') - self.assertEquals(res, (0, b'stdout', b''), msg='exec_command did not return what the _exec_command helper returned') - - # test a retry, followed by success - conn._exec_command.return_value = None - conn._exec_command.side_effect = [(255, '', ''), (0, b'stdout', b'')] - res = conn.exec_command('ssh', 'some data') - self.assertEquals(res, (0, b'stdout', b''), msg='exec_command did not return what the _exec_command helper returned') - - # test multiple failures - conn._exec_command.side_effect = [(255, b'', b'')] * 10 - self.assertRaises(AnsibleConnectionFailure, conn.exec_command, 'ssh', 'some data') - - # test other failure from exec_command - conn._exec_command.side_effect = [Exception('bad')] * 10 - self.assertRaises(Exception, conn.exec_command, 'ssh', 'some data') - @patch('os.path.exists') - def test_plugins_connection_ssh_put_file(self, mock_ospe): + def test_plugins_connection_ssh_put_file(self, mock_ospe, mock_sleep): pc = PlayContext() new_stdin = StringIO() conn = ssh.Connection(pc, new_stdin) @@ -234,6 +206,8 @@ class TestConnectionBaseClass(unittest.TestCase): conn._run.return_value = (0, '', '') conn.host = "some_host" + C.ANSIBLE_SSH_RETRIES = 9 + # Test with C.DEFAULT_SCP_IF_SSH set to smart # Test when SFTP works C.DEFAULT_SCP_IF_SSH = 'smart' @@ -276,7 +250,8 @@ class TestConnectionBaseClass(unittest.TestCase): conn._run.return_value = (0, 'stdout', '') self.assertRaises(AnsibleFileNotFound, conn.put_file, '/path/to/bad/file', '/remote/path/to/file') - def test_plugins_connection_ssh_fetch_file(self): + @patch('time.sleep') + def test_plugins_connection_ssh_fetch_file(self, mock_sleep): pc = PlayContext() new_stdin = StringIO() conn = ssh.Connection(pc, new_stdin) @@ -287,6 +262,8 @@ class TestConnectionBaseClass(unittest.TestCase): conn._run.return_value = (0, '', '') conn.host = "some_host" + C.ANSIBLE_SSH_RETRIES = 9 + # Test with C.DEFAULT_SCP_IF_SSH set to smart # Test when SFTP works C.DEFAULT_SCP_IF_SSH = 'smart' @@ -535,3 +512,120 @@ class TestSSHConnectionRun(object): assert self.mock_selector.register.called is True assert self.mock_selector.register.call_count == 2 assert self.conn._send_initial_data.called is False + + +@pytest.mark.usefixtures('mock_run_env') +class TestSSHConnectionRetries(object): + def test_retry_then_success(self): + self.mock_popen_res.stdout.read.side_effect = [b"", b"my_stdout\n", b"second_line"] + self.mock_popen_res.stderr.read.side_effect = [b"", b"my_stderr"] + type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 3 + [0] * 4) + + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [] + ] + self.mock_selector.get_map.side_effect = lambda: True + + self.conn._build_command = MagicMock() + self.conn._build_command.return_value = 'ssh' + + return_code, b_stdout, b_stderr = self.conn.exec_command('ssh', 'some data') + assert return_code == 0 + assert b_stdout == b'my_stdout\nsecond_line' + assert b_stderr == b'my_stderr' + + @patch('time.sleep') + def test_multiple_failures(self, mock_sleep): + C.ANSIBLE_SSH_RETRIES = 9 + + self.mock_popen_res.stdout.read.side_effect = [b""] * 11 + self.mock_popen_res.stderr.read.side_effect = [b""] * 11 + type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 30) + + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [], + ] * 10 + self.mock_selector.get_map.side_effect = lambda: True + + self.conn._build_command = MagicMock() + self.conn._build_command.return_value = 'ssh' + + pytest.raises(AnsibleConnectionFailure, self.conn.exec_command, 'ssh', 'some data') + assert self.mock_popen.call_count == 10 + + @patch('time.sleep') + def test_abitrary_exceptions(self, mock_sleep): + C.ANSIBLE_SSH_RETRIES = 9 + + self.conn._build_command = MagicMock() + self.conn._build_command.return_value = 'ssh' + + self.mock_popen.side_effect = [Exception('bad')] * 10 + pytest.raises(Exception, self.conn.exec_command, 'ssh', 'some data') + assert self.mock_popen.call_count == 10 + + @patch('time.sleep') + @patch('ansible.plugins.connection.ssh.os') + def test_put_file_retries(self, os_mock, time_mock): + os_mock.path.exists.return_value = True + + self.mock_popen_res.stdout.read.side_effect = [b"", b"my_stdout\n", b"second_line"] + self.mock_popen_res.stderr.read.side_effect = [b"", b"my_stderr"] + type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 3 + [0] * 4) + + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [] + ] + self.mock_selector.get_map.side_effect = lambda: True + + self.conn._build_command = MagicMock() + self.conn._build_command.return_value = 'ssh' + + return_code, b_stdout, b_stderr = self.conn.put_file('/path/to/in/file', '/path/to/dest/file') + assert return_code == 0 + assert b_stdout == b"my_stdout\nsecond_line" + assert b_stderr == b"my_stderr" + assert self.mock_popen.call_count == 2 + + @patch('time.sleep') + @patch('ansible.plugins.connection.ssh.os') + def test_fetch_file_retries(self, os_mock, time_mock): + os_mock.path.exists.return_value = True + + self.mock_popen_res.stdout.read.side_effect = [b"", b"my_stdout\n", b"second_line"] + self.mock_popen_res.stderr.read.side_effect = [b"", b"my_stderr"] + type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 3 + [0] * 4) + + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [] + ] + self.mock_selector.get_map.side_effect = lambda: True + + self.conn._build_command = MagicMock() + self.conn._build_command.return_value = 'ssh' + + return_code, b_stdout, b_stderr = self.conn.fetch_file('/path/to/in/file', '/path/to/dest/file') + assert return_code == 0 + assert b_stdout == b"my_stdout\nsecond_line" + assert b_stderr == b"my_stderr" + assert self.mock_popen.call_count == 2