From 9574f8947114d7ce20284ff2ceba539c3b7b765c Mon Sep 17 00:00:00 2001 From: James Cammarata Date: Thu, 30 Jan 2014 14:39:52 -0600 Subject: [PATCH] Detect remote_user change in accelerate daemon and allow a restart Fixes #5812 --- .../runner/connection_plugins/accelerate.py | 48 +++++++++++++++++ library/utilities/accelerate | 53 ++++++++++++++++--- 2 files changed, 95 insertions(+), 6 deletions(-) diff --git a/lib/ansible/runner/connection_plugins/accelerate.py b/lib/ansible/runner/connection_plugins/accelerate.py index 17cc99c1d7..60c1319262 100644 --- a/lib/ansible/runner/connection_plugins/accelerate.py +++ b/lib/ansible/runner/connection_plugins/accelerate.py @@ -101,6 +101,7 @@ class Connection(object): try: if not self.is_connected: + wrong_user = False tries = 3 self.conn = socket.socket() self.conn.settimeout(constants.ACCELERATE_CONNECT_TIMEOUT) @@ -108,6 +109,12 @@ class Connection(object): while tries > 0: try: self.conn.connect((self.host,self.accport)) + if not self.validate_user(): + # the accelerated daemon was started with a + # different remote_user. The above command + # should have caused the accelerate daemon to + # shutdown, so we'll reconnect. + wrong_user = True break except: vvvv("failed, retrying...") @@ -116,6 +123,9 @@ class Connection(object): if tries == 0: vvv("Could not connect via the accelerated connection, exceeded # of tries") raise errors.AnsibleError("Failed to connect") + elif wrong_user: + vvv("Restarting daemon with a different remote_user") + raise errors.AnsibleError("Wrong user") self.conn.settimeout(constants.ACCELERATE_TIMEOUT) except: if allow_ssh: @@ -159,6 +169,44 @@ class Connection(object): except socket.timeout: raise errors.AnsibleError("timed out while waiting to receive data") + def validate_user(self): + ''' + Checks the remote uid of the accelerated daemon vs. the + one specified for this play and will cause the accel + daemon to exit if they don't match + ''' + + data = dict( + mode='validate_user', + username=self.user, + ) + data = utils.jsonify(data) + data = utils.encrypt(self.key, data) + if self.send_data(data): + raise errors.AnsibleError("Failed to send command to %s" % self.host) + + while True: + # we loop here while waiting for the response, because a + # long running command may cause us to receive keepalive packets + # ({"pong":"true"}) rather than the response we want. + response = self.recv_data() + if not response: + raise errors.AnsibleError("Failed to get a response from %s" % self.host) + response = utils.decrypt(self.key, response) + response = utils.parse_json(response) + if "pong" in response: + # it's a keepalive, go back to waiting + vvvv("%s: received a keepalive packet" % self.host) + continue + else: + vvvv("%s: received the response" % self.host) + break + + if response.get('failed'): + raise errors.AnsibleError("Error while validating user: %s" % response.get("msg")) + else: + return response.get('rc') == 0 + def exec_command(self, cmd, tmp_path, sudo_user=None, sudoable=False, executable='/bin/sh', in_data=None, su=None, su_user=None): ''' run a command on the remote host ''' diff --git a/library/utilities/accelerate b/library/utilities/accelerate index c4069efac9..a6e84e3237 100644 --- a/library/utilities/accelerate +++ b/library/utilities/accelerate @@ -75,6 +75,7 @@ import getpass import json import os import os.path +import pwd import signal import socket import struct @@ -280,6 +281,9 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): elif mode == 'fetch': vvvv("received a fetch request, getting it") response = self.fetch(data) + elif mode == 'validate_user': + vvvv("received a request to validate the user id") + response = self.validate_user(data) vvvv("response result is %s" % str(response)) data2 = json.dumps(response) @@ -287,6 +291,10 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): vvvv("sending the response back to the controller") self.send_data(data2) vvvv("done sending the response") + + if mode == 'validate_user' and response.get('rc') == 1: + vvvv("detected a uid mismatch, shutting down") + self.server.shutdown() except: tb = traceback.format_exc() log("encountered an unhandled exception in the handle() function") @@ -295,6 +303,27 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): data2 = self.server.key.Encrypt(data2) self.send_data(data2) + def validate_user(self, data): + if 'username' not in data: + return dict(failed=True, msg='No username specified') + + vvvv("validating we're running as %s" % data['username']) + + # get the current uid + c_uid = os.getuid() + try: + # the target uid + t_uid = pwd.getpwnam(data['username']).pw_uid + except: + vvvv("could not find user %s" % data['username']) + return dict(failed=True, msg='could not find user %s' % data['username']) + + # and return rc=0 for success, rc=1 for failure + if c_uid == t_uid: + return dict(rc=0) + else: + return dict(rc=1) + def command(self, data): if 'cmd' not in data: return dict(failed=True, msg='internal error: cmd is required') @@ -409,14 +438,26 @@ def daemonize(module, password, port, timeout, minutes, ipv6): signal.signal(signal.SIGALRM, catcher) signal.setitimer(signal.ITIMER_REAL, 60 * minutes) - if ipv6: - server = ThreadedTCPV6Server(("::", port), ThreadedTCPRequestHandler, module, password, timeout) - else: - server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password, timeout) - server.allow_reuse_address = True + tries = 5 + while tries > 0: + try: + if ipv6: + server = ThreadedTCPV6Server(("::", port), ThreadedTCPRequestHandler, module, password, timeout) + else: + server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password, timeout) + server.allow_reuse_address = True + break + except: + vv("Failed to create the TCP server (tries left = %d)" % tries) + tries -= 1 + time.sleep(0.2) + if tries == 0: + vv("Maximum number of attempts to create the TCP server reached, bailing out") + raise Exception("max # of attempts to serve reached") + vv("serving!") - server.serve_forever(poll_interval=1.0) + server.serve_forever(poll_interval=0.1) except Exception, e: tb = traceback.format_exc() log("exception caught, exiting accelerated mode: %s\n%s" % (e, tb))