1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

Fix for persistent connection plugin on Python3 (#24431)

Fix for persistent connection plugin on Python3.  Note that fixes are also needed to each terminal plugin.  This PR only fixes the ios terminal (as proof that this approach is workable.)  Future PRs can address the other terminal types.

* On Python3, pickle needs to work with byte strings, not text strings.
* Set the pickle protocol version to 0 because we're using a pty to feed data to the connection plugin.  A pty can't have control characters.  So we have to send ascii only.  That means
only using protocol=0 for pickling the data.
* ansible-connection isn't being used with py3 in the bug but it needs
several changes to work with python3.
* In python3, closing the pty too early causes no data to be sent.  So
leave stdin open until after we finish with the ansible-connection
process.
* Fix typo using traceback.format_exc()
* Cleanup unnecessary StringIO, BytesIO, and to_bytes calls
* Modify the network_cli and terminal plugins for py3 compat.  Lots of mixing of text and byte strings that needs to be straightened out to be compatible with python3
* Documentation for the bytes<=>text strategy for terminal plugins
* Update unittests for more bytes-oriented internals

Fixes #24355
This commit is contained in:
Toshio Kuratomi 2017-05-12 09:13:51 -07:00 committed by GitHub
parent e539726543
commit d834412ead
6 changed files with 146 additions and 112 deletions

View file

@ -45,7 +45,8 @@ from io import BytesIO
from ansible import constants as C from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.six.moves import cPickle, StringIO from ansible.module_utils.six import PY3
from ansible.module_utils.six.moves import cPickle
from ansible.playbook.play_context import PlayContext from ansible.playbook.play_context import PlayContext
from ansible.plugins import connection_loader from ansible.plugins import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe from ansible.utils.path import unfrackpath, makedirs_safe
@ -73,11 +74,11 @@ def do_fork():
sys.exit(0) sys.exit(0)
if C.DEFAULT_LOG_PATH != '': if C.DEFAULT_LOG_PATH != '':
out_file = file(C.DEFAULT_LOG_PATH, 'a+') out_file = open(C.DEFAULT_LOG_PATH, 'ab+')
err_file = file(C.DEFAULT_LOG_PATH, 'a+', 0) err_file = open(C.DEFAULT_LOG_PATH, 'ab+', 0)
else: else:
out_file = file('/dev/null', 'a+') out_file = open('/dev/null', 'ab+')
err_file = file('/dev/null', 'a+', 0) err_file = open('/dev/null', 'ab+', 0)
os.dup2(out_file.fileno(), sys.stdout.fileno()) os.dup2(out_file.fileno(), sys.stdout.fileno())
os.dup2(err_file.fileno(), sys.stderr.fileno()) os.dup2(err_file.fileno(), sys.stderr.fileno())
@ -211,11 +212,9 @@ class Server():
pass pass
elif data.startswith(b'CONTEXT: '): elif data.startswith(b'CONTEXT: '):
display.display("socket operation is CONTEXT", log_only=True) display.display("socket operation is CONTEXT", log_only=True)
pc_data = data.split(b'CONTEXT: ')[1] pc_data = data.split(b'CONTEXT: ', 1)[1]
src = StringIO(pc_data) pc_data = cPickle.loads(pc_data)
pc_data = cPickle.load(src)
src.close()
pc = PlayContext() pc = PlayContext()
pc.deserialize(pc_data) pc.deserialize(pc_data)
@ -234,12 +233,12 @@ class Server():
display.display("socket operation completed with rc %s" % rc, log_only=True) display.display("socket operation completed with rc %s" % rc, log_only=True)
send_data(s, to_bytes(str(rc))) send_data(s, to_bytes(rc))
send_data(s, to_bytes(stdout)) send_data(s, to_bytes(stdout))
send_data(s, to_bytes(stderr)) send_data(s, to_bytes(stderr))
s.close() s.close()
except Exception as e: except Exception as e:
display.display(traceback.format_exec(), log_only=True) display.display(traceback.format_exc(), log_only=True)
finally: finally:
# when done, close the connection properly and cleanup # when done, close the connection properly and cleanup
# the socket file so it can be recreated # the socket file so it can be recreated
@ -254,21 +253,25 @@ class Server():
os.remove(self.path) os.remove(self.path)
def main(): def main():
# Need stdin as a byte stream
if PY3:
stdin = sys.stdin.buffer
else:
stdin = sys.stdin
try: try:
# read the play context data via stdin, which means depickling it # read the play context data via stdin, which means depickling it
# FIXME: as noted above, we will probably need to deserialize the # FIXME: as noted above, we will probably need to deserialize the
# connection loader here as well at some point, otherwise this # connection loader here as well at some point, otherwise this
# won't find role- or playbook-based connection plugins # won't find role- or playbook-based connection plugins
cur_line = sys.stdin.readline() cur_line = stdin.readline()
init_data = '' init_data = b''
while cur_line.strip() != '#END_INIT#': while cur_line.strip() != b'#END_INIT#':
if cur_line == '': if cur_line == b'':
raise Exception("EOL found before init data was complete") raise Exception("EOF found before init data was complete")
init_data += cur_line init_data += cur_line
cur_line = sys.stdin.readline() cur_line = stdin.readline()
src = BytesIO(to_bytes(init_data)) pc_data = cPickle.loads(init_data)
pc_data = cPickle.load(src)
pc = PlayContext() pc = PlayContext()
pc.deserialize(pc_data) pc.deserialize(pc_data)
@ -319,10 +322,10 @@ def main():
# the connection will timeout here. Need to make this more resilient. # the connection will timeout here. Need to make this more resilient.
rc = 0 rc = 0
while rc == 0: while rc == 0:
data = sys.stdin.readline() data = stdin.readline()
if data == '': if data == b'':
break break
if data.strip() == '': if data.strip() == b'':
continue continue
sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
attempts = 1 attempts = 1
@ -342,11 +345,10 @@ def main():
# send the play_context back into the connection so the connection # send the play_context back into the connection so the connection
# can handle any privilege escalation activities # can handle any privilege escalation activities
pc_data = 'CONTEXT: %s' % src.getvalue() pc_data = b'CONTEXT: %s' % init_data
send_data(sf, to_bytes(pc_data)) send_data(sf, pc_data)
src.close()
send_data(sf, to_bytes(data.strip())) send_data(sf, data.strip())
rc = int(recv_data(sf), 10) rc = int(recv_data(sf), 10)
stdout = recv_data(sf) stdout = recv_data(sf)

View file

@ -18,17 +18,18 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import re
import socket
import json import json
import signal
import datetime
import traceback
import logging import logging
import re
import signal
import socket
import traceback
from collections import Sequence
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleConnectionFailure from ansible.errors import AnsibleConnectionFailure
from ansible.module_utils.six.moves import StringIO from ansible.module_utils.six import BytesIO, binary_type, text_type
from ansible.module_utils._text import to_bytes, to_text
from ansible.plugins import terminal_loader from ansible.plugins import terminal_loader
from ansible.plugins.connection import ensure_connect from ansible.plugins.connection import ensure_connect
from ansible.plugins.connection.paramiko_ssh import Connection as _Connection from ansible.plugins.connection.paramiko_ssh import Connection as _Connection
@ -113,7 +114,7 @@ class Connection(_Connection):
self._terminal.on_authorize(passwd=auth_pass) self._terminal.on_authorize(passwd=auth_pass)
display.display('shell successfully opened', log_only=True) display.display('shell successfully opened', log_only=True)
return (0, 'ok', '') return (0, b'ok', b'')
def close(self): def close(self):
display.display('closing connection', log_only=True) display.display('closing connection', log_only=True)
@ -131,11 +132,11 @@ class Connection(_Connection):
self._shell.close() self._shell.close()
self._shell = None self._shell = None
return (0, 'ok', '') return (0, b'ok', b'')
def receive(self, obj=None): def receive(self, obj=None):
"""Handles receiving of output from command""" """Handles receiving of output from command"""
recv = StringIO() recv = BytesIO()
handled = False handled = False
self._matched_prompt = None self._matched_prompt = None
@ -162,30 +163,30 @@ class Connection(_Connection):
try: try:
command = obj['command'] command = obj['command']
self._history.append(command) self._history.append(command)
self._shell.sendall('%s\r' % command) self._shell.sendall(b'%s\r' % command)
if obj.get('sendonly'): if obj.get('sendonly'):
return return
return self.receive(obj) return self.receive(obj)
except (socket.timeout, AttributeError) as exc: except (socket.timeout, AttributeError):
display.display(traceback.format_exc(), log_only=True) display.display(traceback.format_exc(), log_only=True)
raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip()) raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip())
def _strip(self, data): def _strip(self, data):
"""Removes ANSI codes from device response""" """Removes ANSI codes from device response"""
for regex in self._terminal.ansi_re: for regex in self._terminal.ansi_re:
data = regex.sub('', data) data = regex.sub(b'', data)
return data return data
def _handle_prompt(self, resp, obj): def _handle_prompt(self, resp, obj):
"""Matches the command prompt and responds""" """Matches the command prompt and responds"""
if not isinstance(obj['prompt'], list): if isinstance(obj, (binary_type, text_type)) or not isinstance(obj['prompt'], Sequence):
obj['prompt'] = [obj['prompt']] obj['prompt'] = [obj['prompt']]
prompts = [re.compile(r, re.I) for r in obj['prompt']] prompts = [re.compile(r, re.I) for r in obj['prompt']]
answer = obj['answer'] answer = obj['answer']
for regex in prompts: for regex in prompts:
match = regex.search(resp) match = regex.search(resp)
if match: if match:
self._shell.sendall('%s\r' % answer) self._shell.sendall(b'%s\r' % answer)
return True return True
def _sanitize(self, resp, obj=None): def _sanitize(self, resp, obj=None):
@ -196,7 +197,7 @@ class Connection(_Connection):
if (command and line.startswith(command.strip())) or self._matched_prompt.strip() in line: if (command and line.startswith(command.strip())) or self._matched_prompt.strip() in line:
continue continue
cleaned.append(line) cleaned.append(line)
return str("\n".join(cleaned)).strip() return b"\n".join(cleaned).strip()
def _find_prompt(self, response): def _find_prompt(self, response):
"""Searches the buffered response for a matching command prompt""" """Searches the buffered response for a matching command prompt"""
@ -225,9 +226,9 @@ class Connection(_Connection):
def exec_command(self, cmd): def exec_command(self, cmd):
"""Executes the cmd on in the shell and returns the output """Executes the cmd on in the shell and returns the output
The method accepts two forms of cmd. The first form is as a The method accepts two forms of cmd. The first form is as a byte
string that represents the command to be executed in the shell. The string that represents the command to be executed in the shell. The
second form is as a JSON string with additional keyword. second form is as a utf8 JSON byte string with additional keywords.
Keywords supported for cmd: Keywords supported for cmd:
* command - the command string to execute * command - the command string to execute
@ -235,28 +236,30 @@ class Connection(_Connection):
* answer - the string to respond to the prompt with * answer - the string to respond to the prompt with
* sendonly - bool to disable waiting for response * sendonly - bool to disable waiting for response
:arg cmd: the string that represents the command to be executed :arg cmd: the byte string that represents the command to be executed
which can be a single command or a json encoded string which can be a single command or a json encoded string.
:returns: a tuple of (return code, stdout, stderr). The return :returns: a tuple of (return code, stdout, stderr). The return
code is an integer and stdout and stderr are strings code is an integer and stdout and stderr are byte strings
""" """
try: try:
obj = json.loads(cmd) obj = json.loads(to_text(cmd, errors='surrogate_or_strict'))
obj = dict((k, to_bytes(v, errors='surrogate_or_strict', nonstring='passthru')) for k, v in obj.items())
except (ValueError, TypeError): except (ValueError, TypeError):
obj = {'command': str(cmd).strip()} obj = {'command': to_bytes(cmd.strip(), errors='surrogate_or_strict')}
if obj['command'] == 'close_shell()': if obj['command'] == b'close_shell()':
return self.close_shell() return self.close_shell()
elif obj['command'] == 'open_shell()': elif obj['command'] == b'open_shell()':
return self.open_shell() return self.open_shell()
elif obj['command'] == 'prompt()': elif obj['command'] == b'prompt()':
return (0, self._matched_prompt, '') return (0, self._matched_prompt, b'')
try: try:
if self._shell is None: if self._shell is None:
self.open_shell() self.open_shell()
except AnsibleConnectionFailure as exc: except AnsibleConnectionFailure as exc:
return (1, '', str(exc)) # FIXME: Feels like we should raise this rather than return it
return (1, b'', to_bytes(exc))
try: try:
if not signal.getsignal(signal.SIGALRM): if not signal.getsignal(signal.SIGALRM):
@ -264,6 +267,7 @@ class Connection(_Connection):
signal.alarm(self._play_context.timeout) signal.alarm(self._play_context.timeout)
out = self.send(obj) out = self.send(obj)
signal.alarm(0) signal.alarm(0)
return (0, out, '') return (0, out, b'')
except (AnsibleConnectionFailure, ValueError) as exc: except (AnsibleConnectionFailure, ValueError) as exc:
return (1, '', str(exc)) # FIXME: Feels like we should raise this rather than return it
return (1, b'', to_bytes(exc))

