mirror of
https://github.com/ansible-collections/community.general.git
synced 2024-09-14 20:13:21 +02:00
Enable paramiko to ask whether to add keys to known hosts.
This commit is contained in:
parent
cff4ab511c
commit
ffadbc520a
2 changed files with 64 additions and 14 deletions
|
@ -53,7 +53,7 @@ multiprocessing_runner = None
|
|||
|
||||
################################################
|
||||
|
||||
def _executor_hook(job_queue, result_queue):
|
||||
def _executor_hook(job_queue, result_queue, new_stdin):
|
||||
|
||||
# attempt workaround of https://github.com/newsapps/beeswithmachineguns/issues/17
|
||||
# this function also not present in CentOS 6
|
||||
|
@ -64,7 +64,7 @@ def _executor_hook(job_queue, result_queue):
|
|||
while not job_queue.empty():
|
||||
try:
|
||||
host = job_queue.get(block=False)
|
||||
return_data = multiprocessing_runner._executor(host)
|
||||
return_data = multiprocessing_runner._executor(host, new_stdin)
|
||||
result_queue.put(return_data)
|
||||
|
||||
if 'LEGACY_TEMPLATE_WARNING' in return_data.flags:
|
||||
|
@ -75,6 +75,8 @@ def _executor_hook(job_queue, result_queue):
|
|||
pass
|
||||
except:
|
||||
traceback.print_exc()
|
||||
if new_stdin:
|
||||
new_stdin.close()
|
||||
|
||||
class HostVars(dict):
|
||||
''' A special view of setup_cache that adds values from the inventory when needed. '''
|
||||
|
@ -133,6 +135,9 @@ class Runner(object):
|
|||
error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR # ex. False
|
||||
):
|
||||
|
||||
# used to lock multiprocess inputs that wish to share stdin
|
||||
self.lockfile = tempfile.NamedTemporaryFile()
|
||||
|
||||
if not complex_args:
|
||||
complex_args = {}
|
||||
|
||||
|
@ -331,7 +336,7 @@ class Runner(object):
|
|||
|
||||
# *****************************************************
|
||||
|
||||
def _executor(self, host):
|
||||
def _executor(self, host, new_stdin):
|
||||
''' handler for multiprocessing library '''
|
||||
|
||||
def get_flags():
|
||||
|
@ -344,7 +349,9 @@ class Runner(object):
|
|||
return flags
|
||||
|
||||
try:
|
||||
exec_rc = self._executor_internal(host)
|
||||
self._new_stdin = new_stdin
|
||||
|
||||
exec_rc = self._executor_internal(host, new_stdin)
|
||||
if type(exec_rc) != ReturnData:
|
||||
raise Exception("unexpected return type: %s" % type(exec_rc))
|
||||
exec_rc.flags = get_flags()
|
||||
|
@ -363,7 +370,7 @@ class Runner(object):
|
|||
|
||||
# *****************************************************
|
||||
|
||||
def _executor_internal(self, host):
|
||||
def _executor_internal(self, host, new_stdin):
|
||||
''' executes any module one or more times '''
|
||||
|
||||
host_variables = self.inventory.get_variables(host)
|
||||
|
@ -774,8 +781,9 @@ class Runner(object):
|
|||
|
||||
workers = []
|
||||
for i in range(self.forks):
|
||||
new_stdin = os.fdopen(os.dup(sys.stdin.fileno()))
|
||||
prc = multiprocessing.Process(target=_executor_hook,
|
||||
args=(job_queue, result_queue))
|
||||
args=(job_queue, result_queue, new_stdin))
|
||||
prc.start()
|
||||
workers.append(prc)
|
||||
|
||||
|
@ -847,7 +855,7 @@ class Runner(object):
|
|||
# We aren't iterating over all the hosts in this
|
||||
# group. So, just pick the first host in our group to
|
||||
# construct the conn object with.
|
||||
result_data = self._executor(hosts[0]).result
|
||||
result_data = self._executor(hosts[0], None).result
|
||||
# Create a ResultData item for each host in this group
|
||||
# using the returned result. If we didn't do this we would
|
||||
# get false reports of dark hosts.
|
||||
|
@ -865,7 +873,7 @@ class Runner(object):
|
|||
raise errors.AnsibleError("interrupted")
|
||||
raise
|
||||
else:
|
||||
results = [ self._executor(h) for h in hosts ]
|
||||
results = [ self._executor(h, None) for h in hosts ]
|
||||
return self._partition_results(results)
|
||||
|
||||
# *****************************************************
|
||||
|
|
|
@ -15,6 +15,14 @@
|
|||
# You should have received a copy of the GNU General Public License
|
||||
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
|
||||
# ---
|
||||
# The paramiko transport is provided because many distributions, in particular EL6 and before
|
||||
# do not support ControlPersist in their SSH implementations. This is needed on the Ansible
|
||||
# control machine to be reasonably efficient with connections. Thus paramiko is faster
|
||||
# for most users on these platforms. Users with ControlPersist capability can consider
|
||||
# using -c ssh or configuring the transport in ansible.cfg.
|
||||
|
||||
import warnings
|
||||
import os
|
||||
import pipes
|
||||
|
@ -24,12 +32,19 @@ import logging
|
|||
import traceback
|
||||
import fcntl
|
||||
import sys
|
||||
from termios import tcflush, TCIFLUSH
|
||||
from binascii import hexlify
|
||||
from ansible.callbacks import vvv
|
||||
from ansible import errors
|
||||
from ansible import utils
|
||||
from ansible import constants as C
|
||||
|
||||
AUTHENTICITY_MSG="""
|
||||
paramiko: The authenticity of host '%s' can't be established.
|
||||
The %s key fingerprint is %s.
|
||||
Are you sure you want to continue connecting (yes/no)?
|
||||
"""
|
||||
|
||||
# prevent paramiko warning noise -- see http://stackoverflow.com/questions/3920502/
|
||||
HAVE_PARAMIKO=False
|
||||
with warnings.catch_warnings():
|
||||
|
@ -41,22 +56,49 @@ with warnings.catch_warnings():
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
class MyAutoAddPolicy(object):
|
||||
class MyAddPolicy(object):
|
||||
"""
|
||||
Modified version of AutoAddPolicy in paramiko so we can determine when keys are added.
|
||||
Based on AutoAddPolicy in paramiko so we can determine when keys are added
|
||||
and also prompt for input.
|
||||
|
||||
Policy for automatically adding the hostname and new host key to the
|
||||
local L{HostKeys} object, and saving it. This is used by L{SSHClient}.
|
||||
"""
|
||||
|
||||
def __init__(self, runner):
|
||||
self.runner = runner
|
||||
|
||||
def missing_host_key(self, client, hostname, key):
|
||||
|
||||
if C.HOST_KEY_CHECKING:
|
||||
|
||||
KEY_LOCK = self.runner.lockfile
|
||||
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX)
|
||||
|
||||
old_stdin = sys.stdin
|
||||
sys.stdin = self.runner._new_stdin
|
||||
fingerprint = hexlify(key.get_fingerprint())
|
||||
ktype = key.get_name()
|
||||
|
||||
# clear out any premature input on sys.stdin
|
||||
tcflush(sys.stdin, TCIFLUSH)
|
||||
|
||||
inp = raw_input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint))
|
||||
sys.stdin = old_stdin
|
||||
if inp not in ['yes','y','']:
|
||||
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
|
||||
raise errors.AnsibleError("host connection rejected by user")
|
||||
|
||||
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
|
||||
|
||||
|
||||
key._added_by_ansible_this_time = True
|
||||
|
||||
# existing implementation below:
|
||||
client._host_keys.add(hostname, key.get_name(), key)
|
||||
if client._host_keys_filename is not None:
|
||||
client.save_host_keys(client._host_keys_filename)
|
||||
|
||||
# host keys are actually saved in close() function below
|
||||
# in order to control ordering.
|
||||
|
||||
|
||||
# keep connection objects on a per host basis to avoid repeated attempts to reconnect
|
||||
|
@ -103,7 +145,7 @@ class Connection(object):
|
|||
|
||||
if C.HOST_KEY_CHECKING:
|
||||
ssh.load_system_host_keys()
|
||||
ssh.set_missing_host_key_policy(MyAutoAddPolicy())
|
||||
ssh.set_missing_host_key_policy(MyAddPolicy(self.runner))
|
||||
|
||||
allow_agent = True
|
||||
if self.password is not None:
|
||||
|
|
Loading…
Reference in a new issue