#!/usr/bin/python
# -*- coding: utf-8 -*-

# (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.

import shutil
import stat
import grp
import pwd
try:
    import selinux
    HAVE_SELINUX=True
except ImportError:
    HAVE_SELINUX=False

def add_path_info(kwargs):
    """Add facts about kwargs['path'] to the result dict and return it.

    The dict is modified in place.  For an existing path the keys 'owner',
    'group', 'mode' (octal string of the permission bits) and 'state'
    ('link', 'file' or 'directory') are set, plus 'secontext' when SELinux
    is available, enabled and a complete context could be read.  For a
    path that does not exist only 'state' is set, to 'absent'.

    If kwargs has no 'path' entry (some failure reports are emitted without
    one) the dict is returned untouched instead of raising KeyError.
    """
    path = kwargs.get('path')
    if path is None:
        # Nothing to inspect; do not turn a failure report into a KeyError.
        return kwargs
    if os.path.exists(path):
        (user, group) = user_and_group(path)
        kwargs['owner'] = user
        kwargs['group'] = group
        st = os.stat(path)
        kwargs['mode'] = oct(stat.S_IMODE(st[stat.ST_MODE]))
        if os.path.islink(path):
            kwargs['state'] = 'link'
        elif os.path.isfile(path):
            kwargs['state'] = 'file'
        else:
            kwargs['state'] = 'directory'
        if HAVE_SELINUX and selinux_enabled():
            context = selinux_context(path)
            # selinux_context() hands back a list of None placeholders when
            # the lookup reports failure; joining those would raise
            # TypeError, so only report a fully populated context.
            if None not in context:
                kwargs['secontext'] = ':'.join(context)
    else:
        kwargs['state'] = 'absent'
    return kwargs

def module_exit_json(**kwargs):
    """Report success through module.exit_json, with path facts added.

    The keyword arguments are first passed through add_path_info(), which
    fills in details about kwargs['path'].
    """
    result = add_path_info(kwargs)
    module.exit_json(**result)

def module_fail_json(**kwargs):
    """Report failure through module.fail_json, with path facts added.

    The keyword arguments are first passed through add_path_info(), which
    fills in details about kwargs['path'].
    """
    result = add_path_info(kwargs)
    module.fail_json(**result)

def selinux_mls_enabled():
    """Return True when the selinux bindings report MLS support.

    Detects whether the SELinux in use is MLS-aware.  While this means the
    level/range can be set with selinux.lsetfilecon(), it may or may not
    mean that the selevel is part of the context returned by
    selinux.lgetfilecon().

    Returns False when the selinux module could not be imported.
    """
    if HAVE_SELINUX:
        return selinux.is_selinux_mls_enabled() == 1
    return False

def selinux_enabled():
    """Return True when the selinux bindings report SELinux as enabled.

    Returns False when the selinux module could not be imported.
    """
    if HAVE_SELINUX:
        return selinux.is_selinux_enabled() == 1
    return False

def selinux_initial_context():
    """Return a list of None placeholders, one per context field.

    There are three placeholders, plus a fourth for selevel/mls when
    selinux_mls_enabled() is true.
    """
    placeholders = [None] * 3
    if selinux_mls_enabled():
        placeholders = placeholders + [None]
    return placeholders

def selinux_default_context(path, mode=0):
    """Return the default SELinux context for path as a list of fields.

    path and mode are passed straight to selinux.matchpathcon().

    If SELinux is unavailable or disabled, or no default can be found
    (matchpathcon raises OSError or reports -1), the list of None
    placeholders from selinux_initial_context() is returned instead.
    """
    context = selinux_initial_context()
    if not HAVE_SELINUX or not selinux_enabled():
        return context
    try:
        ret = selinux.matchpathcon(path, mode)
    except OSError:
        return context
    if ret[0] == -1:
        return context
    # Split into at most four fields: the trailing level/range field may
    # itself contain ':' (e.g. "s0:c0.c1023") and must stay in one piece.
    return ret[1].split(':', 3)

