diff --git a/v2/ansible/parsing/__init__.py b/v2/ansible/parsing/__init__.py index 4641623c03..229be2622f 100644 --- a/v2/ansible/parsing/__init__.py +++ b/v2/ansible/parsing/__init__.py @@ -19,19 +19,205 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import json + +from yaml import YAMLError + from ansible.errors import AnsibleError, AnsibleInternalError +from ansible.parsing.vault import VaultLib +from ansible.parsing.yaml import safe_load -def load(self, data): - if instanceof(data, file): +def process_common_errors(msg, probline, column): + replaced = probline.replace(" ","") + + if ":{{" in replaced and "}}" in replaced: + msg = msg + """ +This one looks easy to fix. YAML thought it was looking for the start of a +hash/dictionary and was confused to see a second "{". Most likely this was +meant to be an ansible template evaluation instead, so we have to give the +parser a small hint that we wanted a string instead. The solution here is to +just quote the entire value. + +For instance, if the original line was: + + app_path: {{ base_path }}/foo + +It should be written as: + + app_path: "{{ base_path }}/foo" +""" + return msg + + elif len(probline) and len(probline) > 1 and len(probline) > column and probline[column] == ":" and probline.count(':') > 1: + msg = msg + """ +This one looks easy to fix. There seems to be an extra unquoted colon in the line +and this is confusing the parser. It was only expecting to find one free +colon. The solution is just add some quotes around the colon, or quote the +entire line after the first colon. + +For instance, if the original line was: + + copy: src=file.txt dest=/path/filename:with_colon.txt + +It can be written as: + + copy: src=file.txt dest='/path/filename:with_colon.txt' + +Or: + + copy: 'src=file.txt dest=/path/filename:with_colon.txt' + + +""" + return msg + else: + parts = probline.split(":") + if len(parts) > 1: + middle = parts[1].strip() + match = False + unbalanced = False + if middle.startswith("'") and not middle.endswith("'"): + match = True + elif middle.startswith('"') and not middle.endswith('"'): + match = True + if len(middle) > 0 and middle[0] in [ '"', "'" ] and middle[-1] in [ '"', "'" ] and probline.count("'") > 2 or probline.count('"') > 2: + unbalanced = True + if match: + msg = msg + """ +This one looks easy to fix. It seems that there is a value started +with a quote, and the YAML parser is expecting to see the line ended +with the same kind of quote. For instance: + + when: "ok" in result.stdout + +Could be written as: + + when: '"ok" in result.stdout' + +or equivalently: + + when: "'ok' in result.stdout" + +""" + return msg + + if unbalanced: + msg = msg + """ +We could be wrong, but this one looks like it might be an issue with +unbalanced quotes. If starting a value with a quote, make sure the +line ends with the same set of quotes. For instance this arbitrary +example: + + foo: "bad" "wolf" + +Could be written as: + + foo: '"bad" "wolf"' + +""" + return msg + + return msg + +def process_yaml_error(exc, data, path=None, show_content=True): + if hasattr(exc, 'problem_mark'): + mark = exc.problem_mark + if show_content: + if mark.line -1 >= 0: + before_probline = data.split("\n")[mark.line-1] + else: + before_probline = '' + probline = data.split("\n")[mark.line] + arrow = " " * mark.column + "^" + msg = """Syntax Error while loading YAML script, %s +Note: The error may actually appear before this position: line %s, column %s + +%s +%s +%s""" % (path, mark.line + 1, mark.column + 1, before_probline, probline, arrow) + + unquoted_var = None + if '{{' in probline and '}}' in probline: + if '"{{' not in probline or "'{{" not in probline: + unquoted_var = True + + if not unquoted_var: + msg = process_common_errors(msg, probline, mark.column) + else: + msg = msg + """ +We could be wrong, but this one looks like it might be an issue with +missing quotes. Always quote template expression brackets when they +start a value. For instance: + + with_items: + - {{ foo }} + +Should be written as: + + with_items: + - "{{ foo }}" + +""" + else: + # most likely displaying a file with sensitive content, + # so don't show any of the actual lines of yaml just the + # line number itself + msg = """Syntax error while loading YAML script, %s +The error appears to have been on line %s, column %s, but may actually +be before there depending on the exact syntax problem. +""" % (path, mark.line + 1, mark.column + 1) + + else: + # No problem markers means we have to throw a generic + # "stuff messed up" type message. Sry bud. + if path: + msg = "Could not parse YAML. Check over %s again." % path + else: + msg = "Could not parse YAML." + raise errors.AnsibleYAMLValidationFailed(msg) + + +def load_data(data): + + if isinstance(data, file): fd = open(f) data = fd.read() fd.close() - if instanceof(data, basestring): + if isinstance(data, basestring): try: return json.loads(data) except: return safe_load(data) raise AnsibleInternalError("expected file or string, got %s" % type(data)) + +def load_data_from_file(path, vault_password=None): + ''' + Convert a yaml file to a data structure. + Was previously 'parse_yaml_from_file()'. + ''' + + data = None + show_content = True + + try: + data = open(path).read() + except IOError: + raise errors.AnsibleError("file could not read: %s" % path) + + vault = VaultLib(password=vault_password) + if vault.is_encrypted(data): + # if the file is encrypted and no password was specified, + # the decrypt call would throw an error, but we check first + # since the decrypt function doesn't know the file name + if vault_password is None: + raise errors.AnsibleError("A vault password must be specified to decrypt %s" % path) + data = vault.decrypt(data) + show_content = False + + try: + return load_data(data) + except YAMLError, exc: + process_yaml_error(exc, data, path, show_content) diff --git a/v2/ansible/parsing/vault/__init__.py b/v2/ansible/parsing/vault/__init__.py new file mode 100644 index 0000000000..3b83d2989e --- /dev/null +++ b/v2/ansible/parsing/vault/__init__.py @@ -0,0 +1,563 @@ +# (c) 2014, James Tanner +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +# ansible-pull is a script that runs ansible in local mode +# after checking out a playbooks directory from source repo. There is an +# example playbook to bootstrap this script in the examples/ dir which +# installs ansible and sets it up to run on cron. + +import os +import shlex +import shutil +import tempfile +from io import BytesIO +from subprocess import call +from ansible import errors +from hashlib import sha256 +from hashlib import md5 +from binascii import hexlify +from binascii import unhexlify +from ansible import constants as C + +try: + from Crypto.Hash import SHA256, HMAC + HAS_HASH = True +except ImportError: + HAS_HASH = False + +# Counter import fails for 2.0.1, requires >= 2.6.1 from pip +try: + from Crypto.Util import Counter + HAS_COUNTER = True +except ImportError: + HAS_COUNTER = False + +# KDF import fails for 2.0.1, requires >= 2.6.1 from pip +try: + from Crypto.Protocol.KDF import PBKDF2 + HAS_PBKDF2 = True +except ImportError: + HAS_PBKDF2 = False + +# AES IMPORTS +try: + from Crypto.Cipher import AES as AES + HAS_AES = True +except ImportError: + HAS_AES = False + +CRYPTO_UPGRADE = "ansible-vault requires a newer version of pycrypto than the one installed on your platform. You may fix this with OS-specific commands such as: yum install python-devel; rpm -e --nodeps python-crypto; pip install pycrypto" + +HEADER='$ANSIBLE_VAULT' +CIPHER_WHITELIST=['AES', 'AES256'] + +class VaultLib(object): + + def __init__(self, password): + self.password = password + self.cipher_name = None + self.version = '1.1' + + def is_encrypted(self, data): + if data.startswith(HEADER): + return True + else: + return False + + def encrypt(self, data): + + if self.is_encrypted(data): + raise errors.AnsibleError("data is already encrypted") + + if not self.cipher_name: + self.cipher_name = "AES256" + #raise errors.AnsibleError("the cipher must be set before encrypting data") + + if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: + cipher = globals()['Vault' + self.cipher_name] + this_cipher = cipher() + else: + raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name) + + """ + # combine sha + data + this_sha = sha256(data).hexdigest() + tmp_data = this_sha + "\n" + data + """ + + # encrypt sha + data + enc_data = this_cipher.encrypt(data, self.password) + + # add header + tmp_data = self._add_header(enc_data) + return tmp_data + + def decrypt(self, data): + if self.password is None: + raise errors.AnsibleError("A vault password must be specified to decrypt data") + + if not self.is_encrypted(data): + raise errors.AnsibleError("data is not encrypted") + + # clean out header + data = self._split_header(data) + + # create the cipher object + if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: + cipher = globals()['Vault' + self.cipher_name] + this_cipher = cipher() + else: + raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name) + + # try to unencrypt data + data = this_cipher.decrypt(data, self.password) + if data is None: + raise errors.AnsibleError("Decryption failed") + + return data + + def _add_header(self, data): + # combine header and encrypted data in 80 char columns + + #tmpdata = hexlify(data) + tmpdata = [data[i:i+80] for i in range(0, len(data), 80)] + + if not self.cipher_name: + raise errors.AnsibleError("the cipher must be set before adding a header") + + dirty_data = HEADER + ";" + str(self.version) + ";" + self.cipher_name + "\n" + + for l in tmpdata: + dirty_data += l + '\n' + + return dirty_data + + + def _split_header(self, data): + # used by decrypt + + tmpdata = data.split('\n') + tmpheader = tmpdata[0].strip().split(';') + + self.version = str(tmpheader[1].strip()) + self.cipher_name = str(tmpheader[2].strip()) + clean_data = '\n'.join(tmpdata[1:]) + + """ + # strip out newline, join, unhex + clean_data = [ x.strip() for x in clean_data ] + clean_data = unhexlify(''.join(clean_data)) + """ + + return clean_data + + def __enter__(self): + return self + + def __exit__(self, *err): + pass + +class VaultEditor(object): + # uses helper methods for write_file(self, filename, data) + # to write a file so that code isn't duplicated for simple + # file I/O, ditto read_file(self, filename) and launch_editor(self, filename) + # ... "Don't Repeat Yourself", etc. + + def __init__(self, cipher_name, password, filename): + # instantiates a member variable for VaultLib + self.cipher_name = cipher_name + self.password = password + self.filename = filename + + def create_file(self): + """ create a new encrypted file """ + + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH: + raise errors.AnsibleError(CRYPTO_UPGRADE) + + if os.path.isfile(self.filename): + raise errors.AnsibleError("%s exists, please use 'edit' instead" % self.filename) + + # drop the user into vim on file + old_umask = os.umask(0077) + call(self._editor_shell_command(self.filename)) + tmpdata = self.read_data(self.filename) + this_vault = VaultLib(self.password) + this_vault.cipher_name = self.cipher_name + enc_data = this_vault.encrypt(tmpdata) + self.write_data(enc_data, self.filename) + os.umask(old_umask) + + def decrypt_file(self): + + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH: + raise errors.AnsibleError(CRYPTO_UPGRADE) + + if not os.path.isfile(self.filename): + raise errors.AnsibleError("%s does not exist" % self.filename) + + tmpdata = self.read_data(self.filename) + this_vault = VaultLib(self.password) + if this_vault.is_encrypted(tmpdata): + dec_data = this_vault.decrypt(tmpdata) + if dec_data is None: + raise errors.AnsibleError("Decryption failed") + else: + self.write_data(dec_data, self.filename) + else: + raise errors.AnsibleError("%s is not encrypted" % self.filename) + + def edit_file(self): + + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH: + raise errors.AnsibleError(CRYPTO_UPGRADE) + + # make sure the umask is set to a sane value + old_mask = os.umask(0077) + + # decrypt to tmpfile + tmpdata = self.read_data(self.filename) + this_vault = VaultLib(self.password) + dec_data = this_vault.decrypt(tmpdata) + _, tmp_path = tempfile.mkstemp() + self.write_data(dec_data, tmp_path) + + # drop the user into vim on the tmp file + call(self._editor_shell_command(tmp_path)) + new_data = self.read_data(tmp_path) + + # create new vault + new_vault = VaultLib(self.password) + + # we want the cipher to default to AES256 + #new_vault.cipher_name = this_vault.cipher_name + + # encrypt new data a write out to tmp + enc_data = new_vault.encrypt(new_data) + self.write_data(enc_data, tmp_path) + + # shuffle tmp file into place + self.shuffle_files(tmp_path, self.filename) + + # and restore the old umask + os.umask(old_mask) + + def view_file(self): + + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH: + raise errors.AnsibleError(CRYPTO_UPGRADE) + + # decrypt to tmpfile + tmpdata = self.read_data(self.filename) + this_vault = VaultLib(self.password) + dec_data = this_vault.decrypt(tmpdata) + _, tmp_path = tempfile.mkstemp() + self.write_data(dec_data, tmp_path) + + # drop the user into pager on the tmp file + call(self._pager_shell_command(tmp_path)) + os.remove(tmp_path) + + def encrypt_file(self): + + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH: + raise errors.AnsibleError(CRYPTO_UPGRADE) + + if not os.path.isfile(self.filename): + raise errors.AnsibleError("%s does not exist" % self.filename) + + tmpdata = self.read_data(self.filename) + this_vault = VaultLib(self.password) + this_vault.cipher_name = self.cipher_name + if not this_vault.is_encrypted(tmpdata): + enc_data = this_vault.encrypt(tmpdata) + self.write_data(enc_data, self.filename) + else: + raise errors.AnsibleError("%s is already encrypted" % self.filename) + + def rekey_file(self, new_password): + + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH: + raise errors.AnsibleError(CRYPTO_UPGRADE) + + # decrypt + tmpdata = self.read_data(self.filename) + this_vault = VaultLib(self.password) + dec_data = this_vault.decrypt(tmpdata) + + # create new vault + new_vault = VaultLib(new_password) + + # we want to force cipher to the default + #new_vault.cipher_name = this_vault.cipher_name + + # re-encrypt data and re-write file + enc_data = new_vault.encrypt(dec_data) + self.write_data(enc_data, self.filename) + + def read_data(self, filename): + f = open(filename, "rb") + tmpdata = f.read() + f.close() + return tmpdata + + def write_data(self, data, filename): + if os.path.isfile(filename): + os.remove(filename) + f = open(filename, "wb") + f.write(data) + f.close() + + def shuffle_files(self, src, dest): + # overwrite dest with src + if os.path.isfile(dest): + os.remove(dest) + shutil.move(src, dest) + + def _editor_shell_command(self, filename): + EDITOR = os.environ.get('EDITOR','vim') + editor = shlex.split(EDITOR) + editor.append(filename) + + return editor + + def _pager_shell_command(self, filename): + PAGER = os.environ.get('PAGER','less') + pager = shlex.split(PAGER) + pager.append(filename) + + return pager + +######################################## +# CIPHERS # +######################################## + +class VaultAES(object): + + # this version has been obsoleted by the VaultAES256 class + # which uses encrypt-then-mac (fixing order) and also improving the KDF used + # code remains for upgrade purposes only + # http://stackoverflow.com/a/16761459 + + def __init__(self): + if not HAS_AES: + raise errors.AnsibleError(CRYPTO_UPGRADE) + + def aes_derive_key_and_iv(self, password, salt, key_length, iv_length): + + """ Create a key and an initialization vector """ + + d = d_i = '' + while len(d) < key_length + iv_length: + d_i = md5(d_i + password + salt).digest() + d += d_i + + key = d[:key_length] + iv = d[key_length:key_length+iv_length] + + return key, iv + + def encrypt(self, data, password, key_length=32): + + """ Read plaintext data from in_file and write encrypted to out_file """ + + + # combine sha + data + this_sha = sha256(data).hexdigest() + tmp_data = this_sha + "\n" + data + + in_file = BytesIO(tmp_data) + in_file.seek(0) + out_file = BytesIO() + + bs = AES.block_size + + # Get a block of random data. EL does not have Crypto.Random.new() + # so os.urandom is used for cross platform purposes + salt = os.urandom(bs - len('Salted__')) + + key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) + cipher = AES.new(key, AES.MODE_CBC, iv) + out_file.write('Salted__' + salt) + finished = False + while not finished: + chunk = in_file.read(1024 * bs) + if len(chunk) == 0 or len(chunk) % bs != 0: + padding_length = (bs - len(chunk) % bs) or bs + chunk += padding_length * chr(padding_length) + finished = True + out_file.write(cipher.encrypt(chunk)) + + out_file.seek(0) + enc_data = out_file.read() + tmp_data = hexlify(enc_data) + + return tmp_data + + + def decrypt(self, data, password, key_length=32): + + """ Read encrypted data from in_file and write decrypted to out_file """ + + # http://stackoverflow.com/a/14989032 + + data = ''.join(data.split('\n')) + data = unhexlify(data) + + in_file = BytesIO(data) + in_file.seek(0) + out_file = BytesIO() + + bs = AES.block_size + salt = in_file.read(bs)[len('Salted__'):] + key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) + cipher = AES.new(key, AES.MODE_CBC, iv) + next_chunk = '' + finished = False + + while not finished: + chunk, next_chunk = next_chunk, cipher.decrypt(in_file.read(1024 * bs)) + if len(next_chunk) == 0: + padding_length = ord(chunk[-1]) + chunk = chunk[:-padding_length] + finished = True + out_file.write(chunk) + + # reset the stream pointer to the beginning + out_file.seek(0) + new_data = out_file.read() + + # split out sha and verify decryption + split_data = new_data.split("\n") + this_sha = split_data[0] + this_data = '\n'.join(split_data[1:]) + test_sha = sha256(this_data).hexdigest() + + if this_sha != test_sha: + raise errors.AnsibleError("Decryption failed") + + #return out_file.read() + return this_data + + +class VaultAES256(object): + + """ + Vault implementation using AES-CTR with an HMAC-SHA256 authentication code. + Keys are derived using PBKDF2 + """ + + # http://www.daemonology.net/blog/2009-06-11-cryptographic-right-answers.html + + def __init__(self): + + if not HAS_PBKDF2 or not HAS_COUNTER or not HAS_HASH: + raise errors.AnsibleError(CRYPTO_UPGRADE) + + def gen_key_initctr(self, password, salt): + # 16 for AES 128, 32 for AES256 + keylength = 32 + + # match the size used for counter.new to avoid extra work + ivlength = 16 + + hash_function = SHA256 + + # make two keys and one iv + pbkdf2_prf = lambda p, s: HMAC.new(p, s, hash_function).digest() + + + derivedkey = PBKDF2(password, salt, dkLen=(2 * keylength) + ivlength, + count=10000, prf=pbkdf2_prf) + + key1 = derivedkey[:keylength] + key2 = derivedkey[keylength:(keylength * 2)] + iv = derivedkey[(keylength * 2):(keylength * 2) + ivlength] + + return key1, key2, hexlify(iv) + + + def encrypt(self, data, password): + + salt = os.urandom(32) + key1, key2, iv = self.gen_key_initctr(password, salt) + + # PKCS#7 PAD DATA http://tools.ietf.org/html/rfc5652#section-6.3 + bs = AES.block_size + padding_length = (bs - len(data) % bs) or bs + data += padding_length * chr(padding_length) + + # COUNTER.new PARAMETERS + # 1) nbits (integer) - Length of the counter, in bits. + # 2) initial_value (integer) - initial value of the counter. "iv" from gen_key_initctr + + ctr = Counter.new(128, initial_value=long(iv, 16)) + + # AES.new PARAMETERS + # 1) AES key, must be either 16, 24, or 32 bytes long -- "key" from gen_key_initctr + # 2) MODE_CTR, is the recommended mode + # 3) counter= + + cipher = AES.new(key1, AES.MODE_CTR, counter=ctr) + + # ENCRYPT PADDED DATA + cryptedData = cipher.encrypt(data) + + # COMBINE SALT, DIGEST AND DATA + hmac = HMAC.new(key2, cryptedData, SHA256) + message = "%s\n%s\n%s" % ( hexlify(salt), hmac.hexdigest(), hexlify(cryptedData) ) + message = hexlify(message) + return message + + def decrypt(self, data, password): + + # SPLIT SALT, DIGEST, AND DATA + data = ''.join(data.split("\n")) + data = unhexlify(data) + salt, cryptedHmac, cryptedData = data.split("\n", 2) + salt = unhexlify(salt) + cryptedData = unhexlify(cryptedData) + + key1, key2, iv = self.gen_key_initctr(password, salt) + + # EXIT EARLY IF DIGEST DOESN'T MATCH + hmacDecrypt = HMAC.new(key2, cryptedData, SHA256) + if not self.is_equal(cryptedHmac, hmacDecrypt.hexdigest()): + return None + + # SET THE COUNTER AND THE CIPHER + ctr = Counter.new(128, initial_value=long(iv, 16)) + cipher = AES.new(key1, AES.MODE_CTR, counter=ctr) + + # DECRYPT PADDED DATA + decryptedData = cipher.decrypt(cryptedData) + + # UNPAD DATA + padding_length = ord(decryptedData[-1]) + decryptedData = decryptedData[:-padding_length] + + return decryptedData + + def is_equal(self, a, b): + # http://codahale.com/a-lesson-in-timing-attacks/ + if len(a) != len(b): + return False + + result = 0 + for x, y in zip(a, b): + result |= ord(x) ^ ord(y) + return result == 0 + + diff --git a/v2/ansible/playbook/attribute.py b/v2/ansible/playbook/attribute.py index ecafe653f0..1e7e404181 100644 --- a/v2/ansible/playbook/attribute.py +++ b/v2/ansible/playbook/attribute.py @@ -25,11 +25,7 @@ class Attribute: self.isa = isa self.private = private - self.value = None self.default = default - def __call__(self): - return self.value - class FieldAttribute(Attribute): pass diff --git a/v2/ansible/playbook/base.py b/v2/ansible/playbook/base.py index 08f4d519e6..a992e19a5d 100644 --- a/v2/ansible/playbook/base.py +++ b/v2/ansible/playbook/base.py @@ -19,24 +19,39 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +from inspect import getmembers from io import FileIO from six import iteritems, string_types from ansible.playbook.attribute import Attribute, FieldAttribute -from ansible.parsing import load as ds_load +from ansible.parsing import load_data class Base: + _tags = FieldAttribute(isa='list') + _when = FieldAttribute(isa='list') + def __init__(self): # each class knows attributes set upon it, see Task.py for example self._attributes = dict() - for (name, value) in iteritems(self.__class__.__dict__): - aname = name[1:] + for (name, value) in self._get_base_attributes().iteritems(): + self._attributes[name] = value.default + + def _get_base_attributes(self): + ''' + Returns the list of attributes for this class (or any subclass thereof). + If the attribute name starts with an underscore, it is removed + ''' + base_attributes = dict() + for (name, value) in getmembers(self.__class__): if isinstance(value, Attribute): - self._attributes[aname] = value.default + if name.startswith('_'): + name = name[1:] + base_attributes[name] = value + return base_attributes def munge(self, ds): ''' infrequently used method to do some pre-processing of legacy terms ''' @@ -49,7 +64,7 @@ class Base: assert ds is not None if isinstance(ds, string_types) or isinstance(ds, FileIO): - ds = ds_load(ds) + ds = load_data(ds) # we currently don't do anything with private attributes but may # later decide to filter them out of 'ds' here. @@ -57,20 +72,15 @@ class Base: ds = self.munge(ds) # walk all attributes in the class - for (name, attribute) in iteritems(self.__class__.__dict__): - aname = name[1:] + for (name, attribute) in self._get_base_attributes().iteritems(): - # process Field attributes which get loaded from the YAML - - if isinstance(attribute, FieldAttribute): - - # copy the value over unless a _load_field method is defined - if aname in ds: - method = getattr(self, '_load_%s' % aname, None) - if method: - self._attributes[aname] = method(aname, ds[aname]) - else: - self._attributes[aname] = ds[aname] + # copy the value over unless a _load_field method is defined + if name in ds: + method = getattr(self, '_load_%s' % name, None) + if method: + self._attributes[name] = method(name, ds[name]) + else: + self._attributes[name] = ds[name] # return the constructed object self.validate() @@ -81,20 +91,12 @@ class Base: ''' validation that is done at parse time, not load time ''' # walk all fields in the object - for (name, attribute) in self.__dict__.iteritems(): + for (name, attribute) in self._get_base_attributes().iteritems(): - # find any field attributes - if isinstance(attribute, FieldAttribute): - - if not name.startswith("_"): - raise AnsibleError("FieldAttribute %s must start with _" % name) - - aname = name[1:] - - # run validator only if present - method = getattr(self, '_validate_%s' % (prefix, aname), None) - if method: - method(self, attribute) + # run validator only if present + method = getattr(self, '_validate_%s' % name, None) + if method: + method(self, attribute) def post_validate(self, runner_context): ''' diff --git a/v2/ansible/playbook/block.py b/v2/ansible/playbook/block.py index 45bab87bf6..5e4826d119 100644 --- a/v2/ansible/playbook/block.py +++ b/v2/ansible/playbook/block.py @@ -25,11 +25,13 @@ from ansible.playbook.attribute import Attribute, FieldAttribute class Block(Base): - # TODO: FIXME: block/rescue/always should be enough - _begin = FieldAttribute(isa='list') + _block = FieldAttribute(isa='list') _rescue = FieldAttribute(isa='list') - _end = FieldAttribute(isa='list') - _otherwise = FieldAttribute(isa='list') + _always = FieldAttribute(isa='list') + + # for future consideration? this would be functionally + # similar to the 'else' clause for exceptions + #_otherwise = FieldAttribute(isa='list') def __init__(self, role=None): self.role = role @@ -45,6 +47,20 @@ class Block(Base): b = Block(role=role) return b.load_data(data) + def munge(self, ds): + ''' + If a simple task is given, an implicit block for that single task + is created, which goes in the main portion of the block + ''' + is_block = False + for attr in ('block', 'rescue', 'always'): + if attr in ds: + is_block = True + break + if not is_block: + return dict(block=ds) + return ds + def _load_list_of_tasks(self, ds): assert type(ds) == list task_list = [] @@ -53,15 +69,16 @@ class Block(Base): task_list.append(t) return task_list - def _load_begin(self, attr, ds): + def _load_block(self, attr, ds): return self._load_list_of_tasks(ds) def _load_rescue(self, attr, ds): return self._load_list_of_tasks(ds) - def _load_end(self, attr, ds): + def _load_always(self, attr, ds): return self._load_list_of_tasks(ds) - def _load_otherwise(self, attr, ds): - return self._load_list_of_tasks(ds) + # not currently used + #def _load_otherwise(self, attr, ds): + # return self._load_list_of_tasks(ds) diff --git a/v2/ansible/playbook/role.py b/v2/ansible/playbook/role.py index f36207874d..af3855f3d9 100644 --- a/v2/ansible/playbook/role.py +++ b/v2/ansible/playbook/role.py @@ -19,31 +19,251 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -from v2.playbook.base import PlaybookBase -from v2.utils import list_union +from six import iteritems, string_types -class Role(PlaybookBase): +import os - # TODO: this will be overhauled to match Task.py at some point +from ansible.playbook.attribute import FieldAttribute +from ansible.playbook.base import Base +from ansible.playbook.block import Block +from ansible.parsing import load_data_from_file - def __init__(self): - pass +#from ansible.utils import list_union, unfrackpath + +class Role(Base): + + _role = FieldAttribute(isa='string') + _src = FieldAttribute(isa='string') + _scm = FieldAttribute(isa='string') + _version = FieldAttribute(isa='string') + _params = FieldAttribute(isa='dict') + _metadata = FieldAttribute(isa='dict') + _task_blocks = FieldAttribute(isa='list') + _handler_blocks = FieldAttribute(isa='list') + _default_vars = FieldAttribute(isa='dict') + _role_vars = FieldAttribute(isa='dict') + + def __init__(self, vault_password=None): + self._role_path = None + self._vault_password = vault_password + super(Role, self).__init__() + + def __repr__(self): + return self.get_name() def get_name(self): - return "TEMPORARY" + return self._attributes['role'] - def load(self, ds): - self._ds = ds - self._tasks = [] - self._handlers = [] - self._blocks = [] - self._dependencies = [] - self._metadata = dict() - self._defaults = dict() - self._vars = dict() - self._params = dict() + @staticmethod + def load(data, vault_password=None): + assert isinstance(data, string_types) or isinstance(data, dict) + r = Role(vault_password=vault_password) + r.load_data(data) + return r - def get_vars(self): + #------------------------------------------------------------------------------ + # munge, and other functions used for loading the ds + + def munge(self, ds): + # Role definitions can be strings or dicts, so we fix + # things up here. Anything that is not a role name, tag, + # or conditional will also be added to the params sub- + # dictionary for loading later + if isinstance(ds, string_types): + new_ds = dict(role=ds) + else: + ds = self._munge_role(ds) + + params = dict() + new_ds = dict() + + for (key, value) in iteritems(ds): + if key not in [name for (name, value) in self._get_base_attributes().iteritems()]: + # this key does not match a field attribute, + # so it must be a role param + params[key] = value + else: + # this is a field attribute, so copy it over directly + new_ds[key] = value + + # finally, assign the params to a new entry in the revised ds + new_ds['params'] = params + + # set the role path, based on the role definition + self._role_path = self._get_role_path(new_ds.get('role')) + + # load the role's files, if they exist + new_ds['metadata'] = self._load_role_yaml('meta') + new_ds['task_blocks'] = self._load_role_yaml('tasks') + new_ds['handler_blocks'] = self._load_role_yaml('handlers') + new_ds['default_vars'] = self._load_role_yaml('defaults') + new_ds['role_vars'] = self._load_role_yaml('vars') + + return new_ds + + def _load_role_yaml(self, subdir): + file_path = os.path.join(self._role_path, subdir) + if os.path.exists(file_path) and os.path.isdir(file_path): + main_file = self._resolve_main(file_path) + if os.path.exists(main_file): + return load_data_from_file(main_file, self._vault_password) + return None + + def _resolve_main(self, basepath): + ''' flexibly handle variations in main filenames ''' + possible_mains = ( + os.path.join(basepath, 'main'), + os.path.join(basepath, 'main.yml'), + os.path.join(basepath, 'main.yaml'), + os.path.join(basepath, 'main.json'), + ) + + if sum([os.path.isfile(x) for x in possible_mains]) > 1: + raise errors.AnsibleError("found multiple main files at %s, only one allowed" % (basepath)) + else: + for m in possible_mains: + if os.path.isfile(m): + return m # exactly one main file + return possible_mains[0] # zero mains (we still need to return something) + + def _get_role_path(self, role): + ''' + the 'role', as specified in the ds (or as a bare string), can either + be a simple name or a full path. If it is a full path, we use the + basename as the role name, otherwise we take the name as-given and + append it to the default role path + ''' + + # FIXME: this should use unfrackpath once the utils code has been sorted out + role_path = os.path.normpath(role) + if os.path.exists(role_path): + return role_path + else: + for path in ('./roles', '/etc/ansible/roles'): + role_path = os.path.join(path, role) + if os.path.exists(role_path): + return role_path + # FIXME: raise something here + raise + + def _repo_url_to_role_name(self, repo_url): + # gets the role name out of a repo like + # http://git.example.com/repos/repo.git" => "repo" + + if '://' not in repo_url and '@' not in repo_url: + return repo_url + trailing_path = repo_url.split('/')[-1] + if trailing_path.endswith('.git'): + trailing_path = trailing_path[:-4] + if trailing_path.endswith('.tar.gz'): + trailing_path = trailing_path[:-7] + if ',' in trailing_path: + trailing_path = trailing_path.split(',')[0] + return trailing_path + + def _role_spec_parse(self, role_spec): + # takes a repo and a version like + # git+http://git.example.com/repos/repo.git,v1.0 + # and returns a list of properties such as: + # { + # 'scm': 'git', + # 'src': 'http://git.example.com/repos/repo.git', + # 'version': 'v1.0', + # 'name': 'repo' + # } + + default_role_versions = dict(git='master', hg='tip') + + role_spec = role_spec.strip() + role_version = '' + if role_spec == "" or role_spec.startswith("#"): + return (None, None, None, None) + + tokens = [s.strip() for s in role_spec.split(',')] + + # assume https://github.com URLs are git+https:// URLs and not + # tarballs unless they end in '.zip' + if 'github.com/' in tokens[0] and not tokens[0].startswith("git+") and not tokens[0].endswith('.tar.gz'): + tokens[0] = 'git+' + tokens[0] + + if '+' in tokens[0]: + (scm, role_url) = tokens[0].split('+') + else: + scm = None + role_url = tokens[0] + + if len(tokens) >= 2: + role_version = tokens[1] + + if len(tokens) == 3: + role_name = tokens[2] + else: + role_name = repo_url_to_role_name(tokens[0]) + + if scm and not role_version: + role_version = default_role_versions.get(scm, '') + + return dict(scm=scm, src=role_url, version=role_version, name=role_name) + + def _munge_role(self, ds): + if 'role' in ds: + # Old style: {role: "galaxy.role,version,name", other_vars: "here" } + role_info = self._role_spec_parse(ds['role']) + if isinstance(role_info, dict): + # Warning: Slight change in behaviour here. name may be being + # overloaded. Previously, name was only a parameter to the role. + # Now it is both a parameter to the role and the name that + # ansible-galaxy will install under on the local system. + if 'name' in ds and 'name' in role_info: + del role_info['name'] + ds.update(role_info) + else: + # New style: { src: 'galaxy.role,version,name', other_vars: "here" } + if 'github.com' in ds["src"] and 'http' in ds["src"] and '+' not in ds["src"] and not ds["src"].endswith('.tar.gz'): + ds["src"] = "git+" + ds["src"] + + if '+' in ds["src"]: + (scm, src) = ds["src"].split('+') + ds["scm"] = scm + ds["src"] = src + + if 'name' in role: + ds["role"] = ds["name"] + del ds["name"] + else: + ds["role"] = self._repo_url_to_role_name(ds["src"]) + + # set some values to a default value, if none were specified + ds.setdefault('version', '') + ds.setdefault('scm', None) + + return ds + + #------------------------------------------------------------------------------ + # attribute loading defs + + def _load_list_of_blocks(self, ds): + assert type(ds) == list + block_list = [] + for block in ds: + b = Block(block) + block_list.append(b) + return block_list + + def _load_task_blocks(self, attr, ds): + if ds is None: + return [] + return self._load_list_of_blocks(ds) + + def _load_handler_blocks(self, attr, ds): + if ds is None: + return [] + return self._load_list_of_blocks(ds) + + #------------------------------------------------------------------------------ + # other functions + + def get_variables(self): # returns the merged variables for this role, including # recursively merging those of all child roles return dict() @@ -60,8 +280,3 @@ class Role(PlaybookBase): all_deps = list_union(all_deps, self.dependencies) return all_deps - def get_blocks(self): - # should return - return self.blocks - - diff --git a/v2/test/parsing/test_general.py b/v2/test/parsing/test_general.py index 0a150e1a23..b86fcba289 100644 --- a/v2/test/parsing/test_general.py +++ b/v2/test/parsing/test_general.py @@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type from .. compat import unittest -from ansible.parsing import load +from ansible.parsing import load_data from ansible.errors import AnsibleParserError import json @@ -68,12 +68,12 @@ class TestGeneralParsing(unittest.TestCase): "jkl" : 5678 } """ - output = load(input) + output = load_data(input) self.assertEqual(output['asdf'], '1234') self.assertEqual(output['jkl'], 5678) def parse_json_from_file(self): - output = load(MockFile(dict(a=1,b=2,c=3)),'json') + output = load_data(MockFile(dict(a=1,b=2,c=3)),'json') self.assertEqual(ouput, dict(a=1,b=2,c=3)) def parse_yaml_from_dict(self): @@ -81,12 +81,12 @@ class TestGeneralParsing(unittest.TestCase): asdf: '1234' jkl: 5678 """ - output = load(input) + output = load_data(input) self.assertEqual(output['asdf'], '1234') self.assertEqual(output['jkl'], 5678) def parse_yaml_from_file(self): - output = load(MockFile(dict(a=1,b=2,c=3),'yaml')) + output = load_data(MockFile(dict(a=1,b=2,c=3),'yaml')) self.assertEqual(output, dict(a=1,b=2,c=3)) def parse_fail(self): @@ -95,10 +95,10 @@ class TestGeneralParsing(unittest.TestCase): *** NOT VALID """ - self.assertRaises(load(input), AnsibleParserError) + self.assertRaises(load_data(input), AnsibleParserError) def parse_fail_from_file(self): - self.assertRaises(load(MockFile(None,'fail')), AnsibleParserError) + self.assertRaises(load_data(MockFile(None,'fail')), AnsibleParserError) def parse_fail_invalid_type(self): self.assertRaises(3000, AnsibleParsingError) diff --git a/v2/test/playbook/test_block.py b/v2/test/playbook/test_block.py index a0da7f0e6a..46921ae6d2 100644 --- a/v2/test/playbook/test_block.py +++ b/v2/test/playbook/test_block.py @@ -49,31 +49,39 @@ class TestBlock(unittest.TestCase): def test_load_block_simple(self): ds = dict( - begin = [], + block = [], rescue = [], - end = [], - otherwise = [], + always = [], + #otherwise = [], ) b = Block.load(ds) - self.assertEqual(b.begin, []) + self.assertEqual(b.block, []) self.assertEqual(b.rescue, []) - self.assertEqual(b.end, []) - self.assertEqual(b.otherwise, []) + self.assertEqual(b.always, []) + # not currently used + #self.assertEqual(b.otherwise, []) def test_load_block_with_tasks(self): ds = dict( - begin = [dict(action='begin')], + block = [dict(action='block')], rescue = [dict(action='rescue')], - end = [dict(action='end')], - otherwise = [dict(action='otherwise')], + always = [dict(action='always')], + #otherwise = [dict(action='otherwise')], ) b = Block.load(ds) - self.assertEqual(len(b.begin), 1) - assert isinstance(b.begin[0], Task) + self.assertEqual(len(b.block), 1) + assert isinstance(b.block[0], Task) self.assertEqual(len(b.rescue), 1) assert isinstance(b.rescue[0], Task) - self.assertEqual(len(b.end), 1) - assert isinstance(b.end[0], Task) - self.assertEqual(len(b.otherwise), 1) - assert isinstance(b.otherwise[0], Task) + self.assertEqual(len(b.always), 1) + assert isinstance(b.always[0], Task) + # not currently used + #self.assertEqual(len(b.otherwise), 1) + #assert isinstance(b.otherwise[0], Task) + + def test_load_implicit_block(self): + ds = [dict(action='foo')] + b = Block.load(ds) + self.assertEqual(len(b.block), 1) + assert isinstance(b.block[0], Task) diff --git a/v2/test/playbook/test_role.py b/v2/test/playbook/test_role.py new file mode 100644 index 0000000000..55e170de24 --- /dev/null +++ b/v2/test/playbook/test_role.py @@ -0,0 +1,52 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.playbook.block import Block +from ansible.playbook.role import Role +from ansible.playbook.task import Task +from .. compat import unittest + +class TestRole(unittest.TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_construct_empty_block(self): + r = Role() + + def test_role__load_list_of_blocks(self): + task = dict(action='test') + r = Role() + self.assertEqual(r._load_list_of_blocks([]), []) + res = r._load_list_of_blocks([task]) + self.assertEqual(len(res), 1) + assert isinstance(res[0], Block) + res = r._load_list_of_blocks([task,task,task]) + self.assertEqual(len(res), 3) + + def test_load_role_simple(self): + pass + + def test_load_role_complex(self): + pass