#!/usr/bin/python

# (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.

# I wanted to keep this simple at first, so for now this checks out
# from the given branch of a repo at a particular SHA or
# tag.  Latest is not supported, you should not be doing
# that. Contribs welcome! -- MPD

try:
    import json
except ImportError:
    import simplejson as json
import os
import re
import sys
import shlex
import subprocess
import syslog

# ===========================================
# Basic support methods

def exit_json(rc=0, **kwargs):
    '''
    Emit the module result and terminate the process.

    All keyword arguments are serialized as one JSON object on stdout,
    then the interpreter exits with status ``rc`` (0 unless given).
    ``rc`` itself is not part of the JSON output.
    '''
    # parenthesized so the line is valid both as a Python 2 print
    # statement and as a Python 3 print() call
    print(json.dumps(kwargs))
    sys.exit(rc)

def fail_json(**kwargs):
    '''
    Report a failure: marks the result with failed=True and hands it
    to exit_json, which prints it and ends the process.

    Every keyword is forwarded unchanged, so an ``rc`` keyword is taken
    by exit_json as the process exit status instead of being written to
    the JSON output; without ``rc`` the exit status stays at
    exit_json's default of 0.
    '''
    kwargs.update(failed=True)
    exit_json(**kwargs)

# ===========================================
# convert arguments of form a=b c=d
# to a dictionary
# FIXME: make more idiomatic

if len(sys.argv) == 1:
    fail_json(msg="the git module requires arguments (-a)")

argfile = sys.argv[1]
if not os.path.exists(argfile):
    fail_json(msg="Argument file not found")

# read the argument file and always release the handle
# (try/finally instead of "with" to stay usable on old interpreters)
argfh = open(argfile, 'r')
try:
    args = argfh.read()
finally:
    argfh.close()

try:
    items = shlex.split(args)
except ValueError:
    # shlex raises ValueError on input such as an unterminated quote;
    # report it as a module failure instead of a traceback
    fail_json(msg="could not parse arguments: %s" % sys.exc_info()[1])
syslog.openlog('ansible-%s' % os.path.basename(__file__))
syslog.syslog(syslog.LOG_NOTICE, 'Invoked with %s' % args)

if not items:
    fail_json(msg="the git module requires arguments (-a)")

params = {}
for x in items:
    if "=" not in x:
        fail_json(msg="malformed argument '%s', expected key=value" % x)
    # split on the first '=' only, so values that contain '=' themselves
    # (e.g. a repo URL with a query string) are kept intact
    (k, v) = x.split("=", 1)
    params[k] = v

# dest and repo have no defaults; fail with a readable message
# rather than a KeyError traceback
for required in ('dest', 'repo'):
    if required not in params:
        fail_json(msg="missing required argument: %s" % required)

dest = params['dest']
repo = params['repo']
version = params.get('version', 'HEAD')
remote = params.get('remote', 'origin')

# ===========================================

def get_version(dest):
    '''
    Sample the version of the git repo at dest.

    Changes the working directory to dest, runs
    "git show --abbrev-commit" and returns the second
    whitespace-separated token of the first output line
    (presumably the abbreviated id that follows the word "commit").
    '''
    os.chdir(dest)
    output = os.popen("git show --abbrev-commit").read()
    first_line = output.split("\n")[0]
    return first_line.split()[1]

