
Enable paramiko to ask whether to add keys to known hosts.

Author: Michael DeHaan, 2013-07-04 14:05:41 -04:00
parent cff4ab511c
commit ffadbc520a
2 changed files with 64 additions and 14 deletions

File 1 of 2: the multiprocessing Runner

@@ -53,7 +53,7 @@ multiprocessing_runner = None
 ################################################
-def _executor_hook(job_queue, result_queue):
+def _executor_hook(job_queue, result_queue, new_stdin):
     # attempt workaround of https://github.com/newsapps/beeswithmachineguns/issues/17
     # this function also not present in CentOS 6
@@ -64,7 +64,7 @@ def _executor_hook(job_queue, result_queue):
     while not job_queue.empty():
         try:
             host = job_queue.get(block=False)
-            return_data = multiprocessing_runner._executor(host)
+            return_data = multiprocessing_runner._executor(host, new_stdin)
             result_queue.put(return_data)
             if 'LEGACY_TEMPLATE_WARNING' in return_data.flags:
@@ -75,6 +75,8 @@ def _executor_hook(job_queue, result_queue):
             pass
         except:
             traceback.print_exc()
+    if new_stdin:
+        new_stdin.close()

 class HostVars(dict):
     ''' A special view of setup_cache that adds values from the inventory when needed. '''
@@ -133,6 +135,9 @@ class Runner(object):
         error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR # ex. False
         ):

+        # used to lock multiprocess inputs that wish to share stdin
+        self.lockfile = tempfile.NamedTemporaryFile()
+
         if not complex_args:
             complex_args = {}
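The temporary file created here is never written to; it only gives the forked workers a common descriptor to take fcntl advisory locks on, so that only one process owns the interactive prompt at a time. A minimal sketch of that pattern (the `prompt_exclusively` helper is illustrative, not Ansible's API):

```python
# Minimal sketch of the lockfile pattern: a file that is never written to,
# but whose descriptor every forked worker inherits, so fcntl advisory locks
# can serialize who may talk to the terminal.
import fcntl
import tempfile

lockfile = tempfile.NamedTemporaryFile()   # created once, before forking

def prompt_exclusively(question):
    """Ask on the terminal while holding the shared lock."""
    fcntl.lockf(lockfile, fcntl.LOCK_EX)   # blocks until no other process holds it
    try:
        return input(question)             # raw_input() in the Python 2 original
    finally:
        fcntl.lockf(lockfile, fcntl.LOCK_UN)
```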
@@ -331,7 +336,7 @@ class Runner(object):
     # *****************************************************

-    def _executor(self, host):
+    def _executor(self, host, new_stdin):
         ''' handler for multiprocessing library '''

         def get_flags():
@@ -344,7 +349,9 @@ class Runner(object):
             return flags

         try:
-            exec_rc = self._executor_internal(host)
+            self._new_stdin = new_stdin
+            exec_rc = self._executor_internal(host, new_stdin)
             if type(exec_rc) != ReturnData:
                 raise Exception("unexpected return type: %s" % type(exec_rc))
             exec_rc.flags = get_flags()
@@ -363,7 +370,7 @@ class Runner(object):
     # *****************************************************

-    def _executor_internal(self, host):
+    def _executor_internal(self, host, new_stdin):
         ''' executes any module one or more times '''

         host_variables = self.inventory.get_variables(host)
@@ -774,8 +781,9 @@ class Runner(object):
         workers = []
         for i in range(self.forks):
+            new_stdin = os.fdopen(os.dup(sys.stdin.fileno()))
             prc = multiprocessing.Process(target=_executor_hook,
-                args=(job_queue, result_queue))
+                args=(job_queue, result_queue, new_stdin))
             prc.start()
             workers.append(prc)
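`os.dup(sys.stdin.fileno())` gives every worker its own file object backed by the controlling terminal, so a child can read a yes/no answer without fighting the parent (or the other forks) over `sys.stdin`, and can close its copy once its job queue is drained. A self-contained sketch of the hand-off, assuming the default fork start method on a POSIX system (the `worker` body is illustrative):

```python
# Sketch of the stdin hand-off: the parent dup()s its stdin once per fork so
# each worker gets an independent file object it may read and close without
# touching the parent's sys.stdin.
import multiprocessing
import os
import sys

def worker(job, new_stdin):
    try:
        # interactive prompts inside the worker would read from new_stdin
        pass
    finally:
        if new_stdin:
            new_stdin.close()

if __name__ == "__main__":
    procs = []
    for job in range(4):                      # stands in for self.forks
        new_stdin = os.fdopen(os.dup(sys.stdin.fileno()))
        p = multiprocessing.Process(target=worker, args=(job, new_stdin))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
```

Passing an open file object through `Process(args=...)` only works because fork inherits file descriptors; under the spawn start method the handle would have to be re-opened in the child.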
@@ -847,7 +855,7 @@ class Runner(object):
             # We aren't iterating over all the hosts in this
             # group. So, just pick the first host in our group to
             # construct the conn object with.
-            result_data = self._executor(hosts[0]).result
+            result_data = self._executor(hosts[0], None).result
             # Create a ResultData item for each host in this group
             # using the returned result. If we didn't do this we would
             # get false reports of dark hosts.
@@ -865,7 +873,7 @@ class Runner(object):
                 raise errors.AnsibleError("interrupted")
             raise
         else:
-            results = [ self._executor(h) for h in hosts ]
+            results = [ self._executor(h, None) for h in hosts ]

         return self._partition_results(results)

     # *****************************************************

File 2 of 2: the paramiko connection plugin

@@ -15,6 +15,14 @@
 # You should have received a copy of the GNU General Public License
 # along with Ansible.  If not, see <http://www.gnu.org/licenses/>.

+# ---
+# The paramiko transport is provided because many distributions, in particular EL6 and before
+# do not support ControlPersist in their SSH implementations. This is needed on the Ansible
+# control machine to be reasonably efficient with connections. Thus paramiko is faster
+# for most users on these platforms. Users with ControlPersist capability can consider
+# using -c ssh or configuring the transport in ansible.cfg.
+
 import warnings
 import os
 import pipes
@@ -24,12 +32,19 @@ import logging
 import traceback
 import fcntl
 import sys
+from termios import tcflush, TCIFLUSH
 from binascii import hexlify

 from ansible.callbacks import vvv
 from ansible import errors
 from ansible import utils
 from ansible import constants as C

+AUTHENTICITY_MSG="""
+paramiko: The authenticity of host '%s' can't be established.
+The %s key fingerprint is %s.
+Are you sure you want to continue connecting (yes/no)?
+"""
+
 # prevent paramiko warning noise -- see http://stackoverflow.com/questions/3920502/
 HAVE_PARAMIKO=False
 with warnings.catch_warnings():
@ -41,22 +56,49 @@ with warnings.catch_warnings():
except ImportError: except ImportError:
pass pass
class MyAutoAddPolicy(object): class MyAddPolicy(object):
""" """
Modified version of AutoAddPolicy in paramiko so we can determine when keys are added. Based on AutoAddPolicy in paramiko so we can determine when keys are added
and also prompt for input.
Policy for automatically adding the hostname and new host key to the Policy for automatically adding the hostname and new host key to the
local L{HostKeys} object, and saving it. This is used by L{SSHClient}. local L{HostKeys} object, and saving it. This is used by L{SSHClient}.
""" """
def __init__(self, runner):
self.runner = runner
def missing_host_key(self, client, hostname, key): def missing_host_key(self, client, hostname, key):
if C.HOST_KEY_CHECKING:
KEY_LOCK = self.runner.lockfile
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX)
old_stdin = sys.stdin
sys.stdin = self.runner._new_stdin
fingerprint = hexlify(key.get_fingerprint())
ktype = key.get_name()
# clear out any premature input on sys.stdin
tcflush(sys.stdin, TCIFLUSH)
inp = raw_input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint))
sys.stdin = old_stdin
if inp not in ['yes','y','']:
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
raise errors.AnsibleError("host connection rejected by user")
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
key._added_by_ansible_this_time = True key._added_by_ansible_this_time = True
# existing implementation below: # existing implementation below:
client._host_keys.add(hostname, key.get_name(), key) client._host_keys.add(hostname, key.get_name(), key)
if client._host_keys_filename is not None:
client.save_host_keys(client._host_keys_filename) # host keys are actually saved in close() function below
# in order to control ordering.
# keep connection objects on a per host basis to avoid repeated attempts to reconnect # keep connection objects on a per host basis to avoid repeated attempts to reconnect
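Putting the pieces together, `missing_host_key` mimics OpenSSH's first-connection prompt: take the shared lock, point `sys.stdin` at the worker's duplicated handle, flush any keystrokes that arrived early, and trust the key only on "yes", "y", or an empty answer. A standalone sketch of the same flow (simplified names, Python 3's `input` in place of `raw_input`, and `lockf` used for both lock and unlock):

```python
# Standalone sketch of the prompt flow, not the exact Ansible implementation:
# serialize on a lock, flush stray terminal input, then ask before trusting
# the key presented by the remote host.
import fcntl
import sys
from binascii import hexlify
from termios import tcflush, TCIFLUSH

AUTHENTICITY_MSG = (
    "paramiko: The authenticity of host '%s' can't be established.\n"
    "The %s key fingerprint is %s.\n"
    "Are you sure you want to continue connecting (yes/no)?\n"
)

def confirm_host_key(lockfile, worker_stdin, hostname, key):
    fcntl.lockf(lockfile, fcntl.LOCK_EX)        # one prompt at a time
    old_stdin, sys.stdin = sys.stdin, worker_stdin
    try:
        tcflush(sys.stdin, TCIFLUSH)            # discard premature keystrokes
        fingerprint = hexlify(key.get_fingerprint()).decode()
        answer = input(AUTHENTICITY_MSG % (hostname, key.get_name(), fingerprint))
        return answer.strip().lower() in ("yes", "y", "")
    finally:
        sys.stdin = old_stdin
        fcntl.lockf(lockfile, fcntl.LOCK_UN)
```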
@@ -103,7 +145,7 @@ class Connection(object):
         if C.HOST_KEY_CHECKING:
             ssh.load_system_host_keys()

-        ssh.set_missing_host_key_policy(MyAutoAddPolicy())
+        ssh.set_missing_host_key_policy(MyAddPolicy(self.runner))

         allow_agent = True

         if self.password is not None:
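For reference, `SSHClient.set_missing_host_key_policy` accepts any object exposing a `missing_host_key(client, hostname, key)` method, which is why a plain class such as `MyAddPolicy` can stand in for paramiko's built-in policies. A hedged usage sketch (hostname and credentials are placeholders; the key is stored via the same private `_host_keys` attribute the diff uses):

```python
# Sketch of wiring a prompting policy into paramiko's SSHClient.
import paramiko

class PromptingPolicy(paramiko.MissingHostKeyPolicy):
    def missing_host_key(self, client, hostname, key):
        # a real implementation would prompt the user here, as in the diff above
        client._host_keys.add(hostname, key.get_name(), key)

ssh = paramiko.SSHClient()
ssh.load_system_host_keys()
ssh.set_missing_host_key_policy(PromptingPolicy())
# ssh.connect("example.com", username="user")   # placeholder target
```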