1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

Overhauls to v2 code

* using inspect module instead of iteritems(self.__class__.__dict__, due
  to the fact that the later does not include attributes from parent
  classes
* added tags/when attributes to Base() class for use by all subclasses
* removed value/callable code from Attribute, as they are not used
* started moving some limited code from utils to new places in v2 tree
  (vault, yaml-parsing related defs)
* re-added ability of Block.load() to create implicit blocks from tasks
* started overhaul of Role class and role-related code
This commit is contained in:
James Cammarata 2014-10-19 00:14:30 -05:00
parent 28fd4df787
commit b0069a338e
9 changed files with 1130 additions and 91 deletions

View file

@ -19,19 +19,205 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import json
from yaml import YAMLError
from ansible.errors import AnsibleError, AnsibleInternalError from ansible.errors import AnsibleError, AnsibleInternalError
from ansible.parsing.vault import VaultLib
from ansible.parsing.yaml import safe_load
def load(self, data):
if instanceof(data, file): def process_common_errors(msg, probline, column):
replaced = probline.replace(" ","")
if ":{{" in replaced and "}}" in replaced:
msg = msg + """
This one looks easy to fix. YAML thought it was looking for the start of a
hash/dictionary and was confused to see a second "{". Most likely this was
meant to be an ansible template evaluation instead, so we have to give the
parser a small hint that we wanted a string instead. The solution here is to
just quote the entire value.
For instance, if the original line was:
app_path: {{ base_path }}/foo
It should be written as:
app_path: "{{ base_path }}/foo"
"""
return msg
elif len(probline) and len(probline) > 1 and len(probline) > column and probline[column] == ":" and probline.count(':') > 1:
msg = msg + """
This one looks easy to fix. There seems to be an extra unquoted colon in the line
and this is confusing the parser. It was only expecting to find one free
colon. The solution is just add some quotes around the colon, or quote the
entire line after the first colon.
For instance, if the original line was:
copy: src=file.txt dest=/path/filename:with_colon.txt
It can be written as:
copy: src=file.txt dest='/path/filename:with_colon.txt'
Or:
copy: 'src=file.txt dest=/path/filename:with_colon.txt'
"""
return msg
else:
parts = probline.split(":")
if len(parts) > 1:
middle = parts[1].strip()
match = False
unbalanced = False
if middle.startswith("'") and not middle.endswith("'"):
match = True
elif middle.startswith('"') and not middle.endswith('"'):
match = True
if len(middle) > 0 and middle[0] in [ '"', "'" ] and middle[-1] in [ '"', "'" ] and probline.count("'") > 2 or probline.count('"') > 2:
unbalanced = True
if match:
msg = msg + """
This one looks easy to fix. It seems that there is a value started
with a quote, and the YAML parser is expecting to see the line ended
with the same kind of quote. For instance:
when: "ok" in result.stdout
Could be written as:
when: '"ok" in result.stdout'
or equivalently:
when: "'ok' in result.stdout"
"""
return msg
if unbalanced:
msg = msg + """
We could be wrong, but this one looks like it might be an issue with
unbalanced quotes. If starting a value with a quote, make sure the
line ends with the same set of quotes. For instance this arbitrary
example:
foo: "bad" "wolf"
Could be written as:
foo: '"bad" "wolf"'
"""
return msg
return msg
def process_yaml_error(exc, data, path=None, show_content=True):
if hasattr(exc, 'problem_mark'):
mark = exc.problem_mark
if show_content:
if mark.line -1 >= 0:
before_probline = data.split("\n")[mark.line-1]
else:
before_probline = ''
probline = data.split("\n")[mark.line]
arrow = " " * mark.column + "^"
msg = """Syntax Error while loading YAML script, %s
Note: The error may actually appear before this position: line %s, column %s
%s
%s
%s""" % (path, mark.line + 1, mark.column + 1, before_probline, probline, arrow)
unquoted_var = None
if '{{' in probline and '}}' in probline:
if '"{{' not in probline or "'{{" not in probline:
unquoted_var = True
if not unquoted_var:
msg = process_common_errors(msg, probline, mark.column)
else:
msg = msg + """
We could be wrong, but this one looks like it might be an issue with
missing quotes. Always quote template expression brackets when they
start a value. For instance:
with_items:
- {{ foo }}
Should be written as:
with_items:
- "{{ foo }}"
"""
else:
# most likely displaying a file with sensitive content,
# so don't show any of the actual lines of yaml just the
# line number itself
msg = """Syntax error while loading YAML script, %s
The error appears to have been on line %s, column %s, but may actually
be before there depending on the exact syntax problem.
""" % (path, mark.line + 1, mark.column + 1)
else:
# No problem markers means we have to throw a generic
# "stuff messed up" type message. Sry bud.
if path:
msg = "Could not parse YAML. Check over %s again." % path
else:
msg = "Could not parse YAML."
raise errors.AnsibleYAMLValidationFailed(msg)
def load_data(data):
if isinstance(data, file):
fd = open(f) fd = open(f)
data = fd.read() data = fd.read()
fd.close() fd.close()
if instanceof(data, basestring): if isinstance(data, basestring):
try: try:
return json.loads(data) return json.loads(data)
except: except:
return safe_load(data) return safe_load(data)
raise AnsibleInternalError("expected file or string, got %s" % type(data)) raise AnsibleInternalError("expected file or string, got %s" % type(data))
def load_data_from_file(path, vault_password=None):
'''
Convert a yaml file to a data structure.
Was previously 'parse_yaml_from_file()'.
'''
data = None
show_content = True
try:
data = open(path).read()
except IOError:
raise errors.AnsibleError("file could not read: %s" % path)
vault = VaultLib(password=vault_password)
if vault.is_encrypted(data):
# if the file is encrypted and no password was specified,
# the decrypt call would throw an error, but we check first
# since the decrypt function doesn't know the file name
if vault_password is None:
raise errors.AnsibleError("A vault password must be specified to decrypt %s" % path)
data = vault.decrypt(data)
show_content = False
try:
return load_data(data)
except YAMLError, exc:
process_yaml_error(exc, data, path, show_content)

View file

@ -0,0 +1,563 @@
# (c) 2014, James Tanner <tanner.jc@gmail.com>
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#
# ansible-pull is a script that runs ansible in local mode
# after checking out a playbooks directory from source repo. There is an
# example playbook to bootstrap this script in the examples/ dir which
# installs ansible and sets it up to run on cron.
import os
import shlex
import shutil
import tempfile
from io import BytesIO
from subprocess import call
from ansible import errors
from hashlib import sha256
from hashlib import md5
from binascii import hexlify
from binascii import unhexlify
from ansible import constants as C
try:
from Crypto.Hash import SHA256, HMAC
HAS_HASH = True
except ImportError:
HAS_HASH = False
# Counter import fails for 2.0.1, requires >= 2.6.1 from pip
try:
from Crypto.Util import Counter
HAS_COUNTER = True
except ImportError:
HAS_COUNTER = False
# KDF import fails for 2.0.1, requires >= 2.6.1 from pip
try:
from Crypto.Protocol.KDF import PBKDF2
HAS_PBKDF2 = True
except ImportError:
HAS_PBKDF2 = False
# AES IMPORTS
try:
from Crypto.Cipher import AES as AES
HAS_AES = True
except ImportError:
HAS_AES = False
CRYPTO_UPGRADE = "ansible-vault requires a newer version of pycrypto than the one installed on your platform. You may fix this with OS-specific commands such as: yum install python-devel; rpm -e --nodeps python-crypto; pip install pycrypto"
HEADER='$ANSIBLE_VAULT'
CIPHER_WHITELIST=['AES', 'AES256']
class VaultLib(object):
def __init__(self, password):
self.password = password
self.cipher_name = None
self.version = '1.1'
def is_encrypted(self, data):
if data.startswith(HEADER):
return True
else:
return False
def encrypt(self, data):
if self.is_encrypted(data):
raise errors.AnsibleError("data is already encrypted")
if not self.cipher_name:
self.cipher_name = "AES256"
#raise errors.AnsibleError("the cipher must be set before encrypting data")
if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST:
cipher = globals()['Vault' + self.cipher_name]
this_cipher = cipher()
else:
raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name)
"""
# combine sha + data
this_sha = sha256(data).hexdigest()
tmp_data = this_sha + "\n" + data
"""
# encrypt sha + data
enc_data = this_cipher.encrypt(data, self.password)
# add header
tmp_data = self._add_header(enc_data)
return tmp_data
def decrypt(self, data):
if self.password is None:
raise errors.AnsibleError("A vault password must be specified to decrypt data")
if not self.is_encrypted(data):
raise errors.AnsibleError("data is not encrypted")
# clean out header
data = self._split_header(data)
# create the cipher object
if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST:
cipher = globals()['Vault' + self.cipher_name]
this_cipher = cipher()
else:
raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name)
# try to unencrypt data
data = this_cipher.decrypt(data, self.password)
if data is None:
raise errors.AnsibleError("Decryption failed")
return data
def _add_header(self, data):
# combine header and encrypted data in 80 char columns
#tmpdata = hexlify(data)
tmpdata = [data[i:i+80] for i in range(0, len(data), 80)]
if not self.cipher_name:
raise errors.AnsibleError("the cipher must be set before adding a header")
dirty_data = HEADER + ";" + str(self.version) + ";" + self.cipher_name + "\n"
for l in tmpdata:
dirty_data += l + '\n'
return dirty_data
def _split_header(self, data):
# used by decrypt
tmpdata = data.split('\n')
tmpheader = tmpdata[0].strip().split(';')
self.version = str(tmpheader[1].strip())
self.cipher_name = str(tmpheader[2].strip())
clean_data = '\n'.join(tmpdata[1:])
"""
# strip out newline, join, unhex
clean_data = [ x.strip() for x in clean_data ]
clean_data = unhexlify(''.join(clean_data))
"""
return clean_data
def __enter__(self):
return self
def __exit__(self, *err):
pass
class VaultEditor(object):
# uses helper methods for write_file(self, filename, data)
# to write a file so that code isn't duplicated for simple
# file I/O, ditto read_file(self, filename) and launch_editor(self, filename)
# ... "Don't Repeat Yourself", etc.
def __init__(self, cipher_name, password, filename):
# instantiates a member variable for VaultLib
self.cipher_name = cipher_name
self.password = password
self.filename = filename
def create_file(self):
""" create a new encrypted file """
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH:
raise errors.AnsibleError(CRYPTO_UPGRADE)
if os.path.isfile(self.filename):
raise errors.AnsibleError("%s exists, please use 'edit' instead" % self.filename)
# drop the user into vim on file
old_umask = os.umask(0077)
call(self._editor_shell_command(self.filename))
tmpdata = self.read_data(self.filename)
this_vault = VaultLib(self.password)
this_vault.cipher_name = self.cipher_name
enc_data = this_vault.encrypt(tmpdata)
self.write_data(enc_data, self.filename)
os.umask(old_umask)
def decrypt_file(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH:
raise errors.AnsibleError(CRYPTO_UPGRADE)
if not os.path.isfile(self.filename):
raise errors.AnsibleError("%s does not exist" % self.filename)
tmpdata = self.read_data(self.filename)
this_vault = VaultLib(self.password)
if this_vault.is_encrypted(tmpdata):
dec_data = this_vault.decrypt(tmpdata)
if dec_data is None:
raise errors.AnsibleError("Decryption failed")
else:
self.write_data(dec_data, self.filename)
else:
raise errors.AnsibleError("%s is not encrypted" % self.filename)
def edit_file(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH:
raise errors.AnsibleError(CRYPTO_UPGRADE)
# make sure the umask is set to a sane value
old_mask = os.umask(0077)
# decrypt to tmpfile
tmpdata = self.read_data(self.filename)
this_vault = VaultLib(self.password)
dec_data = this_vault.decrypt(tmpdata)
_, tmp_path = tempfile.mkstemp()
self.write_data(dec_data, tmp_path)
# drop the user into vim on the tmp file
call(self._editor_shell_command(tmp_path))
new_data = self.read_data(tmp_path)
# create new vault
new_vault = VaultLib(self.password)
# we want the cipher to default to AES256
#new_vault.cipher_name = this_vault.cipher_name
# encrypt new data a write out to tmp
enc_data = new_vault.encrypt(new_data)
self.write_data(enc_data, tmp_path)
# shuffle tmp file into place
self.shuffle_files(tmp_path, self.filename)
# and restore the old umask
os.umask(old_mask)
def view_file(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH:
raise errors.AnsibleError(CRYPTO_UPGRADE)
# decrypt to tmpfile
tmpdata = self.read_data(self.filename)
this_vault = VaultLib(self.password)
dec_data = this_vault.decrypt(tmpdata)
_, tmp_path = tempfile.mkstemp()
self.write_data(dec_data, tmp_path)
# drop the user into pager on the tmp file
call(self._pager_shell_command(tmp_path))
os.remove(tmp_path)
def encrypt_file(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH:
raise errors.AnsibleError(CRYPTO_UPGRADE)
if not os.path.isfile(self.filename):
raise errors.AnsibleError("%s does not exist" % self.filename)
tmpdata = self.read_data(self.filename)
this_vault = VaultLib(self.password)
this_vault.cipher_name = self.cipher_name
if not this_vault.is_encrypted(tmpdata):
enc_data = this_vault.encrypt(tmpdata)
self.write_data(enc_data, self.filename)
else:
raise errors.AnsibleError("%s is already encrypted" % self.filename)
def rekey_file(self, new_password):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH:
raise errors.AnsibleError(CRYPTO_UPGRADE)
# decrypt
tmpdata = self.read_data(self.filename)
this_vault = VaultLib(self.password)
dec_data = this_vault.decrypt(tmpdata)
# create new vault
new_vault = VaultLib(new_password)
# we want to force cipher to the default
#new_vault.cipher_name = this_vault.cipher_name
# re-encrypt data and re-write file
enc_data = new_vault.encrypt(dec_data)
self.write_data(enc_data, self.filename)
def read_data(self, filename):
f = open(filename, "rb")
tmpdata = f.read()
f.close()
return tmpdata
def write_data(self, data, filename):
if os.path.isfile(filename):
os.remove(filename)
f = open(filename, "wb")
f.write(data)
f.close()
def shuffle_files(self, src, dest):
# overwrite dest with src
if os.path.isfile(dest):
os.remove(dest)
shutil.move(src, dest)
def _editor_shell_command(self, filename):
EDITOR = os.environ.get('EDITOR','vim')
editor = shlex.split(EDITOR)
editor.append(filename)
return editor
def _pager_shell_command(self, filename):
PAGER = os.environ.get('PAGER','less')
pager = shlex.split(PAGER)
pager.append(filename)
return pager
########################################
# CIPHERS #
########################################
class VaultAES(object):
# this version has been obsoleted by the VaultAES256 class
# which uses encrypt-then-mac (fixing order) and also improving the KDF used
# code remains for upgrade purposes only
# http://stackoverflow.com/a/16761459
def __init__(self):
if not HAS_AES:
raise errors.AnsibleError(CRYPTO_UPGRADE)
def aes_derive_key_and_iv(self, password, salt, key_length, iv_length):
""" Create a key and an initialization vector """
d = d_i = ''
while len(d) < key_length + iv_length:
d_i = md5(d_i + password + salt).digest()
d += d_i
key = d[:key_length]
iv = d[key_length:key_length+iv_length]
return key, iv
def encrypt(self, data, password, key_length=32):
""" Read plaintext data from in_file and write encrypted to out_file """
# combine sha + data
this_sha = sha256(data).hexdigest()
tmp_data = this_sha + "\n" + data
in_file = BytesIO(tmp_data)
in_file.seek(0)
out_file = BytesIO()
bs = AES.block_size
# Get a block of random data. EL does not have Crypto.Random.new()
# so os.urandom is used for cross platform purposes
salt = os.urandom(bs - len('Salted__'))
key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs)
cipher = AES.new(key, AES.MODE_CBC, iv)
out_file.write('Salted__' + salt)
finished = False
while not finished:
chunk = in_file.read(1024 * bs)
if len(chunk) == 0 or len(chunk) % bs != 0:
padding_length = (bs - len(chunk) % bs) or bs
chunk += padding_length * chr(padding_length)
finished = True
out_file.write(cipher.encrypt(chunk))
out_file.seek(0)
enc_data = out_file.read()
tmp_data = hexlify(enc_data)
return tmp_data
def decrypt(self, data, password, key_length=32):
""" Read encrypted data from in_file and write decrypted to out_file """
# http://stackoverflow.com/a/14989032
data = ''.join(data.split('\n'))
data = unhexlify(data)
in_file = BytesIO(data)
in_file.seek(0)
out_file = BytesIO()
bs = AES.block_size
salt = in_file.read(bs)[len('Salted__'):]
key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs)
cipher = AES.new(key, AES.MODE_CBC, iv)
next_chunk = ''
finished = False
while not finished:
chunk, next_chunk = next_chunk, cipher.decrypt(in_file.read(1024 * bs))
if len(next_chunk) == 0:
padding_length = ord(chunk[-1])
chunk = chunk[:-padding_length]
finished = True
out_file.write(chunk)
# reset the stream pointer to the beginning
out_file.seek(0)
new_data = out_file.read()
# split out sha and verify decryption
split_data = new_data.split("\n")
this_sha = split_data[0]
this_data = '\n'.join(split_data[1:])
test_sha = sha256(this_data).hexdigest()
if this_sha != test_sha:
raise errors.AnsibleError("Decryption failed")
#return out_file.read()
return this_data
class VaultAES256(object):
"""
Vault implementation using AES-CTR with an HMAC-SHA256 authentication code.
Keys are derived using PBKDF2
"""
# http://www.daemonology.net/blog/2009-06-11-cryptographic-right-answers.html
def __init__(self):
if not HAS_PBKDF2 or not HAS_COUNTER or not HAS_HASH:
raise errors.AnsibleError(CRYPTO_UPGRADE)
def gen_key_initctr(self, password, salt):
# 16 for AES 128, 32 for AES256
keylength = 32
# match the size used for counter.new to avoid extra work
ivlength = 16
hash_function = SHA256
# make two keys and one iv
pbkdf2_prf = lambda p, s: HMAC.new(p, s, hash_function).digest()
derivedkey = PBKDF2(password, salt, dkLen=(2 * keylength) + ivlength,
count=10000, prf=pbkdf2_prf)
key1 = derivedkey[:keylength]
key2 = derivedkey[keylength:(keylength * 2)]
iv = derivedkey[(keylength * 2):(keylength * 2) + ivlength]
return key1, key2, hexlify(iv)
def encrypt(self, data, password):
salt = os.urandom(32)
key1, key2, iv = self.gen_key_initctr(password, salt)
# PKCS#7 PAD DATA http://tools.ietf.org/html/rfc5652#section-6.3
bs = AES.block_size
padding_length = (bs - len(data) % bs) or bs
data += padding_length * chr(padding_length)
# COUNTER.new PARAMETERS
# 1) nbits (integer) - Length of the counter, in bits.
# 2) initial_value (integer) - initial value of the counter. "iv" from gen_key_initctr
ctr = Counter.new(128, initial_value=long(iv, 16))
# AES.new PARAMETERS
# 1) AES key, must be either 16, 24, or 32 bytes long -- "key" from gen_key_initctr
# 2) MODE_CTR, is the recommended mode
# 3) counter=<CounterObject>
cipher = AES.new(key1, AES.MODE_CTR, counter=ctr)
# ENCRYPT PADDED DATA
cryptedData = cipher.encrypt(data)
# COMBINE SALT, DIGEST AND DATA
hmac = HMAC.new(key2, cryptedData, SHA256)
message = "%s\n%s\n%s" % ( hexlify(salt), hmac.hexdigest(), hexlify(cryptedData) )
message = hexlify(message)
return message
def decrypt(self, data, password):
# SPLIT SALT, DIGEST, AND DATA
data = ''.join(data.split("\n"))
data = unhexlify(data)
salt, cryptedHmac, cryptedData = data.split("\n", 2)
salt = unhexlify(salt)
cryptedData = unhexlify(cryptedData)
key1, key2, iv = self.gen_key_initctr(password, salt)
# EXIT EARLY IF DIGEST DOESN'T MATCH
hmacDecrypt = HMAC.new(key2, cryptedData, SHA256)
if not self.is_equal(cryptedHmac, hmacDecrypt.hexdigest()):
return None
# SET THE COUNTER AND THE CIPHER
ctr = Counter.new(128, initial_value=long(iv, 16))
cipher = AES.new(key1, AES.MODE_CTR, counter=ctr)
# DECRYPT PADDED DATA
decryptedData = cipher.decrypt(cryptedData)
# UNPAD DATA
padding_length = ord(decryptedData[-1])
decryptedData = decryptedData[:-padding_length]
return decryptedData
def is_equal(self, a, b):
# http://codahale.com/a-lesson-in-timing-attacks/
if len(a) != len(b):
return False
result = 0
for x, y in zip(a, b):
result |= ord(x) ^ ord(y)
return result == 0

View file

@ -25,11 +25,7 @@ class Attribute:
self.isa = isa self.isa = isa
self.private = private self.private = private
self.value = None
self.default = default self.default = default
def __call__(self):
return self.value
class FieldAttribute(Attribute): class FieldAttribute(Attribute):
pass pass

View file

@ -19,24 +19,39 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from inspect import getmembers
from io import FileIO from io import FileIO
from six import iteritems, string_types from six import iteritems, string_types
from ansible.playbook.attribute import Attribute, FieldAttribute from ansible.playbook.attribute import Attribute, FieldAttribute
from ansible.parsing import load as ds_load from ansible.parsing import load_data
class Base: class Base:
_tags = FieldAttribute(isa='list')
_when = FieldAttribute(isa='list')
def __init__(self): def __init__(self):
# each class knows attributes set upon it, see Task.py for example # each class knows attributes set upon it, see Task.py for example
self._attributes = dict() self._attributes = dict()
for (name, value) in iteritems(self.__class__.__dict__): for (name, value) in self._get_base_attributes().iteritems():
aname = name[1:] self._attributes[name] = value.default
def _get_base_attributes(self):
'''
Returns the list of attributes for this class (or any subclass thereof).
If the attribute name starts with an underscore, it is removed
'''
base_attributes = dict()
for (name, value) in getmembers(self.__class__):
if isinstance(value, Attribute): if isinstance(value, Attribute):
self._attributes[aname] = value.default if name.startswith('_'):
name = name[1:]
base_attributes[name] = value
return base_attributes
def munge(self, ds): def munge(self, ds):
''' infrequently used method to do some pre-processing of legacy terms ''' ''' infrequently used method to do some pre-processing of legacy terms '''
@ -49,7 +64,7 @@ class Base:
assert ds is not None assert ds is not None
if isinstance(ds, string_types) or isinstance(ds, FileIO): if isinstance(ds, string_types) or isinstance(ds, FileIO):
ds = ds_load(ds) ds = load_data(ds)
# we currently don't do anything with private attributes but may # we currently don't do anything with private attributes but may
# later decide to filter them out of 'ds' here. # later decide to filter them out of 'ds' here.
@ -57,20 +72,15 @@ class Base:
ds = self.munge(ds) ds = self.munge(ds)
# walk all attributes in the class # walk all attributes in the class
for (name, attribute) in iteritems(self.__class__.__dict__): for (name, attribute) in self._get_base_attributes().iteritems():
aname = name[1:]
# process Field attributes which get loaded from the YAML # copy the value over unless a _load_field method is defined
if name in ds:
if isinstance(attribute, FieldAttribute): method = getattr(self, '_load_%s' % name, None)
if method:
# copy the value over unless a _load_field method is defined self._attributes[name] = method(name, ds[name])
if aname in ds: else:
method = getattr(self, '_load_%s' % aname, None) self._attributes[name] = ds[name]
if method:
self._attributes[aname] = method(aname, ds[aname])
else:
self._attributes[aname] = ds[aname]
# return the constructed object # return the constructed object
self.validate() self.validate()
@ -81,20 +91,12 @@ class Base:
''' validation that is done at parse time, not load time ''' ''' validation that is done at parse time, not load time '''
# walk all fields in the object # walk all fields in the object
for (name, attribute) in self.__dict__.iteritems(): for (name, attribute) in self._get_base_attributes().iteritems():
# find any field attributes # run validator only if present
if isinstance(attribute, FieldAttribute): method = getattr(self, '_validate_%s' % name, None)
if method:
if not name.startswith("_"): method(self, attribute)
raise AnsibleError("FieldAttribute %s must start with _" % name)
aname = name[1:]
# run validator only if present
method = getattr(self, '_validate_%s' % (prefix, aname), None)
if method:
method(self, attribute)
def post_validate(self, runner_context): def post_validate(self, runner_context):
''' '''

View file

@ -25,11 +25,13 @@ from ansible.playbook.attribute import Attribute, FieldAttribute
class Block(Base): class Block(Base):
# TODO: FIXME: block/rescue/always should be enough _block = FieldAttribute(isa='list')
_begin = FieldAttribute(isa='list')
_rescue = FieldAttribute(isa='list') _rescue = FieldAttribute(isa='list')
_end = FieldAttribute(isa='list') _always = FieldAttribute(isa='list')
_otherwise = FieldAttribute(isa='list')
# for future consideration? this would be functionally
# similar to the 'else' clause for exceptions
#_otherwise = FieldAttribute(isa='list')
def __init__(self, role=None): def __init__(self, role=None):
self.role = role self.role = role
@ -45,6 +47,20 @@ class Block(Base):
b = Block(role=role) b = Block(role=role)
return b.load_data(data) return b.load_data(data)
def munge(self, ds):
'''
If a simple task is given, an implicit block for that single task
is created, which goes in the main portion of the block
'''
is_block = False
for attr in ('block', 'rescue', 'always'):
if attr in ds:
is_block = True
break
if not is_block:
return dict(block=ds)
return ds
def _load_list_of_tasks(self, ds): def _load_list_of_tasks(self, ds):
assert type(ds) == list assert type(ds) == list
task_list = [] task_list = []
@ -53,15 +69,16 @@ class Block(Base):
task_list.append(t) task_list.append(t)
return task_list return task_list
def _load_begin(self, attr, ds): def _load_block(self, attr, ds):
return self._load_list_of_tasks(ds) return self._load_list_of_tasks(ds)
def _load_rescue(self, attr, ds): def _load_rescue(self, attr, ds):
return self._load_list_of_tasks(ds) return self._load_list_of_tasks(ds)
def _load_end(self, attr, ds): def _load_always(self, attr, ds):
return self._load_list_of_tasks(ds) return self._load_list_of_tasks(ds)
def _load_otherwise(self, attr, ds): # not currently used
return self._load_list_of_tasks(ds) #def _load_otherwise(self, attr, ds):
# return self._load_list_of_tasks(ds)

View file

@ -19,31 +19,251 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from v2.playbook.base import PlaybookBase from six import iteritems, string_types
from v2.utils import list_union
class Role(PlaybookBase): import os
# TODO: this will be overhauled to match Task.py at some point from ansible.playbook.attribute import FieldAttribute
from ansible.playbook.base import Base
from ansible.playbook.block import Block
from ansible.parsing import load_data_from_file
def __init__(self): #from ansible.utils import list_union, unfrackpath
pass
class Role(Base):
_role = FieldAttribute(isa='string')
_src = FieldAttribute(isa='string')
_scm = FieldAttribute(isa='string')
_version = FieldAttribute(isa='string')
_params = FieldAttribute(isa='dict')
_metadata = FieldAttribute(isa='dict')
_task_blocks = FieldAttribute(isa='list')
_handler_blocks = FieldAttribute(isa='list')
_default_vars = FieldAttribute(isa='dict')
_role_vars = FieldAttribute(isa='dict')
def __init__(self, vault_password=None):
self._role_path = None
self._vault_password = vault_password
super(Role, self).__init__()
def __repr__(self):
return self.get_name()
def get_name(self): def get_name(self):
return "TEMPORARY" return self._attributes['role']
def load(self, ds): @staticmethod
self._ds = ds def load(data, vault_password=None):
self._tasks = [] assert isinstance(data, string_types) or isinstance(data, dict)
self._handlers = [] r = Role(vault_password=vault_password)
self._blocks = [] r.load_data(data)
self._dependencies = [] return r
self._metadata = dict()
self._defaults = dict()
self._vars = dict()
self._params = dict()
def get_vars(self): #------------------------------------------------------------------------------
# munge, and other functions used for loading the ds
def munge(self, ds):
# Role definitions can be strings or dicts, so we fix
# things up here. Anything that is not a role name, tag,
# or conditional will also be added to the params sub-
# dictionary for loading later
if isinstance(ds, string_types):
new_ds = dict(role=ds)
else:
ds = self._munge_role(ds)
params = dict()
new_ds = dict()
for (key, value) in iteritems(ds):
if key not in [name for (name, value) in self._get_base_attributes().iteritems()]:
# this key does not match a field attribute,
# so it must be a role param
params[key] = value
else:
# this is a field attribute, so copy it over directly
new_ds[key] = value
# finally, assign the params to a new entry in the revised ds
new_ds['params'] = params
# set the role path, based on the role definition
self._role_path = self._get_role_path(new_ds.get('role'))
# load the role's files, if they exist
new_ds['metadata'] = self._load_role_yaml('meta')
new_ds['task_blocks'] = self._load_role_yaml('tasks')
new_ds['handler_blocks'] = self._load_role_yaml('handlers')
new_ds['default_vars'] = self._load_role_yaml('defaults')
new_ds['role_vars'] = self._load_role_yaml('vars')
return new_ds
def _load_role_yaml(self, subdir):
file_path = os.path.join(self._role_path, subdir)
if os.path.exists(file_path) and os.path.isdir(file_path):
main_file = self._resolve_main(file_path)
if os.path.exists(main_file):
return load_data_from_file(main_file, self._vault_password)
return None
def _resolve_main(self, basepath):
''' flexibly handle variations in main filenames '''
possible_mains = (
os.path.join(basepath, 'main'),
os.path.join(basepath, 'main.yml'),
os.path.join(basepath, 'main.yaml'),
os.path.join(basepath, 'main.json'),
)
if sum([os.path.isfile(x) for x in possible_mains]) > 1:
raise errors.AnsibleError("found multiple main files at %s, only one allowed" % (basepath))
else:
for m in possible_mains:
if os.path.isfile(m):
return m # exactly one main file
return possible_mains[0] # zero mains (we still need to return something)
def _get_role_path(self, role):
'''
the 'role', as specified in the ds (or as a bare string), can either
be a simple name or a full path. If it is a full path, we use the
basename as the role name, otherwise we take the name as-given and
append it to the default role path
'''
# FIXME: this should use unfrackpath once the utils code has been sorted out
role_path = os.path.normpath(role)
if os.path.exists(role_path):
return role_path
else:
for path in ('./roles', '/etc/ansible/roles'):
role_path = os.path.join(path, role)
if os.path.exists(role_path):
return role_path
# FIXME: raise something here
raise
def _repo_url_to_role_name(self, repo_url):
# gets the role name out of a repo like
# http://git.example.com/repos/repo.git" => "repo"
if '://' not in repo_url and '@' not in repo_url:
return repo_url
trailing_path = repo_url.split('/')[-1]
if trailing_path.endswith('.git'):
trailing_path = trailing_path[:-4]
if trailing_path.endswith('.tar.gz'):
trailing_path = trailing_path[:-7]
if ',' in trailing_path:
trailing_path = trailing_path.split(',')[0]
return trailing_path
def _role_spec_parse(self, role_spec):
# takes a repo and a version like
# git+http://git.example.com/repos/repo.git,v1.0
# and returns a list of properties such as:
# {
# 'scm': 'git',
# 'src': 'http://git.example.com/repos/repo.git',
# 'version': 'v1.0',
# 'name': 'repo'
# }
default_role_versions = dict(git='master', hg='tip')
role_spec = role_spec.strip()
role_version = ''
if role_spec == "" or role_spec.startswith("#"):
return (None, None, None, None)
tokens = [s.strip() for s in role_spec.split(',')]
# assume https://github.com URLs are git+https:// URLs and not
# tarballs unless they end in '.zip'
if 'github.com/' in tokens[0] and not tokens[0].startswith("git+") and not tokens[0].endswith('.tar.gz'):
tokens[0] = 'git+' + tokens[0]
if '+' in tokens[0]:
(scm, role_url) = tokens[0].split('+')
else:
scm = None
role_url = tokens[0]
if len(tokens) >= 2:
role_version = tokens[1]
if len(tokens) == 3:
role_name = tokens[2]
else:
role_name = repo_url_to_role_name(tokens[0])
if scm and not role_version:
role_version = default_role_versions.get(scm, '')
return dict(scm=scm, src=role_url, version=role_version, name=role_name)
def _munge_role(self, ds):
if 'role' in ds:
# Old style: {role: "galaxy.role,version,name", other_vars: "here" }
role_info = self._role_spec_parse(ds['role'])
if isinstance(role_info, dict):
# Warning: Slight change in behaviour here. name may be being
# overloaded. Previously, name was only a parameter to the role.
# Now it is both a parameter to the role and the name that
# ansible-galaxy will install under on the local system.
if 'name' in ds and 'name' in role_info:
del role_info['name']
ds.update(role_info)
else:
# New style: { src: 'galaxy.role,version,name', other_vars: "here" }
if 'github.com' in ds["src"] and 'http' in ds["src"] and '+' not in ds["src"] and not ds["src"].endswith('.tar.gz'):
ds["src"] = "git+" + ds["src"]
if '+' in ds["src"]:
(scm, src) = ds["src"].split('+')
ds["scm"] = scm
ds["src"] = src
if 'name' in role:
ds["role"] = ds["name"]
del ds["name"]
else:
ds["role"] = self._repo_url_to_role_name(ds["src"])
# set some values to a default value, if none were specified
ds.setdefault('version', '')
ds.setdefault('scm', None)
return ds
#------------------------------------------------------------------------------
# attribute loading defs
def _load_list_of_blocks(self, ds):
assert type(ds) == list
block_list = []
for block in ds:
b = Block(block)
block_list.append(b)
return block_list
def _load_task_blocks(self, attr, ds):
if ds is None:
return []
return self._load_list_of_blocks(ds)
def _load_handler_blocks(self, attr, ds):
if ds is None:
return []
return self._load_list_of_blocks(ds)
#------------------------------------------------------------------------------
# other functions
def get_variables(self):
# returns the merged variables for this role, including # returns the merged variables for this role, including
# recursively merging those of all child roles # recursively merging those of all child roles
return dict() return dict()
@ -60,8 +280,3 @@ class Role(PlaybookBase):
all_deps = list_union(all_deps, self.dependencies) all_deps = list_union(all_deps, self.dependencies)
return all_deps return all_deps
def get_blocks(self):
# should return
return self.blocks

View file

@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from .. compat import unittest from .. compat import unittest
from ansible.parsing import load from ansible.parsing import load_data
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError
import json import json
@ -68,12 +68,12 @@ class TestGeneralParsing(unittest.TestCase):
"jkl" : 5678 "jkl" : 5678
} }
""" """
output = load(input) output = load_data(input)
self.assertEqual(output['asdf'], '1234') self.assertEqual(output['asdf'], '1234')
self.assertEqual(output['jkl'], 5678) self.assertEqual(output['jkl'], 5678)
def parse_json_from_file(self): def parse_json_from_file(self):
output = load(MockFile(dict(a=1,b=2,c=3)),'json') output = load_data(MockFile(dict(a=1,b=2,c=3)),'json')
self.assertEqual(ouput, dict(a=1,b=2,c=3)) self.assertEqual(ouput, dict(a=1,b=2,c=3))
def parse_yaml_from_dict(self): def parse_yaml_from_dict(self):
@ -81,12 +81,12 @@ class TestGeneralParsing(unittest.TestCase):
asdf: '1234' asdf: '1234'
jkl: 5678 jkl: 5678
""" """
output = load(input) output = load_data(input)
self.assertEqual(output['asdf'], '1234') self.assertEqual(output['asdf'], '1234')
self.assertEqual(output['jkl'], 5678) self.assertEqual(output['jkl'], 5678)
def parse_yaml_from_file(self): def parse_yaml_from_file(self):
output = load(MockFile(dict(a=1,b=2,c=3),'yaml')) output = load_data(MockFile(dict(a=1,b=2,c=3),'yaml'))
self.assertEqual(output, dict(a=1,b=2,c=3)) self.assertEqual(output, dict(a=1,b=2,c=3))
def parse_fail(self): def parse_fail(self):
@ -95,10 +95,10 @@ class TestGeneralParsing(unittest.TestCase):
*** ***
NOT VALID NOT VALID
""" """
self.assertRaises(load(input), AnsibleParserError) self.assertRaises(load_data(input), AnsibleParserError)
def parse_fail_from_file(self): def parse_fail_from_file(self):
self.assertRaises(load(MockFile(None,'fail')), AnsibleParserError) self.assertRaises(load_data(MockFile(None,'fail')), AnsibleParserError)
def parse_fail_invalid_type(self): def parse_fail_invalid_type(self):
self.assertRaises(3000, AnsibleParsingError) self.assertRaises(3000, AnsibleParsingError)

