From d834412eadccca6428501544f398d1bb59084e98 Mon Sep 17 00:00:00 2001 From: Toshio Kuratomi Date: Fri, 12 May 2017 09:13:51 -0700 Subject: [PATCH] Fix for persistent connection plugin on Python3 (#24431) Fix for persistent connection plugin on Python3. Note that fixes are also needed to each terminal plugin. This PR only fixes the ios terminal (as proof that this approach is workable.) Future PRs can address the other terminal types. * On Python3, pickle needs to work with byte strings, not text strings. * Set the pickle protocol version to 0 because we're using a pty to feed data to the connection plugin. A pty can't have control characters. So we have to send ascii only. That means only using protocol=0 for pickling the data. * ansible-connection isn't being used with py3 in the bug but it needs several changes to work with python3. * In python3, closing the pty too early causes no data to be sent. So leave stdin open until after we finish with the ansible-connection process. * Fix typo using traceback.format_exc() * Cleanup unnecessary StringIO, BytesIO, and to_bytes calls * Modify the network_cli and terminal plugins for py3 compat. Lots of mixing of text and byte strings that needs to be straightened out to be compatible with python3 * Documentation for the bytes<=>text strategy for terminal plugins * Update unittests for more bytes-oriented internals Fixes #24355 --- bin/ansible-connection | 58 ++++++++--------- lib/ansible/plugins/connection/network_cli.py | 62 ++++++++++--------- lib/ansible/plugins/connection/persistent.py | 16 +++-- lib/ansible/plugins/terminal/__init__.py | 41 +++++++++--- lib/ansible/plugins/terminal/ios.py | 51 +++++++-------- .../plugins/connection/test_network_cli.py | 30 ++++----- 6 files changed, 146 insertions(+), 112 deletions(-) diff --git a/bin/ansible-connection b/bin/ansible-connection index 10449115f6..676cd70223 100755 --- a/bin/ansible-connection +++ b/bin/ansible-connection @@ -45,7 +45,8 @@ from io import BytesIO from ansible import constants as C from ansible.module_utils._text import to_bytes, to_native -from ansible.module_utils.six.moves import cPickle, StringIO +from ansible.module_utils.six import PY3 +from ansible.module_utils.six.moves import cPickle from ansible.playbook.play_context import PlayContext from ansible.plugins import connection_loader from ansible.utils.path import unfrackpath, makedirs_safe @@ -73,11 +74,11 @@ def do_fork(): sys.exit(0) if C.DEFAULT_LOG_PATH != '': - out_file = file(C.DEFAULT_LOG_PATH, 'a+') - err_file = file(C.DEFAULT_LOG_PATH, 'a+', 0) + out_file = open(C.DEFAULT_LOG_PATH, 'ab+') + err_file = open(C.DEFAULT_LOG_PATH, 'ab+', 0) else: - out_file = file('/dev/null', 'a+') - err_file = file('/dev/null', 'a+', 0) + out_file = open('/dev/null', 'ab+') + err_file = open('/dev/null', 'ab+', 0) os.dup2(out_file.fileno(), sys.stdout.fileno()) os.dup2(err_file.fileno(), sys.stderr.fileno()) @@ -90,7 +91,7 @@ def do_fork(): sys.exit(1) def send_data(s, data): - packed_len = struct.pack('!Q',len(data)) + packed_len = struct.pack('!Q', len(data)) return s.sendall(packed_len + data) def recv_data(s): @@ -101,7 +102,7 @@ def recv_data(s): if not d: return None data += d - data_len = struct.unpack('!Q',data[:header_len])[0] + data_len = struct.unpack('!Q', data[:header_len])[0] data = data[header_len:] while len(data) < data_len: d = s.recv(data_len - len(data)) @@ -211,11 +212,9 @@ class Server(): pass elif data.startswith(b'CONTEXT: '): display.display("socket operation is CONTEXT", log_only=True) - pc_data = data.split(b'CONTEXT: ')[1] + pc_data = data.split(b'CONTEXT: ', 1)[1] - src = StringIO(pc_data) - pc_data = cPickle.load(src) - src.close() + pc_data = cPickle.loads(pc_data) pc = PlayContext() pc.deserialize(pc_data) @@ -234,12 +233,12 @@ class Server(): display.display("socket operation completed with rc %s" % rc, log_only=True) - send_data(s, to_bytes(str(rc))) + send_data(s, to_bytes(rc)) send_data(s, to_bytes(stdout)) send_data(s, to_bytes(stderr)) s.close() except Exception as e: - display.display(traceback.format_exec(), log_only=True) + display.display(traceback.format_exc(), log_only=True) finally: # when done, close the connection properly and cleanup # the socket file so it can be recreated @@ -254,21 +253,25 @@ class Server(): os.remove(self.path) def main(): + # Need stdin as a byte stream + if PY3: + stdin = sys.stdin.buffer + else: + stdin = sys.stdin try: # read the play context data via stdin, which means depickling it # FIXME: as noted above, we will probably need to deserialize the # connection loader here as well at some point, otherwise this # won't find role- or playbook-based connection plugins - cur_line = sys.stdin.readline() - init_data = '' - while cur_line.strip() != '#END_INIT#': - if cur_line == '': - raise Exception("EOL found before init data was complete") + cur_line = stdin.readline() + init_data = b'' + while cur_line.strip() != b'#END_INIT#': + if cur_line == b'': + raise Exception("EOF found before init data was complete") init_data += cur_line - cur_line = sys.stdin.readline() - src = BytesIO(to_bytes(init_data)) - pc_data = cPickle.load(src) + cur_line = stdin.readline() + pc_data = cPickle.loads(init_data) pc = PlayContext() pc.deserialize(pc_data) @@ -319,10 +322,10 @@ def main(): # the connection will timeout here. Need to make this more resilient. rc = 0 while rc == 0: - data = sys.stdin.readline() - if data == '': + data = stdin.readline() + if data == b'': break - if data.strip() == '': + if data.strip() == b'': continue sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) attempts = 1 @@ -342,11 +345,10 @@ def main(): # send the play_context back into the connection so the connection # can handle any privilege escalation activities - pc_data = 'CONTEXT: %s' % src.getvalue() - send_data(sf, to_bytes(pc_data)) - src.close() + pc_data = b'CONTEXT: %s' % init_data + send_data(sf, pc_data) - send_data(sf, to_bytes(data.strip())) + send_data(sf, data.strip()) rc = int(recv_data(sf), 10) stdout = recv_data(sf) diff --git a/lib/ansible/plugins/connection/network_cli.py b/lib/ansible/plugins/connection/network_cli.py index 137884104f..6f004c00fc 100644 --- a/lib/ansible/plugins/connection/network_cli.py +++ b/lib/ansible/plugins/connection/network_cli.py @@ -18,17 +18,18 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import re -import socket import json -import signal -import datetime -import traceback import logging +import re +import signal +import socket +import traceback +from collections import Sequence from ansible import constants as C from ansible.errors import AnsibleConnectionFailure -from ansible.module_utils.six.moves import StringIO +from ansible.module_utils.six import BytesIO, binary_type, text_type +from ansible.module_utils._text import to_bytes, to_text from ansible.plugins import terminal_loader from ansible.plugins.connection import ensure_connect from ansible.plugins.connection.paramiko_ssh import Connection as _Connection @@ -113,7 +114,7 @@ class Connection(_Connection): self._terminal.on_authorize(passwd=auth_pass) display.display('shell successfully opened', log_only=True) - return (0, 'ok', '') + return (0, b'ok', b'') def close(self): display.display('closing connection', log_only=True) @@ -131,11 +132,11 @@ class Connection(_Connection): self._shell.close() self._shell = None - return (0, 'ok', '') + return (0, b'ok', b'') def receive(self, obj=None): """Handles receiving of output from command""" - recv = StringIO() + recv = BytesIO() handled = False self._matched_prompt = None @@ -162,30 +163,30 @@ class Connection(_Connection): try: command = obj['command'] self._history.append(command) - self._shell.sendall('%s\r' % command) + self._shell.sendall(b'%s\r' % command) if obj.get('sendonly'): return return self.receive(obj) - except (socket.timeout, AttributeError) as exc: + except (socket.timeout, AttributeError): display.display(traceback.format_exc(), log_only=True) raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip()) def _strip(self, data): """Removes ANSI codes from device response""" for regex in self._terminal.ansi_re: - data = regex.sub('', data) + data = regex.sub(b'', data) return data def _handle_prompt(self, resp, obj): """Matches the command prompt and responds""" - if not isinstance(obj['prompt'], list): + if isinstance(obj, (binary_type, text_type)) or not isinstance(obj['prompt'], Sequence): obj['prompt'] = [obj['prompt']] prompts = [re.compile(r, re.I) for r in obj['prompt']] answer = obj['answer'] for regex in prompts: match = regex.search(resp) if match: - self._shell.sendall('%s\r' % answer) + self._shell.sendall(b'%s\r' % answer) return True def _sanitize(self, resp, obj=None): @@ -196,7 +197,7 @@ class Connection(_Connection): if (command and line.startswith(command.strip())) or self._matched_prompt.strip() in line: continue cleaned.append(line) - return str("\n".join(cleaned)).strip() + return b"\n".join(cleaned).strip() def _find_prompt(self, response): """Searches the buffered response for a matching command prompt""" @@ -225,9 +226,9 @@ class Connection(_Connection): def exec_command(self, cmd): """Executes the cmd on in the shell and returns the output - The method accepts two forms of cmd. The first form is as a + The method accepts two forms of cmd. The first form is as a byte string that represents the command to be executed in the shell. The - second form is as a JSON string with additional keyword. + second form is as a utf8 JSON byte string with additional keywords. Keywords supported for cmd: * command - the command string to execute @@ -235,28 +236,30 @@ class Connection(_Connection): * answer - the string to respond to the prompt with * sendonly - bool to disable waiting for response - :arg cmd: the string that represents the command to be executed - which can be a single command or a json encoded string + :arg cmd: the byte string that represents the command to be executed + which can be a single command or a json encoded string. :returns: a tuple of (return code, stdout, stderr). The return - code is an integer and stdout and stderr are strings + code is an integer and stdout and stderr are byte strings """ try: - obj = json.loads(cmd) + obj = json.loads(to_text(cmd, errors='surrogate_or_strict')) + obj = dict((k, to_bytes(v, errors='surrogate_or_strict', nonstring='passthru')) for k, v in obj.items()) except (ValueError, TypeError): - obj = {'command': str(cmd).strip()} + obj = {'command': to_bytes(cmd.strip(), errors='surrogate_or_strict')} - if obj['command'] == 'close_shell()': + if obj['command'] == b'close_shell()': return self.close_shell() - elif obj['command'] == 'open_shell()': + elif obj['command'] == b'open_shell()': return self.open_shell() - elif obj['command'] == 'prompt()': - return (0, self._matched_prompt, '') + elif obj['command'] == b'prompt()': + return (0, self._matched_prompt, b'') try: if self._shell is None: self.open_shell() except AnsibleConnectionFailure as exc: - return (1, '', str(exc)) + # FIXME: Feels like we should raise this rather than return it + return (1, b'', to_bytes(exc)) try: if not signal.getsignal(signal.SIGALRM): @@ -264,6 +267,7 @@ class Connection(_Connection): signal.alarm(self._play_context.timeout) out = self.send(obj) signal.alarm(0) - return (0, out, '') + return (0, out, b'') except (AnsibleConnectionFailure, ValueError) as exc: - return (1, '', str(exc)) + # FIXME: Feels like we should raise this rather than return it + return (1, b'', to_bytes(exc)) diff --git a/lib/ansible/plugins/connection/persistent.py b/lib/ansible/plugins/connection/persistent.py index 8465d66268..fc210a9766 100644 --- a/lib/ansible/plugins/connection/persistent.py +++ b/lib/ansible/plugins/connection/persistent.py @@ -24,7 +24,7 @@ import subprocess import sys from ansible.module_utils._text import to_bytes -from ansible.module_utils.six.moves import cPickle, StringIO +from ansible.module_utils.six.moves import cPickle from ansible.plugins.connection import ConnectionBase try: @@ -52,16 +52,20 @@ class Connection(ConnectionBase): stdin = os.fdopen(master, 'wb', 0) os.close(slave) - src = StringIO() - cPickle.dump(self._play_context.serialize(), src) - stdin.write(src.getvalue()) - src.close() + # Need to force a protocol that is compatible with both py2 and py3. + # That would be protocol=2 or less. + # Also need to force a protocol that excludes certain control chars as + # stdin in this case is a pty and control chars will cause problems. + # that means only protocol=0 will work. + src = cPickle.dumps(self._play_context.serialize(), protocol=0) + stdin.write(src) stdin.write(b'\n#END_INIT#\n') stdin.write(to_bytes(action)) stdin.write(b'\n\n') - stdin.close() + (stdout, stderr) = p.communicate() + stdin.close() return (p.returncode, stdout, stderr) diff --git a/lib/ansible/plugins/terminal/__init__.py b/lib/ansible/plugins/terminal/__init__.py index e8e04884eb..e52ae62273 100644 --- a/lib/ansible/plugins/terminal/__init__.py +++ b/lib/ansible/plugins/terminal/__init__.py @@ -30,33 +30,54 @@ from ansible.module_utils.six import with_metaclass class TerminalBase(with_metaclass(ABCMeta, object)): ''' A base class for implementing cli connections + + .. note:: Unlike most of Ansible, nearly all strings in + :class:`TerminalBase` plugins are byte strings. This is because of + how close to the underlying platform these plugins operate. Remember + to mark literal strings as byte string (``b"string"``) and to use + :func:`~ansible.module_utils._text.to_bytes` and + :func:`~ansible.module_utils._text.to_text` to avoid unexpected + problems. ''' - # compiled regular expression as stdout + #: compiled bytes regular expressions as stdout terminal_stdout_re = [] - # compiled regular expression as stderr + #: compiled bytes regular expressions as stderr terminal_stderr_re = [] - # copiled regular expression to remove ANSI codes + #: compiled bytes regular expressions to remove ANSI codes ansi_re = [ - re.compile(r'(\x1b\[\?1h\x1b=)'), - re.compile(r'\x08.') + re.compile(br'(\x1b\[\?1h\x1b=)'), + re.compile(br'\x08.') ] def __init__(self, connection): self._connection = connection def _exec_cli_command(self, cmd, check_rc=True): - """Executes a CLI command on the device""" + """ + Executes a CLI command on the device + + :arg cmd: Byte string consisting of the command to execute + :kwarg check_rc: If True, the default, raise an + :exc:`AnsibleConnectionFailure` if the return code from the + command is nonzero + :returns: A tuple of return code, stdout, and stderr from running the + command. stdout and stderr are both byte strings. + """ rc, out, err = self._connection.exec_command(cmd) if check_rc and rc != 0: raise AnsibleConnectionFailure(err) return rc, out, err def _get_prompt(self): - """ Returns the current prompt from the device""" - for cmd in ['\n', 'prompt()']: + """ + Returns the current prompt from the device + + :returns: A byte string of the prompt + """ + for cmd in (b'\n', b'prompt()'): rc, out, err = self._exec_cli_command(cmd) return out @@ -82,6 +103,8 @@ class TerminalBase(with_metaclass(ABCMeta, object)): def on_authorize(self, passwd=None): """Called when privilege escalation is requested + :kwarg passwd: String containing the password + This method is called when the privilege is requested to be elevated in the play context by setting become to True. It is the responsibility of the terminal plugin to actually do the privilege escalation such @@ -94,6 +117,6 @@ class TerminalBase(with_metaclass(ABCMeta, object)): This method is called when the privilege changed from escalated (become=True) to non escalated (become=False). It is the responsibility - of the this method to actually perform the deauthorization procedure + of this method to actually perform the deauthorization procedure """ pass diff --git a/lib/ansible/plugins/terminal/ios.py b/lib/ansible/plugins/terminal/ios.py index 4ce5dc9406..fb01e79fa3 100644 --- a/lib/ansible/plugins/terminal/ios.py +++ b/lib/ansible/plugins/terminal/ios.py @@ -19,49 +19,52 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import re import json +import re -from ansible.plugins.terminal import TerminalBase from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_bytes +from ansible.plugins.terminal import TerminalBase class TerminalModule(TerminalBase): terminal_stdout_re = [ - re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"), - re.compile(r"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$") + re.compile(br"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"), + re.compile(br"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$") ] terminal_stderr_re = [ - re.compile(r"% ?Error"), - #re.compile(r"^% \w+", re.M), - re.compile(r"% ?Bad secret"), - re.compile(r"invalid input", re.I), - re.compile(r"(?:incomplete|ambiguous) command", re.I), - re.compile(r"connection timed out", re.I), - re.compile(r"[^\r\n]+ not found", re.I), - re.compile(r"'[^']' +returned error code: ?\d+"), + re.compile(br"% ?Error"), + #re.compile(br"^% \w+", re.M), + re.compile(br"% ?Bad secret"), + re.compile(br"invalid input", re.I), + re.compile(br"(?:incomplete|ambiguous) command", re.I), + re.compile(br"connection timed out", re.I), + re.compile(br"[^\r\n]+ not found", re.I), + re.compile(br"'[^']' +returned error code: ?\d+"), ] def on_open_shell(self): try: - for cmd in ['terminal length 0', 'terminal width 512']: + for cmd in (b'terminal length 0', b'terminal width 512'): self._exec_cli_command(cmd) except AnsibleConnectionFailure: raise AnsibleConnectionFailure('unable to set terminal parameters') def on_authorize(self, passwd=None): - if self._get_prompt().endswith('#'): + if self._get_prompt().endswith(b'#'): return - cmd = {'command': 'enable'} + cmd = {u'command': u'enable'} if passwd: - cmd['prompt'] = r"[\r\n]?password: $" - cmd['answer'] = passwd + # Note: python-3.5 cannot combine u"" and r"" together. Thus make + # an r string and use to_text to ensure it's text on both py2 and py3. + cmd[u'prompt'] = to_text(r"[\r\n]?password: $", errors='surrogate_or_strict') + cmd[u'answer'] = passwd try: - self._exec_cli_command(json.dumps(cmd)) + self._exec_cli_command(to_bytes(json.dumps(cmd), errors='surrogate_or_strict')) except AnsibleConnectionFailure: raise AnsibleConnectionFailure('unable to elevate privilege to enable mode') @@ -71,11 +74,9 @@ class TerminalModule(TerminalBase): # if prompt is None most likely the terminal is hung up at a prompt return - if '(config' in prompt: - self._exec_cli_command('end') - self._exec_cli_command('disable') - - elif prompt.endswith('#'): - self._exec_cli_command('disable') - + if b'(config' in prompt: + self._exec_cli_command(b'end') + self._exec_cli_command(b'disable') + elif prompt.endswith(b'#'): + self._exec_cli_command(b'disable') diff --git a/test/units/plugins/connection/test_network_cli.py b/test/units/plugins/connection/test_network_cli.py index 818376a2eb..b9e4c8e40e 100644 --- a/test/units/plugins/connection/test_network_cli.py +++ b/test/units/plugins/connection/test_network_cli.py @@ -117,21 +117,21 @@ class TestConnectionClass(unittest.TestCase): mock_open_shell = MagicMock() conn.open_shell = mock_open_shell - mock_send = MagicMock(return_value='command response') + mock_send = MagicMock(return_value=b'command response') conn.send = mock_send # test sending a single command and converting to dict rc, out, err = conn.exec_command('command') - self.assertEqual(out, 'command response') + self.assertEqual(out, b'command response') self.assertTrue(mock_open_shell.called) - mock_send.assert_called_with({'command': 'command'}) + mock_send.assert_called_with({'command': b'command'}) mock_open_shell.reset_mock() # test sending a json string rc, out, err = conn.exec_command(json.dumps({'command': 'command'})) - self.assertEqual(out, 'command response') - mock_send.assert_called_with({'command': 'command'}) + self.assertEqual(out, b'command response') + mock_send.assert_called_with({'command': b'command'}) self.assertTrue(mock_open_shell.called) mock_open_shell.reset_mock() @@ -139,9 +139,9 @@ class TestConnectionClass(unittest.TestCase): # test _shell already open rc, out, err = conn.exec_command('command') - self.assertEqual(out, 'command response') + self.assertEqual(out, b'command response') self.assertFalse(mock_open_shell.called) - mock_send.assert_called_with({'command': 'command'}) + mock_send.assert_called_with({'command': b'command'}) def test_network_cli_send(self): @@ -150,14 +150,14 @@ class TestConnectionClass(unittest.TestCase): conn = network_cli.Connection(pc, new_stdin) mock__terminal = MagicMock() - mock__terminal.terminal_stdout_re = [re.compile('device#')] - mock__terminal.terminal_stderr_re = [re.compile('^ERROR')] + mock__terminal.terminal_stdout_re = [re.compile(b'device#')] + mock__terminal.terminal_stderr_re = [re.compile(b'^ERROR')] conn._terminal = mock__terminal mock__shell = MagicMock() conn._shell = mock__shell - response = """device#command + response = b"""device#command command response device# @@ -165,15 +165,15 @@ class TestConnectionClass(unittest.TestCase): mock__shell.recv.return_value = response - output = conn.send({'command': 'command'}) + output = conn.send({'command': b'command'}) - mock__shell.sendall.assert_called_with('command\r') - self.assertEqual(output, 'command response') + mock__shell.sendall.assert_called_with(b'command\r') + self.assertEqual(output, b'command response') mock__shell.reset_mock() - mock__shell.recv.return_value = "ERROR: error message" + mock__shell.recv.return_value = b"ERROR: error message" with self.assertRaises(AnsibleConnectionFailure) as exc: - conn.send({'command': 'command'}) + conn.send({'command': b'command'}) self.assertEqual(str(exc.exception), 'ERROR: error message')