1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

Move ssh and local connection plugins from using raw select to selectors

At the moment, this change will use EPoll on Linux, KQueue on *BSDs,
etc, so it should alleviate problems with too many open file
descriptors.

* Bundle a copy of selectors2 so that we have the selectors API everywhere.
* Add licensing information to selectors2 file so it's clear what the
  licensing terms and conditions are.
* Exclude the bundled copy of selectors2 from our boilerplate code-smell test
* Rewrite ssh_run tests to attempt to work around problem with mocking
  select on shippable

Fixes #14143
This commit is contained in:
Toshio Kuratomi 2017-02-01 08:48:23 -08:00
parent 2c70450e23
commit d1a6b07fe1
7 changed files with 1100 additions and 227 deletions

View file

@ -0,0 +1,47 @@
# (c) 2014, 2017 Toshio Kuratomi <tkuratomi@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
'''
Compat selectors library. Python-3.5 has this builtin. The selectors2
package exists on pypi to backport the functionality as far as python-2.6.
'''
# The following makes it easier for us to script updates of the bundled code
_BUNDLED_METADATA = { "pypi_name": "selectors2", "version": "1.1.0" }
import os.path
import sys
try:
# Python 3.4+
import selectors as _system_selectors
except ImportError:
try:
# backport package installed in the system
import selectors2 as _system_selectors
except ImportError:
_system_selectors = None
if _system_selectors:
selectors = _system_selectors
else:
# Our bundled copy
from . import _selectors2 as selectors
sys.modules['ansible.compat.selectors'] = selectors

View file