View file

@ -49,31 +49,39 @@ class TestBlock(unittest.TestCase):
def test_load_block_simple(self): def test_load_block_simple(self):
ds = dict( ds = dict(
begin = [], block = [],
rescue = [], rescue = [],
end = [], always = [],
otherwise = [], #otherwise = [],
) )
b = Block.load(ds) b = Block.load(ds)
self.assertEqual(b.begin, []) self.assertEqual(b.block, [])
self.assertEqual(b.rescue, []) self.assertEqual(b.rescue, [])
self.assertEqual(b.end, []) self.assertEqual(b.always, [])
self.assertEqual(b.otherwise, []) # not currently used
#self.assertEqual(b.otherwise, [])
def test_load_block_with_tasks(self): def test_load_block_with_tasks(self):
ds = dict( ds = dict(
begin = [dict(action='begin')], block = [dict(action='block')],
rescue = [dict(action='rescue')], rescue = [dict(action='rescue')],
end = [dict(action='end')], always = [dict(action='always')],
otherwise = [dict(action='otherwise')], #otherwise = [dict(action='otherwise')],
) )
b = Block.load(ds) b = Block.load(ds)
self.assertEqual(len(b.begin), 1) self.assertEqual(len(b.block), 1)
assert isinstance(b.begin[0], Task) assert isinstance(b.block[0], Task)
self.assertEqual(len(b.rescue), 1) self.assertEqual(len(b.rescue), 1)
assert isinstance(b.rescue[0], Task) assert isinstance(b.rescue[0], Task)
self.assertEqual(len(b.end), 1) self.assertEqual(len(b.always), 1)
assert isinstance(b.end[0], Task) assert isinstance(b.always[0], Task)
self.assertEqual(len(b.otherwise), 1) # not currently used
assert isinstance(b.otherwise[0], Task) #self.assertEqual(len(b.otherwise), 1)
#assert isinstance(b.otherwise[0], Task)
def test_load_implicit_block(self):
ds = [dict(action='foo')]
b = Block.load(ds)
self.assertEqual(len(b.block), 1)
assert isinstance(b.block[0], Task)

View file

@ -0,0 +1,52 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.playbook.block import Block
from ansible.playbook.role import Role
from ansible.playbook.task import Task
from .. compat import unittest
class TestRole(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def test_construct_empty_block(self):
r = Role()
def test_role__load_list_of_blocks(self):
task = dict(action='test')
r = Role()
self.assertEqual(r._load_list_of_blocks([]), [])
res = r._load_list_of_blocks([task])
self.assertEqual(len(res), 1)
assert isinstance(res[0], Block)
res = r._load_list_of_blocks([task,task,task])
self.assertEqual(len(res), 3)
def test_load_role_simple(self):
pass
def test_load_role_complex(self):
pass