mirror of
https://github.com/ansible-collections/community.general.git
synced 2024-09-14 20:13:21 +02:00
Improve interlaced output prevention when asking for host key approval.
This commit is contained in:
parent
c55adc9ac9
commit
2cb7c30834
4 changed files with 91 additions and 73 deletions
|
@ -83,11 +83,23 @@ def log_lockfile():
|
|||
|
||||
LOG_LOCK = open(log_lockfile(), 'w')
|
||||
|
||||
def log_flock():
|
||||
fcntl.flock(LOG_LOCK, fcntl.LOCK_EX)
|
||||
def log_flock(runner):
|
||||
fcntl.lockf(LOG_LOCK, fcntl.LOCK_EX)
|
||||
if runner is not None:
|
||||
try:
|
||||
fcntl.lockf(runner.output_lockfile, fcntl.LOCK_EX)
|
||||
except OSError, e:
|
||||
# already got closed?
|
||||
pass
|
||||
|
||||
def log_unflock():
|
||||
fcntl.flock(LOG_LOCK, fcntl.LOCK_UN)
|
||||
def log_unflock(runner):
|
||||
fcntl.lockf(LOG_LOCK, fcntl.LOCK_UN)
|
||||
if runner is not None:
|
||||
try:
|
||||
fcntl.lockf(runner.output_lockfile, fcntl.LOCK_UN)
|
||||
except OSError, e:
|
||||
# already got closed?
|
||||
pass
|
||||
|
||||
def set_play(callback, play):
|
||||
''' used to notify callback plugins of context '''
|
||||
|
@ -101,9 +113,9 @@ def set_task(callback, task):
|
|||
for callback_plugin in callback_plugins:
|
||||
callback_plugin.task = task
|
||||
|
||||
def display(msg, color=None, stderr=False, screen_only=False, log_only=False):
|
||||
def display(msg, color=None, stderr=False, screen_only=False, log_only=False, runner=None):
|
||||
# prevent a very rare case of interlaced multiprocess I/O
|
||||
log_flock()
|
||||
log_flock(runner)
|
||||
msg2 = msg
|
||||
if color:
|
||||
msg2 = stringc(msg, color)
|
||||
|
@ -120,7 +132,7 @@ def display(msg, color=None, stderr=False, screen_only=False, log_only=False):
|
|||
logger.error(msg)
|
||||
else:
|
||||
logger.info(msg)
|
||||
log_unflock()
|
||||
log_unflock(runner)
|
||||
|
||||
def call_callback_module(method_name, *args, **kwargs):
|
||||
|
||||
|
@ -346,7 +358,7 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
|
|||
def on_unreachable(self, host, res):
|
||||
if type(res) == dict:
|
||||
res = res.get('msg','')
|
||||
display("%s | FAILED => %s" % (host, res), stderr=True, color='red')
|
||||
display("%s | FAILED => %s" % (host, res), stderr=True, color='red', runner=self.runner)
|
||||
if self.options.tree:
|
||||
utils.write_tree_file(
|
||||
self.options.tree, host,
|
||||
|
@ -355,15 +367,15 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
|
|||
super(CliRunnerCallbacks, self).on_unreachable(host, res)
|
||||
|
||||
def on_skipped(self, host, item=None):
|
||||
display("%s | skipped" % (host))
|
||||
display("%s | skipped" % (host), runner=self.runner)
|
||||
super(CliRunnerCallbacks, self).on_skipped(host, item)
|
||||
|
||||
def on_error(self, host, err):
|
||||
display("err: [%s] => %s\n" % (host, err), stderr=True)
|
||||
display("err: [%s] => %s\n" % (host, err), stderr=True, runner=self.runner)
|
||||
super(CliRunnerCallbacks, self).on_error(host, err)
|
||||
|
||||
def on_no_hosts(self):
|
||||
display("no hosts matched\n", stderr=True)
|
||||
display("no hosts matched\n", stderr=True, runner=self.runner)
|
||||
super(CliRunnerCallbacks, self).on_no_hosts()
|
||||
|
||||
def on_async_poll(self, host, res, jid, clock):
|
||||
|
@ -371,27 +383,27 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
|
|||
self._async_notified[jid] = clock + 1
|
||||
if self._async_notified[jid] > clock:
|
||||
self._async_notified[jid] = clock
|
||||
display("<job %s> polling, %ss remaining" % (jid, clock))
|
||||
display("<job %s> polling, %ss remaining" % (jid, clock), runner=self.runner)
|
||||
super(CliRunnerCallbacks, self).on_async_poll(host, res, jid, clock)
|
||||
|
||||
def on_async_ok(self, host, res, jid):
|
||||
display("<job %s> finished on %s => %s"%(jid, host, utils.jsonify(res,format=True)))
|
||||
display("<job %s> finished on %s => %s"%(jid, host, utils.jsonify(res,format=True)), runner=self.runner)
|
||||
super(CliRunnerCallbacks, self).on_async_ok(host, res, jid)
|
||||
|
||||
def on_async_failed(self, host, res, jid):
|
||||
display("<job %s> FAILED on %s => %s"%(jid, host, utils.jsonify(res,format=True)), color='red', stderr=True)
|
||||
display("<job %s> FAILED on %s => %s"%(jid, host, utils.jsonify(res,format=True)), color='red', stderr=True, runner=self.runner)
|
||||
super(CliRunnerCallbacks, self).on_async_failed(host,res,jid)
|
||||
|
||||
def _on_any(self, host, result):
|
||||
result2 = result.copy()
|
||||
result2.pop('invocation', None)
|
||||
(msg, color) = host_report_msg(host, self.options.module_name, result2, self.options.one_line)
|
||||
display(msg, color=color)
|
||||
display(msg, color=color, runner=self.runner)
|
||||
if self.options.tree:
|
||||
utils.write_tree_file(self.options.tree, host, utils.jsonify(result2,format=True))
|
||||
|
||||
def on_file_diff(self, host, diff):
|
||||
display(utils.get_diff(diff))
|
||||
display(utils.get_diff(diff), runner=self.runner)
|
||||
super(CliRunnerCallbacks, self).on_file_diff(host, diff)
|
||||
|
||||
########################################################################
|
||||
|
@ -412,11 +424,12 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
|
|||
msg = "fatal: [%s] => (item=%s) => %s" % (host, item, results)
|
||||
else:
|
||||
msg = "fatal: [%s] => %s" % (host, results)
|
||||
display(msg, color='red')
|
||||
display(msg, color='red', runner=self.runner)
|
||||
super(PlaybookRunnerCallbacks, self).on_unreachable(host, results)
|
||||
|
||||
def on_failed(self, host, results, ignore_errors=False):
|
||||
|
||||
|
||||
results2 = results.copy()
|
||||
results2.pop('invocation', None)
|
||||
|
||||
|
@ -433,21 +446,22 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
|
|||
msg = "failed: [%s] => (item=%s) => %s" % (host, item, utils.jsonify(results2))
|
||||
else:
|
||||
msg = "failed: [%s] => %s" % (host, utils.jsonify(results2))
|
||||
display(msg, color='red')
|
||||
display(msg, color='red', runner=self.runner)
|
||||
|
||||
if stderr:
|
||||
display("stderr: %s" % stderr, color='red')
|
||||
display("stderr: %s" % stderr, color='red', runner=self.runner)
|
||||
if stdout:
|
||||
display("stdout: %s" % stdout, color='red')
|
||||
display("stdout: %s" % stdout, color='red', runner=self.runner)
|
||||
if returned_msg:
|
||||
display("msg: %s" % returned_msg, color='red')
|
||||
display("msg: %s" % returned_msg, color='red', runner=self.runner)
|
||||
if not parsed and module_msg:
|
||||
display("invalid output was: %s" % module_msg, color='red')
|
||||
display("invalid output was: %s" % module_msg, color='red', runner=self.runner)
|
||||
if ignore_errors:
|
||||
display("...ignoring", color='cyan')
|
||||
display("...ignoring", color='cyan', runner=self.runner)
|
||||
super(PlaybookRunnerCallbacks, self).on_failed(host, results, ignore_errors=ignore_errors)
|
||||
|
||||
def on_ok(self, host, host_result):
|
||||
|
||||
item = host_result.get('item', None)
|
||||
|
||||
host_result2 = host_result.copy()
|
||||
|
@ -477,9 +491,9 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
|
|||
|
||||
if msg != '':
|
||||
if not changed:
|
||||
display(msg, color='green')
|
||||
display(msg, color='green', runner=self.runner)
|
||||
else:
|
||||
display(msg, color='yellow')
|
||||
display(msg, color='yellow', runner=self.runner)
|
||||
super(PlaybookRunnerCallbacks, self).on_ok(host, host_result)
|
||||
|
||||
def on_error(self, host, err):
|
||||
|
@ -491,7 +505,7 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
|
|||
else:
|
||||
msg = "err: [%s] => %s" % (host, err)
|
||||
|
||||
display(msg, color='red', stderr=True)
|
||||
display(msg, color='red', stderr=True, runner=self.runner)
|
||||
super(PlaybookRunnerCallbacks, self).on_error(host, err)
|
||||
|
||||
def on_skipped(self, host, item=None):
|
||||
|
@ -500,11 +514,11 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
|
|||
msg = "skipping: [%s] => (item=%s)" % (host, item)
|
||||
else:
|
||||
msg = "skipping: [%s]" % host
|
||||
display(msg, color='cyan')
|
||||
display(msg, color='cyan', runner=self.runner)
|
||||
super(PlaybookRunnerCallbacks, self).on_skipped(host, item)
|
||||
|
||||
def on_no_hosts(self):
|
||||
display("FATAL: no hosts matched or all hosts have already failed -- aborting\n", color='red')
|
||||
display("FATAL: no hosts matched or all hosts have already failed -- aborting\n", color='red', runner=self.runner)
|
||||
super(PlaybookRunnerCallbacks, self).on_no_hosts()
|
||||
|
||||
def on_async_poll(self, host, res, jid, clock):
|
||||
|
@ -513,21 +527,21 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
|
|||
if self._async_notified[jid] > clock:
|
||||
self._async_notified[jid] = clock
|
||||
msg = "<job %s> polling, %ss remaining"%(jid, clock)
|
||||
display(msg, color='cyan')
|
||||
display(msg, color='cyan', runner=self.runner)
|
||||
super(PlaybookRunnerCallbacks, self).on_async_poll(host,res,jid,clock)
|
||||
|
||||
def on_async_ok(self, host, res, jid):
|
||||
msg = "<job %s> finished on %s"%(jid, host)
|
||||
display(msg, color='cyan')
|
||||
display(msg, color='cyan', runner=self.runner)
|
||||
super(PlaybookRunnerCallbacks, self).on_async_ok(host, res, jid)
|
||||
|
||||
def on_async_failed(self, host, res, jid):
|
||||
msg = "<job %s> FAILED on %s" % (jid, host)
|
||||
display(msg, color='red', stderr=True)
|
||||
display(msg, color='red', stderr=True, runner=self.runner)
|
||||
super(PlaybookRunnerCallbacks, self).on_async_failed(host,res,jid)
|
||||
|
||||
def on_file_diff(self, host, diff):
|
||||
display(utils.get_diff(diff))
|
||||
display(utils.get_diff(diff), runner=self.runner)
|
||||
super(PlaybookRunnerCallbacks, self).on_file_diff(host, diff)
|
||||
|
||||
########################################################################
|
||||
|
|
|
@ -51,6 +51,9 @@ except ImportError:
|
|||
HAS_ATFORK=False
|
||||
|
||||
multiprocessing_runner = None
|
||||
|
||||
OUTPUT_LOCKFILE = tempfile.TemporaryFile()
|
||||
PROCESS_LOCKFILE = tempfile.TemporaryFile()
|
||||
|
||||
################################################
|
||||
|
||||
|
@ -134,8 +137,9 @@ class Runner(object):
|
|||
error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR # ex. False
|
||||
):
|
||||
|
||||
# used to lock multiprocess inputs that wish to share stdin
|
||||
self.lockfile = tempfile.NamedTemporaryFile()
|
||||
# used to lock multiprocess inputs and outputs at various levels
|
||||
self.output_lockfile = OUTPUT_LOCKFILE
|
||||
self.process_lockfile = PROCESS_LOCKFILE
|
||||
|
||||
if not complex_args:
|
||||
complex_args = {}
|
||||
|
@ -884,9 +888,6 @@ class Runner(object):
|
|||
else:
|
||||
results = [ self._executor(h, None) for h in hosts ]
|
||||
|
||||
|
||||
self.lockfile.close()
|
||||
|
||||
return self._partition_results(results)
|
||||
|
||||
# *****************************************************
|
||||
|
|
|
@ -72,9 +72,9 @@ class MyAddPolicy(object):
|
|||
|
||||
if C.HOST_KEY_CHECKING:
|
||||
|
||||
KEY_LOCK = self.runner.lockfile
|
||||
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX)
|
||||
|
||||
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_EX)
|
||||
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_EX)
|
||||
|
||||
old_stdin = sys.stdin
|
||||
sys.stdin = self.runner._new_stdin
|
||||
fingerprint = hexlify(key.get_fingerprint())
|
||||
|
@ -86,10 +86,12 @@ class MyAddPolicy(object):
|
|||
inp = raw_input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint))
|
||||
sys.stdin = old_stdin
|
||||
if inp not in ['yes','y','']:
|
||||
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
|
||||
fcntl.flock(self.runner.output_lockfile, fcntl.LOCK_UN)
|
||||
fcntl.flock(self.runner.process_lockfile, fcntl.LOCK_UN)
|
||||
raise errors.AnsibleError("host connection rejected by user")
|
||||
|
||||
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
|
||||
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_UN)
|
||||
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
|
||||
|
||||
|
||||
key._added_by_ansible_this_time = True
|
||||
|
@ -257,22 +259,23 @@ class Connection(object):
|
|||
except IOError:
|
||||
raise errors.AnsibleError("failed to transfer file from %s" % in_path)
|
||||
|
||||
def _any_keys_added(self):
|
||||
added_any = False
|
||||
for hostname, keys in self.ssh._host_keys.iteritems():
|
||||
for keytype, key in keys.iteritems():
|
||||
added_this_time = getattr(key, '_added_by_ansible_this_time', False)
|
||||
if added_this_time:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _save_ssh_host_keys(self, filename):
|
||||
'''
|
||||
not using the paramiko save_ssh_host_keys function as we want to add new SSH keys at the bottom so folks
|
||||
don't complain about it :)
|
||||
'''
|
||||
|
||||
added_any = False
|
||||
for hostname, keys in self.ssh._host_keys.iteritems():
|
||||
for keytype, key in keys.iteritems():
|
||||
added_this_time = getattr(key, '_added_by_ansible_this_time', False)
|
||||
if added_this_time:
|
||||
added_any = True
|
||||
break
|
||||
|
||||
if not added_any:
|
||||
return
|
||||
if not self._any_keys_added():
|
||||
return False
|
||||
|
||||
path = os.path.expanduser("~/.ssh")
|
||||
if not os.path.exists(path):
|
||||
|
@ -300,23 +303,22 @@ class Connection(object):
|
|||
if self.sftp is not None:
|
||||
self.sftp.close()
|
||||
|
||||
# add any new SSH host keys
|
||||
lockfile = self.keyfile.replace("known_hosts",".known_hosts.lock")
|
||||
KEY_LOCK = open(lockfile, 'w')
|
||||
fcntl.flock(KEY_LOCK, fcntl.LOCK_EX)
|
||||
|
||||
try:
|
||||
# just in case any were added recently
|
||||
self.ssh.load_system_host_keys()
|
||||
self.ssh._host_keys.update(self.ssh._system_host_keys)
|
||||
#self.ssh.save_host_keys(self.keyfile)
|
||||
self._save_ssh_host_keys(self.keyfile)
|
||||
except:
|
||||
# unable to save keys, including scenario when key was invalid
|
||||
# and caught earlier
|
||||
traceback.print_exc()
|
||||
pass
|
||||
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
|
||||
if self._any_keys_added():
|
||||
# add any new SSH host keys -- warning -- this could be slow
|
||||
lockfile = self.keyfile.replace("known_hosts",".known_hosts.lock")
|
||||
KEY_LOCK = open(lockfile, 'w')
|
||||
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX)
|
||||
try:
|
||||
# just in case any were added recently
|
||||
self.ssh.load_system_host_keys()
|
||||
self.ssh._host_keys.update(self.ssh._system_host_keys)
|
||||
self._save_ssh_host_keys(self.keyfile)
|
||||
except:
|
||||
# unable to save keys, including scenario when key was invalid
|
||||
# and caught earlier
|
||||
traceback.print_exc()
|
||||
pass
|
||||
fcntl.lockf(KEY_LOCK, fcntl.LOCK_UN)
|
||||
|
||||
self.ssh.close()
|
||||
|
||||
|
|
|
@ -131,8 +131,9 @@ class Connection(object):
|
|||
if C.HOST_KEY_CHECKING and not_in_host_file:
|
||||
# lock around the initial SSH connectivity so the user prompt about whether to add
|
||||
# the host to known hosts is not intermingled with multiprocess output.
|
||||
KEY_LOCK = self.runner.lockfile
|
||||
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX)
|
||||
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_EX)
|
||||
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_EX)
|
||||
|
||||
|
||||
|
||||
try:
|
||||
|
@ -191,8 +192,8 @@ class Connection(object):
|
|||
if C.HOST_KEY_CHECKING and not_in_host_file:
|
||||
# lock around the initial SSH connectivity so the user prompt about whether to add
|
||||
# the host to known hosts is not intermingled with multiprocess output.
|
||||
KEY_LOCK = self.runner.lockfile
|
||||
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX)
|
||||
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_UN)
|
||||
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
|
||||
|
||||
if p.returncode != 0 and stderr.find('Bad configuration option: ControlPersist') != -1:
|
||||
raise errors.AnsibleError('using -c ssh on certain older ssh versions may not support ControlPersist, set ANSIBLE_SSH_ARGS="" (or ansible_ssh_args in the config file) before running again')
|
||||
|
|
Loading…
Reference in a new issue