@ -0,0 +1,667 @@
# This file is from the selectors2.py package. It backports the PSF Licensed
# selectors module from the Python-3.5 stdlib to older versions of Python.
# The author, Seth Michael Larson, dual licenses his modifications under the
# PSF License and MIT License:
# https://github.com/SethMichaelLarson/selectors2#license
#
# Seth's copy of the MIT license is reproduced below
#
# MIT License
#
# Copyright (c) 2016 Seth Michael Larson
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Backport of selectors.py from Python 3.5+ to support Python < 3.4
# Also has the behavior specified in PEP 475 which is to retry syscalls
# in the case of an EINTR error. This module is required because selectors34
# does not follow this behavior and instead returns that no dile descriptor
# events have occurred rather than retry the syscall. The decision to drop
# support for select.devpoll is made to maintain 100% test coverage.
import errno
import math
import select
import socket
import sys
import time
from collections import namedtuple, Mapping
try:
monotonic = time.monotonic
except (AttributeError, ImportError): # Python 3.3<
monotonic = time.time
__author__ = 'Seth Michael Larson'
__email__ = 'sethmichaellarson@protonmail.com'
__version__ = '1.1.0'
__license__ = 'MIT'
__all__ = [
'EVENT_READ',
'EVENT_WRITE',
'SelectorError',
'SelectorKey',
'DefaultSelector'
]
EVENT_READ = (1 << 0)
EVENT_WRITE = (1 << 1)
HAS_SELECT = True # Variable that shows whether the platform has a selector.
_SYSCALL_SENTINEL = object() # Sentinel in case a system call returns None.
class SelectorError(Exception):
def __init__(self, errcode):
super(SelectorError, self).__init__()
self.errno = errcode
def __repr__(self):
return "<SelectorError errno={0}>".format(self.errno)
def __str__(self):
return self.__repr__()
def _fileobj_to_fd(fileobj):
""" Return a file descriptor from a file object. If
given an integer will simply return that integer back. """
if isinstance(fileobj, int):
fd = fileobj
else:
try:
fd = int(fileobj.fileno())
except (AttributeError, TypeError, ValueError):
raise ValueError("Invalid file object: {0!r}".format(fileobj))
if fd < 0:
raise ValueError("Invalid file descriptor: {0}".format(fd))
return fd
# Python 3.5 uses a more direct route to wrap system calls to increase speed.
if sys.version_info >= (3, 5):
def _syscall_wrapper(func, _, *args, **kwargs):
""" This is the short-circuit version of the below logic
because in Python 3.5+ all selectors restart system calls. """
try:
return func(*args, **kwargs)
except (OSError, IOError, select.error) as e:
errcode = None
if hasattr(e, "errno"):
errcode = e.errno
elif hasattr(e, "args"):
errcode = e.args[0]
raise SelectorError(errcode)
else:
def _syscall_wrapper(func, recalc_timeout, *args, **kwargs):
""" Wrapper function for syscalls that could fail due to EINTR.
All functions should be retried if there is time left in the timeout
in accordance with PEP 475. """
timeout = kwargs.get("timeout", None)
if timeout is None:
expires = None
recalc_timeout = False
else:
timeout = float(timeout)
if timeout < 0.0: # Timeout less than 0 treated as no timeout.
expires = None
else:
expires = monotonic() + timeout
args = list(args)
if recalc_timeout and "timeout" not in kwargs:
raise ValueError(
"Timeout must be in args or kwargs to be recalculated")
result = _SYSCALL_SENTINEL
while result is _SYSCALL_SENTINEL:
try:
result = func(*args, **kwargs)
# OSError is thrown by select.select
# IOError is thrown by select.epoll.poll
# select.error is thrown by select.poll.poll
# Aren't we thankful for Python 3.x rework for exceptions?
except (OSError, IOError, select.error) as e:
# select.error wasn't a subclass of OSError in the past.
errcode = None
if hasattr(e, "errno"):
errcode = e.errno
elif hasattr(e, "args"):
errcode = e.args[0]
# Also test for the Windows equivalent of EINTR.
is_interrupt = (errcode == errno.EINTR or (hasattr(errno, "WSAEINTR") and
errcode == errno.WSAEINTR))
if is_interrupt:
if expires is not None:
current_time = monotonic()
if current_time > expires:
raise OSError(errno=errno.ETIMEDOUT)
if recalc_timeout:
if "timeout" in kwargs:
kwargs["timeout"] = expires - current_time
continue
if errcode:
raise SelectorError(errcode)
else:
raise
return result
SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data'])
class _SelectorMapping(Mapping):
""" Mapping of file objects to selector keys """
def __init__(self, selector):
self._selector = selector
def __len__(self):
return len(self._selector._fd_to_key)
def __getitem__(self, fileobj):
try:
fd = self._selector._fileobj_lookup(fileobj)
return self._selector._fd_to_key[fd]
except KeyError:
raise KeyError("{0!r} is not registered.".format(fileobj))
def __iter__(self):
return iter(self._selector._fd_to_key)
class BaseSelector(object):
""" Abstract Selector class
A selector supports registering file objects to be monitored
for specific I/O events.
A file object is a file descriptor or any object with a
`fileno()` method. An arbitrary object can be attached to the
file object which can be used for example to store context info,
a callback, etc.
A selector can use various implementations (select(), poll(), epoll(),
and kqueue()) depending on the platform. The 'DefaultSelector' class uses
the most efficient implementation for the current platform.
"""
def __init__(self):
# Maps file descriptors to keys.
self._fd_to_key = {}
# Read-only mapping returned by get_map()
self._map = _SelectorMapping(self)
def _fileobj_lookup(self, fileobj):
""" Return a file descriptor from a file object.
This wraps _fileobj_to_fd() to do an exhaustive
search in case the object is invalid but we still
have it in our map. Used by unregister() so we can
unregister an object that was previously registered
even if it is closed. It is also used by _SelectorMapping
"""
try:
return _fileobj_to_fd(fileobj)
except ValueError:
# Search through all our mapped keys.
for key in self._fd_to_key.values():
if key.fileobj is fileobj:
return key.fd
# Raise ValueError after all.
raise
def register(self, fileobj, events, data=None):
""" Register a file object for a set of events to monitor. """
if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)):
raise ValueError("Invalid events: {0!r}".format(events))
key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data)
if key.fd in self._fd_to_key:
raise KeyError("{0!r} (FD {1}) is already registered"
.format(fileobj, key.fd))
self._fd_to_key[key.fd] = key
return key
def unregister(self, fileobj):
""" Unregister a file object from being monitored. """
try:
key = self._fd_to_key.pop(self._fileobj_lookup(fileobj))
except KeyError:
raise KeyError("{0!r} is not registered".format(fileobj))
# Getting the fileno of a closed socket on Windows errors with EBADF.
except socket.error as err:
if err.errno != errno.EBADF:
raise
else:
for key in self._fd_to_key.values():
if key.fileobj is fileobj:
self._fd_to_key.pop(key.fd)
break
else:
raise KeyError("{0!r} is not registered".format(fileobj))
return key
def modify(self, fileobj, events, data=None):
""" Change a registered file object monitored events and data. """
# NOTE: Some subclasses optimize this operation even further.
try:
key = self._fd_to_key[self._fileobj_lookup(fileobj)]
except KeyError:
raise KeyError("{0!r} is not registered".format(fileobj))
if events != key.events:
self.unregister(fileobj)
key = self.register(fileobj, events, data)
elif data != key.data:
# Use a shortcut to update the data.
key = key._replace(data=data)
self._fd_to_key[key.fd] = key
return key
def select(self, timeout=None):
""" Perform the actual selection until some monitored file objects
are ready or the timeout expires. """
raise NotImplementedError()
def close(self):
""" Close the selector. This must be called to ensure that all
underlying resources are freed. """
self._fd_to_key.clear()
self._map = None
def get_key(self, fileobj):
""" Return the key associated with a registered file object. """
mapping = self.get_map()
if mapping is None:
raise RuntimeError("Selector is closed")
try:
return mapping[fileobj]
except KeyError:
raise KeyError("{0!r} is not registered".format(fileobj))
def get_map(self):
""" Return a mapping of file objects to selector keys """
return self._map
def _key_from_fd(self, fd):
""" Return the key associated to a given file descriptor
Return None if it is not found. """
try:
return self._fd_to_key[fd]
except KeyError:
return None
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
# Almost all platforms have select.select()
if hasattr(select, "select"):
class SelectSelector(BaseSelector):
""" Select-based selector. """
def __init__(self):
super(SelectSelector, self).__init__()
self._readers = set()
self._writers = set()
def register(self, fileobj, events, data=None):
key = super(SelectSelector, self).register(fileobj, events, data)
if events & EVENT_READ:
self._readers.add(key.fd)
if events & EVENT_WRITE:
self._writers.add(key.fd)
return key
def unregister(self, fileobj):
key = super(SelectSelector, self).unregister(fileobj)
self._readers.discard(key.fd)
self._writers.discard(key.fd)
return key
def _select(self, r, w, timeout=None):
""" Wrapper for select.select because timeout is a positional arg """
return select.select(r, w, [], timeout)
def select(self, timeout=None):
# Selecting on empty lists on Windows errors out.
if not len(self._readers) and not len(self._writers):
return []
timeout = None if timeout is None else max(timeout, 0.0)
ready = []
r, w, _ = _syscall_wrapper(self._select, True, self._readers,
self._writers, timeout)
r = set(r)
w = set(w)
for fd in r | w:
events = 0
if fd in r:
events |= EVENT_READ
if fd in w:
events |= EVENT_WRITE
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
__all__.append('SelectSelector')
if hasattr(select, "poll"):
class PollSelector(BaseSelector):
""" Poll-based selector """
def __init__(self):
super(PollSelector, self).__init__()
self._poll = select.poll()
def register(self, fileobj, events, data=None):
key = super(PollSelector, self).register(fileobj, events, data)
event_mask = 0
if events & EVENT_READ:
event_mask |= select.POLLIN
if events & EVENT_WRITE:
event_mask |= select.POLLOUT
self._poll.register(key.fd, event_mask)
return key
def unregister(self, fileobj):
key = super(PollSelector, self).unregister(fileobj)
self._poll.unregister(key.fd)
return key
def _wrap_poll(self, timeout=None):
""" Wrapper function for select.poll.poll() so that
_syscall_wrapper can work with only seconds. """
if timeout is not None:
if timeout <= 0:
timeout = 0
else:
# select.poll.poll() has a resolution of 1 millisecond,
# round away from zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3)
result = self._poll.poll(timeout)
return result
def select(self, timeout=None):
ready = []
fd_events = _syscall_wrapper(self._wrap_poll, True, timeout=timeout)
for fd, event_mask in fd_events:
events = 0
if event_mask & ~select.POLLIN:
events |= EVENT_WRITE
if event_mask & ~select.POLLOUT:
events |= EVENT_READ
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
__all__.append('PollSelector')
if hasattr(select, "epoll"):
class EpollSelector(BaseSelector):
""" Epoll-based selector """
def __init__(self):
super(EpollSelector, self).__init__()
self._epoll = select.epoll()
def fileno(self):
return self._epoll.fileno()
def register(self, fileobj, events, data=None):
key = super(EpollSelector, self).register(fileobj, events, data)
events_mask = 0
if events & EVENT_READ:
events_mask |= select.EPOLLIN
if events & EVENT_WRITE:
events_mask |= select.EPOLLOUT
_syscall_wrapper(self._epoll.register, False, key.fd, events_mask)
return key
def unregister(self, fileobj):
key = super(EpollSelector, self).unregister(fileobj)
try:
_syscall_wrapper(self._epoll.unregister, False, key.fd)
except SelectorError:
# This can occur when the fd was closed since registry.
pass
return key
def select(self, timeout=None):
if timeout is not None:
if timeout <= 0:
timeout = 0.0
else:
# select.epoll.poll() has a resolution of 1 millisecond
# but luckily takes seconds so we don't need a wrapper
# like PollSelector. Just for better rounding.
timeout = math.ceil(timeout * 1e3) * 1e-3
timeout = float(timeout)
else:
timeout = -1.0 # epoll.poll() must have a float.
# We always want at least 1 to ensure that select can be called
# with no file descriptors registered. Otherwise will fail.
max_events = max(len(self._fd_to_key), 1)
ready = []
fd_events = _syscall_wrapper(self._epoll.poll, True,
timeout=timeout,
maxevents=max_events)
for fd, event_mask in fd_events:
events = 0
if event_mask & ~select.EPOLLIN:
events |= EVENT_WRITE
if event_mask & ~select.EPOLLOUT:
events |= EVENT_READ
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
self._epoll.close()
super(EpollSelector, self).close()
__all__.append('EpollSelector')
if hasattr(select, "devpoll"):
class DevpollSelector(BaseSelector):
"""Solaris /dev/poll selector."""
def __init__(self):
super(DevpollSelector, self).__init__()
self._devpoll = select.devpoll()
def fileno(self):
return self._devpoll.fileno()
def register(self, fileobj, events, data=None):
key = super(DevpollSelector, self).register(fileobj, events, data)
poll_events = 0
if events & EVENT_READ:
poll_events |= select.POLLIN
if events & EVENT_WRITE:
poll_events |= select.POLLOUT
self._devpoll.register(key.fd, poll_events)
return key
def unregister(self, fileobj):
key = super(DevpollSelector, self).unregister(fileobj)
self._devpoll.unregister(key.fd)
return key
def _wrap_poll(self, timeout=None):
""" Wrapper function for select.poll.poll() so that
_syscall_wrapper can work with only seconds. """
if timeout is not None:
if timeout <= 0:
timeout = 0
else:
# select.devpoll.poll() has a resolution of 1 millisecond,
# round away from zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3)
result = self._devpoll.poll(timeout)
return result
def select(self, timeout=None):
ready = []
fd_events = _syscall_wrapper(self._wrap_poll, True, timeout=timeout)
for fd, event_mask in fd_events:
events = 0
if event_mask & ~select.POLLIN:
events |= EVENT_WRITE
if event_mask & ~select.POLLOUT:
events |= EVENT_READ
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
self._devpoll.close()
super(DevpollSelector, self).close()
__all__.append('DevpollSelector')
if hasattr(select, "kqueue"):
class KqueueSelector(BaseSelector):
""" Kqueue / Kevent-based selector """
def __init__(self):
super(KqueueSelector, self).__init__()
self._kqueue = select.kqueue()
def fileno(self):
return self._kqueue.fileno()
def register(self, fileobj, events, data=None):
key = super(KqueueSelector, self).register(fileobj, events, data)
if events & EVENT_READ:
kevent = select.kevent(key.fd,
select.KQ_FILTER_READ,
select.KQ_EV_ADD)
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
if events & EVENT_WRITE:
kevent = select.kevent(key.fd,
select.KQ_FILTER_WRITE,
select.KQ_EV_ADD)
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
return key
def unregister(self, fileobj):
key = super(KqueueSelector, self).unregister(fileobj)
if key.events & EVENT_READ:
kevent = select.kevent(key.fd,
select.KQ_FILTER_READ,
select.KQ_EV_DELETE)
try:
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
except SelectorError:
pass
if key.events & EVENT_WRITE:
kevent = select.kevent(key.fd,
select.KQ_FILTER_WRITE,
select.KQ_EV_DELETE)
try:
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
except SelectorError:
pass
return key
def select(self, timeout=None):
if timeout is not None:
timeout = max(timeout, 0)
max_events = len(self._fd_to_key) * 2
ready_fds = {}
kevent_list = _syscall_wrapper(self._kqueue.control, True,
None, max_events, timeout)
for kevent in kevent_list:
fd = kevent.ident
event_mask = kevent.filter
events = 0
if event_mask == select.KQ_FILTER_READ:
events |= EVENT_READ
if event_mask == select.KQ_FILTER_WRITE:
events |= EVENT_WRITE
key = self._key_from_fd(fd)
if key:
if key.fd not in ready_fds:
ready_fds[key.fd] = (key, events & key.events)
else:
old_events = ready_fds[key.fd][1]
ready_fds[key.fd] = (key, (events | old_events) & key.events)
return list(ready_fds.values())
def close(self):
self._kqueue.close()
super(KqueueSelector, self).close()
__all__.append('KqueueSelector')
# Choose the best implementation, roughly:
# kqueue == epoll == devpoll > poll > select.
# select() also can't accept a FD > FD_SETSIZE (usually around 1024)
if 'KqueueSelector' in globals(): # Platform-specific: Mac OS and BSD
DefaultSelector = KqueueSelector
if 'DevpollSelector' in globals():
DefaultSelector = DevpollSelector
elif 'EpollSelector' in globals(): # Platform-specific: Linux
DefaultSelector = EpollSelector
elif 'PollSelector' in globals(): # Platform-specific: Linux
DefaultSelector = PollSelector
elif 'SelectSelector' in globals(): # Platform-specific: Windows
DefaultSelector = SelectSelector
else: # Platform-specific: AppEngine
def no_selector(_):
raise ValueError("Platform does not have a selector")
DefaultSelector = no_selector
HAS_SELECT = False

