#!/usr/bin/env python

# (c) 2016 Matt Clay <matt@mystile.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import print_function

import datetime
import json
import os
import subprocess
import sys
import traceback
import uuid

from argparse import ArgumentParser
from textwrap import dedent
from time import sleep


def main():
    """Parse the command line and dispatch to the selected command handler.

    Three commands are registered, each bound to a handler through
    ``set_defaults(func=...)``:

    * ``start`` -- start an instance for a platform/version, authenticated
      through the nested ``shippable`` or ``remote`` sub-command
    * ``get`` -- wait for an instance and print an inventory for it
    * ``stop`` -- stop an instance

    Exits through ``parser.error()`` (usage message, exit status 2) when no
    command is given.
    """
    parser = ArgumentParser(description='Manage remote instances for testing.')

    parser.add_argument('-v',
                        '--verbose',
                        dest='verbose',
                        action='store_true',
                        help='write verbose output to stderr')

    parser.add_argument('--endpoint',
                        dest='endpoint',
                        default='https://14blg63h2i.execute-api.us-east-1.amazonaws.com',
                        help='api endpoint')

    parser.add_argument('--stage',
                        dest='stage',
                        default='prod',
                        choices=['dev', 'prod'],
                        help='api stage (default: prod)')

    subparsers = parser.add_subparsers()

    # sub parser: start

    start_parser = subparsers.add_parser('start', help='start instance')
    start_parser.set_defaults(func=start_instance)

    # The nested sub-command name is stored in args.start and selects the
    # authentication method used by the start handler.
    start_subparsers = start_parser.add_subparsers(dest='start')

    start_parser.add_argument('platform', help='platform (ex: windows)')
    start_parser.add_argument('version', help='version (ex: 2012-R2_RTM)')

    # Store the generated id as a string so instance_id has the same type
    # whether it was generated here or supplied on the command line.
    start_parser.add_argument('--id',
                              dest='instance_id',
                              default=str(uuid.uuid4()),
                              help='instance id to create')

    start_parser.add_argument('--public-key',
                              dest='ssh_key',
                              default=None,
                              help='path to ssh public key for authentication')

    start_parser.add_argument('--query',
                              dest='query',
                              action='store_true',
                              default=False,
                              help='query only, do not start instance')

    shippable = start_subparsers.add_parser('shippable', help='start instance for shippable testing')

    shippable.add_argument('--run-id',
                           dest='run',
                           default=os.environ.get('SHIPPABLE_BUILD_ID'),
                           help='shippable run id')

    shippable.add_argument('--job-number',
                           dest='job',
                           type=int,
                           default=os.environ.get('SHIPPABLE_JOB_NUMBER'),
                           help='shippable job number')

    remote_key = get_remote_key()

    remote = start_subparsers.add_parser('remote', help='start instance for remote testing')

    # --key only has to be given explicitly when no stored key was found.
    remote.add_argument('--key',
                        dest='key',
                        default=remote_key,
                        required=remote_key is None,
                        help='remote key')

    remote.add_argument('--nonce',
                        dest='nonce',
                        default=None,
                        help='optional nonce')

    # sub parser: get

    get_parser = subparsers.add_parser('get', help='get instance')
    get_parser.set_defaults(func=get_instance)

    get_parser.add_argument('instance_id', help='id of instance previously created')

    get_parser.add_argument('--template',
                            dest='template',
                            help='inventory template')

    get_parser.add_argument('--tries',
                            dest='tries',
                            default=60,
                            type=int,
                            help='number of tries waiting for instance (default: 60)')

    get_parser.add_argument('--sleep',
                            dest='sleep',
                            default=10,
                            type=int,
                            help='sleep seconds between tries (default: 10)')

    # sub parser: stop

    stop_parser = subparsers.add_parser('stop', help='stop instance')
    stop_parser.set_defaults(func=stop_instance)

    stop_parser.add_argument('instance_id', help='id of instance previously created')

    # parse

    args = parser.parse_args()

    # Python 3 argparse does not require a sub-command by default; without
    # one, no handler was stored and args.func would raise AttributeError.
    if not hasattr(args, 'func'):
        parser.error('a command is required: start, get or stop')

    args.func(args)


def get_remote_key():
    """Return the remote key stored in ``~/.ansible-core-ci.key``, if any.

    The home directory is resolved with ``os.path.expanduser('~')`` rather
    than ``os.environ['HOME']``, so an unset HOME variable does not by itself
    raise KeyError.

    Returns:
        The file contents with surrounding whitespace stripped, or None when
        the file cannot be opened or read.
    """
    path = os.path.join(os.path.expanduser('~'), '.ansible-core-ci.key')

    try:
        with open(path, 'r') as f:
            return f.read().strip()
    except IOError:
        # A missing or unreadable key file simply means no stored key.
        return None


