#!/usr/bin/env python
# Copyright: (c) 2017, Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import (absolute_import, division, print_function)

# On Python 2, classes in this module declared without explicit bases
# become new-style classes.
__metaclass__ = type
# NOTE(review): presumably read by pkg_resources (setuptools) when it is
# imported below -- confirm before changing or removing.
__requires__ = ['ansible']

# pkg_resources is optional here: the import is attempted only for whatever
# side effects it has (the name is not referenced again in this file) and any
# failure to import it is deliberately ignored.
try:
    import pkg_resources
except Exception:
    pass

import fcntl
import os
import shlex
import signal
import socket
import sys
import time
import traceback
import errno
import json

from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.module_utils.six import PY3
from ansible.module_utils.six.moves import cPickle
from ansible.module_utils.connection import Connection, ConnectionError, send_data, recv_data
from ansible.module_utils.service import fork_process
from ansible.playbook.play_context import PlayContext
from ansible.plugins.loader import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.utils.display import Display
from ansible.utils.jsonrpc import JsonRpcServer


class ConnectionProcess(object):
    '''
    The connection process wraps around a Connection object that manages
    the connection to a remote device that persists over the playbook

    Lifecycle: ``start()`` loads and connects the connection plugin and binds
    the local domain socket, ``run()`` serves requests received on that
    socket through a JsonRpcServer, and ``shutdown()`` closes everything and
    removes the socket file.

    The timeout/signal handlers and ``shutdown()`` log through a module
    global named ``display``, which must be bound before they are invoked.
    '''
    def __init__(self, fd, play_context, socket_path, original_path, ansible_playbook_pid=None):
        '''
        :arg fd: writable file object; ``start()`` writes its JSON status
            report to it and then closes it
        :arg play_context: play context used to load the connection plugin
        :arg socket_path: filesystem path of the local domain socket
        :arg original_path: directory used to resolve a relative private
            key file path
        :kwarg ansible_playbook_pid: passed through to the connection plugin
        '''
        self.play_context = play_context
        self.socket_path = socket_path
        self.original_path = original_path

        self.fd = fd
        # traceback text of the last failure seen by run(), or None
        self.exception = None

        self.srv = JsonRpcServer()
        self.sock = None

        self.connection = None
        self._ansible_playbook_pid = ansible_playbook_pid

    def start(self):
        '''
        Load the connection plugin, connect to the remote device and start
        listening on the local domain socket.

        Failures are not raised to the caller. Instead a JSON object is
        written to ``self.fd`` (which is then closed) holding ``messages``
        and, when something went wrong, ``error`` and ``exception``.
        '''
        # Bound before the try block so the finally clause can always use them
        messages = list()
        result = {}

        try:
            messages.append('control socket path is %s' % self.socket_path)

            # If this is a relative path (~ gets expanded later) then plug the
            # key's path on to the directory we originally came from, so we can
            # find it now that our cwd is /
            if self.play_context.private_key_file and self.play_context.private_key_file[0] not in '~/':
                self.play_context.private_key_file = os.path.join(self.original_path, self.play_context.private_key_file)
            self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null',
                                                    ansible_playbook_pid=self._ansible_playbook_pid)
            self.connection.set_options()
            self.connection._connect()
            self.connection._socket_path = self.socket_path
            # expose the connection object through the JSON-RPC server
            self.srv.register(self.connection)
            messages.append('connection to remote device started successfully')

            self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            self.sock.bind(self.socket_path)
            self.sock.listen(1)
            messages.append('local domain socket listeners started successfully')
        except Exception as exc:
            result['error'] = to_text(exc)
            result['exception'] = traceback.format_exc()
        finally:
            # always report back, whether or not startup succeeded
            result['messages'] = messages
            self.fd.write(json.dumps(result))
            self.fd.close()

    def run(self):
        '''
        Serve requests on the local domain socket for as long as the
        connection reports itself as connected.

        SIGALRM is used for two different timeouts: while waiting for a
        client it is armed with C.PERSISTENT_CONNECT_TIMEOUT, and while a
        request is being handled it is armed with the connection's play
        context timeout. ``shutdown()`` is always called on the way out.
        '''
        try:
            while self.connection.connected:
                signal.signal(signal.SIGALRM, self.connect_timeout)
                signal.signal(signal.SIGTERM, self.handler)
                # idle timeout: fires if no client connects in time
                signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)

                self.exception = None
                (s, addr) = self.sock.accept()
                signal.alarm(0)

                # from here on SIGALRM means a single request took too long
                signal.signal(signal.SIGALRM, self.command_timeout)
                while True:
                    data = recv_data(s)
                    if not data:
                        # nothing more from this client
                        break

                    signal.alarm(self.connection._play_context.timeout)
                    resp = self.srv.handle_request(data)
                    signal.alarm(0)

                    send_data(s, to_bytes(resp))

                s.close()

        except Exception as e:
            # socket.accept() will raise EINTR if the socket.close() is called
            if hasattr(e, 'errno'):
                if e.errno != errno.EINTR:
                    self.exception = traceback.format_exc()
            else:
                self.exception = traceback.format_exc()

        finally:
            # when done, close the connection properly and cleanup
            # the socket file so it can be recreated
            self.shutdown()

    def connect_timeout(self, signum, frame):
        ''' SIGALRM handler used while waiting for a client to connect '''
        display.display('persistent connection idle timeout triggered, timeout value is %s secs' % C.PERSISTENT_CONNECT_TIMEOUT, log_only=True)
        self.shutdown()

    def command_timeout(self, signum, frame):
        ''' SIGALRM handler used while a request is being handled '''
        display.display('command timeout triggered, timeout value is %s secs' % self.play_context.timeout, log_only=True)
        self.shutdown()

    def handler(self, signum, frame):
        ''' SIGTERM handler: log the signal number and shut down '''
        display.display('signal handler called with signal %s' % signum, log_only=True)
        self.shutdown()

    def shutdown(self):
        """ Shuts down the local domain socket

        When the socket file exists, the listening socket and the connection
        are closed on a best-effort basis, the socket file is removed and
        the connection (if any) is marked as no longer connected. Nothing is
        closed when the socket file is already gone.
        """
        if os.path.exists(self.socket_path):
            try:
                if self.sock:
                    self.sock.close()
                if self.connection:
                    self.connection.close()
            except Exception:
                # Best effort: an error while closing must not prevent the
                # socket file cleanup below. SystemExit and KeyboardInterrupt
                # are intentionally not swallowed.
                pass
            finally:
                if os.path.exists(self.socket_path):
                    os.remove(self.socket_path)
                    # start() may have failed before a connection was loaded
                    if self.connection is not None:
                        self.connection._socket_path = None
                        self.connection._connected = False
        display.display('shutdown complete', log_only=True)

    # NOTE(review): nothing in this file calls the do_* methods below; run()
    # hands every request to the JsonRpcServer. They look like leftovers of an
    # older text protocol -- confirm before relying on them.

    def do_EXEC(self, data):
        ''' Run the command that follows the ``EXEC: `` prefix in ``data`` '''
        # maxsplit=1 keeps commands intact when they contain 'EXEC: ' themselves
        cmd = data.split(b'EXEC: ', 1)[1]
        return self.connection.exec_command(cmd)

    def do_PUT(self, data):
        ''' Handle ``<op> <src> <dst>`` by putting ``src`` to ``dst`` '''
        (op, src, dst) = shlex.split(to_native(data))
        return self.connection.put_file(src, dst)

    def do_FETCH(self, data):
        ''' Handle ``<op> <src> <dst>`` by fetching ``src`` to ``dst`` '''
        (op, src, dst) = shlex.split(to_native(data))
        return self.connection.fetch_file(src, dst)

    def do_CONTEXT(self, data):
        '''
        Replace the connection's play context with the pickled one that
        follows the ``CONTEXT: `` prefix in ``data``.

        Connections without ``update_play_context`` are silently skipped.
        Always returns ``(0, 'ok', '')``.
        '''
        pc_data = data.split(b'CONTEXT: ', 1)[1]

        # NOTE(review): unpickling executes arbitrary code if the payload is
        # attacker controlled; this is only safe while every peer that can
        # reach this code path is trusted.
        if PY3:
            pc_data = cPickle.loads(pc_data, encoding='bytes')
        else:
            pc_data = cPickle.loads(pc_data)

        pc = PlayContext()
        pc.deserialize(pc_data)

        try:
            self.connection.update_play_context(pc)
        except AttributeError:
            pass

        return (0, 'ok', '')

    def do_RUN(self, data):
        '''
        Wait up to ``play_context.timeout`` seconds (polling once a second)
        for the socket file to appear, then return the socket path wrapped
        in a ``#SOCKET_PATH#`` marker line as the stdout element.
        '''
        timeout = self.play_context.timeout
        while bool(timeout):
            if os.path.exists(self.socket_path):
                break
            time.sleep(1)
            timeout -= 1
        socket_bytes = to_bytes(self.socket_path, errors='surrogate_or_strict')
        return 0, b'\n#SOCKET_PATH#: %s\n' % socket_bytes, ''