View file

@ -24,7 +24,7 @@ import subprocess
import sys import sys
from ansible.module_utils._text import to_bytes from ansible.module_utils._text import to_bytes
from ansible.module_utils.six.moves import cPickle, StringIO from ansible.module_utils.six.moves import cPickle
from ansible.plugins.connection import ConnectionBase from ansible.plugins.connection import ConnectionBase
try: try:
@ -52,16 +52,20 @@ class Connection(ConnectionBase):
stdin = os.fdopen(master, 'wb', 0) stdin = os.fdopen(master, 'wb', 0)
os.close(slave) os.close(slave)
src = StringIO() # Need to force a protocol that is compatible with both py2 and py3.
cPickle.dump(self._play_context.serialize(), src) # That would be protocol=2 or less.
stdin.write(src.getvalue()) # Also need to force a protocol that excludes certain control chars as
src.close() # stdin in this case is a pty and control chars will cause problems.
# that means only protocol=0 will work.
src = cPickle.dumps(self._play_context.serialize(), protocol=0)
stdin.write(src)
stdin.write(b'\n#END_INIT#\n') stdin.write(b'\n#END_INIT#\n')
stdin.write(to_bytes(action)) stdin.write(to_bytes(action))
stdin.write(b'\n\n') stdin.write(b'\n\n')
stdin.close()
(stdout, stderr) = p.communicate() (stdout, stderr) = p.communicate()
stdin.close()
return (p.returncode, stdout, stderr) return (p.returncode, stdout, stderr)

