1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

Implement vault encrypted yaml variables. (#16274)

Make !vault-encrypted create a AnsibleVaultUnicode
yaml object that can be used as a regular string object.

This allows a playbook to include a encrypted vault
blob for the value of a yaml variable. A 'secret_password'
variable can have it's value encrypted instead of having
to vault encrypt an entire vars file.

Add __ENCRYPTED__ to the vault yaml types so
template.Template can treat it similar
to __UNSAFE__ flags.

vault.VaultLib api changes:
    - Split VaultLib.encrypt to encrypt and encrypt_bytestring

    - VaultLib.encrypt() previously accepted the plaintext data
      as either a byte string or a unicode string.
      Doing the right thing based on the input type would fail
      on py3 if given a arg of type 'bytes'. To simplify the
      API, vaultlib.encrypt() now assumes input plaintext is a
      py2 unicode or py3 str. It will encode to utf-8 then call
      the new encrypt_bytestring(). The new methods are less
      ambiguous.

    - moved VaultLib.is_encrypted logic to vault module scope
      and split to is_encrypted() and is_encrypted_file().

Add a test/unit/mock/yaml_helper.py
It has some helpers for testing parsing/yaml

Integration tests added as roles test_vault and test_vault_embedded
This commit is contained in:
Adrian Likins 2016-08-23 20:03:11 -04:00 committed by GitHub
parent dbf7df4439
commit e396d5d508
21 changed files with 934 additions and 111 deletions

View file

@ -149,7 +149,7 @@ class DataLoader():
def _safe_load(self, stream, file_name=None): def _safe_load(self, stream, file_name=None):
''' Implements yaml.safe_load(), except using our custom loader class. ''' ''' Implements yaml.safe_load(), except using our custom loader class. '''
loader = AnsibleLoader(stream, file_name) loader = AnsibleLoader(stream, file_name, self._vault_password)
try: try:
return loader.get_single_data() return loader.get_single_data()
finally: finally:
@ -394,7 +394,7 @@ class DataLoader():
try: try:
with open(to_bytes(real_path), 'rb') as f: with open(to_bytes(real_path), 'rb') as f:
if self._vault.is_encrypted(f): if self._vault.is_encrypted_file(f):
# if the file is encrypted and no password was specified, # if the file is encrypted and no password was specified,
# the decrypt call would throw an error, but we check first # the decrypt call would throw an error, but we check first
# since the decrypt function doesn't know the file name # since the decrypt function doesn't know the file name

View file

@ -89,8 +89,9 @@ HAS_ANY_PBKDF2HMAC = HAS_PBKDF2 or HAS_PBKDF2HMAC
CRYPTO_UPGRADE = "ansible-vault requires a newer version of pycrypto than the one installed on your platform. You may fix this with OS-specific commands such as: yum install python-devel; rpm -e --nodeps python-crypto; pip install pycrypto" CRYPTO_UPGRADE = "ansible-vault requires a newer version of pycrypto than the one installed on your platform. You may fix this with OS-specific commands such as: yum install python-devel; rpm -e --nodeps python-crypto; pip install pycrypto"
b_HEADER = b'$ANSIBLE_VAULT' b_HEADER = b'$ANSIBLE_VAULT'
HEADER = '$ANSIBLE_VAULT'
CIPHER_WHITELIST = frozenset((u'AES', u'AES256')) CIPHER_WHITELIST = frozenset((u'AES', u'AES256'))
CIPHER_WRITE_WHITELIST=frozenset((u'AES256',)) CIPHER_WRITE_WHITELIST = frozenset((u'AES256',))
# See also CIPHER_MAPPING at the bottom of the file which maps cipher strings # See also CIPHER_MAPPING at the bottom of the file which maps cipher strings
# (used in VaultFile header) to a cipher class # (used in VaultFile header) to a cipher class
@ -100,6 +101,37 @@ def check_prereqs():
if not HAS_AES or not HAS_COUNTER or not HAS_ANY_PBKDF2HMAC or not HAS_HASH: if not HAS_AES or not HAS_COUNTER or not HAS_ANY_PBKDF2HMAC or not HAS_HASH:
raise AnsibleError(CRYPTO_UPGRADE) raise AnsibleError(CRYPTO_UPGRADE)
class AnsibleVaultError(AnsibleError):
pass
def is_encrypted(b_data):
""" Test if this is vault encrypted data blob
:arg data: a python2 str or a python3 'bytes' to test whether it is
recognized as vault encrypted data
:returns: True if it is recognized. Otherwise, False.
"""
if b_data.startswith(b_HEADER):
return True
return False
def is_encrypted_file(file_obj):
"""Test if the contents of a file obj are a vault encrypted data blob.
The data read from the file_obj is expected to be bytestrings (py2 'str' or
python3 'bytes'). This more or less expects 'utf-8' encoding.
:arg file_obj: A file object that will be read from.
:returns: True if the file is a vault file. Otherwise, False.
"""
# read the header and reset the file stream to where it started
current_position = file_obj.tell()
b_header_part = file_obj.read(len(b_HEADER))
file_obj.seek(current_position)
return is_encrypted(b_header_part)
class VaultLib: class VaultLib:
def __init__(self, password): def __init__(self, password):
@ -107,36 +139,48 @@ class VaultLib:
self.cipher_name = None self.cipher_name = None
self.b_version = b'1.1' self.b_version = b'1.1'
# really b_data, but for compat
def is_encrypted(self, data): def is_encrypted(self, data):
""" Test if this is vault encrypted data """ Test if this is vault encrypted data
:arg data: a byte str or unicode string to test whether it is :arg data: a python2 utf-8 string or a python3 'bytes' to test whether it is
recognized as vault encrypted data recognized as vault encrypted data
:returns: True if it is recognized. Otherwise, False. :returns: True if it is recognized. Otherwise, False.
""" """
if hasattr(data, 'read'): # This could in the future, check to see if the data is a vault blob and
current_position = data.tell() # is encrypted with a key associated with this vault
header_part = data.read(len(b_HEADER)) # instead of just checking the format.
data.seek(current_position) return is_encrypted(data)
return self.is_encrypted(header_part)
if to_bytes(data, errors='strict', encoding='utf-8').startswith(b_HEADER): def is_encrypted_file(self, file_obj):
return True return is_encrypted_file(file_obj)
return False
def encrypt(self, data): def encrypt(self, data):
"""Vault encrypt a piece of data. """Vault encrypt a piece of data.
:arg data: a utf-8 byte str or unicode string to encrypt. :arg data: a PY2 unicode string or PY3 string to encrypt.
:returns: a utf-8 encoded byte str of encrypted data. The string :returns: a utf-8 encoded byte str of encrypted data. The string
contains a header identifying this as vault encrypted data and contains a header identifying this as vault encrypted data and
formatted to newline terminated lines of 80 characters. This is formatted to newline terminated lines of 80 characters. This is
suitable for dumping as is to a vault file. suitable for dumping as is to a vault file.
"""
b_data = to_bytes(data, errors='strict', encoding='utf-8')
if self.is_encrypted(b_data): The unicode or string passed in as data will encoded to UTF-8 before
encryption. If the a already encoded string or PY2 bytestring needs to
be encrypted, use encrypt_bytestring().
"""
plaintext = data
plaintext_bytes = plaintext.encode('utf-8')
return self.encrypt_bytestring(plaintext_bytes)
def encrypt_bytestring(self, plaintext_bytes):
'''Encrypt a PY2 bytestring.
Like encrypt(), except plaintext_bytes is not encoded to UTF-8
before encryption.'''
if self.is_encrypted(plaintext_bytes):
raise AnsibleError("input is already encrypted") raise AnsibleError("input is already encrypted")
if not self.cipher_name or self.cipher_name not in CIPHER_WRITE_WHITELIST: if not self.cipher_name or self.cipher_name not in CIPHER_WRITE_WHITELIST:
@ -149,11 +193,11 @@ class VaultLib:
this_cipher = Cipher() this_cipher = Cipher()
# encrypt data # encrypt data
b_enc_data = this_cipher.encrypt(b_data, self.b_password) ciphertext_bytes = this_cipher.encrypt(plaintext_bytes, self.b_password)
# format the data for output to the file # format the data for output to the file
b_tmp_data = self._format_output(b_enc_data) ciphertext_envelope = self._format_output(ciphertext_bytes)
return b_tmp_data return ciphertext_envelope
def decrypt(self, data, filename=None): def decrypt(self, data, filename=None):
"""Decrypt a piece of vault encrypted data. """Decrypt a piece of vault encrypted data.
@ -168,9 +212,9 @@ class VaultLib:
raise AnsibleError("A vault password must be specified to decrypt data") raise AnsibleError("A vault password must be specified to decrypt data")
if not self.is_encrypted(b_data): if not self.is_encrypted(b_data):
msg = "input is not encrypted" msg = "input is not vault encrypted data"
if filename: if filename:
msg += "%s is not encrypted" % filename msg += "%s is not a vault encrypted file" % filename
raise AnsibleError(msg) raise AnsibleError(msg)
# clean out header # clean out header
@ -178,6 +222,7 @@ class VaultLib:
# create the cipher object # create the cipher object
cipher_class_name = u'Vault{0}'.format(self.cipher_name) cipher_class_name = u'Vault{0}'.format(self.cipher_name)
if cipher_class_name in globals() and self.cipher_name in CIPHER_WHITELIST: if cipher_class_name in globals() and self.cipher_name in CIPHER_WHITELIST:
Cipher = globals()[cipher_class_name] Cipher = globals()[cipher_class_name]
this_cipher = Cipher() this_cipher = Cipher()
@ -205,10 +250,11 @@ class VaultLib:
if not self.cipher_name: if not self.cipher_name:
raise AnsibleError("the cipher must be set before adding a header") raise AnsibleError("the cipher must be set before adding a header")
header = b';'.join([b_HEADER, self.b_version, b_header = HEADER.encode('utf-8')
to_bytes(self.cipher_name, errors='strict', encoding='utf-8')]) header = b';'.join([b_header, self.b_version,
to_bytes(self.cipher_name,'utf-8',errors='strict')])
tmpdata = [header] tmpdata = [header]
tmpdata += [b_data[i:i+80] for i in range(0, len(b_data), 80)] tmpdata += [b_data[i:i + 80] for i in range(0, len(b_data), 80)]
tmpdata += [b''] tmpdata += [b'']
tmpdata = b'\n'.join(tmpdata) tmpdata = b'\n'.join(tmpdata)
@ -243,6 +289,7 @@ class VaultEditor:
def __init__(self, password): def __init__(self, password):
self.vault = VaultLib(password) self.vault = VaultLib(password)
# TODO: mv shred file stuff to it's own class
def _shred_file_custom(self, tmp_path): def _shred_file_custom(self, tmp_path):
""""Destroy a file, when shred (core-utils) is not available """"Destroy a file, when shred (core-utils) is not available
@ -277,7 +324,6 @@ class VaultEditor:
assert(fh.tell() == file_len) # FIXME remove this assert once we have unittests to check its accuracy assert(fh.tell() == file_len) # FIXME remove this assert once we have unittests to check its accuracy
os.fsync(fh) os.fsync(fh)
def _shred_file(self, tmp_path): def _shred_file(self, tmp_path):
"""Securely destroy a decrypted file """Securely destroy a decrypted file
@ -335,7 +381,9 @@ class VaultEditor:
return return
# encrypt new data and write out to tmp # encrypt new data and write out to tmp
enc_data = self.vault.encrypt(tmpdata) # An existing vaultfile will always be UTF-8,
# so decode to unicode here
enc_data = self.vault.encrypt(tmpdata.decode())
self.write_data(enc_data, tmp_path) self.write_data(enc_data, tmp_path)
# shuffle tmp file into place # shuffle tmp file into place
@ -345,8 +393,10 @@ class VaultEditor:
check_prereqs() check_prereqs()
# A file to be encrypted into a vaultfile could be any encoding
# so treat the contents as a byte string.
plaintext = self.read_data(filename) plaintext = self.read_data(filename)
ciphertext = self.vault.encrypt(plaintext) ciphertext = self.vault.encrypt_bytestring(plaintext)
self.write_data(ciphertext, output_file or filename) self.write_data(ciphertext, output_file or filename)
def decrypt_file(self, filename, output_file=None): def decrypt_file(self, filename, output_file=None):
@ -433,15 +483,20 @@ class VaultEditor:
return data return data
# TODO: add docstrings for arg types since this code is picky about that
def write_data(self, data, filename, shred=True): def write_data(self, data, filename, shred=True):
"""write data to given path """write data to given path
if shred==True, make sure that the original data is first shredded so :arg data: the encrypted and hexlified data as a utf-8 byte string
that is cannot be recovered :arg filename: filename to save 'data' to.
:arg shred: if shred==True, make sure that the original data is first shredded so
that is cannot be recovered.
""" """
bytes = to_bytes(data, errors='strict') # FIXME: do we need this now? data_bytes should always be a utf-8 byte string
b_file_data = to_bytes(data, errors='strict')
if filename == '-': if filename == '-':
sys.stdout.write(bytes) sys.stdout.write(b_file_data)
else: else:
if os.path.isfile(filename): if os.path.isfile(filename):
if shred: if shred:
@ -449,7 +504,7 @@ class VaultEditor:
else: else:
os.remove(filename) os.remove(filename)
with open(filename, "wb") as fh: with open(filename, "wb") as fh:
fh.write(bytes) fh.write(b_file_data)
def shuffle_files(self, src, dest): def shuffle_files(self, src, dest):
prev = None prev = None
@ -462,7 +517,7 @@ class VaultEditor:
# reset permissions if needed # reset permissions if needed
if prev is not None: if prev is not None:
#TODO: selinux, ACLs, xattr? # TODO: selinux, ACLs, xattr?
os.chmod(dest, prev.st_mode) os.chmod(dest, prev.st_mode)
os.chown(dest, prev.st_uid, prev.st_gid) os.chown(dest, prev.st_uid, prev.st_gid)
@ -488,7 +543,7 @@ class VaultFile(object):
_, self.tmpfile = tempfile.mkstemp() _, self.tmpfile = tempfile.mkstemp()
### TODO: # TODO:
# __del__ can be problematic in python... For this use case, make # __del__ can be problematic in python... For this use case, make
# VaultFile a context manager instead (implement __enter__ and __exit__) # VaultFile a context manager instead (implement __enter__ and __exit__)
def __del__(self): def __del__(self):
@ -496,11 +551,7 @@ class VaultFile(object):
os.unlink(self.tmpfile) os.unlink(self.tmpfile)
def is_encrypted(self): def is_encrypted(self):
peak = self.filehandle.readline() return is_encrypted_file(self.filehandle)
if peak.startswith(b_HEADER):
return True
else:
return False
def get_decrypted(self): def get_decrypted(self):
check_prereqs() check_prereqs()
@ -627,7 +678,6 @@ class VaultAES256:
# make two keys and one iv # make two keys and one iv
pbkdf2_prf = lambda p, s: HMAC.new(p, s, hash_function).digest() pbkdf2_prf = lambda p, s: HMAC.new(p, s, hash_function).digest()
derivedkey = PBKDF2(password, salt, dkLen=(2 * keylength) + ivlength, derivedkey = PBKDF2(password, salt, dkLen=(2 * keylength) + ivlength,
count=10000, prf=pbkdf2_prf) count=10000, prf=pbkdf2_prf)
return derivedkey return derivedkey
@ -657,9 +707,7 @@ class VaultAES256:
return key1, key2, hexlify(iv) return key1, key2, hexlify(iv)
def encrypt(self, data, password): def encrypt(self, data, password):
salt = os.urandom(32) salt = os.urandom(32)
key1, key2, iv = self.gen_key_initctr(password, salt) key1, key2, iv = self.gen_key_initctr(password, salt)
@ -691,20 +739,17 @@ class VaultAES256:
return message return message
def decrypt(self, data, password): def decrypt(self, data, password):
# SPLIT SALT, DIGEST, AND DATA # SPLIT SALT, DIGEST, AND DATA
data = unhexlify(data) data = unhexlify(data)
salt, cryptedHmac, cryptedData = data.split(b"\n", 2) salt, cryptedHmac, cryptedData = data.split(b"\n", 2)
salt = unhexlify(salt) salt = unhexlify(salt)
cryptedData = unhexlify(cryptedData) cryptedData = unhexlify(cryptedData)
key1, key2, iv = self.gen_key_initctr(password, salt) key1, key2, iv = self.gen_key_initctr(password, salt)
# EXIT EARLY IF DIGEST DOESN'T MATCH # EXIT EARLY IF DIGEST DOESN'T MATCH
hmacDecrypt = HMAC.new(key2, cryptedData, SHA256) hmacDecrypt = HMAC.new(key2, cryptedData, SHA256)
if not self.is_equal(cryptedHmac, to_bytes(hmacDecrypt.hexdigest())): if not self.is_equal(cryptedHmac, to_bytes(hmacDecrypt.hexdigest())):
return None return None
# SET THE COUNTER AND THE CIPHER # SET THE COUNTER AND THE CIPHER
ctr = Counter.new(128, initial_value=int(iv, 16)) ctr = Counter.new(128, initial_value=int(iv, 16))
cipher = AES.new(key1, AES.MODE_CTR, counter=ctr) cipher = AES.new(key1, AES.MODE_CTR, counter=ctr)
@ -719,7 +764,6 @@ class VaultAES256:
padding_length = decryptedData[-1] padding_length = decryptedData[-1]
decryptedData = decryptedData[:-padding_length] decryptedData = decryptedData[:-padding_length]
return decryptedData return decryptedData
def is_equal(self, a, b): def is_equal(self, a, b):
@ -746,6 +790,6 @@ class VaultAES256:
# Keys could be made bytes later if the code that gets the data is more # Keys could be made bytes later if the code that gets the data is more
# naturally byte-oriented # naturally byte-oriented
CIPHER_MAPPING = { CIPHER_MAPPING = {
u'AES': VaultAES, u'AES': VaultAES,
u'AES256': VaultAES256, u'AES256': VaultAES256,
} }

View file

@ -21,8 +21,13 @@ __metaclass__ = type
from yaml.constructor import Constructor, ConstructorError from yaml.constructor import Constructor, ConstructorError
from yaml.nodes import MappingNode from yaml.nodes import MappingNode
from ansible.parsing.yaml.objects import AnsibleMapping, AnsibleSequence, AnsibleUnicode from ansible.parsing.yaml.objects import AnsibleMapping, AnsibleSequence, AnsibleUnicode
from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode
from ansible.vars.unsafe_proxy import wrap_var from ansible.vars.unsafe_proxy import wrap_var
from ansible.parsing.vault import VaultLib
from ansible.utils.unicode import to_bytes
try: try:
from __main__ import display from __main__ import display
@ -32,9 +37,12 @@ except ImportError:
class AnsibleConstructor(Constructor): class AnsibleConstructor(Constructor):
def __init__(self, file_name=None): def __init__(self, file_name=None, vault_password=None):
self._vault_password = vault_password
self._ansible_file_name = file_name self._ansible_file_name = file_name
super(AnsibleConstructor, self).__init__() super(AnsibleConstructor, self).__init__()
self._vaults = {}
self._vaults['default'] = VaultLib(password=self._vault_password)
def construct_yaml_map(self, node): def construct_yaml_map(self, node):
data = AnsibleMapping() data = AnsibleMapping()
@ -86,6 +94,20 @@ class AnsibleConstructor(Constructor):
return ret return ret
def construct_vault_encrypted_unicode(self, node):
value = self.construct_scalar(node)
ciphertext_data = to_bytes(value)
if self._vault_password is None:
raise ConstructorError(None, None,
"found vault but no vault password provided", node.start_mark)
# could pass in a key id here to choose the vault to associate with
vault = self._vaults['default']
ret = AnsibleVaultEncryptedUnicode(ciphertext_data)
ret.vault = vault
return ret
def construct_yaml_seq(self, node): def construct_yaml_seq(self, node):
data = AnsibleSequence() data = AnsibleSequence()
yield data yield data
@ -109,6 +131,7 @@ class AnsibleConstructor(Constructor):
return (datasource, line, column) return (datasource, line, column)
AnsibleConstructor.add_constructor( AnsibleConstructor.add_constructor(
u'tag:yaml.org,2002:map', u'tag:yaml.org,2002:map',
AnsibleConstructor.construct_yaml_map) AnsibleConstructor.construct_yaml_map)
@ -132,3 +155,7 @@ AnsibleConstructor.add_constructor(
AnsibleConstructor.add_constructor( AnsibleConstructor.add_constructor(
u'!unsafe', u'!unsafe',
AnsibleConstructor.construct_yaml_unsafe) AnsibleConstructor.construct_yaml_unsafe)
AnsibleConstructor.add_constructor(
u'!vault-encrypted',
AnsibleConstructor.construct_vault_encrypted_unicode)

View file

@ -23,8 +23,10 @@ import yaml
from ansible.compat.six import PY3 from ansible.compat.six import PY3
from ansible.parsing.yaml.objects import AnsibleUnicode, AnsibleSequence, AnsibleMapping from ansible.parsing.yaml.objects import AnsibleUnicode, AnsibleSequence, AnsibleMapping
from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode
from ansible.vars.hostvars import HostVars from ansible.vars.hostvars import HostVars
class AnsibleDumper(yaml.SafeDumper): class AnsibleDumper(yaml.SafeDumper):
''' '''
A simple stub class that allows us to add representers A simple stub class that allows us to add representers
@ -35,6 +37,10 @@ class AnsibleDumper(yaml.SafeDumper):
def represent_hostvars(self, data): def represent_hostvars(self, data):
return self.represent_dict(dict(data)) return self.represent_dict(dict(data))
# Note: only want to represent the encrypted data
def represent_vault_encrypted_unicode(self, data):
return self.represent_scalar(u'!vault-encrypted', data._ciphertext.decode(), style='|')
if PY3: if PY3:
represent_unicode = yaml.representer.SafeRepresenter.represent_str represent_unicode = yaml.representer.SafeRepresenter.represent_str
else: else:
@ -60,3 +66,7 @@ AnsibleDumper.add_representer(
yaml.representer.SafeRepresenter.represent_dict, yaml.representer.SafeRepresenter.represent_dict,
) )
AnsibleDumper.add_representer(
AnsibleVaultEncryptedUnicode,
represent_vault_encrypted_unicode,
)

View file

@ -30,10 +30,11 @@ from yaml.resolver import Resolver
from ansible.parsing.yaml.constructor import AnsibleConstructor from ansible.parsing.yaml.constructor import AnsibleConstructor
if HAVE_PYYAML_C: if HAVE_PYYAML_C:
class AnsibleLoader(CParser, AnsibleConstructor, Resolver): class AnsibleLoader(CParser, AnsibleConstructor, Resolver):
def __init__(self, stream, file_name=None): def __init__(self, stream, file_name=None, vault_password=None):
CParser.__init__(self, stream) CParser.__init__(self, stream)
AnsibleConstructor.__init__(self, file_name=file_name) AnsibleConstructor.__init__(self, file_name=file_name, vault_password=vault_password)
Resolver.__init__(self) Resolver.__init__(self)
else: else:
from yaml.composer import Composer from yaml.composer import Composer
@ -42,10 +43,10 @@ else:
from yaml.parser import Parser from yaml.parser import Parser
class AnsibleLoader(Reader, Scanner, Parser, Composer, AnsibleConstructor, Resolver): class AnsibleLoader(Reader, Scanner, Parser, Composer, AnsibleConstructor, Resolver):
def __init__(self, stream, file_name=None): def __init__(self, stream, file_name=None, vault_password=None):
Reader.__init__(self, stream) Reader.__init__(self, stream)
Scanner.__init__(self) Scanner.__init__(self)
Parser.__init__(self) Parser.__init__(self)
Composer.__init__(self) Composer.__init__(self)
AnsibleConstructor.__init__(self, file_name=file_name) AnsibleConstructor.__init__(self, file_name=file_name, vault_password=vault_password)
Resolver.__init__(self) Resolver.__init__(self)

View file

@ -19,7 +19,11 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import yaml
from ansible.compat.six import text_type from ansible.compat.six import text_type
from ansible.errors import AnsibleError
from ansible.utils.unicode import to_bytes
class AnsibleBaseYAMLObject(object): class AnsibleBaseYAMLObject(object):
@ -63,3 +67,69 @@ class AnsibleUnicode(AnsibleBaseYAMLObject, text_type):
class AnsibleSequence(AnsibleBaseYAMLObject, list): class AnsibleSequence(AnsibleBaseYAMLObject, list):
''' sub class for lists ''' ''' sub class for lists '''
pass pass
# Unicode like object that is not evaluated (decrypted) until it needs to be
# TODO: is there a reason these objects are subclasses for YAMLObject?
class AnsibleVaultEncryptedUnicode(yaml.YAMLObject, AnsibleUnicode):
__UNSAFE__ = True
__ENCRYPTED__ = True
yaml_tag = u'!vault-encrypted'
@classmethod
def from_plaintext(cls, seq, vault):
if not vault:
raise vault.AnsibleVaultError('Error creating AnsibleVaultEncryptedUnicode, invalid vault (%s) provided' % vault)
ciphertext = vault.encrypt(seq)
avu = cls(ciphertext)
avu.vault = vault
return avu
def __init__(self, ciphertext):
'''A AnsibleUnicode with a Vault attribute that can decrypt it.
ciphertext is a byte string (str on PY2, bytestring on PY3).
The .data atttribute is a property that returns the decrypted plaintext
of the ciphertext as a PY2 unicode or PY3 string object.
'''
super(AnsibleVaultEncryptedUnicode, self).__init__()
# after construction, calling code has to set the .vault attribute to a vaultlib object
self.vault = None
self._ciphertext = to_bytes(ciphertext)
@property
def data(self):
if not self.vault:
# FIXME: raise exception?
return self._ciphertext
return self.vault.decrypt(self._ciphertext).decode()
@data.setter
def data(self, value):
self._ciphertext = value
def __repr__(self):
return 'AnsibleVaultEncryptedUnicode(%s)' % self._ciphertext
# Compare a regular str/text_type with the decrypted hypertext
def __eq__(self, other):
if self.vault:
return other == self.data
return False
def __hash__(self):
return id(self)
def __ne__(self, other):
if self.vault:
return other != self.data
return True
def __str__(self):
return str(self.data)
def __unicode__(self):
return unicode(self.data)

View file

@ -227,7 +227,7 @@ class Templar:
def _clean_data(self, orig_data): def _clean_data(self, orig_data):
''' remove jinja2 template tags from a string ''' ''' remove jinja2 template tags from a string '''
if not isinstance(orig_data, string_types): if not isinstance(orig_data, string_types) or hasattr(orig_data, '__ENCRYPTED__'):
return orig_data return orig_data
with contextlib.closing(StringIO(orig_data)) as data: with contextlib.closing(StringIO(orig_data)) as data:
@ -281,7 +281,7 @@ class Templar:
''' '''
Templates (possibly recursively) any given data as input. If convert_bare is Templates (possibly recursively) any given data as input. If convert_bare is
set to True, the given data will be wrapped as a jinja2 variable ('{{foo}}') set to True, the given data will be wrapped as a jinja2 variable ('{{foo}}')
before being sent through the template engine. before being sent through the template engine.
''' '''
if fail_on_undefined is None: if fail_on_undefined is None:
@ -304,7 +304,6 @@ class Templar:
result = variable result = variable
if self._contains_vars(variable): if self._contains_vars(variable):
# Check to see if the string we are trying to render is just referencing a single # Check to see if the string we are trying to render is just referencing a single
# var. In this case we don't want to accidentally change the type of the variable # var. In this case we don't want to accidentally change the type of the variable
# to a string by using the jinja template renderer. We just want to pass it. # to a string by using the jinja template renderer. We just want to pass it.
@ -438,7 +437,6 @@ class Templar:
raise AnsibleError("lookup plugin (%s) not found" % name) raise AnsibleError("lookup plugin (%s) not found" % name)
def _do_template(self, data, preserve_trailing_newlines=True, escape_backslashes=True, fail_on_undefined=None, overrides=None): def _do_template(self, data, preserve_trailing_newlines=True, escape_backslashes=True, fail_on_undefined=None, overrides=None):
# For preserving the number of input newlines in the output (used # For preserving the number of input newlines in the output (used
# later in this method) # later in this method)
data_newlines = _count_newlines_from_end(data) data_newlines = _count_newlines_from_end(data)
@ -517,7 +515,6 @@ class Templar:
res_newlines = _count_newlines_from_end(res) res_newlines = _count_newlines_from_end(res)
if data_newlines > res_newlines: if data_newlines > res_newlines:
res += '\n' * (data_newlines - res_newlines) res += '\n' * (data_newlines - res_newlines)
return res return res
except (UndefinedError, AnsibleUndefinedVariable) as e: except (UndefinedError, AnsibleUndefinedVariable) as e:
if fail_on_undefined: if fail_on_undefined:

View file

@ -162,6 +162,8 @@ test_vault: setup
ansible-playbook test_vault.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) --list-hosts -e outputdir=$(TEST_DIR) -e @$(VARS_FILE) ansible-playbook test_vault.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) --list-hosts -e outputdir=$(TEST_DIR) -e @$(VARS_FILE)
ansible-playbook test_vault.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) --syntax-check -e outputdir=$(TEST_DIR) -e @$(VARS_FILE) ansible-playbook test_vault.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) --syntax-check -e outputdir=$(TEST_DIR) -e @$(VARS_FILE)
ansible-playbook test_vault.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) -e outputdir=$(TEST_DIR) -e @$(VARS_FILE) ansible-playbook test_vault.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) -e outputdir=$(TEST_DIR) -e @$(VARS_FILE)
ansible-playbook test_vault_embedded.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) --syntax-check -e outputdir=$(TEST_DIR) -e @$(VARS_FILE)
ansible-playbook test_vault_embedded.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) -e outputdir=$(TEST_DIR) -e @$(VARS_FILE)
# test_delegate_to does not work unless we have permission to ssh to localhost. # test_delegate_to does not work unless we have permission to ssh to localhost.
# Would take some more effort on our test systems to implement that -- probably # Would take some more effort on our test systems to implement that -- probably