def communicate(sock, data):
    """ Send ``data`` over ``sock`` and collect the three-part reply

    The reply is read as three consecutive messages: the return code
    (parsed as a base 10 integer), then stdout, then stderr.

    :returns: tuple of ``(rc, stdout, stderr)``
    """
    send_data(sock, data)
    # The return code is parsed before the remaining parts are read
    reply = [int(recv_data(sock), 10)]
    reply.append(recv_data(sock))
    reply.append(recv_data(sock))
    return tuple(reply)


def main():
    """ Called to initiate the connect to the remote device

    Reads a pickled play context from stdin (terminated by a ``#END_INIT#``
    line), derives the local domain socket path for it and then either forks
    a ConnectionProcess to serve that socket or, when the socket file already
    exists, pushes the play context to the process behind it.

    A JSON object holding ``messages``, ``socket_path`` and, on failure,
    ``error``/``exception`` is written to stdout on success or to stderr when
    an exception was recorded. The process then exits with rc 0 or 1
    respectively.
    """
    rc = 0
    result = {}
    messages = list()
    socket_path = None

    # Need stdin as a byte stream
    if PY3:
        stdin = sys.stdin.buffer
    else:
        stdin = sys.stdin

    try:
        # read the play context data via stdin, which means depickling it
        cur_line = stdin.readline()
        init_data = b''

        # accumulate raw lines until the end-of-init marker is seen
        while cur_line.strip() != b'#END_INIT#':
            if cur_line == b'':
                raise Exception("EOF found before init data was complete")
            init_data += cur_line
            cur_line = stdin.readline()

        # NOTE(review): unpickling runs arbitrary code if stdin is not
        # trusted; this assumes the parent process is the only writer.
        if PY3:
            pc_data = cPickle.loads(init_data, encoding='bytes')
        else:
            pc_data = cPickle.loads(init_data)

        play_context = PlayContext()
        play_context.deserialize(pc_data)

    except Exception as e:
        # remember the failure; it is reported as JSON at the end
        rc = 1
        result.update({
            'error': to_text(e),
            'exception': traceback.format_exc()
        })

    if rc == 0:
        # the ssh plugin class is only used for its control path helper
        ssh = connection_loader.get('ssh', class_only=True)
        # NOTE(review): raises IndexError (outside any try block) when the
        # playbook pid is not passed as the first command line argument
        ansible_playbook_pid = sys.argv[1]
        cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user, play_context.connection, ansible_playbook_pid)

        # create the persistent connection dir if need be and create the paths
        # which we will be using later
        tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
        makedirs_safe(tmp_path)

        lock_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
        # cp is a template; fill in its ``directory`` placeholder
        socket_path = unfrackpath(cp % dict(directory=tmp_path))

        # take an exclusive lock on the lock file before looking at the
        # socket -- presumably so concurrent invocations do not race to
        # create the same socket
        # if the socket file doesn't exist, spin up the daemon process
        lock_fd = os.open(lock_path, os.O_RDWR | os.O_CREAT, 0o600)
        fcntl.lockf(lock_fd, fcntl.LOCK_EX)

        if not os.path.exists(socket_path):
            messages.append('local domain socket does not exist, starting it')
            original_path = os.getcwd()
            # pipe used by the child to report its startup status to us
            r, w = os.pipe()
            pid = fork_process()

            if pid == 0:
                # child: becomes the persistent connection process
                try:
                    os.close(r)
                    wfd = os.fdopen(w, 'w')
                    process = ConnectionProcess(wfd, play_context, socket_path, original_path, ansible_playbook_pid)
                    # start() reports its own failures through wfd rather
                    # than raising them
                    process.start()
                except Exception:
                    # NOTE(review): these messages are never emitted; the
                    # child exits below without writing them anywhere
                    messages.append(traceback.format_exc())
                    rc = 1

                # startup is over, release the lock taken above
                fcntl.lockf(lock_fd, fcntl.LOCK_UN)
                os.close(lock_fd)

                if rc == 0:
                    # serve requests until the connection shuts down
                    process.run()

                sys.exit(rc)

            else:
                # parent: read the child's JSON startup report from the pipe
                os.close(w)
                rfd = os.fdopen(r, 'r')
                # NOTE(review): json.loads raises here if the child closed
                # the pipe without writing a report
                data = json.loads(rfd.read())
                messages.extend(data.pop('messages'))
                result.update(data)

        else:
            messages.append('found existing local domain socket, using it!')
            conn = Connection(socket_path)
            # hand the raw init data (as text) to the running process
            pc_data = to_text(init_data)
            try:
                messages.extend(conn.update_play_context(pc_data))
            except Exception as exc:
                # Only network_cli has update_play context, so missing this is
                # not fatal e.g. netconf
                if isinstance(exc, ConnectionError) and getattr(exc, 'code', None) == -32601:
                    pass
                else:
                    result.update({
                        'error': to_text(exc),
                        'exception': traceback.format_exc()
                    })

    result.update({
        'messages': messages,
        'socket_path': socket_path
    })

    # the presence of an 'exception' entry decides both the exit code and
    # the stream the JSON result is written to
    if 'exception' in result:
        rc = 1
        sys.stderr.write(json.dumps(result))
    else:
        rc = 0
        sys.stdout.write(json.dumps(result))

    sys.exit(rc)

if __name__ == '__main__':
    # ``display`` is bound at module scope here; the ConnectionProcess
    # timeout/signal handlers and shutdown() look it up as a global, so it
    # only exists when this file is executed as a script.
    display = Display()
    main()