def selinux_context(path):
    """Return the current SELinux context of path as a list of fields.

    The context is read with selinux.lgetfilecon().  If SELinux is
    unavailable or disabled, or the lookup reports -1, the list of None
    placeholders from selinux_initial_context() is returned instead.

    If the lookup raises, the module fails with 'failed to retrieve
    selinux context'.
    """
    context = selinux_initial_context()
    if not HAVE_SELINUX or not selinux_enabled():
        return context
    try:
        ret = selinux.lgetfilecon(path)
    except Exception:
        # Call module.fail_json directly: module_fail_json() goes through
        # add_path_info(), which calls back into this function and would
        # recurse for as long as the lookup keeps failing.
        module.fail_json(path=path, msg='failed to retrieve selinux context')
    if ret[0] == -1:
        return context
    # Split into at most four fields: the trailing level/range field may
    # itself contain ':' (e.g. "s0:c0.c1023") and must stay in one piece.
    return ret[1].split(':', 3)

# ===========================================
# support functions

def user_and_group(filename):
    """Return (user, group) for the owner of filename.

    Names are looked up from the uid/gid reported by os.stat().  When an
    id has no passwd/group entry, the numeric id is returned as a string.
    """
    info = os.stat(filename)

    try:
        owner_name = pwd.getpwuid(info.st_uid).pw_name
    except KeyError:
        owner_name = str(info.st_uid)

    try:
        group_name = grp.getgrgid(info.st_gid).gr_name
    except KeyError:
        group_name = str(info.st_gid)

    return (owner_name, group_name)

def set_context_if_different(path, context, changed):
    """Apply the requested SELinux context fields to path where they differ.

    context is a list of fields; a None entry means "leave this field
    alone".  Returns True if the context was rewritten, otherwise the
    incoming changed value.  When SELinux is unavailable or disabled,
    changed is returned untouched.

    The module fails if selinux.lsetfilecon() raises OSError or returns
    a non-zero status.
    """
    if not HAVE_SELINUX or not selinux_enabled():
        return changed
    cur_context = selinux_context(path)
    new_context = list(cur_context)
    # Only compare the fields both lists have: the argument context may
    # carry a selevel the current one lacks, and the current one may carry
    # fields the argument lacks.  Indexing past either end would raise
    # IndexError.
    for i in range(min(len(cur_context), len(context))):
        if context[i] is not None and context[i] != cur_context[i]:
            new_context[i] = context[i]
    if cur_context != new_context:
        try:
            rc = selinux.lsetfilecon(path, ':'.join(new_context))
        except OSError:
            module_fail_json(path=path, msg='invalid selinux context')
        if rc != 0:
            module_fail_json(path=path, msg='set selinux context failed')
        changed = True
    return changed

def set_owner_if_different(path, owner, changed):
    """Chown path to owner when its current owner name differs.

    Returns True after a chown, otherwise the incoming changed value.
    An owner of None means "do not manage ownership".  The module fails
    if the user cannot be looked up or the chown raises OSError.
    """
    if owner is None:
        return changed

    current_owner = user_and_group(path)[0]
    if current_owner == owner:
        return changed

    try:
        uid = pwd.getpwnam(owner).pw_uid
    except KeyError:
        module_fail_json(path=path, msg='chown failed: failed to look up user %s' % owner)

    try:
        # -1 leaves the group id untouched
        os.chown(path, uid, -1)
    except OSError:
        module_fail_json(path=path, msg='chown failed')
    return True