View file

@ -0,0 +1,7 @@
- assert:
that:
- 'secret_var == "secret"'
- copy: src=vault-secret.txt dest={{output_dir}}/secret.txt

View file

@ -0,0 +1,9 @@
$ANSIBLE_VAULT;1.1;AES256
31626536666232643662346539623662393436386162643439643434656231343435653936343235
6139346364396166336636383734333430373763336434310a303137623539653939336132626234
64613232396532313731313935333433353330666466646663303233323331636234326464643166
6538653264636166370a613161313064653566323037393962643032353230396536313865326362
34396262303130326632623162623230346238633932393938393766313036643835613936356233
33323730373331386337353339613165373064323134343930333031623036326164353534646631
31313963666234623731316238656233396638643331306231373539643039383434373035306233
30386230363730643561

View file

@ -0,0 +1,14 @@
---
- name: Assert that a embedded vault of a string with no newline works
assert:
that:
- '"{{ vault_encrypted_one_line_var }}" == "Setec Astronomy"'
- name: Assert that a multi line embedded vault works, including new line
assert:
that:
- vault_encrypted_var == "Setec Astronomy\n"
# TODO: add a expected fail here
# - debug: var=vault_encrypted_one_line_var_with_embedded_template

View file

@ -0,0 +1,17 @@
# If you use normal 'ansible-vault create' or edit, files always have at least one new line
# so c&p from a vault encrypted that wasn't specifically created sans new line ends up with one.
# (specifically created, as in 'echo -n "just one line" > my_secret.yml'
vault_encrypted_var: !vault-encrypted |
$ANSIBLE_VAULT;1.1;AES256
66386439653236336462626566653063336164663966303231363934653561363964363833313662
6431626536303530376336343832656537303632313433360a626438346336353331386135323734
62656361653630373231613662633962316233633936396165386439616533353965373339616234
3430613539666330390a313736323265656432366236633330313963326365653937323833366536
34623731376664623134383463316265643436343438623266623965636363326136
vault_encrypted_one_line_var: !vault-encrypted |
$ANSIBLE_VAULT;1.1;AES256
33363965326261303234626463623963633531343539616138316433353830356566396130353436
3562643163366231316662386565383735653432386435610a306664636137376132643732393835
63383038383730306639353234326630666539346233376330303938323639306661313032396437
6233623062366136310a633866373936313238333730653739323461656662303864663666653563
3138

