#!/usr/bin/env python

# (c) 2017, Ansible, Inc. <support@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.

########################################################
from __future__ import (absolute_import, division, print_function)

__metaclass__ = type
__requires__ = ['ansible']

try:
    import pkg_resources
except Exception:
    pass

import fcntl
import os
import shlex
import signal
import socket
import sys
import time
import traceback
import datetime
import errno

from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.six import PY3
from ansible.module_utils.six.moves import cPickle
from ansible.module_utils.connection import send_data, recv_data
from ansible.playbook.play_context import PlayContext
from ansible.plugins import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.errors import AnsibleConnectionFailure
from ansible.utils.display import Display


def do_fork():
    '''
    Detach a daemon process using the classic Unix double fork. Modelled on
    http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/

    Returns the first child's pid (> 0) in the calling process and 0 in the
    detached grandchild. The intermediate child exits with status 0, and any
    OSError raised along the way terminates the current process with status 1.
    '''
    try:
        first_pid = os.fork()
        if first_pid > 0:
            # still in the original process: report the child's pid
            return first_pid

        # first child: become a session leader and reset the umask; the
        # working directory is left untouched (no chdir is performed)
        os.setsid()
        os.umask(0)

        second_pid = os.fork()
        if second_pid > 0:
            # the intermediate child has done its job
            sys.exit(0)

        # grandchild: point stdout/stderr at the configured log file, or at
        # /dev/null when no log path is configured
        log_target = C.DEFAULT_LOG_PATH if C.DEFAULT_LOG_PATH != '' else '/dev/null'
        out_file = open(log_target, 'ab+')
        err_file = open(log_target, 'ab+', 0)

        os.dup2(out_file.fileno(), sys.stdout.fileno())
        os.dup2(err_file.fileno(), sys.stderr.fileno())
        os.close(sys.stdin.fileno())

        return second_pid
    except OSError:
        sys.exit(1)