View file

@ -30,33 +30,54 @@ from ansible.module_utils.six import with_metaclass
class TerminalBase(with_metaclass(ABCMeta, object)): class TerminalBase(with_metaclass(ABCMeta, object)):
''' '''
A base class for implementing cli connections A base class for implementing cli connections
.. note:: Unlike most of Ansible, nearly all strings in
:class:`TerminalBase` plugins are byte strings. This is because of
how close to the underlying platform these plugins operate. Remember
to mark literal strings as byte string (``b"string"``) and to use
:func:`~ansible.module_utils._text.to_bytes` and
:func:`~ansible.module_utils._text.to_text` to avoid unexpected
problems.
''' '''
# compiled regular expression as stdout #: compiled bytes regular expressions as stdout
terminal_stdout_re = [] terminal_stdout_re = []
# compiled regular expression as stderr #: compiled bytes regular expressions as stderr
terminal_stderr_re = [] terminal_stderr_re = []
# copiled regular expression to remove ANSI codes #: compiled bytes regular expressions to remove ANSI codes
ansi_re = [ ansi_re = [
re.compile(r'(\x1b\[\?1h\x1b=)'), re.compile(br'(\x1b\[\?1h\x1b=)'),
re.compile(r'\x08.') re.compile(br'\x08.')
] ]
def __init__(self, connection): def __init__(self, connection):
self._connection = connection self._connection = connection
def _exec_cli_command(self, cmd, check_rc=True): def _exec_cli_command(self, cmd, check_rc=True):
"""Executes a CLI command on the device""" """
Executes a CLI command on the device
:arg cmd: Byte string consisting of the command to execute
:kwarg check_rc: If True, the default, raise an
:exc:`AnsibleConnectionFailure` if the return code from the
command is nonzero
:returns: A tuple of return code, stdout, and stderr from running the
command. stdout and stderr are both byte strings.
"""
rc, out, err = self._connection.exec_command(cmd) rc, out, err = self._connection.exec_command(cmd)
if check_rc and rc != 0: if check_rc and rc != 0:
raise AnsibleConnectionFailure(err) raise AnsibleConnectionFailure(err)
return rc, out, err return rc, out, err
def _get_prompt(self): def _get_prompt(self):
""" Returns the current prompt from the device""" """
for cmd in ['\n', 'prompt()']: Returns the current prompt from the device
:returns: A byte string of the prompt
"""
for cmd in (b'\n', b'prompt()'):
rc, out, err = self._exec_cli_command(cmd) rc, out, err = self._exec_cli_command(cmd)
return out return out
@ -82,6 +103,8 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
def on_authorize(self, passwd=None): def on_authorize(self, passwd=None):
"""Called when privilege escalation is requested """Called when privilege escalation is requested
:kwarg passwd: String containing the password
This method is called when the privilege is requested to be elevated This method is called when the privilege is requested to be elevated
in the play context by setting become to True. It is the responsibility in the play context by setting become to True. It is the responsibility
of the terminal plugin to actually do the privilege escalation such of the terminal plugin to actually do the privilege escalation such
@ -94,6 +117,6 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
This method is called when the privilege changed from escalated This method is called when the privilege changed from escalated
(become=True) to non escalated (become=False). It is the responsibility (become=True) to non escalated (become=False). It is the responsibility
of the this method to actually perform the deauthorization procedure of this method to actually perform the deauthorization procedure
""" """
pass pass