def set_group_if_different(path, group, changed):
    """Chgrp path to group when its current group name differs.

    Returns True after a chown, otherwise the incoming changed value.
    A group of None means "do not manage the group".  The module fails
    if the group cannot be looked up or the chown raises OSError.
    """
    if group is None:
        return changed

    current_group = user_and_group(path)[1]
    if current_group == group:
        return changed

    try:
        gid = grp.getgrnam(group).gr_gid
    except KeyError:
        module_fail_json(path=path, msg='chgrp failed: failed to look up group %s' % group)

    try:
        # -1 leaves the user id untouched
        os.chown(path, -1, gid)
    except OSError:
        module_fail_json(path=path, msg='chgrp failed')
    return True

def set_mode_if_different(path, mode, changed):
    """Chmod path to mode when its permission bits differ.

    mode is a string of octal digits (e.g. '0644'); None means "do not
    manage the mode".  Returns True if the permission bits read back after
    the chmod differ from the previous ones, otherwise the incoming
    changed value.  The module fails if mode cannot be parsed as octal or
    the chmod raises.
    """
    if mode is None:
        return changed

    # Imported here so the function does not depend on a file-level import.
    import sys

    try:
        # FIXME: support English modes
        mode = int(mode, 8)
    except Exception:
        # sys.exc_info() rather than binding the exception in the except
        # clause: this form parses on both Python 2 and Python 3.
        e = sys.exc_info()[1]
        module_fail_json(path=path, msg='mode needs to be something octalish', details=str(e))

    st = os.stat(path)
    prev_mode = stat.S_IMODE(st[stat.ST_MODE])

    if prev_mode != mode:
        try:
            os.chmod(path, mode)
        except Exception:
            e = sys.exc_info()[1]
            module_fail_json(path=path, msg='chmod failed', details=str(e))

        # Re-read the mode so a change is only reported if it took effect.
        st = os.stat(path)
        new_mode = stat.S_IMODE(st[stat.ST_MODE])

        if new_mode != prev_mode:
            return True
    return changed


def rmtree_error(func, path, exc_info):
    """Error hook passed to shutil.rmtree as onerror: fail the module.

    func and exc_info are accepted to match the hook signature but are
    not used; only the offending path is reported.
    """
    failure = dict(path=path, msg='failed to remove directory')
    module_fail_json(**failure)