View file

@ -1,13 +1,4 @@
- hosts: testhost - hosts: testhost
vars_files:
- vars/test_var_encrypted.yml
gather_facts: False gather_facts: False
roles:
tasks: - { role: test_vault, tags: test_vault}
- assert:
that:
- 'secret_var == "secret"'
- copy: src=vault-secret.txt dest={{output_dir}}/secret.txt

View file

@ -0,0 +1,4 @@
- hosts: testhost
gather_facts: False
roles:
- { role: test_vault_embedded, tags: test_vault_embedded}

View file

@ -0,0 +1,120 @@
import io
import yaml
from ansible.parsing.yaml.loader import AnsibleLoader
from ansible.parsing.yaml.dumper import AnsibleDumper
from ansible.compat.six import PY3
class YamlTestUtils(object):
"""Mixin class to combine with a unittest.TestCase subclass."""
def _loader(self, stream):
"""Vault related tests will want to override this.
Vault cases should setup a AnsibleLoader that has the vault password."""
return AnsibleLoader(stream)
def _dump_stream(self, obj, stream, dumper=None):
"""Dump to a py2-unicode or py3-string stream."""
if PY3:
return yaml.dump(obj, stream, Dumper=dumper)
else:
return yaml.dump(obj, stream, Dumper=dumper, encoding=None)
def _dump_string(self, obj, dumper=None):
"""Dump to a py2-unicode or py3-string"""
if PY3:
return yaml.dump(obj, Dumper=dumper)
else:
return yaml.dump(obj, Dumper=dumper, encoding=None)
def _dump_load_cycle(self, obj):
# Each pass though a dump or load revs the 'generation'
# obj to yaml string
string_from_object_dump = self._dump_string(obj, dumper=AnsibleDumper)
# wrap a stream/file like StringIO around that yaml
stream_from_object_dump = io.StringIO(string_from_object_dump)
loader = self._loader(stream_from_object_dump)
# load the yaml stream to create a new instance of the object (gen 2)
obj_2 = loader.get_data()
# dump the gen 2 objects directory to strings
string_from_object_dump_2 = self._dump_string(obj_2,
dumper=AnsibleDumper)
# The gen 1 and gen 2 yaml strings
self.assertEquals(string_from_object_dump, string_from_object_dump_2)
# the gen 1 (orig) and gen 2 py object
self.assertEquals(obj, obj_2)
# again! gen 3... load strings into py objects
stream_3 = io.StringIO(string_from_object_dump_2)
loader_3 = self._loader(stream_3)
obj_3 = loader_3.get_data()
string_from_object_dump_3 = self._dump_string(obj_3, dumper=AnsibleDumper)
self.assertEquals(obj, obj_3)
# should be transitive, but...
self.assertEquals(obj_2, obj_3)
self.assertEquals(string_from_object_dump, string_from_object_dump_3)
def _old_dump_load_cycle(self, obj):
'''Dump the passed in object to yaml, load it back up, dump again, compare.'''
stream = io.StringIO()
yaml_string = self._dump_string(obj, dumper=AnsibleDumper)
self._dump_stream(obj, stream, dumper=AnsibleDumper)
yaml_string_from_stream = stream.getvalue()
# reset stream
stream.seek(0)
loader = self._loader(stream)
# loader = AnsibleLoader(stream, vault_password=self.vault_password)
obj_from_stream = loader.get_data()
stream_from_string = io.StringIO(yaml_string)
loader2 = self._loader(stream_from_string)
# loader2 = AnsibleLoader(stream_from_string, vault_password=self.vault_password)
obj_from_string = loader2.get_data()
stream_obj_from_stream = io.StringIO()
stream_obj_from_string = io.StringIO()
if PY3:
yaml.dump(obj_from_stream, stream_obj_from_stream, Dumper=AnsibleDumper)
yaml.dump(obj_from_stream, stream_obj_from_string, Dumper=AnsibleDumper)
else:
yaml.dump(obj_from_stream, stream_obj_from_stream, Dumper=AnsibleDumper, encoding=None)
yaml.dump(obj_from_stream, stream_obj_from_string, Dumper=AnsibleDumper, encoding=None)
yaml_string_stream_obj_from_stream = stream_obj_from_stream.getvalue()
yaml_string_stream_obj_from_string = stream_obj_from_string.getvalue()
stream_obj_from_stream.seek(0)
stream_obj_from_string.seek(0)
if PY3:
yaml_string_obj_from_stream = yaml.dump(obj_from_stream, Dumper=AnsibleDumper)
yaml_string_obj_from_string = yaml.dump(obj_from_string, Dumper=AnsibleDumper)
else:
yaml_string_obj_from_stream = yaml.dump(obj_from_stream, Dumper=AnsibleDumper, encoding=None)
yaml_string_obj_from_string = yaml.dump(obj_from_string, Dumper=AnsibleDumper, encoding=None)
assert yaml_string == yaml_string_obj_from_stream
assert yaml_string == yaml_string_obj_from_stream == yaml_string_obj_from_string
assert yaml_string == yaml_string_obj_from_stream == yaml_string_obj_from_string == yaml_string_stream_obj_from_stream == yaml_string_stream_obj_from_string
assert obj == obj_from_stream
assert obj == obj_from_string
assert obj == yaml_string_obj_from_stream
assert obj == yaml_string_obj_from_string
assert obj == obj_from_stream == obj_from_string == yaml_string_obj_from_stream == yaml_string_obj_from_string
return {'obj': obj,
'yaml_string': yaml_string,
'yaml_string_from_stream': yaml_string_from_stream,
'obj_from_stream': obj_from_stream,
'obj_from_string': obj_from_string,
'yaml_string_obj_from_string': yaml_string_obj_from_string}