View file

@ -19,49 +19,52 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import re
import json import json
import re
from ansible.plugins.terminal import TerminalBase
from ansible.errors import AnsibleConnectionFailure from ansible.errors import AnsibleConnectionFailure
from ansible.module_utils._text import to_bytes
from ansible.plugins.terminal import TerminalBase
class TerminalModule(TerminalBase): class TerminalModule(TerminalBase):
terminal_stdout_re = [ terminal_stdout_re = [
re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"), re.compile(br"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"),
re.compile(r"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$") re.compile(br"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$")
] ]
terminal_stderr_re = [ terminal_stderr_re = [
re.compile(r"% ?Error"), re.compile(br"% ?Error"),
#re.compile(r"^% \w+", re.M), #re.compile(br"^% \w+", re.M),
re.compile(r"% ?Bad secret"), re.compile(br"% ?Bad secret"),
re.compile(r"invalid input", re.I), re.compile(br"invalid input", re.I),
re.compile(r"(?:incomplete|ambiguous) command", re.I), re.compile(br"(?:incomplete|ambiguous) command", re.I),
re.compile(r"connection timed out", re.I), re.compile(br"connection timed out", re.I),
re.compile(r"[^\r\n]+ not found", re.I), re.compile(br"[^\r\n]+ not found", re.I),
re.compile(r"'[^']' +returned error code: ?\d+"), re.compile(br"'[^']' +returned error code: ?\d+"),
] ]
def on_open_shell(self): def on_open_shell(self):
try: try:
for cmd in ['terminal length 0', 'terminal width 512']: for cmd in (b'terminal length 0', b'terminal width 512'):
self._exec_cli_command(cmd) self._exec_cli_command(cmd)
except AnsibleConnectionFailure: except AnsibleConnectionFailure:
raise AnsibleConnectionFailure('unable to set terminal parameters') raise AnsibleConnectionFailure('unable to set terminal parameters')
def on_authorize(self, passwd=None): def on_authorize(self, passwd=None):
if self._get_prompt().endswith('#'): if self._get_prompt().endswith(b'#'):
return return
cmd = {'command': 'enable'} cmd = {u'command': u'enable'}
if passwd: if passwd:
cmd['prompt'] = r"[\r\n]?password: $" # Note: python-3.5 cannot combine u"" and r"" together. Thus make
cmd['answer'] = passwd # an r string and use to_text to ensure it's text on both py2 and py3.
cmd[u'prompt'] = to_text(r"[\r\n]?password: $", errors='surrogate_or_strict')
cmd[u'answer'] = passwd
try: try:
self._exec_cli_command(json.dumps(cmd)) self._exec_cli_command(to_bytes(json.dumps(cmd), errors='surrogate_or_strict'))
except AnsibleConnectionFailure: except AnsibleConnectionFailure:
raise AnsibleConnectionFailure('unable to elevate privilege to enable mode') raise AnsibleConnectionFailure('unable to elevate privilege to enable mode')
@ -71,11 +74,9 @@ class TerminalModule(TerminalBase):
# if prompt is None most likely the terminal is hung up at a prompt # if prompt is None most likely the terminal is hung up at a prompt
return return
if '(config' in prompt: if b'(config' in prompt:
self._exec_cli_command('end') self._exec_cli_command(b'end')
self._exec_cli_command('disable') self._exec_cli_command(b'disable')
elif prompt.endswith('#'):
self._exec_cli_command('disable')
elif prompt.endswith(b'#'):
self._exec_cli_command(b'disable')