def main():
    """Module entry point: bring a path to the requested state.

    Reads the module parameters, works out the path's current state
    ('absent', 'link', 'file' or 'directory') and then:

    * state=absent: removes the path (directories recursively);
    * state=file: fails if the file is missing, otherwise applies
      context/owner/group/mode;
    * state=directory: creates the directory if missing, then applies
      context/owner/group/mode;
    * state=link: creates or re-points a symlink to the absolute src, then
      applies context/owner/group/mode.

    Converting an existing path from one kind to another is refused.
    Every branch finishes by calling exit_json/fail_json on the module.
    """

    # FIXME: pass this around, should not use global
    global module

    # Imported here so the function does not depend on a file-level import.
    import sys

    module = AnsibleModule(
        check_invalid_arguments = False,
        argument_spec = dict(
            state = dict(choices=['file','directory','link','absent'], default='file'),
            path  = dict(aliases=['dest', 'name'], required=True),
            src   = dict(),
            mode  = dict(),
            owner = dict(),
            group = dict(),
            seuser = dict(),
            serole = dict(),
            selevel = dict(),
            setype = dict(),
        )
    )

    params = module.params
    state  = params['state']
    path   = os.path.expanduser(params['path'])
    src    = params.get('src', None)
    if src:
        src = os.path.expanduser(src)

    mode   = params.get('mode', None)
    owner  = params.get('owner', None)
    group  = params.get('group', None)

    # selinux related options
    seuser    = params.get('seuser', None)
    serole    = params.get('serole', None)
    setype    = params.get('setype', None)
    # The argument spec declares 'selevel'; 'serange' is still honoured
    # because that is the key this code used to read.
    selevel   = params.get('selevel', None) or params.get('serange', None) or 's0'
    secontext = [seuser, serole, setype]
    if selinux_mls_enabled():
        secontext.append(selevel)

    # Replace '_default' markers with the policy default for that field.
    # Only walk the fields both lists have so neither is indexed past its end.
    default_secontext = selinux_default_context(path)
    for i in range(min(len(secontext), len(default_secontext))):
        if secontext[i] == '_default':
            secontext[i] = default_secontext[i]

    if state == 'link' and (src is None or path is None):
        # path is passed along because module_fail_json looks it up
        module_fail_json(path=path, msg='src and dest are required for "link" state')
    elif path is None:
        # no path to describe, so bypass the path-info wrapper
        module.fail_json(msg='path is required')

    changed = False

    # lexists: a dangling symlink still counts as present
    prev_state = 'absent'
    if os.path.lexists(path):
        if os.path.islink(path):
            prev_state = 'link'
        elif os.path.isfile(path):
            prev_state = 'file'
        else:
            prev_state = 'directory'

    if prev_state != 'absent' and state == 'absent':
        try:
            if prev_state == 'directory':
                if os.path.islink(path):
                    os.unlink(path)
                else:
                    shutil.rmtree(path, ignore_errors=False, onerror=rmtree_error)
            else:
                os.unlink(path)
        except Exception:
            # sys.exc_info() rather than binding the exception in the
            # except clause: this form parses on both Python 2 and 3.
            e = sys.exc_info()[1]
            module_fail_json(path=path, msg=str(e))
        module_exit_json(path=path, changed=True)

    if prev_state != 'absent' and prev_state != state:
        module_fail_json(path=path, msg='refusing to convert between %s and %s' % (prev_state, state))

    if prev_state == 'absent' and state == 'absent':
        module_exit_json(path=path, changed=False)

    if state == 'file':

        if prev_state == 'absent':
            module_fail_json(path=path, msg='file does not exist, use copy or template module to create')

        # set modes owners and context as needed
        changed = set_context_if_different(path, secontext, changed)
        changed = set_owner_if_different(path, owner, changed)
        changed = set_group_if_different(path, group, changed)
        changed = set_mode_if_different(path, mode, changed)

        module_exit_json(path=path, changed=changed)

    elif state == 'directory':

        if prev_state == 'absent':
            os.makedirs(path)
            changed = True

        # set modes owners and context as needed
        changed = set_context_if_different(path, secontext, changed)
        changed = set_owner_if_different(path, owner, changed)
        changed = set_group_if_different(path, group, changed)
        changed = set_mode_if_different(path, mode, changed)

        module_exit_json(path=path, changed=changed)

    elif state == 'link':

        if os.path.isabs(src):
            abs_src = src
        else:
            module.fail_json(msg="absolute paths are required")
        if not os.path.exists(abs_src):
            module_fail_json(path=path, src=src, msg='src file does not exist')

        if prev_state == 'absent':
            os.symlink(src, path)
            changed = True
        elif prev_state == 'link':
            old_src = os.readlink(path)
            if not os.path.isabs(old_src):
                old_src = os.path.join(os.path.dirname(path), old_src)
            if old_src != src:
                os.unlink(path)
                os.symlink(src, path)
                # re-pointing the link is a change and must be reported
                changed = True
        else:
            module_fail_json(path=path, src=src, msg='unexpected position reached')

        # set modes owners and context as needed
        changed = set_context_if_different(path, secontext, changed)
        changed = set_owner_if_different(path, owner, changed)
        changed = set_group_if_different(path, group, changed)
        changed = set_mode_if_different(path, mode, changed)

        module.exit_json(dest=path, src=src, changed=changed)

    module_fail_json(path=path, msg='unexpected position reached')

# this is magic, see lib/ansible/module_common.py
# NOTE(review): the marker line below is presumably substituted with shared
# module code before this file runs, which would be where names used here
# without an import (os, AnsibleModule) come from -- confirm against
# module_common.py.  The marker text must stay exactly as written.
#<<INCLUDE_ANSIBLE_MODULE_COMMON>>
main()