View file

@ -20,14 +20,12 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from six import PY3 from six import PY3
from yaml.scanner import ScannerError
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, mock_open from ansible.compat.tests.mock import patch, mock_open
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError
from ansible.parsing.dataloader import DataLoader from ansible.parsing.dataloader import DataLoader
from ansible.parsing.yaml.objects import AnsibleMapping
class TestDataLoader(unittest.TestCase): class TestDataLoader(unittest.TestCase):
@ -85,6 +83,6 @@ class TestDataLoaderWithVault(unittest.TestCase):
else: else:
builtins_name = '__builtin__' builtins_name = '__builtin__'
with patch(builtins_name + '.open', mock_open(read_data=vaulted_data)): with patch(builtins_name + '.open', mock_open(read_data=vaulted_data.encode('utf-8'))):
output = self._loader.load_from_file('dummy_vault.txt') output = self._loader.load_from_file('dummy_vault.txt')
self.assertEqual(output, dict(foo='bar')) self.assertEqual(output, dict(foo='bar'))

View file

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> # (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
# #
# This file is part of Ansible # This file is part of Ansible
@ -19,14 +20,12 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import getpass
import os
import shutil
import time
import tempfile
import six import six
from binascii import unhexlify import binascii
import io
import os
from binascii import hexlify from binascii import hexlify
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
@ -35,6 +34,7 @@ from ansible.utils.unicode import to_bytes, to_unicode
from ansible import errors from ansible import errors
from ansible.parsing.vault import VaultLib from ansible.parsing.vault import VaultLib
from ansible.parsing import vault
# Counter import fails for 2.0.1, requires >= 2.6.1 from pip # Counter import fails for 2.0.1, requires >= 2.6.1 from pip
try: try:
@ -57,6 +57,150 @@ try:
except ImportError: except ImportError:
HAS_AES = False HAS_AES = False
class TestVaultIsEncrypted(unittest.TestCase):
def test_utf8_not_encrypted(self):
b_data = "foobar".encode('utf8')
self.assertFalse(vault.is_encrypted(b_data))
def test_utf8_encrypted(self):
data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible")
b_data = data.encode('utf8')
self.assertTrue(vault.is_encrypted(b_data))
def test_bytes_not_encrypted(self):
b_data = b"foobar"
self.assertFalse(vault.is_encrypted(b_data))
def test_bytes_encrypted(self):
b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" + hexlify(b"ansible")
self.assertTrue(vault.is_encrypted(b_data))
def test_unicode_not_encrypted_py3(self):
if not six.PY3:
raise SkipTest()
data = u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ "
self.assertRaises(TypeError, vault.is_encrypted, data)
def test_unicode_not_encrypted_py2(self):
if six.PY3:
raise SkipTest()
data = u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ "
# py2 will take a unicode string, but that should always fails
self.assertFalse(vault.is_encrypted(data))
def test_unicode_is_encrypted_py3(self):
if not six.PY3:
raise SkipTest()
data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible")
# should still be a type error
self.assertRaises(TypeError, vault.is_encrypted, data)
def test_unicode_is_encrypted_py2(self):
if six.PY3:
raise SkipTest()
data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible")
# THis works, but arguably shouldn't...
self.assertTrue(vault.is_encrypted(data))
class TestVaultIsEncryptedFile(unittest.TestCase):
def test_utf8_not_encrypted(self):
b_data = "foobar".encode('utf8')
b_data_fo = io.BytesIO(b_data)
self.assertFalse(vault.is_encrypted_file(b_data_fo))
def test_utf8_encrypted(self):
data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible")
b_data = data.encode('utf8')
b_data_fo = io.BytesIO(b_data)
self.assertTrue(vault.is_encrypted_file(b_data_fo))
def test_bytes_not_encrypted(self):
b_data = b"foobar"
b_data_fo = io.BytesIO(b_data)
self.assertFalse(vault.is_encrypted_file(b_data_fo))
def test_bytes_encrypted(self):
b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" + hexlify(b"ansible")
b_data_fo = io.BytesIO(b_data)
self.assertTrue(vault.is_encrypted_file(b_data_fo))
class TestVaultCipherAes256(unittest.TestCase):
def test(self):
vault_cipher = vault.VaultAES256()
self.assertIsInstance(vault_cipher, vault.VaultAES256)
# TODO: tag these as slow tests
def test_create_key(self):
vault_cipher = vault.VaultAES256()
password = 'hunter42'
b_salt = os.urandom(32)
b_key = vault_cipher.create_key(password=password, salt=b_salt, keylength=32, ivlength=16)
self.assertIsInstance(b_key, six.binary_type)
def test_create_key_known(self):
vault_cipher = vault.VaultAES256()
password = 'hunter42'
# A fixed salt
b_salt = b'q' * 32 # q is the most random letter.
b_key = vault_cipher.create_key(password=password, salt=b_salt, keylength=32, ivlength=16)
self.assertIsInstance(b_key, six.binary_type)
# verify we get the same answer
# we could potentially run a few iterations of this and time it to see if it's roughly constant time
# and or that it exceeds some minimal time, but that would likely cause unreliable fails, esp in CI
b_key_2 = vault_cipher.create_key(password=password, salt=b_salt, keylength=32, ivlength=16)
self.assertIsInstance(b_key, six.binary_type)
self.assertEqual(b_key, b_key_2)
def test_is_equal_is_equal(self):
vault_cipher = vault.VaultAES256()
res = vault_cipher.is_equal(b'abcdefghijklmnopqrstuvwxyz', b'abcdefghijklmnopqrstuvwxyz')
self.assertTrue(res)
def test_is_equal_unequal_length(self):
vault_cipher = vault.VaultAES256()
res = vault_cipher.is_equal(b'abcdefghijklmnopqrstuvwxyz', b'abcdefghijklmnopqrstuvwx and sometimes y')
self.assertFalse(res)
def test_is_equal_not_equal(self):
vault_cipher = vault.VaultAES256()
res = vault_cipher.is_equal(b'abcdefghijklmnopqrstuvwxyz', b'AbcdefghijKlmnopQrstuvwxZ')
self.assertFalse(res)
def test_is_equal_empty(self):
vault_cipher = vault.VaultAES256()
res = vault_cipher.is_equal(b'', b'')
self.assertTrue(res)
# NOTE: I'm not really sure what the method should do if it doesn't get bytes,
# but this at least sees if it explodes (maybe it should?)
def test_is_equal_unicode_py3(self):
if not six.PY3:
raise SkipTest
vault_cipher = vault.VaultAES256()
self.assertRaises(TypeError, vault_cipher.is_equal,
u'私はガラスを食べられます。それは私を傷つけません。',
u'私はガラスを食べられます。それは私を傷つけません。')
def test_is_equal_unicode_py2(self):
if not six.PY2:
raise SkipTest
vault_cipher = vault.VaultAES256()
res = vault_cipher.is_equal(u'私はガラスを食べられます。それは私を傷つけません。',
u'私はガラスを食べられます。それは私を傷つけません。')
self.assertTrue(res)
def test_is_equal_unicode_different(self):
vault_cipher = vault.VaultAES256()
res = vault_cipher.is_equal(u'私はガラスを食べられます。それは私を傷つけません。',
u'Pot să mănânc sticlă și ea nu mă rănește.')
self.assertFalse(res)
class TestVaultLib(unittest.TestCase): class TestVaultLib(unittest.TestCase):
def test_methods_exist(self): def test_methods_exist(self):
@ -69,10 +213,24 @@ class TestVaultLib(unittest.TestCase):
for slot in slots: for slot in slots:
assert hasattr(v, slot), "VaultLib is missing the %s method" % slot assert hasattr(v, slot), "VaultLib is missing the %s method" % slot
def test_encrypt(self):
v = VaultLib(password='the_unit_test_password')
plaintext = u'Some text to encrypt.'
ciphertext = v.encrypt(plaintext)
self.assertIsInstance(ciphertext, (bytes, str))
# TODO: assert something...
def test_is_encrypted(self): def test_is_encrypted(self):
v = VaultLib(None) v = VaultLib(None)
assert not v.is_encrypted(u"foobar"), "encryption check on plaintext failed" assert not v.is_encrypted("foobar".encode('utf-8')), "encryption check on plaintext failed"
data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible")
assert v.is_encrypted(data.encode('utf-8')), "encryption check on headered text failed"
def test_is_encrypted_bytes(self):
v = VaultLib(None)
assert not v.is_encrypted(b"foobar"), "encryption check on plaintext failed"
data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" + hexlify(b"ansible")
assert v.is_encrypted(data), "encryption check on headered text failed" assert v.is_encrypted(data), "encryption check on headered text failed"
def test_format_output(self): def test_format_output(self):
@ -115,35 +273,82 @@ class TestVaultLib(unittest.TestCase):
raise SkipTest raise SkipTest
v = VaultLib('ansible') v = VaultLib('ansible')
v.cipher_name = 'AES256' v.cipher_name = 'AES256'
enc_data = v.encrypt(b"foobar") plaintext = "foobar"
enc_data = v.encrypt(plaintext)
dec_data = v.decrypt(enc_data) dec_data = v.decrypt(enc_data)
assert enc_data != b"foobar", "encryption failed" assert enc_data != b"foobar", "encryption failed"
assert dec_data == b"foobar", "decryption failed" assert dec_data == b"foobar", "decryption failed"
def test_encrypt_decrypt_aes256_existing_vault(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
raise SkipTest
v = VaultLib('test-vault-password')
v.cipher_name = 'AES256'
plaintext = b"Setec Astronomy"
enc_data = '''$ANSIBLE_VAULT;1.1;AES256
33363965326261303234626463623963633531343539616138316433353830356566396130353436
3562643163366231316662386565383735653432386435610a306664636137376132643732393835
63383038383730306639353234326630666539346233376330303938323639306661313032396437
6233623062366136310a633866373936313238333730653739323461656662303864663666653563
3138'''
dec_data = v.decrypt(enc_data)
assert dec_data == plaintext, "decryption failed"
def test_encrypt_decrypt_aes256_bad_hmac(self):
# FIXME This test isn't working quite yet.
raise SkipTest
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
raise SkipTest
v = VaultLib('test-vault-password')
v.cipher_name = 'AES256'
# plaintext = "Setec Astronomy"
enc_data = '''$ANSIBLE_VAULT;1.1;AES256
33363965326261303234626463623963633531343539616138316433353830356566396130353436
3562643163366231316662386565383735653432386435610a306664636137376132643732393835
63383038383730306639353234326630666539346233376330303938323639306661313032396437
6233623062366136310a633866373936313238333730653739323461656662303864663666653563
3138'''
b_data = to_bytes(enc_data, errors='strict', encoding='utf-8')
b_data = v._split_header(b_data)
foo = binascii.unhexlify(b_data)
lines = foo.splitlines()
# line 0 is salt, line 1 is hmac, line 2+ is ciphertext
b_salt = lines[0]
b_hmac = lines[1]
b_ciphertext_data = b'\n'.join(lines[2:])
b_ciphertext = binascii.unhexlify(b_ciphertext_data)
# b_orig_ciphertext = b_ciphertext[:]
# now muck with the text
# b_munged_ciphertext = b_ciphertext[:10] + b'\x00' + b_ciphertext[11:]
# b_munged_ciphertext = b_ciphertext
# assert b_orig_ciphertext != b_munged_ciphertext
b_ciphertext_data = binascii.hexlify(b_ciphertext)
b_payload = b'\n'.join([b_salt, b_hmac, b_ciphertext_data])
# reformat
b_invalid_ciphertext = v._format_output(b_payload)
# assert we throw an error
v.decrypt(b_invalid_ciphertext)
def test_encrypt_encrypted(self): def test_encrypt_encrypted(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
raise SkipTest raise SkipTest
v = VaultLib('ansible') v = VaultLib('ansible')
v.cipher_name = 'AES' v.cipher_name = 'AES'
data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(six.b("ansible")) data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(six.b("ansible"))
error_hit = False self.assertRaises(errors.AnsibleError, v.encrypt, data,)
try:
enc_data = v.encrypt(data)
except errors.AnsibleError as e:
error_hit = True
assert error_hit, "No error was thrown when trying to encrypt data with a header"
def test_decrypt_decrypted(self): def test_decrypt_decrypted(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
raise SkipTest raise SkipTest
v = VaultLib('ansible') v = VaultLib('ansible')
data = "ansible" data = "ansible"
error_hit = False self.assertRaises(errors.AnsibleError, v.decrypt, data)
try:
dec_data = v.decrypt(data)
except errors.AnsibleError as e:
error_hit = True
assert error_hit, "No error was thrown when trying to decrypt data without a header"
def test_cipher_not_set(self): def test_cipher_not_set(self):
# not setting the cipher should default to AES256 # not setting the cipher should default to AES256
@ -151,10 +356,5 @@ class TestVaultLib(unittest.TestCase):
raise SkipTest raise SkipTest
v = VaultLib('ansible') v = VaultLib('ansible')
data = "ansible" data = "ansible"
error_hit = False v.encrypt(data)
try: self.assertEquals(v.cipher_name, "AES256")
enc_data = v.encrypt(data)
except errors.AnsibleError as e:
error_hit = True
assert not error_hit, "An error was thrown when trying to encrypt data without the cipher set"
assert v.cipher_name == "AES256", "cipher name is not set to AES256: %s" % v.cipher_name

View file

@ -22,13 +22,8 @@ __metaclass__ = type
#!/usr/bin/env python #!/usr/bin/env python
import sys import sys
import getpass
import os import os
import shutil
import time
import tempfile import tempfile
from binascii import unhexlify
from binascii import hexlify
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from ansible.compat.tests import unittest from ansible.compat.tests import unittest

View file

@ -0,0 +1,64 @@
# coding: utf-8
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import io
import yaml
try:
from _yaml import ParserError
except ImportError:
from yaml.parser import ParserError
from ansible.parsing.yaml import dumper
from ansible.parsing.yaml.loader import AnsibleLoader
from ansible.compat.tests import unittest
from ansible.parsing.yaml import objects
from ansible.parsing import vault
from units.mock.yaml_helper import YamlTestUtils
class TestAnsibleDumper(unittest.TestCase, YamlTestUtils):
def setUp(self):
self.vault_password = "hunter42"
self.good_vault = vault.VaultLib(self.vault_password)
self.vault = self.good_vault
self.stream = self._build_stream()
self.dumper = dumper.AnsibleDumper
def _build_stream(self,yaml_text=None):
text = yaml_text or u''
stream = io.StringIO(text)
return stream
def _loader(self, stream):
return AnsibleLoader(stream, vault_password=self.vault_password)
def test(self):
plaintext = 'This is a string we are going to encrypt.'
avu = objects.AnsibleVaultEncryptedUnicode.from_plaintext(plaintext, vault=self.vault)
yaml_out = self._dump_string(avu, dumper=self.dumper)
stream = self._build_stream(yaml_out)
loader = self._loader(stream)
data_from_yaml = loader.get_single_data()
self.assertEquals(plaintext, data_from_yaml.data)

View file

@ -26,20 +26,26 @@ from six import text_type, binary_type
from collections import Sequence, Set, Mapping from collections import Sequence, Set, Mapping
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch
from ansible import errors
from ansible.parsing.yaml.loader import AnsibleLoader from ansible.parsing.yaml.loader import AnsibleLoader
from ansible.parsing import vault
from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode
from ansible.parsing.yaml.dumper import AnsibleDumper
from ansible.utils.unicode import to_bytes
from units.mock.yaml_helper import YamlTestUtils
try: try:
from _yaml import ParserError from _yaml import ParserError
except ImportError: except ImportError:
from yaml.parser import ParserError from yaml.parser import ParserError
class NameStringIO(StringIO): class NameStringIO(StringIO):
"""In py2.6, StringIO doesn't let you set name because a baseclass has it """In py2.6, StringIO doesn't let you set name because a baseclass has it
as readonly property""" as readonly property"""
name = None name = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(NameStringIO, self).__init__(*args, **kwargs) super(NameStringIO, self).__init__(*args, **kwargs)
@ -159,6 +165,124 @@ class TestAnsibleLoaderBasic(unittest.TestCase):
self.assertEqual(data[0][u'baz'].ansible_pos, ('myfile.yml', 2, 9)) self.assertEqual(data[0][u'baz'].ansible_pos, ('myfile.yml', 2, 9))
class TestAnsibleLoaderVault(unittest.TestCase, YamlTestUtils):
def setUp(self):
self.vault_password = "hunter42"
self.vault = vault.VaultLib(self.vault_password)
def test_wrong_password(self):
plaintext = u"Ansible"
bob_password = "this is a different password"
bobs_vault = vault.VaultLib(bob_password)
ciphertext = bobs_vault.encrypt(plaintext)
try:
self.vault.decrypt(ciphertext)
except Exception as e:
self.assertIsInstance(e, errors.AnsibleError)
self.assertEqual(e.message, 'Decryption failed')
def _encrypt_plaintext(self, plaintext):
# Construct a yaml repr of a vault by hand
vaulted_var_bytes = self.vault.encrypt(plaintext)
# add yaml tag
vaulted_var = vaulted_var_bytes.decode()
lines = vaulted_var.splitlines()
lines2 = []
for line in lines:
lines2.append(' %s' % line)
vaulted_var = '\n'.join(lines2)
tagged_vaulted_var = u"""!vault-encrypted |\n%s""" % vaulted_var
return tagged_vaulted_var
def _build_stream(self, yaml_text):
stream = NameStringIO(yaml_text)
stream.name = 'my.yml'
return stream
def _loader(self, stream):
return AnsibleLoader(stream, vault_password=self.vault_password)
def _load_yaml(self, yaml_text, password):
stream = self._build_stream(yaml_text)
loader = self._loader(stream)
data_from_yaml = loader.get_single_data()
return data_from_yaml
def test_dump_load_cycle(self):
avu = AnsibleVaultEncryptedUnicode.from_plaintext('The plaintext for test_dump_load_cycle.', vault=self.vault)
self._dump_load_cycle(avu)
def test_embedded_vault_from_dump(self):
avu = AnsibleVaultEncryptedUnicode.from_plaintext('setec astronomy', vault=self.vault)
blip = {'stuff1': [{'a dict key': 24},
{'shhh-ssh-secrets': avu,
'nothing to see here': 'move along'}],
'another key': 24.1}
blip = ['some string', 'another string', avu]
stream = NameStringIO()
self._dump_stream(blip, stream, dumper=AnsibleDumper)
print(stream.getvalue())
stream.seek(0)
stream.seek(0)
loader = self._loader(stream)
data_from_yaml = loader.get_data()
stream2 = NameStringIO(u'')
# verify we can dump the object again
self._dump_stream(data_from_yaml, stream2, dumper=AnsibleDumper)
def test_embedded_vault(self):
plaintext_var = u"""This is the plaintext string."""
tagged_vaulted_var = self._encrypt_plaintext(plaintext_var)
another_vaulted_var = self._encrypt_plaintext(plaintext_var)
different_var = u"""A different string that is not the same as the first one."""
different_vaulted_var = self._encrypt_plaintext(different_var)
yaml_text = u"""---\nwebster: daniel\noed: oxford\nthe_secret: %s\nanother_secret: %s\ndifferent_secret: %s""" % (tagged_vaulted_var, another_vaulted_var, different_vaulted_var)
data_from_yaml = self._load_yaml(yaml_text, self.vault_password)
vault_string = data_from_yaml['the_secret']
self.assertEquals(plaintext_var, data_from_yaml['the_secret'])
test_dict = {}
test_dict[vault_string] = 'did this work?'
self.assertEquals(vault_string.data, vault_string)
# This looks weird and useless, but the object in question has a custom __eq__
self.assertEquals(vault_string, vault_string)
another_vault_string = data_from_yaml['another_secret']
different_vault_string = data_from_yaml['different_secret']
self.assertEquals(vault_string, another_vault_string)
self.assertNotEquals(vault_string, different_vault_string)
# More testing of __eq__/__ne__
self.assertTrue('some string' != vault_string)
self.assertNotEquals('some string', vault_string)
# Note this is a compare of the str/unicode of these, they are diferent types
# so we want to test self == other, and other == self etc
self.assertEquals(plaintext_var, vault_string)
self.assertEquals(vault_string, plaintext_var)
self.assertFalse(plaintext_var != vault_string)
self.assertFalse(vault_string != plaintext_var)
class TestAnsibleLoaderPlay(unittest.TestCase): class TestAnsibleLoaderPlay(unittest.TestCase):
def setUp(self): def setUp(self):
@ -242,7 +366,7 @@ class TestAnsibleLoaderPlay(unittest.TestCase):
def check_vars(self): def check_vars(self):
# Numbers don't have line/col information yet # Numbers don't have line/col information yet
#self.assertEqual(self.data[0][u'vars'][u'number'].ansible_pos, (self.play_filename, 4, 21)) # self.assertEqual(self.data[0][u'vars'][u'number'].ansible_pos, (self.play_filename, 4, 21))
self.assertEqual(self.data[0][u'vars'][u'string'].ansible_pos, (self.play_filename, 5, 29)) self.assertEqual(self.data[0][u'vars'][u'string'].ansible_pos, (self.play_filename, 5, 29))
self.assertEqual(self.data[0][u'vars'][u'utf8_string'].ansible_pos, (self.play_filename, 6, 34)) self.assertEqual(self.data[0][u'vars'][u'utf8_string'].ansible_pos, (self.play_filename, 6, 34))
@ -255,8 +379,8 @@ class TestAnsibleLoaderPlay(unittest.TestCase):
self.assertEqual(self.data[0][u'vars'][u'list'][0].ansible_pos, (self.play_filename, 11, 25)) self.assertEqual(self.data[0][u'vars'][u'list'][0].ansible_pos, (self.play_filename, 11, 25))
self.assertEqual(self.data[0][u'vars'][u'list'][1].ansible_pos, (self.play_filename, 12, 25)) self.assertEqual(self.data[0][u'vars'][u'list'][1].ansible_pos, (self.play_filename, 12, 25))
# Numbers don't have line/col info yet # Numbers don't have line/col info yet
#self.assertEqual(self.data[0][u'vars'][u'list'][2].ansible_pos, (self.play_filename, 13, 25)) # self.assertEqual(self.data[0][u'vars'][u'list'][2].ansible_pos, (self.play_filename, 13, 25))
#self.assertEqual(self.data[0][u'vars'][u'list'][3].ansible_pos, (self.play_filename, 14, 25)) # self.assertEqual(self.data[0][u'vars'][u'list'][3].ansible_pos, (self.play_filename, 14, 25))
def check_tasks(self): def check_tasks(self):
# #

View file

@ -0,0 +1,129 @@
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright 2016, Adrian Likins <alikins@redhat.com>
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.compat.tests import unittest
from ansible.parsing import vault
from ansible.parsing.yaml.loader import AnsibleLoader
# module under test
from ansible.parsing.yaml import objects
from units.mock.yaml_helper import YamlTestUtils
class TestAnsibleVaultUnicodeNoVault(unittest.TestCase, YamlTestUtils):
def test_empty_init(self):
self.assertRaises(TypeError, objects.AnsibleVaultEncryptedUnicode)
def test_empty_string_init(self):
seq = ''.encode('utf8')
self.assert_values(seq)
def test_empty_byte_string_init(self):
seq = b''
self.assert_values(seq)
def _assert_values(self, avu, seq):
self.assertIsInstance(avu, objects.AnsibleVaultEncryptedUnicode)
self.assertTrue(avu.vault is None)
# AnsibleVaultEncryptedUnicode without a vault should never == any string
self.assertNotEquals(avu, seq)
def assert_values(self, seq):
avu = objects.AnsibleVaultEncryptedUnicode(seq)
self._assert_values(avu, seq)
def test_single_char(self):
seq = 'a'.encode('utf8')
self.assert_values(seq)
def test_string(self):
seq = 'some letters'
self.assert_values(seq)
def test_byte_string(self):
seq = 'some letters'.encode('utf8')
self.assert_values(seq)
class TestAnsibleVaultEncryptedUnicode(unittest.TestCase, YamlTestUtils):
def setUp(self):
self.vault_password = "hunter42"
self.good_vault = vault.VaultLib(self.vault_password)
self.wrong_vault_password = 'not-hunter42'
self.wrong_vault = vault.VaultLib(self.wrong_vault_password)
self.vault = self.good_vault
def _loader(self, stream):
return AnsibleLoader(stream, vault_password=self.vault_password)
def test_dump_load_cycle(self):
aveu = self._from_plaintext('the test string for TestAnsibleVaultEncryptedUnicode.test_dump_load_cycle')
self._dump_load_cycle(aveu)
def assert_values(self, avu, seq):
self.assertIsInstance(avu, objects.AnsibleVaultEncryptedUnicode)
self.assertEquals(avu, seq)
self.assertTrue(avu.vault is self.vault)
self.assertIsInstance(avu.vault, vault.VaultLib)
def _from_plaintext(self, seq):
return objects.AnsibleVaultEncryptedUnicode.from_plaintext(seq, vault=self.vault)
def _from_ciphertext(self, ciphertext):
avu = objects.AnsibleVaultEncryptedUnicode(ciphertext)
avu.vault = self.vault
return avu
def test_empty_init(self):
self.assertRaises(TypeError, objects.AnsibleVaultEncryptedUnicode)
def test_empty_string_init_from_plaintext(self):
seq = ''
avu = self._from_plaintext(seq)
self.assert_values(avu,seq)
def test_empty_unicode_init_from_plaintext(self):
seq = u''
avu = self._from_plaintext(seq)
self.assert_values(avu,seq)
def test_string_from_plaintext(self):
seq = 'some letters'
avu = self._from_plaintext(seq)
self.assert_values(avu,seq)
def test_unicode_from_plaintext(self):
seq = u'some letters'
avu = self._from_plaintext(seq)
self.assert_values(avu,seq)
# TODO/FIXME: make sure bad password fails differently than 'thats not encrypted'
def test_empty_string_wrong_password(self):
seq = ''
self.vault = self.wrong_vault
avu = self._from_plaintext(seq)
self.assert_values(avu, seq)