def get_instance_uri(args):
    """Build the API URI for the instance described by the parsed arguments.

    The URI has the form ``<endpoint>/<stage>/jobs/<instance_id>`` and is
    assembled from the ``endpoint``, ``stage`` and ``instance_id`` attributes
    of ``args``.
    """
    parts = (args.endpoint, args.stage, args.instance_id)
    return '%s/%s/jobs/%s' % parts


def start_instance(args):
    """Ask the API to start an instance and print its id to stdout.

    The request body carries the platform, version, optional public key
    (read from the file named by ``args.ssh_key``) and the query flag, plus
    an ``auth`` section chosen by ``args.start`` ('shippable' or 'remote').
    When ``args.query`` is set, the JSON response is also written to stderr.

    Raises:
        Exception: when ``args.start`` names no auth method, or when the
            API responds with a status code other than 200.
    """
    public_key = None

    if args.ssh_key is not None:
        with open(args.ssh_key, 'r') as key_file:
            public_key = key_file.read()

    payload = {
        'config': {
            'platform': args.platform,
            'version': args.version,
            'public_key': public_key,
            'query': args.query,
        },
    }

    # Choose the credentials section matching the nested sub-command.
    if args.start == 'shippable':
        credentials = {
            'shippable': {
                'run_id': args.run,
                'job_number': args.job,
            },
        }
    elif args.start == 'remote':
        credentials = {
            'remote': {
                'key': args.key,
                'nonce': args.nonce,
            },
        }
    else:
        raise Exception('auth required')

    payload['auth'] = credentials

    if args.verbose:
        print_stderr('starting instance: %s/%s (%s)' % (args.platform, args.version, args.instance_id))

    request_headers = {'Content-Type': 'application/json'}
    target = get_instance_uri(args)
    response = requests.put(target, data=json.dumps(payload), headers=request_headers)

    if response.status_code != 200:
        raise Exception(create_http_error(response))

    if args.verbose:
        print_stderr('instance started: %s' % args.instance_id)

    # The instance id is the only thing written to stdout.
    print(args.instance_id)

    if args.query:
        print_stderr(json.dumps(response.json(), indent=4, sort_keys=True))


def stop_instance(args):
    """Ask the API to stop the instance identified by ``args.instance_id``.

    Progress messages are written to stderr when ``args.verbose`` is set.

    Raises:
        Exception: when the API responds with a status code other than 200.
    """
    def report(message):
        # Emit a progress line naming the instance, in verbose mode only.
        if args.verbose:
            print_stderr(message % args.instance_id)

    report('stopping instance: %s')

    response = requests.delete(get_instance_uri(args))

    if response.status_code != 200:
        raise Exception(create_http_error(response))

    report('instance stopped: %s')


def get_instance(args):
    """Poll the API until the instance is running, then print an inventory.

    The instance URI is requested up to ``args.tries`` times, sleeping
    ``args.sleep`` seconds between attempts. Once the reported status is
    'running', an inventory built from the response's ``connection`` data
    and ``args.template`` is printed to stdout.

    Raises:
        Exception: when the API responds with a status code other than 200,
            or when the instance is not running after the final attempt.
    """
    if args.verbose:
        print_stderr('waiting for instance: %s' % args.instance_id)

    uri = get_instance_uri(args)
    start_time = datetime.datetime.utcnow()

    for attempt in range(args.tries):
        # Sleep only between polls, so the timeout is raised immediately
        # after the final attempt instead of after one more pointless wait.
        if attempt:
            sleep(args.sleep)

        response = requests.get(uri)

        if response.status_code != 200:
            raise Exception(create_http_error(response))

        response_json = response.json()
        status = response_json['status']

        if status == 'running':
            end_time = datetime.datetime.utcnow()
            duration = end_time - start_time

            if args.verbose:
                print_stderr('waited %s for instance availability' % duration)

            connection = response_json['connection']
            inventory = make_inventory(args.template, connection, args.instance_id)

            print(inventory)

            return

    raise Exception('timeout waiting for instance')


def make_inventory(inventory_template, connection, instance_id):
    """Render an inventory by filling placeholders in a template.

    Args:
        inventory_template: path to a template file, or None to use the
            built-in WinRM template.
        connection: mapping with 'hostname' and 'username' keys and optional
            'port' (default 22) and 'password' (default '') keys.
        instance_id: text substituted for the ``@instance_id`` placeholder.

    Returns:
        The template text with ``@instance_id``, ``@ansible_host``,
        ``@ansible_port``, ``@ansible_user`` and ``@ansible_password``
        replaced, in that order.
    """
    if inventory_template is not None:
        with open(inventory_template, 'r') as template_file:
            text = template_file.read()
    else:
        text = dedent('''
            [winrm]
            windows # @instance_id

            [winrm:vars]
            ansible_connection=winrm
            ansible_host=@ansible_host
            ansible_user=@ansible_user
            ansible_password=@ansible_password
            ansible_port=5986
            ansible_winrm_server_cert_validation=ignore
        ''').strip()

    # Apply the substitutions one at a time, in a fixed order.
    text = text.replace('@instance_id', instance_id)
    text = text.replace('@ansible_host', connection['hostname'])
    text = text.replace('@ansible_port', str(connection.get('port', 22)))
    text = text.replace('@ansible_user', connection['username'])
    text = text.replace('@ansible_password', connection.get('password', ''))

    return text


def print_stderr(*args, **kwargs):
    """Write to standard error using print() semantics.

    Positional and keyword arguments are passed through to print(); the
    ``file`` keyword is supplied by this function.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)


def create_http_error(response):
    """Format an error message from an HTTP response with a JSON body.

    The text is taken from the body's 'message' key when present, otherwise
    from 'errorMessage' (stripped, with a formatted 'stackTrace' appended
    when one is included), otherwise from the string form of the whole body.

    Returns:
        A string of the form ``<status code>: <message><stack trace>``.
    """
    payload = response.json()
    remote_trace = ''

    if 'message' not in payload and 'errorMessage' not in payload:
        # No recognised key: fall back to the entire decoded body.
        summary = str(payload)
    elif 'message' in payload:
        summary = payload['message']
    else:
        summary = payload['errorMessage'].strip()

        if 'stackTrace' in payload:
            frames = traceback.format_list(payload['stackTrace'])
            trace = '\n'.join([frame.rstrip() for frame in frames])
            remote_trace = ('\nTraceback (from remote server):\n%s' % trace)

    return '%s: %s%s' % (response.status_code, summary, remote_trace)


class HttpRequest:
    """Primitive replacement for requests to avoid an extra dependency.

    Requests are made by running /usr/bin/curl; urllib2 is avoided due to
    its lack of SNI support.
    """

    def __init__(self):
        """
        primitive replacement for requests to avoid extra dependency
        avoids use of urllib2 due to lack of SNI support
        """

    def get(self, url):
        """Send a GET request to ``url`` and return the HttpResponse."""
        return self.request('GET', url)

    def delete(self, url):
        """Send a DELETE request to ``url`` and return the HttpResponse."""
        return self.request('DELETE', url)

    def put(self, url, data=None, headers=None):
        """Send a PUT request with optional body and headers."""
        return self.request('PUT', url, data, headers)

    def request(self, method, url, data=None, headers=None):
        """Run curl for the given request and return an HttpResponse.

        Args:
            method: HTTP method passed to curl's ``-X`` option.
            url: target URL.
            data: optional request body passed to curl's ``-d`` option.
            headers: optional mapping of header names to values; it is not
                modified.

        Raises:
            subprocess.CalledProcessError: when curl exits with a non-zero
                status.
            ValueError: when curl's output has no header/body separator.
        """
        command = self._build_command(method, url, data, headers)
        status_code, body = self._parse_output(subprocess.check_output(command))

        return HttpResponse(status_code, body)

    @staticmethod
    def _build_command(method, url, data=None, headers=None):
        """Return the curl argument list for a request."""
        command = ['/usr/bin/curl', '-s', '-S', '-i', '-X', method]

        # Copy the headers so the caller's dict is never mutated.
        send_headers = dict(headers) if headers is not None else {}
        send_headers['Expect'] = ''  # don't send expect continue header

        for name in send_headers:
            command += ['-H', '%s: %s' % (name, send_headers[name])]

        if data is not None:
            command += ['-d', data]

        command += [url]

        return command

    @staticmethod
    def _parse_output(output):
        """Split ``curl -i`` output into a (status_code, body) tuple."""
        # check_output returns bytes on Python 3, which cannot be split with
        # a str separator; on Python 2 the output is already a str.
        if not isinstance(output, str):
            output = output.decode('utf-8')

        header, body = output.split('\r\n\r\n', 1)

        # The status line looks like "HTTP/1.1 200 OK"; the code is field two.
        status_line = header.split('\r\n')[0]
        status_code = int(status_line.split(' ')[1])

        return status_code, body


class HttpResponse:
    """Status code and raw body of a completed HTTP request."""

    def __init__(self, status_code, response):
        """Store the numeric status code and the unparsed response body."""
        self.status_code, self.response = status_code, response

    def json(self):
        """Return the response body decoded as JSON.

        Raises:
            ValueError: when the body is not valid JSON; the message
                includes the raw body.
        """
        raw = self.response

        try:
            decoded = json.loads(raw)
        except ValueError:
            raise ValueError('Cannot parse response as JSON:\n%s' % raw)

        return decoded


# Module-level stand-in for the third-party requests library: the functions
# in this file call requests.get/put/delete on this HttpRequest instance.
requests = HttpRequest()

# Run the command line interface only when executed as a script.
if __name__ == '__main__':
    main()