From b45342923c3633bd08701dad37c5cfd5bd4b75e8 Mon Sep 17 00:00:00 2001 From: James Cammarata Date: Tue, 27 Aug 2013 13:12:35 -0500 Subject: [PATCH] Initial support for sudoable commands over fireball2 Caveats: * requiretty must be disabled in the sudoers config * asking for a password doesn't work yet, so any sudoers users must be configured with NOPASSWD * if not starting the daemon as root, the user running the daemon must have sudoers entries to allow them to run the command as the target sudo_user --- .../runner/connection_plugins/fireball2.py | 77 +++++++++++++------ library/utilities/fireball2 | 36 +++++++-- 2 files changed, 83 insertions(+), 30 deletions(-) diff --git a/lib/ansible/runner/connection_plugins/fireball2.py b/lib/ansible/runner/connection_plugins/fireball2.py index 30c56442a2..91628d1d70 100644 --- a/lib/ansible/runner/connection_plugins/fireball2.py +++ b/lib/ansible/runner/connection_plugins/fireball2.py @@ -61,15 +61,25 @@ class Connection(object): def connect(self, allow_ssh=True): ''' activates the connection object ''' - if self.is_connected: - return self - try: - self.conn = socket.socket() - self.conn.connect((self.host,self.fbport)) + if not self.is_connected: + # TODO: make the timeout and retries configurable? + tries = 10 + self.conn = socket.socket() + self.conn.settimeout(30.0) + while tries > 0: + try: + self.conn.connect((self.host,self.fbport)) + break + except: + time.sleep(0.1) + tries -= 1 + if tries == 0: + vvv("Could not connect via the fireball2 connection, exceeded # of tries") + raise errors.AnsibleError("Failed to connect") except: if allow_ssh: - print "Falling back to ssh to startup accelerated mode" + vvv("Falling back to ssh to startup accelerated mode") res = self._execute_fb_module() return self.connect(allow_ssh=False) else: @@ -84,23 +94,29 @@ class Connection(object): def recv_data(self): header_len = 8 # size of a packed unsigned long long data = b"" - while len(data) < header_len: - d = self.conn.recv(1024) - if not d: - return None - data += d - data_len = struct.unpack('Q',data[:header_len])[0] - data = data[header_len:] - while len(data) < data_len: - d = self.conn.recv(1024) - if not d: - return None - data += d - return data + try: + while len(data) < header_len: + d = self.conn.recv(1024) + if not d: + return None + data += d + data_len = struct.unpack('Q',data[:header_len])[0] + data = data[header_len:] + while len(data) < data_len: + d = self.conn.recv(1024) + if not d: + return None + data += d + return data + except socket.timeout: + raise errors.AnsibleError("timed out while waiting to receive data") def exec_command(self, cmd, tmp_path, sudo_user, sudoable=False, executable='/bin/sh'): ''' run a command on the remote host ''' + if self.runner.sudo or sudoable and sudo_user: + cmd, prompt = utils.make_sudo_cmd(sudo_user, executable, cmd) + vvv("EXEC COMMAND %s" % cmd) data = dict( @@ -112,12 +128,15 @@ class Connection(object): data = utils.jsonify(data) data = utils.encrypt(self.key, data) if self.send_data(data): - raise errors.AnisbleError("Failed to send command to %s:%s" % (self.host,self.port)) + raise errors.AnisbleError("Failed to send command to %s" % self.host) response = self.recv_data() + if not response: + raise errors.AnsibleError("Failed to get a response from %s" % self.host) response = utils.decrypt(self.key, response) response = utils.parse_json(response) + vvv("COMMAND DONE: rc=%s" % str(response.get('rc',""))) return (response.get('rc',None), '', response.get('stdout',''), response.get('stderr','')) def put_file(self, in_path, out_path): @@ -132,17 +151,23 @@ class Connection(object): data = base64.b64encode(data) data = dict(mode='put', data=data, out_path=out_path) + if self.runner.sudo: + data['user'] = self.runner.sudo_user + # TODO: support chunked file transfer data = utils.jsonify(data) data = utils.encrypt(self.key, data) if self.send_data(data): - raise errors.AnsibleError("failed to send the file to %s:%s" % (self.host,self.port)) + raise errors.AnsibleError("failed to send the file to %s" % self.host) response = self.recv_data() - response = utils.decrypt(self.key, data) + if not response: + raise errors.AnsibleError("Failed to get a response from %s" % self.host) + response = utils.decrypt(self.key, response) response = utils.parse_json(response) - # no meaningful response needed for this + if response.get('failed',False): + raise errors.AnsibleError("failed to put the file in the requested location") def fetch_file(self, in_path, out_path): ''' save a remote file to the specified path ''' @@ -152,10 +177,12 @@ class Connection(object): data = utils.jsonify(data) data = utils.encrypt(self.key, data) if self.send_data(data): - raise errors.AnsibleError("failed to initiate the file fetch with %s:%s" % (self.host,self.port)) + raise errors.AnsibleError("failed to initiate the file fetch with %s" % self.host) response = self.recv_data() - response = utils.decrypt(self.key, data) + if not response: + raise errors.AnsibleError("Failed to get a response from %s" % self.host) + response = utils.decrypt(self.key, response) response = utils.parse_json(response) response = response['data'] response = base64.b64decode(response) diff --git a/library/utilities/fireball2 b/library/utilities/fireball2 index 9d32e1ddd2..3df2ebb5cd 100644 --- a/library/utilities/fireball2 +++ b/library/utilities/fireball2 @@ -60,12 +60,15 @@ EXAMPLES = ''' ''' import os +import os.path +import tempfile import sys import shutil import socket import struct import time import base64 +import getpass import syslog import signal import time @@ -138,7 +141,6 @@ def daemonize_self(module, password, port, minutes): os.dup2(dev_null.fileno(), sys.stderr.fileno()) log("daemonizing successful") -#class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer): class ThreadedTCPServer(SocketServer.ThreadingTCPServer): def __init__(self, server_address, RequestHandlerClass, module, password): self.module = module @@ -171,11 +173,14 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): def handle(self): while True: + #log("waiting for data") data = self.recv_data() if not data: break try: + #log("got data, decrypting") data = self.server.key.Decrypt(data) + #log("decryption done") except: log("bad decrypt, skipping...") data2 = json.dumps(dict(rc=1)) @@ -183,6 +188,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): send_data(client, data2) return + #log("loading json from the data") data = json.loads(data) mode = data['mode'] @@ -212,7 +218,8 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): stdout = '' if stderr is None: stderr = '' - log("got stdout: %s" % stdout) + #log("got stdout: %s" % stdout) + #log("got stderr: %s" % stderr) return dict(rc=rc, stdout=stdout, stderr=stderr) @@ -234,14 +241,32 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): if 'out_path' not in data: return dict(failed=True, msg='internal error: out_path is required') + final_path = None + if 'user' in data and data.get('user') != getpass.getuser(): + log("the target user doesn't match this user, we'll move the file into place via sudo") + (fd,out_path) = tempfile.mkstemp(prefix='ansible.', dir=os.path.expanduser('~/.ansible/tmp/')) + out_fd = os.fdopen(fd, 'w', 0) + final_path = data['out_path'] + else: + out_path = data['out_path'] + out_fd = open(out_path, 'w') + # FIXME: should probably support chunked file transfer for binary files # at some point. For now, just base64 encodes the file # so don't use it to move ISOs, use rsync. - fh = open(data['out_path'], 'w') - fh.write(base64.b64decode(data['data'])) - fh.close() + try: + out_fd.write(base64.b64decode(data['data'])) + out_fd.close() + except: + return dict(failed=True, stdout="Could not write the file") + if final_path: + log("moving %s to %s" % (out_path, final_path)) + args = ['sudo','mv',out_path,final_path] + rc, stdout, stderr = self.server.module.run_command(args, close_fds=True) + if rc != 0: + return dict(failed=True, stdout="failed to copy the file into position with sudo") return dict() def daemonize(module, password, port, minutes): @@ -257,6 +282,7 @@ def daemonize(module, password, port, minutes): server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password) server.allow_reuse_address = True + log("serving!") server.serve_forever(poll_interval=1.0) except Exception, e: tb = traceback.format_exc()