View file

@ -1,5 +1,5 @@
# (c) 2012, Michael DeHaan <michael.dehaan@gmail.com> # (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2015 Toshio Kuratomi <tkuratomi@ansible.com> # (c) 2015, 2017 Toshio Kuratomi <tkuratomi@ansible.com>
# #
# This file is part of Ansible # This file is part of Ansible
# #
@ -19,16 +19,14 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import os import os
import select
import shutil import shutil
import subprocess import subprocess
import fcntl import fcntl
import getpass import getpass
from ansible.compat.six import text_type, binary_type
import ansible.constants as C import ansible.constants as C
from ansible.compat import selectors
from ansible.compat.six import text_type, binary_type
from ansible.errors import AnsibleError, AnsibleFileNotFound from ansible.errors import AnsibleError, AnsibleFileNotFound
from ansible.module_utils._text import to_bytes, to_native from ansible.module_utils._text import to_bytes, to_native
from ansible.plugins.connection import ConnectionBase from ansible.plugins.connection import ConnectionBase
@ -90,21 +88,31 @@ class Connection(ConnectionBase):
if self._play_context.prompt and sudoable: if self._play_context.prompt and sudoable:
fcntl.fcntl(p.stdout, fcntl.F_SETFL, fcntl.fcntl(p.stdout, fcntl.F_GETFL) | os.O_NONBLOCK) fcntl.fcntl(p.stdout, fcntl.F_SETFL, fcntl.fcntl(p.stdout, fcntl.F_GETFL) | os.O_NONBLOCK)
fcntl.fcntl(p.stderr, fcntl.F_SETFL, fcntl.fcntl(p.stderr, fcntl.F_GETFL) | os.O_NONBLOCK) fcntl.fcntl(p.stderr, fcntl.F_SETFL, fcntl.fcntl(p.stderr, fcntl.F_GETFL) | os.O_NONBLOCK)
become_output = b'' selector = selectors.DefaultSelector()
while not self.check_become_success(become_output) and not self.check_password_prompt(become_output): selector.register(p.stdout, selectors.EVENT_READ)
selector.register(p.stderr, selectors.EVENT_READ)
become_output = b''
try:
while not self.check_become_success(become_output) and not self.check_password_prompt(become_output):
events = selector.select(self._play_context.timeout)
if not events:
stdout, stderr = p.communicate()
raise AnsibleError('timeout waiting for privilege escalation password prompt:\n' + to_native(become_output))
for key, event in events:
if key.fileobj == p.stdout:
chunk = p.stdout.read()
elif key.fileobj == p.stderr:
chunk = p.stderr.read()
if not chunk:
stdout, stderr = p.communicate()
raise AnsibleError('privilege output closed while waiting for password prompt:\n' + to_native(become_output))
become_output += chunk
finally:
selector.close()
rfd, wfd, efd = select.select([p.stdout, p.stderr], [], [p.stdout, p.stderr], self._play_context.timeout)
if p.stdout in rfd:
chunk = p.stdout.read()
elif p.stderr in rfd:
chunk = p.stderr.read()
else:
stdout, stderr = p.communicate()
raise AnsibleError('timeout waiting for privilege escalation password prompt:\n' + to_native(become_output))
if not chunk:
stdout, stderr = p.communicate()
raise AnsibleError('privilege output closed while waiting for password prompt:\n' + to_native(become_output))
become_output += chunk
if not self.check_become_success(become_output): if not self.check_become_success(become_output):
p.stdin.write(to_bytes(self._play_context.become_pass, errors='surrogate_or_strict') + b'\n') p.stdin.write(to_bytes(self._play_context.become_pass, errors='surrogate_or_strict') + b'\n')
fcntl.fcntl(p.stdout, fcntl.F_SETFL, fcntl.fcntl(p.stdout, fcntl.F_GETFL) & ~os.O_NONBLOCK) fcntl.fcntl(p.stdout, fcntl.F_SETFL, fcntl.fcntl(p.stdout, fcntl.F_GETFL) & ~os.O_NONBLOCK)