class Server():
    '''
    Daemon side of a persistent connection.

    On construction the configured connection plugin is loaded and connected
    to the remote host, then a local unix domain socket is bound at
    ``socket_path``. ``run()`` accepts clients on that socket and dispatches
    each received message to the ``do_<OP>`` method named by the text before
    the first ``:`` of the message.
    '''

    def __init__(self, socket_path, play_context):
        '''
        :param socket_path: filesystem path of the local control socket
        :param play_context: PlayContext describing the remote connection
        :raises AnsibleConnectionFailure: when the plugin reports that it is
            not connected after ``_connect()``
        '''
        self.socket_path = socket_path
        self.play_context = play_context

        # defined up front so shutdown() can always inspect them
        self.socket = None
        self.connection = None

        display.display(
            'creating new control socket for host %s:%s as user %s' %
            (play_context.remote_addr, play_context.port, play_context.remote_user),
            log_only=True
        )

        display.display('control socket path is %s' % socket_path, log_only=True)
        display.display('current working directory is %s' % os.getcwd(), log_only=True)

        self._start_time = datetime.datetime.now()

        display.display("using connection plugin %s" % self.play_context.connection, log_only=True)

        self.connection = connection_loader.get(play_context.connection, play_context, sys.stdin)
        self.connection._connect()

        if not self.connection.connected:
            # the attribute is ``play_context``; the previous ``_play_context``
            # lookup raised AttributeError instead of the intended failure
            raise AnsibleConnectionFailure('unable to connect to remote host %s' % self.play_context.remote_addr)

        connection_time = datetime.datetime.now() - self._start_time
        display.display('connection established to %s in %s' % (play_context.remote_addr, connection_time), log_only=True)

        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.socket.bind(self.socket_path)
        self.socket.listen(1)
        display.display('local socket is set to listening', log_only=True)

    def run(self):
        '''
        Serve requests on the control socket until a timeout, a SIGTERM or an
        error ends the loop, then shut everything down.

        For every message the reply is sent as three frames: rc, stdout and
        stderr. An unknown operation is answered with rc 255.
        '''
        try:
            while True:
                # the alarm bounds how long we wait for the next client
                signal.signal(signal.SIGALRM, self.connect_timeout)
                signal.signal(signal.SIGTERM, self.handler)
                signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)

                (s, addr) = self.socket.accept()
                display.display('incoming request accepted on persistent socket', log_only=True)
                signal.alarm(0)

                while True:
                    data = recv_data(s)
                    if not data:
                        break

                    # re-arm the alarm to bound the execution of this command
                    signal.signal(signal.SIGALRM, self.command_timeout)
                    signal.alarm(self.play_context.timeout)

                    # data is a byte string: split on a bytes separator and
                    # convert the op to native text for the method lookup
                    op = to_native(data.split(b':', 1)[0])
                    display.display('socket operation is %s' % op, log_only=True)

                    method = getattr(self, 'do_%s' % op, None)

                    rc = 255
                    stdout = stderr = ''

                    if not method:
                        stderr = 'Invalid action specified'
                    else:
                        rc, stdout, stderr = method(data)

                    signal.alarm(0)

                    display.display('socket operation completed with rc %s' % rc, log_only=True)

                    send_data(s, to_bytes(rc))
                    send_data(s, to_bytes(stdout))
                    send_data(s, to_bytes(stderr))

                s.close()

        except Exception as e:
            # socket.accept() will raise EINTR if the socket.close() is called;
            # not every exception carries an errno, so look it up defensively
            if getattr(e, 'errno', None) != errno.EINTR:
                display.display(traceback.format_exc(), log_only=True)

        finally:
            # when done, close the connection properly and cleanup
            # the socket file so it can be recreated
            self.shutdown()
            end_time = datetime.datetime.now()
            delta = end_time - self._start_time
            display.display('shutdown local socket, connection was active for %s secs' % delta, log_only=True)

    def connect_timeout(self, signum, frame):
        '''SIGALRM handler used while waiting for a client to connect.'''
        display.display('connect timeout triggered, timeout value is %s secs' % C.PERSISTENT_CONNECT_TIMEOUT, log_only=True)
        self.shutdown()

    def command_timeout(self, signum, frame):
        '''SIGALRM handler used while a command is being executed.'''
        display.display('command timeout triggered, timeout value is %s secs' % self.play_context.timeout, log_only=True)
        self.shutdown()

    def handler(self, signum, frame):
        '''SIGTERM handler: log the signal and shut down.'''
        display.display('signal handler called with signal %s' % signum, log_only=True)
        self.shutdown()

    def shutdown(self):
        '''
        Close the listener and the remote connection and remove the socket
        file. Does nothing when the socket file no longer exists, so calling
        it again after a completed shutdown is harmless.
        '''
        display.display('shutdown persistent connection requested', log_only=True)

        if not os.path.exists(self.socket_path):
            display.display('persistent connection is not active', log_only=True)
            return

        try:
            if self.socket:
                display.display('closing local listener', log_only=True)
                self.socket.close()
            if self.connection:
                display.display('closing the connection', log_only=True)
                # Server has no close() of its own; the remote connection
                # is the object that needs closing here
                self.connection.close()
        except Exception:
            # cleanup stays best-effort, but leave a trace in the log
            display.display(traceback.format_exc(), log_only=True)
        finally:
            if os.path.exists(self.socket_path):
                display.display('removing the local control socket', log_only=True)
                os.remove(self.socket_path)

        display.display('shutdown complete', log_only=True)

    def do_EXEC(self, data):
        '''Run the command that follows the ``EXEC: `` prefix on the connection.'''
        # split once only, so a command that itself contains 'EXEC: ' is
        # passed through intact
        cmd = data.split(b'EXEC: ', 1)[1]
        return self.connection.exec_command(cmd)

    def do_PUT(self, data):
        '''Handle ``PUT: <src> <dst>`` by calling the connection's put_file.'''
        (op, src, dst) = shlex.split(to_native(data))
        return self.connection.put_file(src, dst)

    def do_FETCH(self, data):
        '''Handle ``FETCH: <src> <dst>`` by calling the connection's fetch_file.'''
        (op, src, dst) = shlex.split(to_native(data))
        return self.connection.fetch_file(src, dst)

    def do_CONTEXT(self, data):
        '''
        Unpickle the play context that follows the ``CONTEXT: `` prefix and
        hand it to the connection when the plugin supports
        ``update_play_context``.
        '''
        pc_data = data.split(b'CONTEXT: ', 1)[1]

        # NOTE(review): unpickling executes arbitrary code; this is only safe
        # as long as every client of the control socket is trusted
        if PY3:
            pc_data = cPickle.loads(pc_data, encoding='bytes')
        else:
            pc_data = cPickle.loads(pc_data)

        pc = PlayContext()
        pc.deserialize(pc_data)

        try:
            self.connection.update_play_context(pc)
        except AttributeError:
            # plugin does not implement update_play_context
            pass

        return (0, 'ok', '')

    def do_RUN(self, data):
        '''
        Wait up to ``play_context.timeout`` seconds for the socket file to
        exist and return its path as stdout. The path is returned with rc 0
        even when the wait runs out.
        '''
        timeout = self.play_context.timeout
        while bool(timeout):
            if os.path.exists(self.socket_path):
                break
            time.sleep(1)
            timeout -= 1
        return (0, self.socket_path, '')


def communicate(sock, data):
    '''
    Send one request over the control socket and collect the reply.

    The reply is read as three consecutive frames: the return code (parsed
    as a base-10 integer), then stdout, then stderr.

    :returns: tuple of (rc, stdout, stderr)
    '''
    send_data(sock, data)
    # tuple elements are evaluated left to right, so the frames are read in
    # the order they were sent
    reply = (int(recv_data(sock), 10), recv_data(sock), recv_data(sock))
    return reply

