#!/usr/bin/env python # (c) 2017, Ansible, Inc. <support@ansible.com> # # This file is part of Ansible # # Ansible is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # Ansible is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with Ansible. If not, see <http://www.gnu.org/licenses/>. ######################################################## from __future__ import (absolute_import, division, print_function) __metaclass__ = type __requires__ = ['ansible'] try: import pkg_resources except Exception: pass import fcntl import os import shlex import signal import socket import sys import time import traceback import datetime import errno from ansible import constants as C from ansible.module_utils._text import to_bytes, to_native from ansible.module_utils.six import PY3 from ansible.module_utils.six.moves import cPickle from ansible.module_utils.connection import send_data, recv_data from ansible.playbook.play_context import PlayContext from ansible.plugins import connection_loader from ansible.utils.path import unfrackpath, makedirs_safe from ansible.errors import AnsibleConnectionFailure from ansible.utils.display import Display def do_fork(): ''' Does the required double fork for a daemon process. Based on http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/ ''' try: pid = os.fork() if pid > 0: return pid #os.chdir("/") os.setsid() os.umask(0) try: pid = os.fork() if pid > 0: sys.exit(0) if C.DEFAULT_LOG_PATH != '': out_file = open(C.DEFAULT_LOG_PATH, 'ab+') err_file = open(C.DEFAULT_LOG_PATH, 'ab+', 0) else: out_file = open('/dev/null', 'ab+') err_file = open('/dev/null', 'ab+', 0) os.dup2(out_file.fileno(), sys.stdout.fileno()) os.dup2(err_file.fileno(), sys.stderr.fileno()) os.close(sys.stdin.fileno()) return pid except OSError as e: sys.exit(1) except OSError as e: sys.exit(1) class Server(): def __init__(self, socket_path, play_context): self.socket_path = socket_path self.play_context = play_context display.display( 'creating new control socket for host %s:%s as user %s' % (play_context.remote_addr, play_context.port, play_context.remote_user), log_only=True ) display.display('control socket path is %s' % socket_path, log_only=True) display.display('current working directory is %s' % os.getcwd(), log_only=True) self._start_time = datetime.datetime.now() display.display("using connection plugin %s" % self.play_context.connection, log_only=True) self.connection = connection_loader.get(play_context.connection, play_context, sys.stdin) self.connection._connect() if not self.connection.connected: raise AnsibleConnectionFailure('unable to connect to remote host %s' % self._play_context.remote_addr) connection_time = datetime.datetime.now() - self._start_time display.display('connection established to %s in %s' % (play_context.remote_addr, connection_time), log_only=True) self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.socket.bind(self.socket_path) self.socket.listen(1) display.display('local socket is set to listening', log_only=True) def run(self): try: while True: signal.signal(signal.SIGALRM, self.connect_timeout) signal.signal(signal.SIGTERM, self.handler) signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT) (s, addr) = self.socket.accept() display.display('incoming request accepted on persistent socket', log_only=True) signal.alarm(0) while True: data = recv_data(s) if not data: break signal.signal(signal.SIGALRM, self.command_timeout) signal.alarm(self.play_context.timeout) op = data.split(':')[0] display.display('socket operation is %s' % op, log_only=True) method = getattr(self, 'do_%s' % op, None) rc = 255 stdout = stderr = '' if not method: stderr = 'Invalid action specified' else: rc, stdout, stderr = method(data) signal.alarm(0) display.display('socket operation completed with rc %s' % rc, log_only=True) send_data(s, to_bytes(rc)) send_data(s, to_bytes(stdout)) send_data(s, to_bytes(stderr)) s.close() except Exception as e: # socket.accept() will raise EINTR if the socket.close() is called if e.errno != errno.EINTR: display.display(traceback.format_exc(), log_only=True) finally: # when done, close the connection properly and cleanup # the socket file so it can be recreated self.shutdown() end_time = datetime.datetime.now() delta = end_time - self._start_time display.display('shutdown local socket, connection was active for %s secs' % delta, log_only=True) def connect_timeout(self, signum, frame): display.display('connect timeout triggered, timeout value is %s secs' % C.PERSISTENT_CONNECT_TIMEOUT, log_only=True) self.shutdown() def command_timeout(self, signum, frame): display.display('commnad timeout triggered, timeout value is %s secs' % self.play_context.timeout, log_only=True) self.shutdown() def handler(self, signum, frame): display.display('signal handler called with signal %s' % signum, log_only=True) self.shutdown() def shutdown(self): display.display('shutdown persistent connection requested', log_only=True) if not os.path.exists(self.socket_path): display.display('persistent connection is not active', log_only=True) return try: if self.socket: display.display('closing local listener', log_only=True) self.socket.close() if self.connection: display.display('closing the connection', log_only=True) self.close() except: pass finally: if os.path.exists(self.socket_path): display.display('removing the local control socket', log_only=True) os.remove(self.socket_path) display.display('shutdown complete', log_only=True) def do_EXEC(self, data): cmd = data.split(b'EXEC: ')[1] return self.connection.exec_command(cmd) def do_PUT(self, data): (op, src, dst) = shlex.split(to_native(data)) return self.connection.fetch_file(src, dst) def do_FETCH(self, data): (op, src, dst) = shlex.split(to_native(data)) return self.connection.put_file(src, dst) def do_CONTEXT(self, data): pc_data = data.split(b'CONTEXT: ', 1)[1] if PY3: pc_data = cPickle.loads(pc_data, encoding='bytes') else: pc_data = cPickle.loads(pc_data) pc = PlayContext() pc.deserialize(pc_data) try: self.connection.update_play_context(pc) except AttributeError: pass return (0, 'ok', '') def do_RUN(self, data): timeout = self.play_context.timeout while bool(timeout): if os.path.exists(self.socket_path): break time.sleep(1) timeout -= 1 return (0, self.socket_path, '') def communicate(sock, data): send_data(sock, data) rc = int(recv_data(sock), 10) stdout = recv_data(sock) stderr = recv_data(sock) return (rc, stdout, stderr) def main(): # Need stdin as a byte stream if PY3: stdin = sys.stdin.buffer else: stdin = sys.stdin try: # read the play context data via stdin, which means depickling it # FIXME: as noted above, we will probably need to deserialize the # connection loader here as well at some point, otherwise this # won't find role- or playbook-based connection plugins cur_line = stdin.readline() init_data = b'' while cur_line.strip() != b'#END_INIT#': if cur_line == b'': raise Exception("EOF found before init data was complete") init_data += cur_line cur_line = stdin.readline() if PY3: pc_data = cPickle.loads(init_data, encoding='bytes') else: pc_data = cPickle.loads(init_data) pc = PlayContext() pc.deserialize(pc_data) except Exception as e: # FIXME: better error message/handling/logging sys.stderr.write(traceback.format_exc()) sys.exit("FAIL: %s" % e) ssh = connection_loader.get('ssh', class_only=True) cp = ssh._create_control_path(pc.remote_addr, pc.connection, pc.remote_user) # create the persistent connection dir if need be and create the paths # which we will be using later tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) makedirs_safe(tmp_path) lock_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path) socket_path = unfrackpath(cp % dict(directory=tmp_path)) # if the socket file doesn't exist, spin up the daemon process lock_fd = os.open(lock_path, os.O_RDWR|os.O_CREAT, 0o600) fcntl.lockf(lock_fd, fcntl.LOCK_EX) if not os.path.exists(socket_path): pid = do_fork() if pid == 0: rc = 0 try: server = Server(socket_path, pc) except AnsibleConnectionFailure as exc: display.display('connecting to host %s returned an error' % pc.remote_addr, log_only=True) display.display(str(exc), log_only=True) rc = 1 except Exception as exc: display.display('failed to create control socket for host %s' % pc.remote_addr, log_only=True) display.display(traceback.format_exc(), log_only=True) rc = 1 fcntl.lockf(lock_fd, fcntl.LOCK_UN) os.close(lock_fd) if rc == 0: server.run() sys.exit(rc) else: display.display('re-using existing socket for %s@%s:%s' % (pc.remote_user, pc.remote_addr, pc.port), log_only=True) fcntl.lockf(lock_fd, fcntl.LOCK_UN) os.close(lock_fd) timeout = pc.timeout while bool(timeout): if os.path.exists(socket_path): display.vvvv('connected to local socket in %s' % (pc.timeout - timeout), pc.remote_addr) break time.sleep(1) timeout -= 1 else: raise AnsibleConnectionFailure('timeout waiting for local socket', pc.remote_addr) # now connect to the daemon process # FIXME: if the socket file existed but the daemonized process was killed, # the connection will timeout here. Need to make this more resilient. while True: data = stdin.readline() if data == b'': break if data.strip() == b'': continue sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) attempts = C.PERSISTENT_CONNECT_RETRIES while bool(attempts): try: sock.connect(socket_path) break except socket.error: time.sleep(C.PERSISTENT_CONNECT_INTERVAL) attempts -= 1 else: display.display('number of connection attempts exceeded, unable to connect to control socket', pc.remote_addr, pc.remote_user, log_only=True) display.display('persistent_connect_interval=%s, persistent_connect_retries=%s' % (C.PERSISTENT_CONNECT_INTERVAL, C.PERSISTENT_CONNECT_RETRIES), pc.remote_addr, pc.remote_user, log_only=True) sys.stderr.write('failed to connect to control socket') sys.exit(255) # send the play_context back into the connection so the connection # can handle any privilege escalation activities pc_data = b'CONTEXT: %s' % init_data communicate(sock, pc_data) rc, stdout, stderr = communicate(sock, data.strip()) sys.stdout.write(to_native(stdout)) sys.stderr.write(to_native(stderr)) sock.close() break sys.exit(rc) if __name__ == '__main__': display = Display() main()