View file

@ -1,5 +1,6 @@
# (c) 2012, Michael DeHaan <michael.dehaan@gmail.com> # (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>
# Copyright 2015 Abhijit Menon-Sen <ams@2ndQuadrant.com> # Copyright 2015 Abhijit Menon-Sen <ams@2ndQuadrant.com>
# Copyright 2017 Toshio Kuratomi <tkuratomi@ansible.com>
# #
# This file is part of Ansible # This file is part of Ansible
# #
@ -24,11 +25,11 @@ import fcntl
import hashlib import hashlib
import os import os
import pty import pty
import select
import subprocess import subprocess
import time import time
from ansible import constants as C from ansible import constants as C
from ansible.compat import selectors
from ansible.compat.six import PY3, text_type, binary_type from ansible.compat.six import PY3, text_type, binary_type
from ansible.compat.six.moves import shlex_quote from ansible.compat.six.moves import shlex_quote
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
@ -443,148 +444,158 @@ class Connection(ConnectionBase):
# they will race each other when we can't connect, and the connect # they will race each other when we can't connect, and the connect
# timeout usually fails # timeout usually fails
timeout = 2 + self._play_context.timeout timeout = 2 + self._play_context.timeout
rpipes = [p.stdout, p.stderr] for fd in (p.stdout, p.stderr):
for fd in rpipes:
fcntl.fcntl(fd, fcntl.F_SETFL, fcntl.fcntl(fd, fcntl.F_GETFL) | os.O_NONBLOCK) fcntl.fcntl(fd, fcntl.F_SETFL, fcntl.fcntl(fd, fcntl.F_GETFL) | os.O_NONBLOCK)
# If we can send initial data without waiting for anything, we do so ### TODO: bcoca would like to use SelectSelector() when open
# before we call select. # filehandles is low, then switch to more efficient ones when higher.
# select is faster when filehandles is low.
selector = selectors.DefaultSelector()
selector.register(p.stdout, selectors.EVENT_READ)
selector.register(p.stderr, selectors.EVENT_READ)
# If we can send initial data without waiting for anything, we do so
# before we start polling
if states[state] == 'ready_to_send' and in_data: if states[state] == 'ready_to_send' and in_data:
self._send_initial_data(stdin, in_data) self._send_initial_data(stdin, in_data)
state += 1 state += 1
while True: try:
rfd, wfd, efd = select.select(rpipes, [], [], timeout) while True:
events = selector.select(timeout)
# We pay attention to timeouts only while negotiating a prompt. # We pay attention to timeouts only while negotiating a prompt.
if not rfd: if not events:
if state <= states.index('awaiting_escalation'): # We timed out
# If the process has already exited, then it's not really a if state <= states.index('awaiting_escalation'):
# timeout; we'll let the normal error handling deal with it. # If the process has already exited, then it's not really a
if p.poll() is not None: # timeout; we'll let the normal error handling deal with it.
if p.poll() is not None:
break
self._terminate_process(p)
raise AnsibleError('Timeout (%ds) waiting for privilege escalation prompt: %s' % (timeout, to_native(b_stdout)))
# Read whatever output is available on stdout and stderr, and stop
# listening to the pipe if it's been closed.
for key, event in events:
if key.fileobj == p.stdout:
b_chunk = p.stdout.read()
if b_chunk == b'':
# stdout has been closed, stop watching it
selector.unregister(p.stdout)
# When ssh has ControlMaster (+ControlPath/Persist) enabled, the
# first connection goes into the background and we never see EOF
# on stderr. If we see EOF on stdout, lower the select timeout
# to reduce the time wasted selecting on stderr if we observe
# that the process has not yet existed after this EOF. Otherwise
# we may spend a long timeout period waiting for an EOF that is
# not going to arrive until the persisted connection closes.
timeout = 1
b_tmp_stdout += b_chunk
display.debug("stdout chunk (state=%s):\n>>>%s<<<\n" % (state, to_text(b_chunk)))
elif key.fileobj == p.stderr:
b_chunk = p.stderr.read()
if b_chunk == b'':
# stderr has been closed, stop watching it
selector.unregister(p.stderr)
b_tmp_stderr += b_chunk
display.debug("stderr chunk (state=%s):\n>>>%s<<<\n" % (state, to_text(b_chunk)))
# We examine the output line-by-line until we have negotiated any
# privilege escalation prompt and subsequent success/error message.
# Afterwards, we can accumulate output without looking at it.
if state < states.index('ready_to_send'):
if b_tmp_stdout:
b_output, b_unprocessed = self._examine_output('stdout', states[state], b_tmp_stdout, sudoable)
b_stdout += b_output
b_tmp_stdout = b_unprocessed
if b_tmp_stderr:
b_output, b_unprocessed = self._examine_output('stderr', states[state], b_tmp_stderr, sudoable)
b_stderr += b_output
b_tmp_stderr = b_unprocessed
else:
b_stdout += b_tmp_stdout
b_stderr += b_tmp_stderr
b_tmp_stdout = b_tmp_stderr = b''
# If we see a privilege escalation prompt, we send the password.
# (If we're expecting a prompt but the escalation succeeds, we
# didn't need the password and can carry on regardless.)
if states[state] == 'awaiting_prompt':
if self._flags['become_prompt']:
display.debug('Sending become_pass in response to prompt')
stdin.write(to_bytes(self._play_context.become_pass) + b'\n')
self._flags['become_prompt'] = False
state += 1
elif self._flags['become_success']:
state += 1
# We've requested escalation (with or without a password), now we
# wait for an error message or a successful escalation.
if states[state] == 'awaiting_escalation':
if self._flags['become_success']:
display.debug('Escalation succeeded')
self._flags['become_success'] = False
state += 1
elif self._flags['become_error']:
display.debug('Escalation failed')
self._terminate_process(p)
self._flags['become_error'] = False
raise AnsibleError('Incorrect %s password' % self._play_context.become_method)
elif self._flags['become_nopasswd_error']:
display.debug('Escalation requires password')
self._terminate_process(p)
self._flags['become_nopasswd_error'] = False
raise AnsibleError('Missing %s password' % self._play_context.become_method)
elif self._flags['become_prompt']:
# This shouldn't happen, because we should see the "Sorry,
# try again" message first.
display.debug('Escalation prompt repeated')
self._terminate_process(p)
self._flags['become_prompt'] = False
raise AnsibleError('Incorrect %s password' % self._play_context.become_method)
# Once we're sure that the privilege escalation prompt, if any, has
# been dealt with, we can send any initial data and start waiting
# for output.
if states[state] == 'ready_to_send':
if in_data:
self._send_initial_data(stdin, in_data)
state += 1
# Now we're awaiting_exit: has the child process exited? If it has,
# and we've read all available output from it, we're done.
if p.poll() is not None:
if not selector.get_map() or not events:
break break
self._terminate_process(p) # We should not see further writes to the stdout/stderr file
raise AnsibleError('Timeout (%ds) waiting for privilege escalation prompt: %s' % (timeout, to_native(b_stdout))) # descriptors after the process has closed, set the select
# timeout to gather any last writes we may have missed.
timeout = 0
continue
# Read whatever output is available on stdout and stderr, and stop # If the process has not yet exited, but we've already read EOF from
# listening to the pipe if it's been closed. # its stdout and stderr (and thus no longer watching any file
# descriptors), we can just wait for it to exit.
if p.stdout in rfd: elif not selector.get_map():
b_chunk = p.stdout.read() p.wait()
if b_chunk == b'':
rpipes.remove(p.stdout)
# When ssh has ControlMaster (+ControlPath/Persist) enabled, the
# first connection goes into the background and we never see EOF
# on stderr. If we see EOF on stdout, lower the select timeout
# to reduce the time wasted selecting on stderr if we observe
# that the process has not yet existed after this EOF. Otherwise
# we may spend a long timeout period waiting for an EOF that is
# not going to arrive until the persisted connection closes.
timeout = 1
b_tmp_stdout += b_chunk
display.debug("stdout chunk (state=%s):\n>>>%s<<<\n" % (state, to_text(b_chunk)))
if p.stderr in rfd:
b_chunk = p.stderr.read()
if b_chunk == b'':
rpipes.remove(p.stderr)
b_tmp_stderr += b_chunk
display.debug("stderr chunk (state=%s):\n>>>%s<<<\n" % (state, to_text(b_chunk)))
# We examine the output line-by-line until we have negotiated any
# privilege escalation prompt and subsequent success/error message.
# Afterwards, we can accumulate output without looking at it.
if state < states.index('ready_to_send'):
if b_tmp_stdout:
b_output, b_unprocessed = self._examine_output('stdout', states[state], b_tmp_stdout, sudoable)
b_stdout += b_output
b_tmp_stdout = b_unprocessed
if b_tmp_stderr:
b_output, b_unprocessed = self._examine_output('stderr', states[state], b_tmp_stderr, sudoable)
b_stderr += b_output
b_tmp_stderr = b_unprocessed
else:
b_stdout += b_tmp_stdout
b_stderr += b_tmp_stderr
b_tmp_stdout = b_tmp_stderr = b''
# If we see a privilege escalation prompt, we send the password.
# (If we're expecting a prompt but the escalation succeeds, we
# didn't need the password and can carry on regardless.)
if states[state] == 'awaiting_prompt':
if self._flags['become_prompt']:
display.debug('Sending become_pass in response to prompt')
stdin.write(to_bytes(self._play_context.become_pass) + b'\n')
self._flags['become_prompt'] = False
state += 1
elif self._flags['become_success']:
state += 1
# We've requested escalation (with or without a password), now we
# wait for an error message or a successful escalation.
if states[state] == 'awaiting_escalation':
if self._flags['become_success']:
display.debug('Escalation succeeded')
self._flags['become_success'] = False
state += 1
elif self._flags['become_error']:
display.debug('Escalation failed')
self._terminate_process(p)
self._flags['become_error'] = False
raise AnsibleError('Incorrect %s password' % self._play_context.become_method)
elif self._flags['become_nopasswd_error']:
display.debug('Escalation requires password')
self._terminate_process(p)
self._flags['become_nopasswd_error'] = False
raise AnsibleError('Missing %s password' % self._play_context.become_method)
elif self._flags['become_prompt']:
# This shouldn't happen, because we should see the "Sorry,
# try again" message first.
display.debug('Escalation prompt repeated')
self._terminate_process(p)
self._flags['become_prompt'] = False
raise AnsibleError('Incorrect %s password' % self._play_context.become_method)
# Once we're sure that the privilege escalation prompt, if any, has
# been dealt with, we can send any initial data and start waiting
# for output.
if states[state] == 'ready_to_send':
if in_data:
self._send_initial_data(stdin, in_data)
state += 1
# Now we're awaiting_exit: has the child process exited? If it has,
# and we've read all available output from it, we're done.
if p.poll() is not None:
if not rpipes or not rfd:
break break
# We should not see further writes to the stdout/stderr file
# descriptors after the process has closed, set the select
# timeout to gather any last writes we may have missed.
timeout = 0
continue
# If the process has not yet exited, but we've already read EOF from # Otherwise there may still be outstanding data to read.
# its stdout and stderr (and thus removed both from rpipes), we can finally:
# just wait for it to exit. selector.close()
# close stdin after process is terminated and stdout/stderr are read
elif not rpipes: # completely (see also issue #848)
p.wait() stdin.close()
break
# Otherwise there may still be outstanding data to read.
# close stdin after process is terminated and stdout/stderr are read
# completely (see also issue #848)
stdin.close()
if C.HOST_KEY_CHECKING: if C.HOST_KEY_CHECKING:
if cmd[0] == b"sshpass" and p.returncode == 6: if cmd[0] == b"sshpass" and p.returncode == 6:

View file

@ -18,8 +18,8 @@ setup(name='ansible',
author_email='info@ansible.com', author_email='info@ansible.com',
url='http://ansible.com/', url='http://ansible.com/',
license='GPLv3', license='GPLv3',
# Ansible will also make use of a system copy of python-six if installed but use a # Ansible will also make use of a system copy of python-six and
# Bundled copy if it's not. # python-selectors2 if installed but use a Bundled copy if it's not.
install_requires=['paramiko', 'jinja2', "PyYAML", 'setuptools', 'pycrypto >= 2.6'], install_requires=['paramiko', 'jinja2', "PyYAML", 'setuptools', 'pycrypto >= 2.6'],
package_dir={ '': 'lib' }, package_dir={ '': 'lib' },
packages=find_packages('lib'), packages=find_packages('lib'),

View file

@ -7,6 +7,7 @@ metaclass2=$(find ./lib/ansible -path ./lib/ansible/modules -prune \
-o -path ./lib/ansible/modules/__init__.py \ -o -path ./lib/ansible/modules/__init__.py \
-o -path ./lib/ansible/module_utils -prune \ -o -path ./lib/ansible/module_utils -prune \
-o -path ./lib/ansible/compat/six/_six.py -prune \ -o -path ./lib/ansible/compat/six/_six.py -prune \
-o -path ./lib/ansible/compat/selectors/_selectors2.py -prune \
-o -path ./lib/ansible/utils/module_docs_fragments -prune \ -o -path ./lib/ansible/utils/module_docs_fragments -prune \
-o -name '*.py' -exec grep -HL '__metaclass__ = type' '{}' '+') -o -name '*.py' -exec grep -HL '__metaclass__ = type' '{}' '+')
@ -14,6 +15,7 @@ future2=$(find ./lib/ansible -path ./lib/ansible/modules -prune \
-o -path ./lib/ansible/modules/__init__.py \ -o -path ./lib/ansible/modules/__init__.py \
-o -path ./lib/ansible/module_utils -prune \ -o -path ./lib/ansible/module_utils -prune \
-o -path ./lib/ansible/compat/six/_six.py -prune \ -o -path ./lib/ansible/compat/six/_six.py -prune \
-o -path ./lib/ansible/compat/selectors/_selectors2.py -prune \
-o -path ./lib/ansible/utils/module_docs_fragments -prune \ -o -path ./lib/ansible/utils/module_docs_fragments -prune \
-o -name '*.py' -exec grep -HL 'from __future__ import (absolute_import, division, print_function)' '{}' '+') -o -name '*.py' -exec grep -HL 'from __future__ import (absolute_import, division, print_function)' '{}' '+')