def main():
    '''
    Entry point of the persistent connection helper.

    Reads a pickled play context from stdin (terminated by a ``#END_INIT#``
    line), starts the daemonized Server for the target host if its control
    socket does not exist yet, then forwards the first non-blank command line
    read from stdin to the daemon and relays its stdout/stderr.

    Exits with the return code reported by the daemon, with 0 when stdin is
    closed before any command arrives, and with 255 when the control socket
    cannot be reached.
    '''
    # Need stdin as a byte stream
    if PY3:
        stdin = sys.stdin.buffer
    else:
        stdin = sys.stdin

    try:
        # read the play context data via stdin, which means depickling it
        # FIXME: as noted above, we will probably need to deserialize the
        #        connection loader here as well at some point, otherwise this
        #        won't find role- or playbook-based connection plugins
        cur_line = stdin.readline()
        init_data = b''
        while cur_line.strip() != b'#END_INIT#':
            if cur_line == b'':
                raise Exception("EOF found before init data was complete")
            init_data += cur_line
            cur_line = stdin.readline()
        # NOTE(review): unpickling executes arbitrary code; stdin must only
        # ever be fed by a trusted parent process
        if PY3:
            pc_data = cPickle.loads(init_data, encoding='bytes')
        else:
            pc_data = cPickle.loads(init_data)

        pc = PlayContext()
        pc.deserialize(pc_data)

    except Exception as e:
        # FIXME: better error message/handling/logging
        sys.stderr.write(traceback.format_exc())
        sys.exit("FAIL: %s" % e)

    ssh = connection_loader.get('ssh', class_only=True)
    cp = ssh._create_control_path(pc.remote_addr, pc.connection, pc.remote_user)

    # create the persistent connection dir if need be and create the paths
    # which we will be using later
    tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
    makedirs_safe(tmp_path)
    lock_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
    socket_path = unfrackpath(cp % dict(directory=tmp_path))

    # if the socket file doesn't exist, spin up the daemon process; the
    # lock file serializes this check between concurrent invocations
    lock_fd = os.open(lock_path, os.O_RDWR | os.O_CREAT, 0o600)
    fcntl.lockf(lock_fd, fcntl.LOCK_EX)

    if not os.path.exists(socket_path):
        pid = do_fork()
        if pid == 0:
            # daemonized process: build the server, release the lock and
            # serve until shutdown; this branch never returns
            rc = 0
            try:
                server = Server(socket_path, pc)
            except AnsibleConnectionFailure as exc:
                display.display('connecting to host %s returned an error' % pc.remote_addr, log_only=True)
                display.display(str(exc), log_only=True)
                rc = 1
            except Exception as exc:
                display.display('failed to create control socket for host %s' % pc.remote_addr, log_only=True)
                display.display(traceback.format_exc(), log_only=True)
                rc = 1
            fcntl.lockf(lock_fd, fcntl.LOCK_UN)
            os.close(lock_fd)
            if rc == 0:
                server.run()
            sys.exit(rc)
    else:
        display.display('re-using existing socket for %s@%s:%s' % (pc.remote_user, pc.remote_addr, pc.port), log_only=True)

    fcntl.lockf(lock_fd, fcntl.LOCK_UN)
    os.close(lock_fd)

    # wait (one second per step) for the daemon to create the socket file
    timeout = pc.timeout
    while bool(timeout):
        if os.path.exists(socket_path):
            display.vvvv('connected to local socket in %s' % (pc.timeout - timeout), pc.remote_addr)
            break
        time.sleep(1)
        timeout -= 1
    else:
        raise AnsibleConnectionFailure('timeout waiting for local socket', pc.remote_addr)

    # exit status used when stdin is closed before any command is read;
    # without this default the final sys.exit(rc) hit an unbound name
    rc = 0

    # now connect to the daemon process
    # FIXME: if the socket file existed but the daemonized process was killed,
    #        the connection will timeout here. Need to make this more resilient.
    while True:
        data = stdin.readline()
        if data == b'':
            break
        if data.strip() == b'':
            continue

        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

        attempts = C.PERSISTENT_CONNECT_RETRIES
        while bool(attempts):
            try:
                sock.connect(socket_path)
                break
            except socket.error:
                time.sleep(C.PERSISTENT_CONNECT_INTERVAL)
                attempts -= 1
        else:
            # NOTE(review): the extra positional arguments below appear to
            # land in display()'s color/stderr parameters rather than being
            # part of the message - confirm against Display.display
            display.display('number of connection attempts exceeded, unable to connect to control socket', pc.remote_addr, pc.remote_user, log_only=True)
            display.display('persistent_connect_interval=%s, persistent_connect_retries=%s' % (C.PERSISTENT_CONNECT_INTERVAL, C.PERSISTENT_CONNECT_RETRIES), pc.remote_addr, pc.remote_user, log_only=True)
            sys.stderr.write('failed to connect to control socket')
            sys.exit(255)

        try:
            # send the play_context back into the connection so the connection
            # can handle any privilege escalation activities
            communicate(sock, b'CONTEXT: ' + init_data)

            rc, stdout, stderr = communicate(sock, data.strip())
        finally:
            # release the client socket even when the exchange fails
            sock.close()

        sys.stdout.write(to_native(stdout))
        sys.stderr.write(to_native(stderr))

        # only a single command is handled per invocation
        break

    sys.exit(rc)

if __name__ == '__main__':
    # ``display`` becomes a module-level global here; do_fork()'s callers,
    # Server and main() look it up by name at call time, so it only exists
    # when this file is executed as a script (not when it is imported).
    display = Display()
    main()