View file

@ -117,21 +117,21 @@ class TestConnectionClass(unittest.TestCase):
mock_open_shell = MagicMock() mock_open_shell = MagicMock()
conn.open_shell = mock_open_shell conn.open_shell = mock_open_shell
mock_send = MagicMock(return_value='command response') mock_send = MagicMock(return_value=b'command response')
conn.send = mock_send conn.send = mock_send
# test sending a single command and converting to dict # test sending a single command and converting to dict
rc, out, err = conn.exec_command('command') rc, out, err = conn.exec_command('command')
self.assertEqual(out, 'command response') self.assertEqual(out, b'command response')
self.assertTrue(mock_open_shell.called) self.assertTrue(mock_open_shell.called)
mock_send.assert_called_with({'command': 'command'}) mock_send.assert_called_with({'command': b'command'})
mock_open_shell.reset_mock() mock_open_shell.reset_mock()
# test sending a json string # test sending a json string
rc, out, err = conn.exec_command(json.dumps({'command': 'command'})) rc, out, err = conn.exec_command(json.dumps({'command': 'command'}))
self.assertEqual(out, 'command response') self.assertEqual(out, b'command response')
mock_send.assert_called_with({'command': 'command'}) mock_send.assert_called_with({'command': b'command'})
self.assertTrue(mock_open_shell.called) self.assertTrue(mock_open_shell.called)
mock_open_shell.reset_mock() mock_open_shell.reset_mock()
@ -139,9 +139,9 @@ class TestConnectionClass(unittest.TestCase):
# test _shell already open # test _shell already open
rc, out, err = conn.exec_command('command') rc, out, err = conn.exec_command('command')
self.assertEqual(out, 'command response') self.assertEqual(out, b'command response')
self.assertFalse(mock_open_shell.called) self.assertFalse(mock_open_shell.called)
mock_send.assert_called_with({'command': 'command'}) mock_send.assert_called_with({'command': b'command'})
def test_network_cli_send(self): def test_network_cli_send(self):
@ -150,14 +150,14 @@ class TestConnectionClass(unittest.TestCase):
conn = network_cli.Connection(pc, new_stdin) conn = network_cli.Connection(pc, new_stdin)
mock__terminal = MagicMock() mock__terminal = MagicMock()
mock__terminal.terminal_stdout_re = [re.compile('device#')] mock__terminal.terminal_stdout_re = [re.compile(b'device#')]
mock__terminal.terminal_stderr_re = [re.compile('^ERROR')] mock__terminal.terminal_stderr_re = [re.compile(b'^ERROR')]
conn._terminal = mock__terminal conn._terminal = mock__terminal
mock__shell = MagicMock() mock__shell = MagicMock()
conn._shell = mock__shell conn._shell = mock__shell
response = """device#command response = b"""device#command
command response command response
device# device#
@ -165,15 +165,15 @@ class TestConnectionClass(unittest.TestCase):
mock__shell.recv.return_value = response mock__shell.recv.return_value = response
output = conn.send({'command': 'command'}) output = conn.send({'command': b'command'})
mock__shell.sendall.assert_called_with('command\r') mock__shell.sendall.assert_called_with(b'command\r')
self.assertEqual(output, 'command response') self.assertEqual(output, b'command response')
mock__shell.reset_mock() mock__shell.reset_mock()
mock__shell.recv.return_value = "ERROR: error message" mock__shell.recv.return_value = b"ERROR: error message"
with self.assertRaises(AnsibleConnectionFailure) as exc: with self.assertRaises(AnsibleConnectionFailure) as exc:
conn.send({'command': 'command'}) conn.send({'command': b'command'})
self.assertEqual(str(exc.exception), 'ERROR: error message') self.assertEqual(str(exc.exception), 'ERROR: error message')