def clone(repo, dest):
    '''
    Clone repo into dest, creating the parent directory of dest first
    when it does not exist yet.

    Returns a tuple (rc, out, err): the exit status of "git clone" and
    its captured stdout and stderr.  If git cannot be started at all,
    rc is 127 and err carries the reason.
    '''
    parent = os.path.dirname(dest)
    if parent and not os.path.isdir(parent):
        try:
            os.makedirs(parent)
        except OSError:
            # best effort only: if the directory really cannot be
            # created, git clone fails below and reports why
            pass
    # argument list and no shell: repo and dest reach git verbatim, so
    # spaces or shell metacharacters in them cannot split or alter the
    # command
    try:
        proc = subprocess.Popen(["git", "clone", repo, dest],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
    except OSError:
        # git could not be executed; 127 mirrors the status a shell
        # gives for a command it cannot find
        return (127, '', 'error: could not execute git: %s' % sys.exc_info()[1])
    (out, err) = proc.communicate()
    return (proc.returncode, out, err)

def reset(dest):
    '''
    Hard-reset the repo at dest to HEAD.

    Changes the working directory to dest and runs
    "git reset --hard HEAD", which drops modifications to tracked
    files.  Returns (rc, out, err): the exit status plus the captured
    stdout and stderr of that command.
    '''
    os.chdir(dest)
    proc = subprocess.Popen("git reset --hard HEAD", shell=True,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
    captured = proc.communicate()
    return (proc.returncode, captured[0], captured[1])

def switchLocalBranch(branch):
    '''
    Check out branch in the current working directory.

    Returns (out, err), the captured stdout and stderr of
    "git checkout"; the exit status is not reported.  If git cannot be
    started, out is empty and err carries the reason.
    '''
    # argument list and no shell, so the branch name is passed to git
    # verbatim and cannot be interpreted as shell syntax
    try:
        proc = subprocess.Popen(["git", "checkout", branch],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
    except OSError:
        return ('', 'error: could not execute git: %s' % sys.exc_info()[1])
    return proc.communicate()

def get_branches(dest):
    '''
    List the branches of the repo at dest.

    Changes the working directory to dest, runs "git branch -a" and
    returns its non-blank output lines with surrounding whitespace
    removed.  Lines are otherwise left as git printed them, so the
    checked-out branch still carries its leading "* " marker.

    Ends the module through fail_json when git exits non-zero.
    '''
    os.chdir(dest)
    cmd = subprocess.Popen("git branch -a", shell=True,
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = cmd.communicate()
    if cmd.returncode != 0:
        # git explains failures on stderr, so include it in the message
        fail_json(msg="Could not determine branch data - received %s" % (out + err))
    # skip blank lines (e.g. after the final newline) so that an empty
    # string never counts as a branch name
    stripped = [line.strip() for line in out.split('\n')]
    return [name for name in stripped if name]

def is_remote_branch(dest, remote, branch):
    '''
    Return True when "remotes/<remote>/<branch>" is one of the entries
    get_branches reports for the repo at dest, False otherwise.
    '''
    known = get_branches(dest)
    wanted = 'remotes/%s/%s' % (remote, branch)
    return wanted in known

def is_local_branch(dest, branch):
    '''
    Return True when branch is one of the local branches get_branches
    reports for the repo at dest, False otherwise.
    '''
    branches = get_branches(dest)
    name = '%s' % branch
    # get_branches leaves git's "* " marker on the checked-out branch,
    # so that entry has to be matched with the marker in front
    return name in branches or ('* %s' % name) in branches

def pull(repo, dest, version):
    '''
    Update the repo at dest from the remote named origin.

    If version is a local branch other than the checked-out one, that
    branch is checked out first.  Then "git pull -u origin" is run.
    The repo argument is not used here.

    Returns (rc, out, err): exit status, stdout and stderr of the pull.
    '''
    os.chdir(dest)
    cur_branch = ''
    for b in get_branches(dest):
        if b.startswith('* '):
            # drop the "* " marker so the name is comparable to version
            cur_branch = b[2:]
            break
    if version != cur_branch and is_local_branch(dest, version):
        # the checkout output is not inspected
        switchLocalBranch(version)

    cmd = subprocess.Popen("git pull -u origin", shell=True,
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = cmd.communicate()
    return (cmd.returncode, out, err)

def switch_version(dest, remote, version):
    '''
    Once pulled, switch the repo at dest to version.

    - version other than HEAD that is not a local branch but exists as
      remotes/<remote>/<version>: create a local tracking branch
    - any other version (local branch, SHA, tag): forced checkout
    - HEAD: rebase onto origin

    Returns (rc, out, err) of the git command that was run.  If git
    cannot be started, rc is 127 and err carries the reason.
    '''
    os.chdir(dest)
    if version != 'HEAD':
        if not is_local_branch(dest, version) and is_remote_branch(dest, remote, version):
            argv = ["git", "checkout", "--track", "-b", version,
                    "%s/%s" % (remote, version)]
        else:
            argv = ["git", "checkout", "--force", version]
    else:
        # is there a better way to do this?
        argv = ["git", "rebase", "origin"]
    # argument list and no shell, so version and remote are passed to
    # git verbatim and cannot be interpreted as shell syntax
    try:
        proc = subprocess.Popen(argv, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
    except OSError:
        return (127, '', 'error: could not execute git: %s' % sys.exc_info()[1])
    (out, err) = proc.communicate()
    return (proc.returncode, out, err)
 

gitconfig = os.path.join(dest, '.git', 'config')

out = err = status = None

# an existing .git/config means dest already holds a checkout:
# record its version, drop local changes and pull.
# otherwise there is nothing there yet, so clone.

before = None
if os.path.exists(gitconfig):
    before = get_version(dest)
    (rc, out, err) = reset(dest)
    if rc != 0:
        fail_json(out=out, err=err, rc=rc)
    (rc, out, err) = pull(repo, dest, version)
else:
    (rc, out, err) = clone(repo, dest)
    if rc != 0:
        fail_json(out=out, err=err, rc=rc)

# errors from clone or pull are detected by scanning the output

if 'error' in out or 'ERROR' in err:
    fail_json(out=out, err=err)

# the requested version is selected after both clone and pull

(rc, out, err) = switch_version(dest, remote, version)
if 'error' in err:
    fail_json(out=out, err=err)

# report a change when the sampled version differs from the one
# recorded before (before is None after a fresh clone)

after = get_version(dest)
changed = (before != after)

exit_json(changed=changed, before=before, after=after)