mirror of
https://github.com/ansible-collections/community.general.git
synced 2024-09-14 20:13:21 +02:00
Fix for persistent connection plugin on Python3 (#24431)
Fix for persistent connection plugin on Python3. Note that fixes are also needed to each terminal plugin. This PR only fixes the ios terminal (as proof that this approach is workable.) Future PRs can address the other terminal types. * On Python3, pickle needs to work with byte strings, not text strings. * Set the pickle protocol version to 0 because we're using a pty to feed data to the connection plugin. A pty can't have control characters. So we have to send ascii only. That means only using protocol=0 for pickling the data. * ansible-connection isn't being used with py3 in the bug but it needs several changes to work with python3. * In python3, closing the pty too early causes no data to be sent. So leave stdin open until after we finish with the ansible-connection process. * Fix typo using traceback.format_exc() * Cleanup unnecessary StringIO, BytesIO, and to_bytes calls * Modify the network_cli and terminal plugins for py3 compat. Lots of mixing of text and byte strings that needs to be straightened out to be compatible with python3 * Documentation for the bytes<=>text strategy for terminal plugins * Update unittests for more bytes-oriented internals Fixes #24355
This commit is contained in:
parent
e539726543
commit
d834412ead
6 changed files with 146 additions and 112 deletions
|
@ -45,7 +45,8 @@ from io import BytesIO
|
||||||
|
|
||||||
from ansible import constants as C
|
from ansible import constants as C
|
||||||
from ansible.module_utils._text import to_bytes, to_native
|
from ansible.module_utils._text import to_bytes, to_native
|
||||||
from ansible.module_utils.six.moves import cPickle, StringIO
|
from ansible.module_utils.six import PY3
|
||||||
|
from ansible.module_utils.six.moves import cPickle
|
||||||
from ansible.playbook.play_context import PlayContext
|
from ansible.playbook.play_context import PlayContext
|
||||||
from ansible.plugins import connection_loader
|
from ansible.plugins import connection_loader
|
||||||
from ansible.utils.path import unfrackpath, makedirs_safe
|
from ansible.utils.path import unfrackpath, makedirs_safe
|
||||||
|
@ -73,11 +74,11 @@ def do_fork():
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
if C.DEFAULT_LOG_PATH != '':
|
if C.DEFAULT_LOG_PATH != '':
|
||||||
out_file = file(C.DEFAULT_LOG_PATH, 'a+')
|
out_file = open(C.DEFAULT_LOG_PATH, 'ab+')
|
||||||
err_file = file(C.DEFAULT_LOG_PATH, 'a+', 0)
|
err_file = open(C.DEFAULT_LOG_PATH, 'ab+', 0)
|
||||||
else:
|
else:
|
||||||
out_file = file('/dev/null', 'a+')
|
out_file = open('/dev/null', 'ab+')
|
||||||
err_file = file('/dev/null', 'a+', 0)
|
err_file = open('/dev/null', 'ab+', 0)
|
||||||
|
|
||||||
os.dup2(out_file.fileno(), sys.stdout.fileno())
|
os.dup2(out_file.fileno(), sys.stdout.fileno())
|
||||||
os.dup2(err_file.fileno(), sys.stderr.fileno())
|
os.dup2(err_file.fileno(), sys.stderr.fileno())
|
||||||
|
@ -211,11 +212,9 @@ class Server():
|
||||||
pass
|
pass
|
||||||
elif data.startswith(b'CONTEXT: '):
|
elif data.startswith(b'CONTEXT: '):
|
||||||
display.display("socket operation is CONTEXT", log_only=True)
|
display.display("socket operation is CONTEXT", log_only=True)
|
||||||
pc_data = data.split(b'CONTEXT: ')[1]
|
pc_data = data.split(b'CONTEXT: ', 1)[1]
|
||||||
|
|
||||||
src = StringIO(pc_data)
|
pc_data = cPickle.loads(pc_data)
|
||||||
pc_data = cPickle.load(src)
|
|
||||||
src.close()
|
|
||||||
|
|
||||||
pc = PlayContext()
|
pc = PlayContext()
|
||||||
pc.deserialize(pc_data)
|
pc.deserialize(pc_data)
|
||||||
|
@ -234,12 +233,12 @@ class Server():
|
||||||
|
|
||||||
display.display("socket operation completed with rc %s" % rc, log_only=True)
|
display.display("socket operation completed with rc %s" % rc, log_only=True)
|
||||||
|
|
||||||
send_data(s, to_bytes(str(rc)))
|
send_data(s, to_bytes(rc))
|
||||||
send_data(s, to_bytes(stdout))
|
send_data(s, to_bytes(stdout))
|
||||||
send_data(s, to_bytes(stderr))
|
send_data(s, to_bytes(stderr))
|
||||||
s.close()
|
s.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
display.display(traceback.format_exec(), log_only=True)
|
display.display(traceback.format_exc(), log_only=True)
|
||||||
finally:
|
finally:
|
||||||
# when done, close the connection properly and cleanup
|
# when done, close the connection properly and cleanup
|
||||||
# the socket file so it can be recreated
|
# the socket file so it can be recreated
|
||||||
|
@ -254,21 +253,25 @@ class Server():
|
||||||
os.remove(self.path)
|
os.remove(self.path)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# Need stdin as a byte stream
|
||||||
|
if PY3:
|
||||||
|
stdin = sys.stdin.buffer
|
||||||
|
else:
|
||||||
|
stdin = sys.stdin
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# read the play context data via stdin, which means depickling it
|
# read the play context data via stdin, which means depickling it
|
||||||
# FIXME: as noted above, we will probably need to deserialize the
|
# FIXME: as noted above, we will probably need to deserialize the
|
||||||
# connection loader here as well at some point, otherwise this
|
# connection loader here as well at some point, otherwise this
|
||||||
# won't find role- or playbook-based connection plugins
|
# won't find role- or playbook-based connection plugins
|
||||||
cur_line = sys.stdin.readline()
|
cur_line = stdin.readline()
|
||||||
init_data = ''
|
init_data = b''
|
||||||
while cur_line.strip() != '#END_INIT#':
|
while cur_line.strip() != b'#END_INIT#':
|
||||||
if cur_line == '':
|
if cur_line == b'':
|
||||||
raise Exception("EOL found before init data was complete")
|
raise Exception("EOF found before init data was complete")
|
||||||
init_data += cur_line
|
init_data += cur_line
|
||||||
cur_line = sys.stdin.readline()
|
cur_line = stdin.readline()
|
||||||
src = BytesIO(to_bytes(init_data))
|
pc_data = cPickle.loads(init_data)
|
||||||
pc_data = cPickle.load(src)
|
|
||||||
|
|
||||||
pc = PlayContext()
|
pc = PlayContext()
|
||||||
pc.deserialize(pc_data)
|
pc.deserialize(pc_data)
|
||||||
|
@ -319,10 +322,10 @@ def main():
|
||||||
# the connection will timeout here. Need to make this more resilient.
|
# the connection will timeout here. Need to make this more resilient.
|
||||||
rc = 0
|
rc = 0
|
||||||
while rc == 0:
|
while rc == 0:
|
||||||
data = sys.stdin.readline()
|
data = stdin.readline()
|
||||||
if data == '':
|
if data == b'':
|
||||||
break
|
break
|
||||||
if data.strip() == '':
|
if data.strip() == b'':
|
||||||
continue
|
continue
|
||||||
sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||||
attempts = 1
|
attempts = 1
|
||||||
|
@ -342,11 +345,10 @@ def main():
|
||||||
|
|
||||||
# send the play_context back into the connection so the connection
|
# send the play_context back into the connection so the connection
|
||||||
# can handle any privilege escalation activities
|
# can handle any privilege escalation activities
|
||||||
pc_data = 'CONTEXT: %s' % src.getvalue()
|
pc_data = b'CONTEXT: %s' % init_data
|
||||||
send_data(sf, to_bytes(pc_data))
|
send_data(sf, pc_data)
|
||||||
src.close()
|
|
||||||
|
|
||||||
send_data(sf, to_bytes(data.strip()))
|
send_data(sf, data.strip())
|
||||||
|
|
||||||
rc = int(recv_data(sf), 10)
|
rc = int(recv_data(sf), 10)
|
||||||
stdout = recv_data(sf)
|
stdout = recv_data(sf)
|
||||||
|
|
|
@ -18,17 +18,18 @@
|
||||||
from __future__ import (absolute_import, division, print_function)
|
from __future__ import (absolute_import, division, print_function)
|
||||||
__metaclass__ = type
|
__metaclass__ = type
|
||||||
|
|
||||||
import re
|
|
||||||
import socket
|
|
||||||
import json
|
import json
|
||||||
import signal
|
|
||||||
import datetime
|
|
||||||
import traceback
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
import signal
|
||||||
|
import socket
|
||||||
|
import traceback
|
||||||
|
from collections import Sequence
|
||||||
|
|
||||||
from ansible import constants as C
|
from ansible import constants as C
|
||||||
from ansible.errors import AnsibleConnectionFailure
|
from ansible.errors import AnsibleConnectionFailure
|
||||||
from ansible.module_utils.six.moves import StringIO
|
from ansible.module_utils.six import BytesIO, binary_type, text_type
|
||||||
|
from ansible.module_utils._text import to_bytes, to_text
|
||||||
from ansible.plugins import terminal_loader
|
from ansible.plugins import terminal_loader
|
||||||
from ansible.plugins.connection import ensure_connect
|
from ansible.plugins.connection import ensure_connect
|
||||||
from ansible.plugins.connection.paramiko_ssh import Connection as _Connection
|
from ansible.plugins.connection.paramiko_ssh import Connection as _Connection
|
||||||
|
@ -113,7 +114,7 @@ class Connection(_Connection):
|
||||||
self._terminal.on_authorize(passwd=auth_pass)
|
self._terminal.on_authorize(passwd=auth_pass)
|
||||||
|
|
||||||
display.display('shell successfully opened', log_only=True)
|
display.display('shell successfully opened', log_only=True)
|
||||||
return (0, 'ok', '')
|
return (0, b'ok', b'')
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
display.display('closing connection', log_only=True)
|
display.display('closing connection', log_only=True)
|
||||||
|
@ -131,11 +132,11 @@ class Connection(_Connection):
|
||||||
self._shell.close()
|
self._shell.close()
|
||||||
self._shell = None
|
self._shell = None
|
||||||
|
|
||||||
return (0, 'ok', '')
|
return (0, b'ok', b'')
|
||||||
|
|
||||||
def receive(self, obj=None):
|
def receive(self, obj=None):
|
||||||
"""Handles receiving of output from command"""
|
"""Handles receiving of output from command"""
|
||||||
recv = StringIO()
|
recv = BytesIO()
|
||||||
handled = False
|
handled = False
|
||||||
|
|
||||||
self._matched_prompt = None
|
self._matched_prompt = None
|
||||||
|
@ -162,30 +163,30 @@ class Connection(_Connection):
|
||||||
try:
|
try:
|
||||||
command = obj['command']
|
command = obj['command']
|
||||||
self._history.append(command)
|
self._history.append(command)
|
||||||
self._shell.sendall('%s\r' % command)
|
self._shell.sendall(b'%s\r' % command)
|
||||||
if obj.get('sendonly'):
|
if obj.get('sendonly'):
|
||||||
return
|
return
|
||||||
return self.receive(obj)
|
return self.receive(obj)
|
||||||
except (socket.timeout, AttributeError) as exc:
|
except (socket.timeout, AttributeError):
|
||||||
display.display(traceback.format_exc(), log_only=True)
|
display.display(traceback.format_exc(), log_only=True)
|
||||||
raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip())
|
raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip())
|
||||||
|
|
||||||
def _strip(self, data):
|
def _strip(self, data):
|
||||||
"""Removes ANSI codes from device response"""
|
"""Removes ANSI codes from device response"""
|
||||||
for regex in self._terminal.ansi_re:
|
for regex in self._terminal.ansi_re:
|
||||||
data = regex.sub('', data)
|
data = regex.sub(b'', data)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _handle_prompt(self, resp, obj):
|
def _handle_prompt(self, resp, obj):
|
||||||
"""Matches the command prompt and responds"""
|
"""Matches the command prompt and responds"""
|
||||||
if not isinstance(obj['prompt'], list):
|
if isinstance(obj, (binary_type, text_type)) or not isinstance(obj['prompt'], Sequence):
|
||||||
obj['prompt'] = [obj['prompt']]
|
obj['prompt'] = [obj['prompt']]
|
||||||
prompts = [re.compile(r, re.I) for r in obj['prompt']]
|
prompts = [re.compile(r, re.I) for r in obj['prompt']]
|
||||||
answer = obj['answer']
|
answer = obj['answer']
|
||||||
for regex in prompts:
|
for regex in prompts:
|
||||||
match = regex.search(resp)
|
match = regex.search(resp)
|
||||||
if match:
|
if match:
|
||||||
self._shell.sendall('%s\r' % answer)
|
self._shell.sendall(b'%s\r' % answer)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _sanitize(self, resp, obj=None):
|
def _sanitize(self, resp, obj=None):
|
||||||
|
@ -196,7 +197,7 @@ class Connection(_Connection):
|
||||||
if (command and line.startswith(command.strip())) or self._matched_prompt.strip() in line:
|
if (command and line.startswith(command.strip())) or self._matched_prompt.strip() in line:
|
||||||
continue
|
continue
|
||||||
cleaned.append(line)
|
cleaned.append(line)
|
||||||
return str("\n".join(cleaned)).strip()
|
return b"\n".join(cleaned).strip()
|
||||||
|
|
||||||
def _find_prompt(self, response):
|
def _find_prompt(self, response):
|
||||||
"""Searches the buffered response for a matching command prompt"""
|
"""Searches the buffered response for a matching command prompt"""
|
||||||
|
@ -225,9 +226,9 @@ class Connection(_Connection):
|
||||||
def exec_command(self, cmd):
|
def exec_command(self, cmd):
|
||||||
"""Executes the cmd on in the shell and returns the output
|
"""Executes the cmd on in the shell and returns the output
|
||||||
|
|
||||||
The method accepts two forms of cmd. The first form is as a
|
The method accepts two forms of cmd. The first form is as a byte
|
||||||
string that represents the command to be executed in the shell. The
|
string that represents the command to be executed in the shell. The
|
||||||
second form is as a JSON string with additional keyword.
|
second form is as a utf8 JSON byte string with additional keywords.
|
||||||
|
|
||||||
Keywords supported for cmd:
|
Keywords supported for cmd:
|
||||||
* command - the command string to execute
|
* command - the command string to execute
|
||||||
|
@ -235,28 +236,30 @@ class Connection(_Connection):
|
||||||
* answer - the string to respond to the prompt with
|
* answer - the string to respond to the prompt with
|
||||||
* sendonly - bool to disable waiting for response
|
* sendonly - bool to disable waiting for response
|
||||||
|
|
||||||
:arg cmd: the string that represents the command to be executed
|
:arg cmd: the byte string that represents the command to be executed
|
||||||
which can be a single command or a json encoded string
|
which can be a single command or a json encoded string.
|
||||||
:returns: a tuple of (return code, stdout, stderr). The return
|
:returns: a tuple of (return code, stdout, stderr). The return
|
||||||
code is an integer and stdout and stderr are strings
|
code is an integer and stdout and stderr are byte strings
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
obj = json.loads(cmd)
|
obj = json.loads(to_text(cmd, errors='surrogate_or_strict'))
|
||||||
|
obj = dict((k, to_bytes(v, errors='surrogate_or_strict', nonstring='passthru')) for k, v in obj.items())
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
obj = {'command': str(cmd).strip()}
|
obj = {'command': to_bytes(cmd.strip(), errors='surrogate_or_strict')}
|
||||||
|
|
||||||
if obj['command'] == 'close_shell()':
|
if obj['command'] == b'close_shell()':
|
||||||
return self.close_shell()
|
return self.close_shell()
|
||||||
elif obj['command'] == 'open_shell()':
|
elif obj['command'] == b'open_shell()':
|
||||||
return self.open_shell()
|
return self.open_shell()
|
||||||
elif obj['command'] == 'prompt()':
|
elif obj['command'] == b'prompt()':
|
||||||
return (0, self._matched_prompt, '')
|
return (0, self._matched_prompt, b'')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self._shell is None:
|
if self._shell is None:
|
||||||
self.open_shell()
|
self.open_shell()
|
||||||
except AnsibleConnectionFailure as exc:
|
except AnsibleConnectionFailure as exc:
|
||||||
return (1, '', str(exc))
|
# FIXME: Feels like we should raise this rather than return it
|
||||||
|
return (1, b'', to_bytes(exc))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not signal.getsignal(signal.SIGALRM):
|
if not signal.getsignal(signal.SIGALRM):
|
||||||
|
@ -264,6 +267,7 @@ class Connection(_Connection):
|
||||||
signal.alarm(self._play_context.timeout)
|
signal.alarm(self._play_context.timeout)
|
||||||
out = self.send(obj)
|
out = self.send(obj)
|
||||||
signal.alarm(0)
|
signal.alarm(0)
|
||||||
return (0, out, '')
|
return (0, out, b'')
|
||||||
except (AnsibleConnectionFailure, ValueError) as exc:
|
except (AnsibleConnectionFailure, ValueError) as exc:
|
||||||
return (1, '', str(exc))
|
# FIXME: Feels like we should raise this rather than return it
|
||||||
|
return (1, b'', to_bytes(exc))
|
||||||
|
|
|
@ -24,7 +24,7 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from ansible.module_utils._text import to_bytes
|
from ansible.module_utils._text import to_bytes
|
||||||
from ansible.module_utils.six.moves import cPickle, StringIO
|
from ansible.module_utils.six.moves import cPickle
|
||||||
from ansible.plugins.connection import ConnectionBase
|
from ansible.plugins.connection import ConnectionBase
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -52,16 +52,20 @@ class Connection(ConnectionBase):
|
||||||
stdin = os.fdopen(master, 'wb', 0)
|
stdin = os.fdopen(master, 'wb', 0)
|
||||||
os.close(slave)
|
os.close(slave)
|
||||||
|
|
||||||
src = StringIO()
|
# Need to force a protocol that is compatible with both py2 and py3.
|
||||||
cPickle.dump(self._play_context.serialize(), src)
|
# That would be protocol=2 or less.
|
||||||
stdin.write(src.getvalue())
|
# Also need to force a protocol that excludes certain control chars as
|
||||||
src.close()
|
# stdin in this case is a pty and control chars will cause problems.
|
||||||
|
# that means only protocol=0 will work.
|
||||||
|
src = cPickle.dumps(self._play_context.serialize(), protocol=0)
|
||||||
|
stdin.write(src)
|
||||||
|
|
||||||
stdin.write(b'\n#END_INIT#\n')
|
stdin.write(b'\n#END_INIT#\n')
|
||||||
stdin.write(to_bytes(action))
|
stdin.write(to_bytes(action))
|
||||||
stdin.write(b'\n\n')
|
stdin.write(b'\n\n')
|
||||||
stdin.close()
|
|
||||||
(stdout, stderr) = p.communicate()
|
(stdout, stderr) = p.communicate()
|
||||||
|
stdin.close()
|
||||||
|
|
||||||
return (p.returncode, stdout, stderr)
|
return (p.returncode, stdout, stderr)
|
||||||
|
|
||||||
|
|
|
@ -30,33 +30,54 @@ from ansible.module_utils.six import with_metaclass
|
||||||
class TerminalBase(with_metaclass(ABCMeta, object)):
|
class TerminalBase(with_metaclass(ABCMeta, object)):
|
||||||
'''
|
'''
|
||||||
A base class for implementing cli connections
|
A base class for implementing cli connections
|
||||||
|
|
||||||
|
.. note:: Unlike most of Ansible, nearly all strings in
|
||||||
|
:class:`TerminalBase` plugins are byte strings. This is because of
|
||||||
|
how close to the underlying platform these plugins operate. Remember
|
||||||
|
to mark literal strings as byte string (``b"string"``) and to use
|
||||||
|
:func:`~ansible.module_utils._text.to_bytes` and
|
||||||
|
:func:`~ansible.module_utils._text.to_text` to avoid unexpected
|
||||||
|
problems.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
# compiled regular expression as stdout
|
#: compiled bytes regular expressions as stdout
|
||||||
terminal_stdout_re = []
|
terminal_stdout_re = []
|
||||||
|
|
||||||
# compiled regular expression as stderr
|
#: compiled bytes regular expressions as stderr
|
||||||
terminal_stderr_re = []
|
terminal_stderr_re = []
|
||||||
|
|
||||||
# copiled regular expression to remove ANSI codes
|
#: compiled bytes regular expressions to remove ANSI codes
|
||||||
ansi_re = [
|
ansi_re = [
|
||||||
re.compile(r'(\x1b\[\?1h\x1b=)'),
|
re.compile(br'(\x1b\[\?1h\x1b=)'),
|
||||||
re.compile(r'\x08.')
|
re.compile(br'\x08.')
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
self._connection = connection
|
self._connection = connection
|
||||||
|
|
||||||
def _exec_cli_command(self, cmd, check_rc=True):
|
def _exec_cli_command(self, cmd, check_rc=True):
|
||||||
"""Executes a CLI command on the device"""
|
"""
|
||||||
|
Executes a CLI command on the device
|
||||||
|
|
||||||
|
:arg cmd: Byte string consisting of the command to execute
|
||||||
|
:kwarg check_rc: If True, the default, raise an
|
||||||
|
:exc:`AnsibleConnectionFailure` if the return code from the
|
||||||
|
command is nonzero
|
||||||
|
:returns: A tuple of return code, stdout, and stderr from running the
|
||||||
|
command. stdout and stderr are both byte strings.
|
||||||
|
"""
|
||||||
rc, out, err = self._connection.exec_command(cmd)
|
rc, out, err = self._connection.exec_command(cmd)
|
||||||
if check_rc and rc != 0:
|
if check_rc and rc != 0:
|
||||||
raise AnsibleConnectionFailure(err)
|
raise AnsibleConnectionFailure(err)
|
||||||
return rc, out, err
|
return rc, out, err
|
||||||
|
|
||||||
def _get_prompt(self):
|
def _get_prompt(self):
|
||||||
""" Returns the current prompt from the device"""
|
"""
|
||||||
for cmd in ['\n', 'prompt()']:
|
Returns the current prompt from the device
|
||||||
|
|
||||||
|
:returns: A byte string of the prompt
|
||||||
|
"""
|
||||||
|
for cmd in (b'\n', b'prompt()'):
|
||||||
rc, out, err = self._exec_cli_command(cmd)
|
rc, out, err = self._exec_cli_command(cmd)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -82,6 +103,8 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
|
||||||
def on_authorize(self, passwd=None):
|
def on_authorize(self, passwd=None):
|
||||||
"""Called when privilege escalation is requested
|
"""Called when privilege escalation is requested
|
||||||
|
|
||||||
|
:kwarg passwd: String containing the password
|
||||||
|
|
||||||
This method is called when the privilege is requested to be elevated
|
This method is called when the privilege is requested to be elevated
|
||||||
in the play context by setting become to True. It is the responsibility
|
in the play context by setting become to True. It is the responsibility
|
||||||
of the terminal plugin to actually do the privilege escalation such
|
of the terminal plugin to actually do the privilege escalation such
|
||||||
|
@ -94,6 +117,6 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
|
||||||
|
|
||||||
This method is called when the privilege changed from escalated
|
This method is called when the privilege changed from escalated
|
||||||
(become=True) to non escalated (become=False). It is the responsibility
|
(become=True) to non escalated (become=False). It is the responsibility
|
||||||
of the this method to actually perform the deauthorization procedure
|
of this method to actually perform the deauthorization procedure
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -19,49 +19,52 @@
|
||||||
from __future__ import (absolute_import, division, print_function)
|
from __future__ import (absolute_import, division, print_function)
|
||||||
__metaclass__ = type
|
__metaclass__ = type
|
||||||
|
|
||||||
import re
|
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
from ansible.plugins.terminal import TerminalBase
|
|
||||||
from ansible.errors import AnsibleConnectionFailure
|
from ansible.errors import AnsibleConnectionFailure
|
||||||
|
from ansible.module_utils._text import to_bytes
|
||||||
|
from ansible.plugins.terminal import TerminalBase
|
||||||
|
|
||||||
|
|
||||||
class TerminalModule(TerminalBase):
|
class TerminalModule(TerminalBase):
|
||||||
|
|
||||||
terminal_stdout_re = [
|
terminal_stdout_re = [
|
||||||
re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"),
|
re.compile(br"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"),
|
||||||
re.compile(r"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$")
|
re.compile(br"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$")
|
||||||
]
|
]
|
||||||
|
|
||||||
terminal_stderr_re = [
|
terminal_stderr_re = [
|
||||||
re.compile(r"% ?Error"),
|
re.compile(br"% ?Error"),
|
||||||
#re.compile(r"^% \w+", re.M),
|
#re.compile(br"^% \w+", re.M),
|
||||||
re.compile(r"% ?Bad secret"),
|
re.compile(br"% ?Bad secret"),
|
||||||
re.compile(r"invalid input", re.I),
|
re.compile(br"invalid input", re.I),
|
||||||
re.compile(r"(?:incomplete|ambiguous) command", re.I),
|
re.compile(br"(?:incomplete|ambiguous) command", re.I),
|
||||||
re.compile(r"connection timed out", re.I),
|
re.compile(br"connection timed out", re.I),
|
||||||
re.compile(r"[^\r\n]+ not found", re.I),
|
re.compile(br"[^\r\n]+ not found", re.I),
|
||||||
re.compile(r"'[^']' +returned error code: ?\d+"),
|
re.compile(br"'[^']' +returned error code: ?\d+"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def on_open_shell(self):
|
def on_open_shell(self):
|
||||||
try:
|
try:
|
||||||
for cmd in ['terminal length 0', 'terminal width 512']:
|
for cmd in (b'terminal length 0', b'terminal width 512'):
|
||||||
self._exec_cli_command(cmd)
|
self._exec_cli_command(cmd)
|
||||||
except AnsibleConnectionFailure:
|
except AnsibleConnectionFailure:
|
||||||
raise AnsibleConnectionFailure('unable to set terminal parameters')
|
raise AnsibleConnectionFailure('unable to set terminal parameters')
|
||||||
|
|
||||||
def on_authorize(self, passwd=None):
|
def on_authorize(self, passwd=None):
|
||||||
if self._get_prompt().endswith('#'):
|
if self._get_prompt().endswith(b'#'):
|
||||||
return
|
return
|
||||||
|
|
||||||
cmd = {'command': 'enable'}
|
cmd = {u'command': u'enable'}
|
||||||
if passwd:
|
if passwd:
|
||||||
cmd['prompt'] = r"[\r\n]?password: $"
|
# Note: python-3.5 cannot combine u"" and r"" together. Thus make
|
||||||
cmd['answer'] = passwd
|
# an r string and use to_text to ensure it's text on both py2 and py3.
|
||||||
|
cmd[u'prompt'] = to_text(r"[\r\n]?password: $", errors='surrogate_or_strict')
|
||||||
|
cmd[u'answer'] = passwd
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._exec_cli_command(json.dumps(cmd))
|
self._exec_cli_command(to_bytes(json.dumps(cmd), errors='surrogate_or_strict'))
|
||||||
except AnsibleConnectionFailure:
|
except AnsibleConnectionFailure:
|
||||||
raise AnsibleConnectionFailure('unable to elevate privilege to enable mode')
|
raise AnsibleConnectionFailure('unable to elevate privilege to enable mode')
|
||||||
|
|
||||||
|
@ -71,11 +74,9 @@ class TerminalModule(TerminalBase):
|
||||||
# if prompt is None most likely the terminal is hung up at a prompt
|
# if prompt is None most likely the terminal is hung up at a prompt
|
||||||
return
|
return
|
||||||
|
|
||||||
if '(config' in prompt:
|
if b'(config' in prompt:
|
||||||
self._exec_cli_command('end')
|
self._exec_cli_command(b'end')
|
||||||
self._exec_cli_command('disable')
|
self._exec_cli_command(b'disable')
|
||||||
|
|
||||||
elif prompt.endswith('#'):
|
|
||||||
self._exec_cli_command('disable')
|
|
||||||
|
|
||||||
|
|
||||||
|
elif prompt.endswith(b'#'):
|
||||||
|
self._exec_cli_command(b'disable')
|
||||||
|
|
|
@ -117,21 +117,21 @@ class TestConnectionClass(unittest.TestCase):
|
||||||
mock_open_shell = MagicMock()
|
mock_open_shell = MagicMock()
|
||||||
conn.open_shell = mock_open_shell
|
conn.open_shell = mock_open_shell
|
||||||
|
|
||||||
mock_send = MagicMock(return_value='command response')
|
mock_send = MagicMock(return_value=b'command response')
|
||||||
conn.send = mock_send
|
conn.send = mock_send
|
||||||
|
|
||||||
# test sending a single command and converting to dict
|
# test sending a single command and converting to dict
|
||||||
rc, out, err = conn.exec_command('command')
|
rc, out, err = conn.exec_command('command')
|
||||||
self.assertEqual(out, 'command response')
|
self.assertEqual(out, b'command response')
|
||||||
self.assertTrue(mock_open_shell.called)
|
self.assertTrue(mock_open_shell.called)
|
||||||
mock_send.assert_called_with({'command': 'command'})
|
mock_send.assert_called_with({'command': b'command'})
|
||||||
|
|
||||||
mock_open_shell.reset_mock()
|
mock_open_shell.reset_mock()
|
||||||
|
|
||||||
# test sending a json string
|
# test sending a json string
|
||||||
rc, out, err = conn.exec_command(json.dumps({'command': 'command'}))
|
rc, out, err = conn.exec_command(json.dumps({'command': 'command'}))
|
||||||
self.assertEqual(out, 'command response')
|
self.assertEqual(out, b'command response')
|
||||||
mock_send.assert_called_with({'command': 'command'})
|
mock_send.assert_called_with({'command': b'command'})
|
||||||
self.assertTrue(mock_open_shell.called)
|
self.assertTrue(mock_open_shell.called)
|
||||||
|
|
||||||
mock_open_shell.reset_mock()
|
mock_open_shell.reset_mock()
|
||||||
|
@ -139,9 +139,9 @@ class TestConnectionClass(unittest.TestCase):
|
||||||
|
|
||||||
# test _shell already open
|
# test _shell already open
|
||||||
rc, out, err = conn.exec_command('command')
|
rc, out, err = conn.exec_command('command')
|
||||||
self.assertEqual(out, 'command response')
|
self.assertEqual(out, b'command response')
|
||||||
self.assertFalse(mock_open_shell.called)
|
self.assertFalse(mock_open_shell.called)
|
||||||
mock_send.assert_called_with({'command': 'command'})
|
mock_send.assert_called_with({'command': b'command'})
|
||||||
|
|
||||||
|
|
||||||
def test_network_cli_send(self):
|
def test_network_cli_send(self):
|
||||||
|
@ -150,14 +150,14 @@ class TestConnectionClass(unittest.TestCase):
|
||||||
conn = network_cli.Connection(pc, new_stdin)
|
conn = network_cli.Connection(pc, new_stdin)
|
||||||
|
|
||||||
mock__terminal = MagicMock()
|
mock__terminal = MagicMock()
|
||||||
mock__terminal.terminal_stdout_re = [re.compile('device#')]
|
mock__terminal.terminal_stdout_re = [re.compile(b'device#')]
|
||||||
mock__terminal.terminal_stderr_re = [re.compile('^ERROR')]
|
mock__terminal.terminal_stderr_re = [re.compile(b'^ERROR')]
|
||||||
conn._terminal = mock__terminal
|
conn._terminal = mock__terminal
|
||||||
|
|
||||||
mock__shell = MagicMock()
|
mock__shell = MagicMock()
|
||||||
conn._shell = mock__shell
|
conn._shell = mock__shell
|
||||||
|
|
||||||
response = """device#command
|
response = b"""device#command
|
||||||
command response
|
command response
|
||||||
|
|
||||||
device#
|
device#
|
||||||
|
@ -165,15 +165,15 @@ class TestConnectionClass(unittest.TestCase):
|
||||||
|
|
||||||
mock__shell.recv.return_value = response
|
mock__shell.recv.return_value = response
|
||||||
|
|
||||||
output = conn.send({'command': 'command'})
|
output = conn.send({'command': b'command'})
|
||||||
|
|
||||||
mock__shell.sendall.assert_called_with('command\r')
|
mock__shell.sendall.assert_called_with(b'command\r')
|
||||||
self.assertEqual(output, 'command response')
|
self.assertEqual(output, b'command response')
|
||||||
|
|
||||||
mock__shell.reset_mock()
|
mock__shell.reset_mock()
|
||||||
mock__shell.recv.return_value = "ERROR: error message"
|
mock__shell.recv.return_value = b"ERROR: error message"
|
||||||
|
|
||||||
with self.assertRaises(AnsibleConnectionFailure) as exc:
|
with self.assertRaises(AnsibleConnectionFailure) as exc:
|
||||||
conn.send({'command': 'command'})
|
conn.send({'command': b'command'})
|
||||||
self.assertEqual(str(exc.exception), 'ERROR: error message')
|
self.assertEqual(str(exc.exception), 'ERROR: error message')
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue