#!/usr/bin/env python # (c) 2016 Matt Clay # # This file is part of Ansible # # Ansible is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # Ansible is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with Ansible. If not, see . from __future__ import print_function import datetime import json import os import subprocess import sys import traceback import uuid from argparse import ArgumentParser from textwrap import dedent from time import sleep def main(): parser = ArgumentParser(description='Manage remote instances for testing.') parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='write verbose output to stderr') parser.add_argument('--endpoint', dest='endpoint', default='https://14blg63h2i.execute-api.us-east-1.amazonaws.com', help='api endpoint') parser.add_argument('--stage', dest='stage', default='prod', choices=['dev', 'prod'], help='api stage (default: prod)') subparsers = parser.add_subparsers() # sub parser: start start_parser = subparsers.add_parser('start', help='start instance') start_parser.set_defaults(func=start_instance) start_subparsers = start_parser.add_subparsers(dest='start') start_parser.add_argument('platform', help='platform (ex: windows)') start_parser.add_argument('version', help='version (ex: 2012-R2_RTM)') start_parser.add_argument('--id', dest='instance_id', default=uuid.uuid4(), help='instance id to create') start_parser.add_argument('--public-key', dest='ssh_key', default=None, help='path to ssh public key for authentication') start_parser.add_argument('--query', dest='query', action='store_true', default=False, help='query only, do not start instance') shippable = start_subparsers.add_parser('shippable', help='start instance for shippable testing') shippable.add_argument('--run-id', dest='run', default=os.environ.get('SHIPPABLE_BUILD_ID'), help='shippable run id') shippable.add_argument('--job-number', dest='job', type=int, default=os.environ.get('SHIPPABLE_JOB_NUMBER'), help='shippable job number') remote_key = get_remote_key() remote = start_subparsers.add_parser('remote', help='start instance for remote testing') remote.add_argument('--key', dest='key', default=remote_key, required=remote_key is None, help='remote key') remote.add_argument('--nonce', dest='nonce', default=None, help='optional nonce') # sub parser: get get_parser = subparsers.add_parser('get', help='get instance') get_parser.set_defaults(func=get_instance) get_parser.add_argument('instance_id', help='id of instance previously created') get_parser.add_argument('--template', dest='template', help='inventory template') get_parser.add_argument('--tries', dest='tries', default=60, type=int, help='number of tries waiting for instance (default: 60)') get_parser.add_argument('--sleep', dest='sleep', default=10, type=int, help='sleep seconds between tries (default: 10)') # sub parser: stop stop_parser = subparsers.add_parser('stop', help='stop instance') stop_parser.set_defaults(func=stop_instance) stop_parser.add_argument('instance_id', help='id of instance previously created') # parse args = parser.parse_args() args.func(args) def get_remote_key(): path = os.path.join(os.environ['HOME'], '.ansible-core-ci.key') try: with open(path, 'r') as f: return f.read().strip() except IOError: return None def get_instance_uri(args): return '%s/%s/jobs/%s' % (args.endpoint, args.stage, args.instance_id) def start_instance(args): if args.ssh_key is None: public_key = None else: with open(args.ssh_key, 'r') as f: public_key = f.read() data = dict( config=dict( platform=args.platform, version=args.version, public_key=public_key, query=args.query, ) ) if args.start == 'shippable': auth = dict( shippable=dict( run_id=args.run, job_number=args.job, ), ) elif args.start == 'remote': auth = dict( remote=dict( key=args.key, nonce=args.nonce, ), ) else: raise Exception('auth required') data.update(dict(auth=auth)) if args.verbose: print_stderr('starting instance: %s/%s (%s)' % (args.platform, args.version, args.instance_id)) headers = { 'Content-Type': 'application/json', } uri = get_instance_uri(args) response = requests.put(uri, data=json.dumps(data), headers=headers) if response.status_code != 200: raise Exception(create_http_error(response)) if args.verbose: print_stderr('instance started: %s' % args.instance_id) print(args.instance_id) if args.query: print_stderr(json.dumps(response.json(), indent=4, sort_keys=True)) def stop_instance(args): if args.verbose: print_stderr('stopping instance: %s' % args.instance_id) uri = get_instance_uri(args) response = requests.delete(uri) if response.status_code != 200: raise Exception(create_http_error(response)) if args.verbose: print_stderr('instance stopped: %s' % args.instance_id) def get_instance(args): if args.verbose: print_stderr('waiting for instance: %s' % args.instance_id) uri = get_instance_uri(args) start_time = datetime.datetime.utcnow() for i in range(args.tries): response = requests.get(uri) if response.status_code != 200: raise Exception(create_http_error(response)) response_json = response.json() status = response_json['status'] if status == 'running': end_time = datetime.datetime.utcnow() duration = end_time - start_time if args.verbose: print_stderr('waited %s for instance availability' % duration) connection = response_json['connection'] inventory = make_inventory(args.template, connection, args.instance_id) print(inventory) return sleep(args.sleep) raise Exception('timeout waiting for instance') def make_inventory(inventory_template, connection, instance_id): if inventory_template is None: template = dedent(''' [winrm] windows # @instance_id [winrm:vars] ansible_connection=winrm ansible_host=@ansible_host ansible_user=@ansible_user ansible_password=@ansible_password ansible_port=5986 ansible_winrm_server_cert_validation=ignore ''').strip() else: with open(inventory_template, 'r') as f: template = f.read() inventory = template\ .replace('@instance_id', instance_id)\ .replace('@ansible_host', connection['hostname'])\ .replace('@ansible_user', connection['username'])\ .replace('@ansible_password', connection.get('password', '')) return inventory def print_stderr(*args, **kwargs): """Print to stderr.""" print(*args, file=sys.stderr, **kwargs) def create_http_error(response): response_json = response.json() stack_trace = '' if 'message' in response_json: message = response_json['message'] elif 'errorMessage' in response_json: message = response_json['errorMessage'].strip() if 'stackTrace' in response_json: trace = '\n'.join([x.rstrip() for x in traceback.format_list(response_json['stackTrace'])]) stack_trace = ('\nTraceback (from remote server):\n%s' % trace) else: message = str(response_json) return '%s: %s%s' % (response.status_code, message, stack_trace) class HttpRequest: def __init__(self): """ primitive replacement for requests to avoid extra dependency avoids use of urllib2 due to lack of SNI support """ def get(self, url): return self.request('GET', url) def delete(self, url): return self.request('DELETE', url) def put(self, url, data=None, headers=None): return self.request('PUT', url, data, headers) def request(self, method, url, data=None, headers=None): args = ['/usr/bin/curl', '-s', '-i', '-X', method] if headers is not None: for header in headers: args += ['-H', '%s: %s' % (header, headers[header])] if data is not None: args += ['-d', data] args += [url] header, body = subprocess.check_output(args).split('\r\n\r\n', 1) response_headers = header.split('\r\n') first_line = response_headers[0] http_response = first_line.split(' ') status_code = int(http_response[1]) return HttpResponse(status_code, body) class HttpResponse: def __init__(self, status_code, response): self.status_code = status_code self.response = response def json(self): return json.loads(self.response) requests = HttpRequest() if __name__ == '__main__': main()