mirror of
				https://github.com/ansible-collections/community.general.git
				synced 2024-09-14 20:13:21 +02:00 
			
		
		
		
	
		
			
				
	
	
		
			457 lines
		
	
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			457 lines
		
	
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/python
 | |
| # -*- coding: utf-8 -*-
 | |
| 
 | |
| # (c) 2013, James Cammarata <jcammarata@ansibleworks.com>
 | |
| #
 | |
| # This file is part of Ansible
 | |
| #
 | |
| # Ansible is free software: you can redistribute it and/or modify
 | |
| # it under the terms of the GNU General Public License as published by
 | |
| # the Free Software Foundation, either version 3 of the License, or
 | |
| # (at your option) any later version.
 | |
| #
 | |
| # Ansible is distributed in the hope that it will be useful,
 | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of
 | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | |
| # GNU General Public License for more details.
 | |
| #
 | |
| # You should have received a copy of the GNU General Public License
 | |
| # along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
 | |
| 
 | |
| DOCUMENTATION = '''
 | |
| ---
 | |
| module: accelerate
 | |
| short_description: Enable accelerated mode on remote node
 | |
| description:
 | |
|      - This module launches an ephemeral I(accelerate) daemon on the remote node which
 | |
|        Ansible can use to communicate with nodes at high speed.
 | |
|      - The daemon listens on a configurable port for a configurable amount of time.
 | |
|      - Fireball mode is AES encrypted
 | |
| version_added: "1.3"
 | |
| options:
 | |
|   port:
 | |
|     description:
 | |
|       - TCP port for the socket connection
 | |
|     required: false
 | |
|     default: 5099
 | |
|     aliases: []
 | |
|   timeout:
 | |
|     description:
 | |
|       - The number of seconds the socket will wait for data. If none is received when the timeout value is reached, the connection will be closed.
 | |
|     required: false
 | |
|     default: 300
 | |
|     aliases: []
 | |
|   minutes:
 | |
|     description:
 | |
|       - The I(accelerate) listener daemon is started on nodes and will stay around for
 | |
|         this number of minutes before turning itself off.
 | |
|     required: false
 | |
|     default: 30
 | |
|   ipv6:
 | |
|     description:
 | |
|       - The listener daemon on the remote host will bind to the ipv6 localhost socket
 | |
|         if this parameter is set to true.
 | |
|     required: false
 | |
|     default: false
 | |
| notes:
 | |
|     - See the advanced playbooks chapter for more about using accelerated mode.
 | |
| requirements: [ "python-keyczar" ]
 | |
| author: James Cammarata
 | |
| '''
 | |
| 
 | |
| EXAMPLES = '''
 | |
| # To use accelerate mode, simply add "accelerate: true" to your play. The initial
 | |
| # key exchange and starting up of the daemon will occur over SSH, but all commands and
 | |
| # subsequent actions will be conducted over the raw socket connection using AES encryption
 | |
| 
 | |
| - hosts: devservers
 | |
|   accelerate: true
 | |
|   tasks:
 | |
|       - command: /usr/bin/anything
 | |
| '''
 | |
| 
 | |
| import base64
 | |
| import getpass
 | |
| import os
 | |
| import os.path
 | |
| import shutil
 | |
| import signal
 | |
| import socket
 | |
| import struct
 | |
| import sys
 | |
| import syslog
 | |
| import tempfile
 | |
| import time
 | |
| import traceback
 | |
| 
 | |
| import SocketServer
 | |
| 
 | |
| from datetime import datetime
 | |
| from threading import Thread
 | |
| 
 | |
| syslog.openlog('ansible-%s' % os.path.basename(__file__))
 | |
| PIDFILE = os.path.expanduser("~/.accelerate.pid")
 | |
| 
 | |
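| # the chunk size to read and send. The arithmetic below sizes a chunk to fit
 | |
| # a single frame, assuming mtu 1500 and leaving room for base64 (+33%)
 | |
| # encoding and header (100 bytes): 4 * (975/3) + 100 = 1400, which leaves
 | |
| # room for the TCP/IP header.
 | |
| # NOTE: the value set below (10240) is larger than the 975-byte chunk this
 | |
| # calculation describes, so a chunk does not fit in one 1500-byte frame.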
 | |
| CHUNK_SIZE=10240
 | |
| 
 | |
| # FIXME: this all should be moved to module_common, as it's 
 | |
| #        pretty much a copy from the callbacks/util code
 | |
| DEBUG_LEVEL=0
 | |
def log(msg, cap=0):
    """Write msg to syslog when the configured debug level allows it.

    msg -- the text to log.
    cap -- the minimum DEBUG_LEVEL at which this message is emitted; the
           default of 0 means the message is always logged.
    """
    # A message is emitted only when the configured verbosity reaches the
    # level the message was tagged with. Testing "cap >= DEBUG_LEVEL" did the
    # opposite: it logged everything at debug=0 and hid ordinary messages as
    # the debug level was raised.
    if DEBUG_LEVEL >= cap:
        syslog.syslog(syslog.LOG_NOTICE|syslog.LOG_DAEMON, msg)
 | |
| 
 | |
def vv(msg):
    """Forward msg to log() tagged with verbosity level 2."""
    log(msg, 2)
 | |
| 
 | |
def vvv(msg):
    """Forward msg to log() tagged with verbosity level 3."""
    log(msg, 3)
 | |
| 
 | |
def vvvv(msg):
    """Forward msg to log() tagged with verbosity level 4."""
    log(msg, 4)
 | |
| 
 | |
# If a pidfile from an earlier run exists, kill the process it names (when the
# file holds a parsable pid) and then remove the pidfile.
if os.path.exists(PIDFILE):
    try:
        data = int(open(PIDFILE).read())
    except ValueError:
        # the pidfile did not contain an integer, so there is nothing to kill
        pass
    else:
        try:
            os.kill(data, signal.SIGKILL)
        except OSError:
            # the signal could not be delivered; carry on and remove the file
            pass
    os.unlink(PIDFILE)
 | |
| 
 | |
# keyczar is optional at import time; HAS_KEYCZAR records whether AesKey
# could be imported so main() can report its absence.
try:
    from keyczar.keys import AesKey
except ImportError:
    HAS_KEYCZAR = False
else:
    HAS_KEYCZAR = True
 | |
| 
 | |
| # NOTE: this shares a fair amount of code in common with async_wrapper, if async_wrapper were a new module we could move
 | |
| # this into utils.module_common and probably should anyway
 | |
| 
 | |
def daemonize_self(module, password, port, minutes):
    """Detach the calling process from its parent by forking twice.

    The original process reports success through module.exit_json() and does
    not return. The intermediate child writes the daemon's pid to PIDFILE and
    exits. Only the final (daemon) process returns from this function, with
    stdin, stdout and stderr redirected to the null device.

    module   -- the AnsibleModule used to report the result in the first parent.
    password -- unused here; kept for interface compatibility.
    port     -- only used in the message reported by the first parent.
    minutes  -- only used in the message reported by the first parent.
    """
    # daemonizing code: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66012
    try:
        pid = os.fork()
        if pid > 0:
            vvv("exiting pid %s" % pid)
            # exit first parent
            module.exit_json(msg="daemonized accelerate on port %s for %s minutes with pid %s" % (port, minutes, str(pid)))
    except OSError as e:
        log("fork #1 failed: %d (%s)" % (e.errno, e.strerror))
        sys.exit(1)

    # decouple from parent environment
    os.chdir("/")
    os.setsid()
    os.umask(0o022)

    # do second fork
    try:
        pid = os.fork()
        if pid > 0:
            log("daemon pid %s, writing %s" % (pid, PIDFILE))
            pid_file = open(PIDFILE, "w")
            try:
                pid_file.write("%s" % pid)
            finally:
                # release the handle even if the write fails
                pid_file.close()
            vvv("pidfile written")
            sys.exit(0)
    except OSError as e:
        log("fork #2 failed: %d (%s)" % (e.errno, e.strerror))
        sys.exit(1)

    # Redirect the standard streams to the null device. It has to be opened
    # read/write because the same descriptor backs stdout and stderr; the mode
    # string 'rw' is not a valid mode and typically yields a read-only stream.
    dev_null = os.open(os.devnull, os.O_RDWR)
    for stream in (sys.stdin, sys.stdout, sys.stderr):
        os.dup2(dev_null, stream.fileno())
    # the duplicates keep the device open; drop the extra descriptor unless it
    # happens to be one of the standard ones
    if dev_null > 2:
        os.close(dev_null)
    log("daemonizing successful")
 | |
| 
 | |
class ThreadWithReturnValue(Thread):
    """Thread that keeps the value returned by its target callable.

    The value is stored in self._return once run() completes and is also
    returned by join(). It stays None while the target has not finished, when
    no target was given, or when the target raised.
    """

    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None, Verbose=None):
        # A fresh dict per instance; a dict literal as the default would be
        # shared between every thread created without explicit kwargs.
        if kwargs is None:
            kwargs = {}
        init_kwargs = dict(group=group, target=target, name=name, args=args, kwargs=kwargs)
        # Only forward the verbose flag when the caller asked for it.
        if Verbose is not None:
            init_kwargs['verbose'] = Verbose
        Thread.__init__(self, **init_kwargs)
        # Keep our own references instead of reaching into the base class's
        # name-mangled private attributes.
        self._call_target = target
        self._call_args = args
        self._call_kwargs = kwargs
        self._return = None

    def run(self):
        """Invoke the target (if any) and remember what it returned."""
        if self._call_target is not None:
            self._return = self._call_target(*self._call_args,
                                             **self._call_kwargs)

    def join(self,timeout=None):
        """Wait like Thread.join() and return the stored result."""
        Thread.join(self, timeout=timeout)
        return self._return
 | |
| 
 | |
class ThreadedTCPServer(SocketServer.ThreadingTCPServer):
    """Threading TCP server that carries the state request handlers read.

    Handlers reach the Ansible module through self.module and the AES key
    (loaded from password with AesKey.Read) through self.key.
    """

    def __init__(self, server_address, RequestHandlerClass, module, password, timeout):
        # Everything is assigned before the base constructor is invoked, so
        # the settings are already in place when it runs.
        self.allow_reuse_address = True
        self.timeout = timeout
        self.module = module
        self.key = AesKey.Read(password)
        SocketServer.ThreadingTCPServer.__init__(self, server_address, RequestHandlerClass)
 | |
| 
 | |
class ThreadedTCPV6Server(SocketServer.ThreadingTCPServer):
    """IPv6 counterpart of the threading TCP server.

    Same per-server state as the IPv4 server (self.module, self.key,
    self.timeout), with the address family set to AF_INET6.
    """

    def __init__(self, server_address, RequestHandlerClass, module, password, timeout):
        # Everything, including the address family, is assigned before the
        # base constructor is invoked.
        self.address_family = socket.AF_INET6
        self.allow_reuse_address = True
        self.timeout = timeout
        self.module = module
        self.key = AesKey.Read(password)
        SocketServer.ThreadingTCPServer.__init__(self, server_address, RequestHandlerClass)
 | |
| 
 | |
class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
    """Serves one controller connection of the accelerate daemon.

    Every message, in both directions, is an 8-byte network-order unsigned
    length followed by that many payload bytes. Payloads are JSON documents
    encrypted and decrypted with self.server.key. A request names its
    operation in the 'mode' key: 'command', 'put' or 'fetch'.
    """

    def send_data(self, data):
        """Send data prefixed with its packed 8-byte length.

        Returns whatever the socket's sendall() returns.
        """
        packed_len = struct.pack('!Q', len(data))
        return self.request.sendall(packed_len + data)

    def _send_json(self, payload):
        """Serialize payload to JSON, encrypt it and send it as one message."""
        encoded = json.dumps(payload)
        encoded = self.server.key.Encrypt(encoded)
        return self.send_data(encoded)

    def recv_data(self):
        """Receive one length-prefixed message.

        Returns the payload, or None when the peer stops sending before the
        header or the announced payload is complete.
        """
        header_len = 8 # size of a packed unsigned long long
        data = ""
        vvvv("in recv_data(), waiting for the header")
        while len(data) < header_len:
            d = self.request.recv(header_len - len(data))
            if not d:
                vvv("received nothing, bailing out")
                return None
            data += d
        vvvv("in recv_data(), got the header, unpacking")
        data_len = struct.unpack('!Q',data[:header_len])[0]
        data = data[header_len:]
        vvvv("data received so far (expecting %d): %d" % (data_len,len(data)))
        while len(data) < data_len:
            d = self.request.recv(data_len - len(data))
            if not d:
                vvv("received nothing, bailing out")
                return None
            data += d
            vvvv("data received so far (expecting %d): %d" % (data_len,len(data)))
        vvvv("received all of the data, returning")
        return data

    def handle(self):
        """Process requests from the connection until it closes or fails.

        A payload that cannot be decrypted is answered with rc=1 and ends the
        session. Any other error is logged and answered with a failure
        message.
        """
        try:
            while True:
                vvvv("waiting for data")
                data = self.recv_data()
                if not data:
                    vvvv("received nothing back from recv_data(), breaking out")
                    break
                try:
                    vvvv("got data, decrypting")
                    data = self.server.key.Decrypt(data)
                    vvvv("decryption done")
                except Exception:
                    vv("bad decrypt, skipping...")
                    # reply on this connection via the handler's own method
                    self._send_json(dict(rc=1))
                    return

                vvvv("loading json from the data")
                data = json.loads(data)

                mode = data['mode']
                response = {}
                last_pong = datetime.now()
                if mode == 'command':
                    vvvv("received a command request, running it")
                    twrv = ThreadWithReturnValue(target=self.command, args=(data,))
                    twrv.start()
                    response = None
                    # the command runs in a worker thread so that keepalive
                    # packets can be sent while it is still executing
                    while twrv.is_alive():
                        if (datetime.now() - last_pong).seconds >= 15:
                            last_pong = datetime.now()
                            vvvv("command still running, sending keepalive packet")
                            self._send_json(dict(pong=True))
                        time.sleep(0.1)
                    response = twrv.join()
                    vvvv("thread is done, response from join was %s" % response)
                elif mode == 'put':
                    vvvv("received a put request, putting it")
                    response = self.put(data)
                elif mode == 'fetch':
                    vvvv("received a fetch request, getting it")
                    response = self.fetch(data)

                vvvv("response result is %s" % str(response))
                vvvv("sending the response back to the controller")
                self._send_json(response)
                vvvv("done sending the response")
        except Exception:
            tb = traceback.format_exc()
            log("encountered an unhandled exception in the handle() function")
            log("error was:\n%s" % tb)
            self._send_json(dict(rc=1, failed=True, msg="unhandled error in the handle() function"))

    def command(self, data):
        """Run data['cmd'] through the module and return rc/stdout/stderr.

        Returns a failure dict when 'cmd', 'tmp_path' or 'executable' is
        missing from the request.
        """
        if 'cmd' not in data:
            return dict(failed=True, msg='internal error: cmd is required')
        if 'tmp_path' not in data:
            return dict(failed=True, msg='internal error: tmp_path is required')
        if 'executable' not in data:
            return dict(failed=True, msg='internal error: executable is required')

        vvvv("executing: %s" % data['cmd'])
        rc, stdout, stderr = self.server.module.run_command(data['cmd'], executable=data['executable'], close_fds=True)
        if stdout is None:
            stdout = ''
        if stderr is None:
            stderr = ''
        vvvv("got stdout: %s" % stdout)
        vvvv("got stderr: %s" % stderr)

        return dict(rc=rc, stdout=stdout, stderr=stderr)

    def fetch(self, data):
        """Stream the file at data['in_path'] to the peer in base64 chunks.

        Each chunk message carries 'data' and a 'last' flag, and the peer's
        reply is read before the next chunk is sent. Returns an empty dict on
        success, or a dict with failed=True and a message otherwise.
        """
        if 'in_path' not in data:
            return dict(failed=True, msg='internal error: in_path is required')

        in_path = data['in_path']
        fd = None
        try:
            fd = open(in_path, 'rb')
            fstat = os.stat(in_path)
            vvv("FETCH file is %d bytes" % fstat.st_size)
            while fd.tell() < fstat.st_size:
                chunk = fd.read(CHUNK_SIZE)
                # An empty read means the file is shorter than the size
                # sampled above; treat it as the final chunk rather than
                # looping forever.
                last = (not chunk) or fd.tell() >= fstat.st_size

                if self._send_json(dict(data=base64.b64encode(chunk), last=last)):
                    return dict(failed=True, stderr="failed to send data")

                response = self.recv_data()
                if not response:
                    log("failed to get a response, aborting")
                    return dict(failed=True, stderr="Failed to get a response from %s" % self.client_address[0])
                response = self.server.key.Decrypt(response)
                response = json.loads(response)

                if response.get('failed',False):
                    log("got a failed response from the master")
                    return dict(failed=True, stderr="Master reported failure, aborting transfer")

                if last:
                    break
        except Exception as e:
            tb = traceback.format_exc()
            log("failed to fetch the file: %s" % tb)
            return dict(failed=True, stderr="Could not fetch the file: %s" % str(e))
        finally:
            # fd is still None when the open itself failed
            if fd is not None:
                fd.close()

        return dict()

    def put(self, data):
        """Write the base64 chunks sent by the peer to data['out_path'].

        When the request names a 'user' other than the current one, the data
        is staged in a temporary file under ~/.ansible/tmp/ and moved to the
        final path with the module's atomic_move(). Each chunk is acknowledged
        with an empty JSON object. Returns an empty dict on success, or a
        dict with failed=True otherwise.
        """
        if 'data' not in data:
            return dict(failed=True, msg='internal error: data is required')
        if 'out_path' not in data:
            return dict(failed=True, msg='internal error: out_path is required')

        final_path = None
        if 'user' in data and data.get('user') != getpass.getuser():
            vv("the target user doesn't match this user, we'll move the file into place via sudo")
            (fd,out_path) = tempfile.mkstemp(prefix='ansible.', dir=os.path.expanduser('~/.ansible/tmp/'))
            out_fd = os.fdopen(fd, 'wb', 0)
            final_path = data['out_path']
        else:
            out_path = data['out_path']
            # the decoded chunks are raw bytes, so write in binary mode
            out_fd = open(out_path, 'wb')

        try:
            written = 0
            while True:
                out = base64.b64decode(data['data'])
                written += len(out)
                out_fd.write(out)
                self._send_json(dict())
                if data['last']:
                    break
                data = self.recv_data()
                if not data:
                    raise IOError("connection closed before the last chunk was received")
                data = self.server.key.Decrypt(data)
                data = json.loads(data)
        except Exception:
            out_fd.close()
            tb = traceback.format_exc()
            log("failed to put the file: %s" % tb)
            if final_path:
                # do not leave the partially written staging file behind
                try:
                    os.unlink(out_path)
                except OSError:
                    pass
            return dict(failed=True, stdout="Could not write the file")

        vvvv("wrote %d bytes" % written)
        out_fd.close()

        if final_path:
            vvv("moving %s to %s" % (out_path, final_path))
            self.server.module.atomic_move(out_path, final_path)
        return dict()
 | |
| 
 | |
def daemonize(module, password, port, timeout, minutes, ipv6):
    """Daemonize the process and serve accelerate connections.

    After daemonizing, a SIGALRM timer of 60 * minutes seconds is armed whose
    handler calls module.exit_json(). A threaded TCP server is then bound on
    the given port (all IPv6 addresses when ipv6 is true, all IPv4 addresses
    otherwise) and served until the process exits. Any exception is logged
    and the process exits with status 0.
    """
    try:
        daemonize_self(module, password, port, minutes)

        def on_alarm(signum, _frame):
            module.exit_json(msg='timer expired')

        signal.signal(signal.SIGALRM, on_alarm)
        signal.setitimer(signal.ITIMER_REAL, 60 * minutes)

        # pick the server flavour and its wildcard bind address
        if ipv6:
            server_class, bind_host = ThreadedTCPV6Server, "::"
        else:
            server_class, bind_host = ThreadedTCPServer, "0.0.0.0"
        server = server_class((bind_host, port), ThreadedTCPRequestHandler, module, password, timeout)
        server.allow_reuse_address = True

        vv("serving!")
        server.serve_forever(poll_interval=1.0)
    except Exception as e:
        log("exception caught, exiting accelerated mode: %s\n%s" % (e, traceback.format_exc()))
        sys.exit(0)
 | |
| 
 | |
def main():
    """Module entry point: parse the arguments and start the daemon.

    Fails the module when keyczar could not be imported. Otherwise decodes
    the base64 'password' parameter, stores the 'debug' value in the
    module-level DEBUG_LEVEL and hands control to daemonize().
    """
    global DEBUG_LEVEL
    module = AnsibleModule(
        argument_spec = dict(
            port=dict(required=False, default=5099),
            # Declared as a boolean so that string values such as "false" or
            # "no" are converted properly; bool() on any non-empty string is
            # True.
            ipv6=dict(required=False, default=False, type='bool'),
            timeout=dict(required=False, default=300),
            password=dict(required=True),
            minutes=dict(required=False, default=30),
            debug=dict(required=False, default=0, type='int')
        ),
        supports_check_mode=True
    )

    # Report the missing dependency before touching any other parameter.
    if not HAS_KEYCZAR:
        module.fail_json(msg="keyczar is not installed (on the remote side)")

    password  = base64.b64decode(module.params['password'])
    port      = int(module.params['port'])
    timeout   = int(module.params['timeout'])
    minutes   = int(module.params['minutes'])
    debug     = int(module.params['debug'])
    ipv6      = bool(module.params['ipv6'])

    DEBUG_LEVEL=debug

    daemonize(module, password, port, timeout, minutes, ipv6)
 | |
| 
 | |
| # this is magic, see lib/ansible/module_common.py
 | |
| #<<INCLUDE_ANSIBLE_MODULE_COMMON>>
 | |
| main()
 |