View file

@ -23,10 +23,13 @@ __metaclass__ = type
from io import StringIO from io import StringIO
import pytest
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock from ansible.compat.tests.mock import patch, MagicMock
from ansible import constants as C from ansible import constants as C
from ansible.compat.selectors import SelectorKey, EVENT_READ
from ansible.compat.six.moves import shlex_quote from ansible.compat.six.moves import shlex_quote
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
from ansible.playbook.play_context import PlayContext from ansible.playbook.play_context import PlayContext
@ -83,82 +86,6 @@ class TestConnectionBaseClass(unittest.TestCase):
res, stdout, stderr = conn._exec_command('ssh') res, stdout, stderr = conn._exec_command('ssh')
res, stdout, stderr = conn._exec_command('ssh', 'this is some data') res, stdout, stderr = conn._exec_command('ssh', 'this is some data')
@patch('select.select')
@patch('fcntl.fcntl')
@patch('os.write')
@patch('os.close')
@patch('pty.openpty')
@patch('subprocess.Popen')
def test_plugins_connection_ssh__run(self, mock_Popen, mock_openpty, mock_osclose, mock_oswrite, mock_fcntl, mock_select):
pc = PlayContext()
new_stdin = StringIO()
conn = ssh.Connection(pc, new_stdin)
conn._send_initial_data = MagicMock()
conn._examine_output = MagicMock()
conn._terminate_process = MagicMock()
conn.sshpass_pipe = [MagicMock(), MagicMock()]
mock_popen_res = MagicMock()
mock_popen_res.poll = MagicMock()
mock_popen_res.wait = MagicMock()
mock_popen_res.stdin = MagicMock()
mock_popen_res.stdin.fileno.return_value = 1000
mock_popen_res.stdout = MagicMock()
mock_popen_res.stdout.fileno.return_value = 1001
mock_popen_res.stderr = MagicMock()
mock_popen_res.stderr.fileno.return_value = 1002
mock_popen_res.return_code = 0
mock_Popen.return_value = mock_popen_res
def _mock_select(rlist, wlist, elist, timeout=None):
rvals = []
if mock_popen_res.stdin in rlist:
rvals.append(mock_popen_res.stdin)
if mock_popen_res.stderr in rlist:
rvals.append(mock_popen_res.stderr)
return (rvals, [], [])
mock_select.side_effect = _mock_select
mock_popen_res.stdout.read.side_effect = [b"some data", b""]
mock_popen_res.stderr.read.side_effect = [b""]
conn._run("ssh", "this is input data")
# test with a password set to trigger the sshpass write
pc.password = '12345'
mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""]
mock_popen_res.stderr.read.side_effect = [b""]
conn._run(["ssh", "is", "a", "cmd"], "this is more data")
# test with password prompting enabled
pc.password = None
pc.prompt = True
mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""]
mock_popen_res.stderr.read.side_effect = [b""]
conn._run("ssh", "this is input data")
# test with some become settings
pc.prompt = False
pc.become = True
pc.success_key = 'BECOME-SUCCESS-abcdefg'
mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""]
mock_popen_res.stderr.read.side_effect = [b""]
conn._run("ssh", "this is input data")
# simulate no data input
mock_openpty.return_value = (98, 99)
mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""]
mock_popen_res.stderr.read.side_effect = [b""]
conn._run("ssh", "")
# simulate no data input but Popen using new pty's fails
mock_Popen.return_value = None
mock_Popen.side_effect = [OSError(), mock_popen_res]
mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""]
mock_popen_res.stderr.read.side_effect = [b""]
conn._run("ssh", "")
def test_plugins_connection_ssh__examine_output(self): def test_plugins_connection_ssh__examine_output(self):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
@ -341,7 +268,6 @@ class TestConnectionBaseClass(unittest.TestCase):
conn.put_file(u'/path/to/in/file/with/unicode-fö〩', u'/path/to/dest/file/with/unicode-fö〩') conn.put_file(u'/path/to/in/file/with/unicode-fö〩', u'/path/to/dest/file/with/unicode-fö〩')
conn._run.assert_called_with('some command to run', expected_in_data, checkrc=False) conn._run.assert_called_with('some command to run', expected_in_data, checkrc=False)
# test that a non-zero rc raises an error # test that a non-zero rc raises an error
conn._run.return_value = (1, 'stdout', 'some errors') conn._run.return_value = (1, 'stdout', 'some errors')
self.assertRaises(AnsibleError, conn.put_file, '/path/to/bad/file', '/remote/path/to/file') self.assertRaises(AnsibleError, conn.put_file, '/path/to/bad/file', '/remote/path/to/file')
@ -398,3 +324,215 @@ class TestConnectionBaseClass(unittest.TestCase):
# test that a non-zero rc raises an error # test that a non-zero rc raises an error
conn._run.return_value = (1, 'stdout', 'some errors') conn._run.return_value = (1, 'stdout', 'some errors')
self.assertRaises(AnsibleError, conn.fetch_file, '/path/to/bad/file', '/remote/path/to/file') self.assertRaises(AnsibleError, conn.fetch_file, '/path/to/bad/file', '/remote/path/to/file')
class MockSelector(object):
def __init__(self):
self.files_watched = 0
self.register = MagicMock(side_effect=self._register)
self.unregister = MagicMock(side_effect=self._unregister)
self.close = MagicMock()
self.get_map = MagicMock(side_effect=self._get_map)
self.select = MagicMock()
def _register(self, *args, **kwargs):
self.files_watched += 1
def _unregister(self, *args, **kwargs):
self.files_watched -= 1
def _get_map(self, *args, **kwargs):
return self.files_watched
@pytest.fixture
def mock_run_env(request, mocker):
pc = PlayContext()
new_stdin = StringIO()
conn = ssh.Connection(pc, new_stdin)
conn._send_initial_data = MagicMock()
conn._examine_output = MagicMock()
conn._terminate_process = MagicMock()
conn.sshpass_pipe = [MagicMock(), MagicMock()]
request.cls.pc = pc
request.cls.conn = conn
mock_popen_res = MagicMock()
mock_popen_res.poll = MagicMock()
mock_popen_res.wait = MagicMock()
mock_popen_res.stdin = MagicMock()
mock_popen_res.stdin.fileno.return_value = 1000
mock_popen_res.stdout = MagicMock()
mock_popen_res.stdout.fileno.return_value = 1001
mock_popen_res.stderr = MagicMock()
mock_popen_res.stderr.fileno.return_value = 1002
mock_popen_res.returncode = 0
request.cls.mock_popen_res = mock_popen_res
mock_popen = mocker.patch('subprocess.Popen', return_value=mock_popen_res)
request.cls.mock_popen = mock_popen
request.cls.mock_selector = MockSelector()
mocker.patch('ansible.compat.selectors.DefaultSelector', lambda: request.cls.mock_selector)
request.cls.mock_openpty = mocker.patch('pty.openpty')
mocker.patch('fcntl.fcntl')
mocker.patch('os.write')
mocker.patch('os.close')
@pytest.mark.usefixtures('mock_run_env')
class TestSSHConnectionRun(object):
# FIXME:
# These tests are little more than a smoketest. Need to enhance them
# a bit to check that they're calling the relevant functions and making
# complete coverage of the code paths
def test_no_escalation(self):
self.mock_popen_res.stdout.read.side_effect = [b"my_stdout\n", b"second_line"]
self.mock_popen_res.stderr.read.side_effect = [b"my_stderr"]
self.mock_selector.select.side_effect = [
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
[]]
self.mock_selector.get_map.side_effect = lambda: True
return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data")
assert return_code == 0
assert b_stdout == b'my_stdout\nsecond_line'
assert b_stderr == b'my_stderr'
assert self.mock_selector.register.called is True
assert self.mock_selector.register.call_count == 2
assert self.conn._send_initial_data.called is True
assert self.conn._send_initial_data.call_count == 1
assert self.conn._send_initial_data.call_args[0][1] == 'this is input data'
def test_with_password(self):
# test with a password set to trigger the sshpass write
self.pc.password = '12345'
self.mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""]
self.mock_popen_res.stderr.read.side_effect = [b""]
self.mock_selector.select.side_effect = [
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[]]
self.mock_selector.get_map.side_effect = lambda: True
return_code, b_stdout, b_stderr = self.conn._run(["ssh", "is", "a", "cmd"], "this is more data")
assert return_code == 0
assert b_stdout == b'some data'
assert b_stderr == b''
assert self.mock_selector.register.called is True
assert self.mock_selector.register.call_count == 2
assert self.conn._send_initial_data.called is True
assert self.conn._send_initial_data.call_count == 1
assert self.conn._send_initial_data.call_args[0][1] == 'this is more data'
def _password_with_prompt_examine_output(self, sourice, state, b_chunk, sudoable):
if state == 'awaiting_prompt':
self.conn._flags['become_prompt'] = True
elif state == 'awaiting_escalation':
self.conn._flags['become_success'] = True
return (b'', b'')
def test_pasword_with_prompt(self):
# test with password prompting enabled
self.pc.password = None
self.pc.prompt = b'Password:'
self.conn._examine_output.side_effect = self._password_with_prompt_examine_output
self.mock_popen_res.stdout.read.side_effect = [b"Password:", b"Success", b""]
self.mock_popen_res.stderr.read.side_effect = [b""]
self.mock_selector.select.side_effect = [
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ),
(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[]]
self.mock_selector.get_map.side_effect = lambda: True
return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data")
assert return_code == 0
assert b_stdout == b''
assert b_stderr == b''
assert self.mock_selector.register.called is True
assert self.mock_selector.register.call_count == 2
assert self.conn._send_initial_data.called is True
assert self.conn._send_initial_data.call_count == 1
assert self.conn._send_initial_data.call_args[0][1] == 'this is input data'
def test_pasword_with_become(self):
# test with some become settings
self.pc.prompt = b'Password:'
self.pc.become = True
self.pc.success_key = 'BECOME-SUCCESS-abcdefg'
self.conn._examine_output.side_effect = self._password_with_prompt_examine_output
self.mock_popen_res.stdout.read.side_effect = [b"Password:", b"BECOME-SUCCESS-abcdefg", b"abc"]
self.mock_popen_res.stderr.read.side_effect = [b"123"]
self.mock_selector.select.side_effect = [
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[]]
self.mock_selector.get_map.side_effect = lambda: True
return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data")
assert return_code == 0
assert b_stdout == b'abc'
assert b_stderr == b'123'
assert self.mock_selector.register.called is True
assert self.mock_selector.register.call_count == 2
assert self.conn._send_initial_data.called is True
assert self.conn._send_initial_data.call_count == 1
assert self.conn._send_initial_data.call_args[0][1] == 'this is input data'
def test_pasword_without_data(self):
# simulate no data input
self.mock_openpty.return_value = (98, 99)
self.mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""]
self.mock_popen_res.stderr.read.side_effect = [b""]
self.mock_selector.select.side_effect = [
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[]]
self.mock_selector.get_map.side_effect = lambda: True
return_code, b_stdout, b_stderr = self.conn._run("ssh", "")
assert return_code == 0
assert b_stdout == b'some data'
assert b_stderr == b''
assert self.mock_selector.register.called is True
assert self.mock_selector.register.call_count == 2
assert self.conn._send_initial_data.called is False
def test_pasword_without_data(self):
# simulate no data input but Popen using new pty's fails
self.mock_popen.return_value = None
self.mock_popen.side_effect = [OSError(), self.mock_popen_res]
# simulate no data input
self.mock_openpty.return_value = (98, 99)
self.mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""]
self.mock_popen_res.stderr.read.side_effect = [b""]
self.mock_selector.select.side_effect = [
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[]]
self.mock_selector.get_map.side_effect = lambda: True
return_code, b_stdout, b_stderr = self.conn._run("ssh", "")
assert return_code == 0
assert b_stdout == b'some data'
assert b_stderr == b''
assert self.mock_selector.register.called is True
assert self.mock_selector.register.call_count == 2
assert self.conn._send_initial_data.called is False