diff --git a/lib/ansible/parsing/dataloader.py b/lib/ansible/parsing/dataloader.py index d0db34ce31..9cac94d83c 100644 --- a/lib/ansible/parsing/dataloader.py +++ b/lib/ansible/parsing/dataloader.py @@ -149,7 +149,7 @@ class DataLoader(): def _safe_load(self, stream, file_name=None): ''' Implements yaml.safe_load(), except using our custom loader class. ''' - loader = AnsibleLoader(stream, file_name) + loader = AnsibleLoader(stream, file_name, self._vault_password) try: return loader.get_single_data() finally: @@ -394,7 +394,7 @@ class DataLoader(): try: with open(to_bytes(real_path), 'rb') as f: - if self._vault.is_encrypted(f): + if self._vault.is_encrypted_file(f): # if the file is encrypted and no password was specified, # the decrypt call would throw an error, but we check first # since the decrypt function doesn't know the file name diff --git a/lib/ansible/parsing/vault/__init__.py b/lib/ansible/parsing/vault/__init__.py index bf2d5b2081..42cebeb551 100644 --- a/lib/ansible/parsing/vault/__init__.py +++ b/lib/ansible/parsing/vault/__init__.py @@ -89,8 +89,9 @@ HAS_ANY_PBKDF2HMAC = HAS_PBKDF2 or HAS_PBKDF2HMAC CRYPTO_UPGRADE = "ansible-vault requires a newer version of pycrypto than the one installed on your platform. You may fix this with OS-specific commands such as: yum install python-devel; rpm -e --nodeps python-crypto; pip install pycrypto" b_HEADER = b'$ANSIBLE_VAULT' +HEADER = '$ANSIBLE_VAULT' CIPHER_WHITELIST = frozenset((u'AES', u'AES256')) -CIPHER_WRITE_WHITELIST=frozenset((u'AES256',)) +CIPHER_WRITE_WHITELIST = frozenset((u'AES256',)) # See also CIPHER_MAPPING at the bottom of the file which maps cipher strings # (used in VaultFile header) to a cipher class @@ -100,6 +101,37 @@ def check_prereqs(): if not HAS_AES or not HAS_COUNTER or not HAS_ANY_PBKDF2HMAC or not HAS_HASH: raise AnsibleError(CRYPTO_UPGRADE) + +class AnsibleVaultError(AnsibleError): + pass + +def is_encrypted(b_data): + """ Test if this is vault encrypted data blob + + :arg data: a python2 str or a python3 'bytes' to test whether it is + recognized as vault encrypted data + :returns: True if it is recognized. Otherwise, False. + """ + if b_data.startswith(b_HEADER): + return True + return False + +def is_encrypted_file(file_obj): + """Test if the contents of a file obj are a vault encrypted data blob. + + The data read from the file_obj is expected to be bytestrings (py2 'str' or + python3 'bytes'). This more or less expects 'utf-8' encoding. + + :arg file_obj: A file object that will be read from. + :returns: True if the file is a vault file. Otherwise, False. + """ + # read the header and reset the file stream to where it started + current_position = file_obj.tell() + b_header_part = file_obj.read(len(b_HEADER)) + file_obj.seek(current_position) + return is_encrypted(b_header_part) + + class VaultLib: def __init__(self, password): @@ -107,36 +139,48 @@ class VaultLib: self.cipher_name = None self.b_version = b'1.1' + # really b_data, but for compat def is_encrypted(self, data): """ Test if this is vault encrypted data - :arg data: a byte str or unicode string to test whether it is + :arg data: a python2 utf-8 string or a python3 'bytes' to test whether it is recognized as vault encrypted data :returns: True if it is recognized. Otherwise, False. """ - if hasattr(data, 'read'): - current_position = data.tell() - header_part = data.read(len(b_HEADER)) - data.seek(current_position) - return self.is_encrypted(header_part) + # This could in the future, check to see if the data is a vault blob and + # is encrypted with a key associated with this vault + # instead of just checking the format. + return is_encrypted(data) - if to_bytes(data, errors='strict', encoding='utf-8').startswith(b_HEADER): - return True - return False + def is_encrypted_file(self, file_obj): + return is_encrypted_file(file_obj) def encrypt(self, data): """Vault encrypt a piece of data. - :arg data: a utf-8 byte str or unicode string to encrypt. + :arg data: a PY2 unicode string or PY3 string to encrypt. :returns: a utf-8 encoded byte str of encrypted data. The string contains a header identifying this as vault encrypted data and formatted to newline terminated lines of 80 characters. This is suitable for dumping as is to a vault file. - """ - b_data = to_bytes(data, errors='strict', encoding='utf-8') - if self.is_encrypted(b_data): + The unicode or string passed in as data will encoded to UTF-8 before + encryption. If the a already encoded string or PY2 bytestring needs to + be encrypted, use encrypt_bytestring(). + """ + plaintext = data + plaintext_bytes = plaintext.encode('utf-8') + + return self.encrypt_bytestring(plaintext_bytes) + + def encrypt_bytestring(self, plaintext_bytes): + '''Encrypt a PY2 bytestring. + + Like encrypt(), except plaintext_bytes is not encoded to UTF-8 + before encryption.''' + + if self.is_encrypted(plaintext_bytes): raise AnsibleError("input is already encrypted") if not self.cipher_name or self.cipher_name not in CIPHER_WRITE_WHITELIST: @@ -149,11 +193,11 @@ class VaultLib: this_cipher = Cipher() # encrypt data - b_enc_data = this_cipher.encrypt(b_data, self.b_password) + ciphertext_bytes = this_cipher.encrypt(plaintext_bytes, self.b_password) # format the data for output to the file - b_tmp_data = self._format_output(b_enc_data) - return b_tmp_data + ciphertext_envelope = self._format_output(ciphertext_bytes) + return ciphertext_envelope def decrypt(self, data, filename=None): """Decrypt a piece of vault encrypted data. @@ -168,9 +212,9 @@ class VaultLib: raise AnsibleError("A vault password must be specified to decrypt data") if not self.is_encrypted(b_data): - msg = "input is not encrypted" + msg = "input is not vault encrypted data" if filename: - msg += "%s is not encrypted" % filename + msg += "%s is not a vault encrypted file" % filename raise AnsibleError(msg) # clean out header @@ -178,6 +222,7 @@ class VaultLib: # create the cipher object cipher_class_name = u'Vault{0}'.format(self.cipher_name) + if cipher_class_name in globals() and self.cipher_name in CIPHER_WHITELIST: Cipher = globals()[cipher_class_name] this_cipher = Cipher() @@ -205,10 +250,11 @@ class VaultLib: if not self.cipher_name: raise AnsibleError("the cipher must be set before adding a header") - header = b';'.join([b_HEADER, self.b_version, - to_bytes(self.cipher_name, errors='strict', encoding='utf-8')]) + b_header = HEADER.encode('utf-8') + header = b';'.join([b_header, self.b_version, + to_bytes(self.cipher_name,'utf-8',errors='strict')]) tmpdata = [header] - tmpdata += [b_data[i:i+80] for i in range(0, len(b_data), 80)] + tmpdata += [b_data[i:i + 80] for i in range(0, len(b_data), 80)] tmpdata += [b''] tmpdata = b'\n'.join(tmpdata) @@ -243,6 +289,7 @@ class VaultEditor: def __init__(self, password): self.vault = VaultLib(password) + # TODO: mv shred file stuff to it's own class def _shred_file_custom(self, tmp_path): """"Destroy a file, when shred (core-utils) is not available @@ -277,7 +324,6 @@ class VaultEditor: assert(fh.tell() == file_len) # FIXME remove this assert once we have unittests to check its accuracy os.fsync(fh) - def _shred_file(self, tmp_path): """Securely destroy a decrypted file @@ -335,7 +381,9 @@ class VaultEditor: return # encrypt new data and write out to tmp - enc_data = self.vault.encrypt(tmpdata) + # An existing vaultfile will always be UTF-8, + # so decode to unicode here + enc_data = self.vault.encrypt(tmpdata.decode()) self.write_data(enc_data, tmp_path) # shuffle tmp file into place @@ -345,8 +393,10 @@ class VaultEditor: check_prereqs() + # A file to be encrypted into a vaultfile could be any encoding + # so treat the contents as a byte string. plaintext = self.read_data(filename) - ciphertext = self.vault.encrypt(plaintext) + ciphertext = self.vault.encrypt_bytestring(plaintext) self.write_data(ciphertext, output_file or filename) def decrypt_file(self, filename, output_file=None): @@ -433,15 +483,20 @@ class VaultEditor: return data + # TODO: add docstrings for arg types since this code is picky about that def write_data(self, data, filename, shred=True): """write data to given path - - if shred==True, make sure that the original data is first shredded so - that is cannot be recovered + + :arg data: the encrypted and hexlified data as a utf-8 byte string + :arg filename: filename to save 'data' to. + :arg shred: if shred==True, make sure that the original data is first shredded so + that is cannot be recovered. """ - bytes = to_bytes(data, errors='strict') + # FIXME: do we need this now? data_bytes should always be a utf-8 byte string + b_file_data = to_bytes(data, errors='strict') + if filename == '-': - sys.stdout.write(bytes) + sys.stdout.write(b_file_data) else: if os.path.isfile(filename): if shred: @@ -449,7 +504,7 @@ class VaultEditor: else: os.remove(filename) with open(filename, "wb") as fh: - fh.write(bytes) + fh.write(b_file_data) def shuffle_files(self, src, dest): prev = None @@ -462,7 +517,7 @@ class VaultEditor: # reset permissions if needed if prev is not None: - #TODO: selinux, ACLs, xattr? + # TODO: selinux, ACLs, xattr? os.chmod(dest, prev.st_mode) os.chown(dest, prev.st_uid, prev.st_gid) @@ -488,7 +543,7 @@ class VaultFile(object): _, self.tmpfile = tempfile.mkstemp() - ### TODO: + # TODO: # __del__ can be problematic in python... For this use case, make # VaultFile a context manager instead (implement __enter__ and __exit__) def __del__(self): @@ -496,11 +551,7 @@ class VaultFile(object): os.unlink(self.tmpfile) def is_encrypted(self): - peak = self.filehandle.readline() - if peak.startswith(b_HEADER): - return True - else: - return False + return is_encrypted_file(self.filehandle) def get_decrypted(self): check_prereqs() @@ -627,7 +678,6 @@ class VaultAES256: # make two keys and one iv pbkdf2_prf = lambda p, s: HMAC.new(p, s, hash_function).digest() - derivedkey = PBKDF2(password, salt, dkLen=(2 * keylength) + ivlength, count=10000, prf=pbkdf2_prf) return derivedkey @@ -657,9 +707,7 @@ class VaultAES256: return key1, key2, hexlify(iv) - def encrypt(self, data, password): - salt = os.urandom(32) key1, key2, iv = self.gen_key_initctr(password, salt) @@ -691,20 +739,17 @@ class VaultAES256: return message def decrypt(self, data, password): - # SPLIT SALT, DIGEST, AND DATA data = unhexlify(data) salt, cryptedHmac, cryptedData = data.split(b"\n", 2) salt = unhexlify(salt) cryptedData = unhexlify(cryptedData) - key1, key2, iv = self.gen_key_initctr(password, salt) # EXIT EARLY IF DIGEST DOESN'T MATCH hmacDecrypt = HMAC.new(key2, cryptedData, SHA256) if not self.is_equal(cryptedHmac, to_bytes(hmacDecrypt.hexdigest())): return None - # SET THE COUNTER AND THE CIPHER ctr = Counter.new(128, initial_value=int(iv, 16)) cipher = AES.new(key1, AES.MODE_CTR, counter=ctr) @@ -719,7 +764,6 @@ class VaultAES256: padding_length = decryptedData[-1] decryptedData = decryptedData[:-padding_length] - return decryptedData def is_equal(self, a, b): @@ -746,6 +790,6 @@ class VaultAES256: # Keys could be made bytes later if the code that gets the data is more # naturally byte-oriented CIPHER_MAPPING = { - u'AES': VaultAES, - u'AES256': VaultAES256, - } + u'AES': VaultAES, + u'AES256': VaultAES256, +} diff --git a/lib/ansible/parsing/yaml/constructor.py b/lib/ansible/parsing/yaml/constructor.py index 6c984ad080..a67481a781 100644 --- a/lib/ansible/parsing/yaml/constructor.py +++ b/lib/ansible/parsing/yaml/constructor.py @@ -21,8 +21,13 @@ __metaclass__ = type from yaml.constructor import Constructor, ConstructorError from yaml.nodes import MappingNode + from ansible.parsing.yaml.objects import AnsibleMapping, AnsibleSequence, AnsibleUnicode +from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode + from ansible.vars.unsafe_proxy import wrap_var +from ansible.parsing.vault import VaultLib +from ansible.utils.unicode import to_bytes try: from __main__ import display @@ -32,9 +37,12 @@ except ImportError: class AnsibleConstructor(Constructor): - def __init__(self, file_name=None): + def __init__(self, file_name=None, vault_password=None): + self._vault_password = vault_password self._ansible_file_name = file_name super(AnsibleConstructor, self).__init__() + self._vaults = {} + self._vaults['default'] = VaultLib(password=self._vault_password) def construct_yaml_map(self, node): data = AnsibleMapping() @@ -86,6 +94,20 @@ class AnsibleConstructor(Constructor): return ret + def construct_vault_encrypted_unicode(self, node): + value = self.construct_scalar(node) + ciphertext_data = to_bytes(value) + + if self._vault_password is None: + raise ConstructorError(None, None, + "found vault but no vault password provided", node.start_mark) + + # could pass in a key id here to choose the vault to associate with + vault = self._vaults['default'] + ret = AnsibleVaultEncryptedUnicode(ciphertext_data) + ret.vault = vault + return ret + def construct_yaml_seq(self, node): data = AnsibleSequence() yield data @@ -109,6 +131,7 @@ class AnsibleConstructor(Constructor): return (datasource, line, column) + AnsibleConstructor.add_constructor( u'tag:yaml.org,2002:map', AnsibleConstructor.construct_yaml_map) @@ -132,3 +155,7 @@ AnsibleConstructor.add_constructor( AnsibleConstructor.add_constructor( u'!unsafe', AnsibleConstructor.construct_yaml_unsafe) + +AnsibleConstructor.add_constructor( + u'!vault-encrypted', + AnsibleConstructor.construct_vault_encrypted_unicode) diff --git a/lib/ansible/parsing/yaml/dumper.py b/lib/ansible/parsing/yaml/dumper.py index a8a5015b8e..3ced4b638d 100644 --- a/lib/ansible/parsing/yaml/dumper.py +++ b/lib/ansible/parsing/yaml/dumper.py @@ -23,8 +23,10 @@ import yaml from ansible.compat.six import PY3 from ansible.parsing.yaml.objects import AnsibleUnicode, AnsibleSequence, AnsibleMapping +from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode from ansible.vars.hostvars import HostVars + class AnsibleDumper(yaml.SafeDumper): ''' A simple stub class that allows us to add representers @@ -35,6 +37,10 @@ class AnsibleDumper(yaml.SafeDumper): def represent_hostvars(self, data): return self.represent_dict(dict(data)) +# Note: only want to represent the encrypted data +def represent_vault_encrypted_unicode(self, data): + return self.represent_scalar(u'!vault-encrypted', data._ciphertext.decode(), style='|') + if PY3: represent_unicode = yaml.representer.SafeRepresenter.represent_str else: @@ -60,3 +66,7 @@ AnsibleDumper.add_representer( yaml.representer.SafeRepresenter.represent_dict, ) +AnsibleDumper.add_representer( + AnsibleVaultEncryptedUnicode, + represent_vault_encrypted_unicode, +) diff --git a/lib/ansible/parsing/yaml/loader.py b/lib/ansible/parsing/yaml/loader.py index e8547ff0d1..050e9f0553 100644 --- a/lib/ansible/parsing/yaml/loader.py +++ b/lib/ansible/parsing/yaml/loader.py @@ -30,10 +30,11 @@ from yaml.resolver import Resolver from ansible.parsing.yaml.constructor import AnsibleConstructor if HAVE_PYYAML_C: + class AnsibleLoader(CParser, AnsibleConstructor, Resolver): - def __init__(self, stream, file_name=None): + def __init__(self, stream, file_name=None, vault_password=None): CParser.__init__(self, stream) - AnsibleConstructor.__init__(self, file_name=file_name) + AnsibleConstructor.__init__(self, file_name=file_name, vault_password=vault_password) Resolver.__init__(self) else: from yaml.composer import Composer @@ -42,10 +43,10 @@ else: from yaml.parser import Parser class AnsibleLoader(Reader, Scanner, Parser, Composer, AnsibleConstructor, Resolver): - def __init__(self, stream, file_name=None): + def __init__(self, stream, file_name=None, vault_password=None): Reader.__init__(self, stream) Scanner.__init__(self) Parser.__init__(self) Composer.__init__(self) - AnsibleConstructor.__init__(self, file_name=file_name) + AnsibleConstructor.__init__(self, file_name=file_name, vault_password=vault_password) Resolver.__init__(self) diff --git a/lib/ansible/parsing/yaml/objects.py b/lib/ansible/parsing/yaml/objects.py index 1389fd59f2..60553d5cb4 100644 --- a/lib/ansible/parsing/yaml/objects.py +++ b/lib/ansible/parsing/yaml/objects.py @@ -19,7 +19,11 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import yaml + from ansible.compat.six import text_type +from ansible.errors import AnsibleError +from ansible.utils.unicode import to_bytes class AnsibleBaseYAMLObject(object): @@ -63,3 +67,69 @@ class AnsibleUnicode(AnsibleBaseYAMLObject, text_type): class AnsibleSequence(AnsibleBaseYAMLObject, list): ''' sub class for lists ''' pass + + +# Unicode like object that is not evaluated (decrypted) until it needs to be +# TODO: is there a reason these objects are subclasses for YAMLObject? +class AnsibleVaultEncryptedUnicode(yaml.YAMLObject, AnsibleUnicode): + __UNSAFE__ = True + __ENCRYPTED__ = True + yaml_tag = u'!vault-encrypted' + + @classmethod + def from_plaintext(cls, seq, vault): + if not vault: + raise vault.AnsibleVaultError('Error creating AnsibleVaultEncryptedUnicode, invalid vault (%s) provided' % vault) + + ciphertext = vault.encrypt(seq) + avu = cls(ciphertext) + avu.vault = vault + return avu + + def __init__(self, ciphertext): + '''A AnsibleUnicode with a Vault attribute that can decrypt it. + + ciphertext is a byte string (str on PY2, bytestring on PY3). + + The .data atttribute is a property that returns the decrypted plaintext + of the ciphertext as a PY2 unicode or PY3 string object. + ''' + super(AnsibleVaultEncryptedUnicode, self).__init__() + + # after construction, calling code has to set the .vault attribute to a vaultlib object + self.vault = None + self._ciphertext = to_bytes(ciphertext) + + @property + def data(self): + if not self.vault: + # FIXME: raise exception? + return self._ciphertext + return self.vault.decrypt(self._ciphertext).decode() + + @data.setter + def data(self, value): + self._ciphertext = value + + def __repr__(self): + return 'AnsibleVaultEncryptedUnicode(%s)' % self._ciphertext + + # Compare a regular str/text_type with the decrypted hypertext + def __eq__(self, other): + if self.vault: + return other == self.data + return False + + def __hash__(self): + return id(self) + + def __ne__(self, other): + if self.vault: + return other != self.data + return True + + def __str__(self): + return str(self.data) + + def __unicode__(self): + return unicode(self.data) diff --git a/lib/ansible/template/__init__.py b/lib/ansible/template/__init__.py index 73fe50fe8a..f55cab6eec 100644 --- a/lib/ansible/template/__init__.py +++ b/lib/ansible/template/__init__.py @@ -227,7 +227,7 @@ class Templar: def _clean_data(self, orig_data): ''' remove jinja2 template tags from a string ''' - if not isinstance(orig_data, string_types): + if not isinstance(orig_data, string_types) or hasattr(orig_data, '__ENCRYPTED__'): return orig_data with contextlib.closing(StringIO(orig_data)) as data: @@ -281,7 +281,7 @@ class Templar: ''' Templates (possibly recursively) any given data as input. If convert_bare is set to True, the given data will be wrapped as a jinja2 variable ('{{foo}}') - before being sent through the template engine. + before being sent through the template engine. ''' if fail_on_undefined is None: @@ -304,7 +304,6 @@ class Templar: result = variable if self._contains_vars(variable): - # Check to see if the string we are trying to render is just referencing a single # var. In this case we don't want to accidentally change the type of the variable # to a string by using the jinja template renderer. We just want to pass it. @@ -438,7 +437,6 @@ class Templar: raise AnsibleError("lookup plugin (%s) not found" % name) def _do_template(self, data, preserve_trailing_newlines=True, escape_backslashes=True, fail_on_undefined=None, overrides=None): - # For preserving the number of input newlines in the output (used # later in this method) data_newlines = _count_newlines_from_end(data) @@ -517,7 +515,6 @@ class Templar: res_newlines = _count_newlines_from_end(res) if data_newlines > res_newlines: res += '\n' * (data_newlines - res_newlines) - return res except (UndefinedError, AnsibleUndefinedVariable) as e: if fail_on_undefined: diff --git a/test/integration/Makefile b/test/integration/Makefile index bf861f641d..e4c3dcc446 100644 --- a/test/integration/Makefile +++ b/test/integration/Makefile @@ -162,6 +162,8 @@ test_vault: setup ansible-playbook test_vault.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) --list-hosts -e outputdir=$(TEST_DIR) -e @$(VARS_FILE) ansible-playbook test_vault.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) --syntax-check -e outputdir=$(TEST_DIR) -e @$(VARS_FILE) ansible-playbook test_vault.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) -e outputdir=$(TEST_DIR) -e @$(VARS_FILE) + ansible-playbook test_vault_embedded.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) --syntax-check -e outputdir=$(TEST_DIR) -e @$(VARS_FILE) + ansible-playbook test_vault_embedded.yml -i $(INVENTORY) $(CREDENTIALS_ARG) -v $(TEST_FLAGS) --vault-password-file $(VAULT_PASSWORD_FILE) -e outputdir=$(TEST_DIR) -e @$(VARS_FILE) # test_delegate_to does not work unless we have permission to ssh to localhost. # Would take some more effort on our test systems to implement that -- probably diff --git a/test/integration/roles/test_vault/tasks/main.yml b/test/integration/roles/test_vault/tasks/main.yml new file mode 100644 index 0000000000..2c7591a957 --- /dev/null +++ b/test/integration/roles/test_vault/tasks/main.yml @@ -0,0 +1,7 @@ + +- assert: + that: + - 'secret_var == "secret"' + + +- copy: src=vault-secret.txt dest={{output_dir}}/secret.txt diff --git a/test/integration/roles/test_vault/vars/main.yml b/test/integration/roles/test_vault/vars/main.yml new file mode 100644 index 0000000000..cfac107aed --- /dev/null +++ b/test/integration/roles/test_vault/vars/main.yml @@ -0,0 +1,9 @@ +$ANSIBLE_VAULT;1.1;AES256 +31626536666232643662346539623662393436386162643439643434656231343435653936343235 +6139346364396166336636383734333430373763336434310a303137623539653939336132626234 +64613232396532313731313935333433353330666466646663303233323331636234326464643166 +6538653264636166370a613161313064653566323037393962643032353230396536313865326362 +34396262303130326632623162623230346238633932393938393766313036643835613936356233 +33323730373331386337353339613165373064323134343930333031623036326164353534646631 +31313963666234623731316238656233396638643331306231373539643039383434373035306233 +30386230363730643561 diff --git a/test/integration/roles/test_vault_embedded/tasks/main.yml b/test/integration/roles/test_vault_embedded/tasks/main.yml new file mode 100644 index 0000000000..4dda2acbcd --- /dev/null +++ b/test/integration/roles/test_vault_embedded/tasks/main.yml @@ -0,0 +1,14 @@ +--- +- name: Assert that a embedded vault of a string with no newline works + assert: + that: + - '"{{ vault_encrypted_one_line_var }}" == "Setec Astronomy"' + +- name: Assert that a multi line embedded vault works, including new line + assert: + that: + - vault_encrypted_var == "Setec Astronomy\n" + +# TODO: add a expected fail here +# - debug: var=vault_encrypted_one_line_var_with_embedded_template + diff --git a/test/integration/roles/test_vault_embedded/vars/main.yml b/test/integration/roles/test_vault_embedded/vars/main.yml new file mode 100644 index 0000000000..e9c568eac1 --- /dev/null +++ b/test/integration/roles/test_vault_embedded/vars/main.yml @@ -0,0 +1,17 @@ +# If you use normal 'ansible-vault create' or edit, files always have at least one new line +# so c&p from a vault encrypted that wasn't specifically created sans new line ends up with one. +# (specifically created, as in 'echo -n "just one line" > my_secret.yml' +vault_encrypted_var: !vault-encrypted | + $ANSIBLE_VAULT;1.1;AES256 + 66386439653236336462626566653063336164663966303231363934653561363964363833313662 + 6431626536303530376336343832656537303632313433360a626438346336353331386135323734 + 62656361653630373231613662633962316233633936396165386439616533353965373339616234 + 3430613539666330390a313736323265656432366236633330313963326365653937323833366536 + 34623731376664623134383463316265643436343438623266623965636363326136 +vault_encrypted_one_line_var: !vault-encrypted | + $ANSIBLE_VAULT;1.1;AES256 + 33363965326261303234626463623963633531343539616138316433353830356566396130353436 + 3562643163366231316662386565383735653432386435610a306664636137376132643732393835 + 63383038383730306639353234326630666539346233376330303938323639306661313032396437 + 6233623062366136310a633866373936313238333730653739323461656662303864663666653563 + 3138 diff --git a/test/integration/test_vault.yml b/test/integration/test_vault.yml index a3b2498766..65b5fa5c1a 100644 --- a/test/integration/test_vault.yml +++ b/test/integration/test_vault.yml @@ -1,13 +1,4 @@ - hosts: testhost - vars_files: - - vars/test_var_encrypted.yml - gather_facts: False - - tasks: - - assert: - that: - - 'secret_var == "secret"' - - - copy: src=vault-secret.txt dest={{output_dir}}/secret.txt - + roles: + - { role: test_vault, tags: test_vault} diff --git a/test/integration/test_vault_embedded.yml b/test/integration/test_vault_embedded.yml new file mode 100644 index 0000000000..ee9739f8bb --- /dev/null +++ b/test/integration/test_vault_embedded.yml @@ -0,0 +1,4 @@ +- hosts: testhost + gather_facts: False + roles: + - { role: test_vault_embedded, tags: test_vault_embedded} diff --git a/test/units/mock/yaml_helper.py b/test/units/mock/yaml_helper.py new file mode 100644 index 0000000000..7f78780721 --- /dev/null +++ b/test/units/mock/yaml_helper.py @@ -0,0 +1,120 @@ +import io +import yaml + +from ansible.parsing.yaml.loader import AnsibleLoader +from ansible.parsing.yaml.dumper import AnsibleDumper + +from ansible.compat.six import PY3 + +class YamlTestUtils(object): + """Mixin class to combine with a unittest.TestCase subclass.""" + def _loader(self, stream): + """Vault related tests will want to override this. + + Vault cases should setup a AnsibleLoader that has the vault password.""" + return AnsibleLoader(stream) + + def _dump_stream(self, obj, stream, dumper=None): + """Dump to a py2-unicode or py3-string stream.""" + if PY3: + return yaml.dump(obj, stream, Dumper=dumper) + else: + return yaml.dump(obj, stream, Dumper=dumper, encoding=None) + + def _dump_string(self, obj, dumper=None): + """Dump to a py2-unicode or py3-string""" + if PY3: + return yaml.dump(obj, Dumper=dumper) + else: + return yaml.dump(obj, Dumper=dumper, encoding=None) + + def _dump_load_cycle(self, obj): + # Each pass though a dump or load revs the 'generation' + # obj to yaml string + string_from_object_dump = self._dump_string(obj, dumper=AnsibleDumper) + + # wrap a stream/file like StringIO around that yaml + stream_from_object_dump = io.StringIO(string_from_object_dump) + loader = self._loader(stream_from_object_dump) + # load the yaml stream to create a new instance of the object (gen 2) + obj_2 = loader.get_data() + + # dump the gen 2 objects directory to strings + string_from_object_dump_2 = self._dump_string(obj_2, + dumper=AnsibleDumper) + + # The gen 1 and gen 2 yaml strings + self.assertEquals(string_from_object_dump, string_from_object_dump_2) + # the gen 1 (orig) and gen 2 py object + self.assertEquals(obj, obj_2) + + # again! gen 3... load strings into py objects + stream_3 = io.StringIO(string_from_object_dump_2) + loader_3 = self._loader(stream_3) + obj_3 = loader_3.get_data() + + string_from_object_dump_3 = self._dump_string(obj_3, dumper=AnsibleDumper) + + self.assertEquals(obj, obj_3) + # should be transitive, but... + self.assertEquals(obj_2, obj_3) + self.assertEquals(string_from_object_dump, string_from_object_dump_3) + + def _old_dump_load_cycle(self, obj): + '''Dump the passed in object to yaml, load it back up, dump again, compare.''' + stream = io.StringIO() + + yaml_string = self._dump_string(obj, dumper=AnsibleDumper) + self._dump_stream(obj, stream, dumper=AnsibleDumper) + + yaml_string_from_stream = stream.getvalue() + + # reset stream + stream.seek(0) + + loader = self._loader(stream) + # loader = AnsibleLoader(stream, vault_password=self.vault_password) + obj_from_stream = loader.get_data() + + stream_from_string = io.StringIO(yaml_string) + loader2 = self._loader(stream_from_string) + # loader2 = AnsibleLoader(stream_from_string, vault_password=self.vault_password) + obj_from_string = loader2.get_data() + + stream_obj_from_stream = io.StringIO() + stream_obj_from_string = io.StringIO() + + if PY3: + yaml.dump(obj_from_stream, stream_obj_from_stream, Dumper=AnsibleDumper) + yaml.dump(obj_from_stream, stream_obj_from_string, Dumper=AnsibleDumper) + else: + yaml.dump(obj_from_stream, stream_obj_from_stream, Dumper=AnsibleDumper, encoding=None) + yaml.dump(obj_from_stream, stream_obj_from_string, Dumper=AnsibleDumper, encoding=None) + + yaml_string_stream_obj_from_stream = stream_obj_from_stream.getvalue() + yaml_string_stream_obj_from_string = stream_obj_from_string.getvalue() + + stream_obj_from_stream.seek(0) + stream_obj_from_string.seek(0) + + if PY3: + yaml_string_obj_from_stream = yaml.dump(obj_from_stream, Dumper=AnsibleDumper) + yaml_string_obj_from_string = yaml.dump(obj_from_string, Dumper=AnsibleDumper) + else: + yaml_string_obj_from_stream = yaml.dump(obj_from_stream, Dumper=AnsibleDumper, encoding=None) + yaml_string_obj_from_string = yaml.dump(obj_from_string, Dumper=AnsibleDumper, encoding=None) + + assert yaml_string == yaml_string_obj_from_stream + assert yaml_string == yaml_string_obj_from_stream == yaml_string_obj_from_string + assert yaml_string == yaml_string_obj_from_stream == yaml_string_obj_from_string == yaml_string_stream_obj_from_stream == yaml_string_stream_obj_from_string + assert obj == obj_from_stream + assert obj == obj_from_string + assert obj == yaml_string_obj_from_stream + assert obj == yaml_string_obj_from_string + assert obj == obj_from_stream == obj_from_string == yaml_string_obj_from_stream == yaml_string_obj_from_string + return {'obj': obj, + 'yaml_string': yaml_string, + 'yaml_string_from_stream': yaml_string_from_stream, + 'obj_from_stream': obj_from_stream, + 'obj_from_string': obj_from_string, + 'yaml_string_obj_from_string': yaml_string_obj_from_string} diff --git a/test/units/parsing/test_data_loader.py b/test/units/parsing/test_data_loader.py index 19e3107447..0182d40d3e 100644 --- a/test/units/parsing/test_data_loader.py +++ b/test/units/parsing/test_data_loader.py @@ -20,14 +20,12 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type from six import PY3 -from yaml.scanner import ScannerError from ansible.compat.tests import unittest from ansible.compat.tests.mock import patch, mock_open from ansible.errors import AnsibleParserError from ansible.parsing.dataloader import DataLoader -from ansible.parsing.yaml.objects import AnsibleMapping class TestDataLoader(unittest.TestCase): @@ -85,6 +83,6 @@ class TestDataLoaderWithVault(unittest.TestCase): else: builtins_name = '__builtin__' - with patch(builtins_name + '.open', mock_open(read_data=vaulted_data)): + with patch(builtins_name + '.open', mock_open(read_data=vaulted_data.encode('utf-8'))): output = self._loader.load_from_file('dummy_vault.txt') self.assertEqual(output, dict(foo='bar')) diff --git a/test/units/parsing/vault/test_vault.py b/test/units/parsing/vault/test_vault.py index 1b48b07169..4fe0fff5e0 100644 --- a/test/units/parsing/vault/test_vault.py +++ b/test/units/parsing/vault/test_vault.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # (c) 2012-2014, Michael DeHaan # # This file is part of Ansible @@ -19,14 +20,12 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import getpass -import os -import shutil -import time -import tempfile import six -from binascii import unhexlify +import binascii +import io +import os + from binascii import hexlify from nose.plugins.skip import SkipTest @@ -35,6 +34,7 @@ from ansible.utils.unicode import to_bytes, to_unicode from ansible import errors from ansible.parsing.vault import VaultLib +from ansible.parsing import vault # Counter import fails for 2.0.1, requires >= 2.6.1 from pip try: @@ -57,6 +57,150 @@ try: except ImportError: HAS_AES = False + +class TestVaultIsEncrypted(unittest.TestCase): + def test_utf8_not_encrypted(self): + b_data = "foobar".encode('utf8') + self.assertFalse(vault.is_encrypted(b_data)) + + def test_utf8_encrypted(self): + data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") + b_data = data.encode('utf8') + self.assertTrue(vault.is_encrypted(b_data)) + + def test_bytes_not_encrypted(self): + b_data = b"foobar" + self.assertFalse(vault.is_encrypted(b_data)) + + def test_bytes_encrypted(self): + b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" + hexlify(b"ansible") + self.assertTrue(vault.is_encrypted(b_data)) + + def test_unicode_not_encrypted_py3(self): + if not six.PY3: + raise SkipTest() + data = u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ " + self.assertRaises(TypeError, vault.is_encrypted, data) + + def test_unicode_not_encrypted_py2(self): + if six.PY3: + raise SkipTest() + data = u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ " + # py2 will take a unicode string, but that should always fails + self.assertFalse(vault.is_encrypted(data)) + + def test_unicode_is_encrypted_py3(self): + if not six.PY3: + raise SkipTest() + data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") + # should still be a type error + self.assertRaises(TypeError, vault.is_encrypted, data) + + def test_unicode_is_encrypted_py2(self): + if six.PY3: + raise SkipTest() + data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") + # THis works, but arguably shouldn't... + self.assertTrue(vault.is_encrypted(data)) + + +class TestVaultIsEncryptedFile(unittest.TestCase): + def test_utf8_not_encrypted(self): + b_data = "foobar".encode('utf8') + b_data_fo = io.BytesIO(b_data) + self.assertFalse(vault.is_encrypted_file(b_data_fo)) + + def test_utf8_encrypted(self): + data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") + b_data = data.encode('utf8') + b_data_fo = io.BytesIO(b_data) + self.assertTrue(vault.is_encrypted_file(b_data_fo)) + + def test_bytes_not_encrypted(self): + b_data = b"foobar" + b_data_fo = io.BytesIO(b_data) + self.assertFalse(vault.is_encrypted_file(b_data_fo)) + + def test_bytes_encrypted(self): + b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" + hexlify(b"ansible") + b_data_fo = io.BytesIO(b_data) + self.assertTrue(vault.is_encrypted_file(b_data_fo)) + + +class TestVaultCipherAes256(unittest.TestCase): + def test(self): + vault_cipher = vault.VaultAES256() + self.assertIsInstance(vault_cipher, vault.VaultAES256) + + # TODO: tag these as slow tests + def test_create_key(self): + vault_cipher = vault.VaultAES256() + password = 'hunter42' + b_salt = os.urandom(32) + b_key = vault_cipher.create_key(password=password, salt=b_salt, keylength=32, ivlength=16) + self.assertIsInstance(b_key, six.binary_type) + + def test_create_key_known(self): + vault_cipher = vault.VaultAES256() + password = 'hunter42' + + # A fixed salt + b_salt = b'q' * 32 # q is the most random letter. + b_key = vault_cipher.create_key(password=password, salt=b_salt, keylength=32, ivlength=16) + self.assertIsInstance(b_key, six.binary_type) + + # verify we get the same answer + # we could potentially run a few iterations of this and time it to see if it's roughly constant time + # and or that it exceeds some minimal time, but that would likely cause unreliable fails, esp in CI + b_key_2 = vault_cipher.create_key(password=password, salt=b_salt, keylength=32, ivlength=16) + self.assertIsInstance(b_key, six.binary_type) + self.assertEqual(b_key, b_key_2) + + def test_is_equal_is_equal(self): + vault_cipher = vault.VaultAES256() + res = vault_cipher.is_equal(b'abcdefghijklmnopqrstuvwxyz', b'abcdefghijklmnopqrstuvwxyz') + self.assertTrue(res) + + def test_is_equal_unequal_length(self): + vault_cipher = vault.VaultAES256() + res = vault_cipher.is_equal(b'abcdefghijklmnopqrstuvwxyz', b'abcdefghijklmnopqrstuvwx and sometimes y') + self.assertFalse(res) + + def test_is_equal_not_equal(self): + vault_cipher = vault.VaultAES256() + res = vault_cipher.is_equal(b'abcdefghijklmnopqrstuvwxyz', b'AbcdefghijKlmnopQrstuvwxZ') + self.assertFalse(res) + + def test_is_equal_empty(self): + vault_cipher = vault.VaultAES256() + res = vault_cipher.is_equal(b'', b'') + self.assertTrue(res) + + # NOTE: I'm not really sure what the method should do if it doesn't get bytes, + # but this at least sees if it explodes (maybe it should?) + def test_is_equal_unicode_py3(self): + if not six.PY3: + raise SkipTest + vault_cipher = vault.VaultAES256() + self.assertRaises(TypeError, vault_cipher.is_equal, + u'私はガラスを食べられます。それは私を傷つけません。', + u'私はガラスを食べられます。それは私を傷つけません。') + + def test_is_equal_unicode_py2(self): + if not six.PY2: + raise SkipTest + vault_cipher = vault.VaultAES256() + res = vault_cipher.is_equal(u'私はガラスを食べられます。それは私を傷つけません。', + u'私はガラスを食べられます。それは私を傷つけません。') + self.assertTrue(res) + + def test_is_equal_unicode_different(self): + vault_cipher = vault.VaultAES256() + res = vault_cipher.is_equal(u'私はガラスを食べられます。それは私を傷つけません。', + u'Pot să mănânc sticlă și ea nu mă rănește.') + self.assertFalse(res) + + class TestVaultLib(unittest.TestCase): def test_methods_exist(self): @@ -69,10 +213,24 @@ class TestVaultLib(unittest.TestCase): for slot in slots: assert hasattr(v, slot), "VaultLib is missing the %s method" % slot + def test_encrypt(self): + v = VaultLib(password='the_unit_test_password') + plaintext = u'Some text to encrypt.' + ciphertext = v.encrypt(plaintext) + + self.assertIsInstance(ciphertext, (bytes, str)) + # TODO: assert something... + def test_is_encrypted(self): v = VaultLib(None) - assert not v.is_encrypted(u"foobar"), "encryption check on plaintext failed" + assert not v.is_encrypted("foobar".encode('utf-8')), "encryption check on plaintext failed" data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") + assert v.is_encrypted(data.encode('utf-8')), "encryption check on headered text failed" + + def test_is_encrypted_bytes(self): + v = VaultLib(None) + assert not v.is_encrypted(b"foobar"), "encryption check on plaintext failed" + data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" + hexlify(b"ansible") assert v.is_encrypted(data), "encryption check on headered text failed" def test_format_output(self): @@ -115,35 +273,82 @@ class TestVaultLib(unittest.TestCase): raise SkipTest v = VaultLib('ansible') v.cipher_name = 'AES256' - enc_data = v.encrypt(b"foobar") + plaintext = "foobar" + enc_data = v.encrypt(plaintext) dec_data = v.decrypt(enc_data) assert enc_data != b"foobar", "encryption failed" assert dec_data == b"foobar", "decryption failed" + def test_encrypt_decrypt_aes256_existing_vault(self): + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: + raise SkipTest + v = VaultLib('test-vault-password') + v.cipher_name = 'AES256' + plaintext = b"Setec Astronomy" + enc_data = '''$ANSIBLE_VAULT;1.1;AES256 +33363965326261303234626463623963633531343539616138316433353830356566396130353436 +3562643163366231316662386565383735653432386435610a306664636137376132643732393835 +63383038383730306639353234326630666539346233376330303938323639306661313032396437 +6233623062366136310a633866373936313238333730653739323461656662303864663666653563 +3138''' + + dec_data = v.decrypt(enc_data) + assert dec_data == plaintext, "decryption failed" + + def test_encrypt_decrypt_aes256_bad_hmac(self): + # FIXME This test isn't working quite yet. + raise SkipTest + + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: + raise SkipTest + v = VaultLib('test-vault-password') + v.cipher_name = 'AES256' + # plaintext = "Setec Astronomy" + enc_data = '''$ANSIBLE_VAULT;1.1;AES256 +33363965326261303234626463623963633531343539616138316433353830356566396130353436 +3562643163366231316662386565383735653432386435610a306664636137376132643732393835 +63383038383730306639353234326630666539346233376330303938323639306661313032396437 +6233623062366136310a633866373936313238333730653739323461656662303864663666653563 +3138''' + b_data = to_bytes(enc_data, errors='strict', encoding='utf-8') + b_data = v._split_header(b_data) + foo = binascii.unhexlify(b_data) + lines = foo.splitlines() + # line 0 is salt, line 1 is hmac, line 2+ is ciphertext + b_salt = lines[0] + b_hmac = lines[1] + b_ciphertext_data = b'\n'.join(lines[2:]) + + b_ciphertext = binascii.unhexlify(b_ciphertext_data) + # b_orig_ciphertext = b_ciphertext[:] + + # now muck with the text + # b_munged_ciphertext = b_ciphertext[:10] + b'\x00' + b_ciphertext[11:] + # b_munged_ciphertext = b_ciphertext + # assert b_orig_ciphertext != b_munged_ciphertext + + b_ciphertext_data = binascii.hexlify(b_ciphertext) + b_payload = b'\n'.join([b_salt, b_hmac, b_ciphertext_data]) + # reformat + b_invalid_ciphertext = v._format_output(b_payload) + + # assert we throw an error + v.decrypt(b_invalid_ciphertext) + def test_encrypt_encrypted(self): if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: raise SkipTest v = VaultLib('ansible') v.cipher_name = 'AES' data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(six.b("ansible")) - error_hit = False - try: - enc_data = v.encrypt(data) - except errors.AnsibleError as e: - error_hit = True - assert error_hit, "No error was thrown when trying to encrypt data with a header" + self.assertRaises(errors.AnsibleError, v.encrypt, data,) def test_decrypt_decrypted(self): if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: raise SkipTest v = VaultLib('ansible') data = "ansible" - error_hit = False - try: - dec_data = v.decrypt(data) - except errors.AnsibleError as e: - error_hit = True - assert error_hit, "No error was thrown when trying to decrypt data without a header" + self.assertRaises(errors.AnsibleError, v.decrypt, data) def test_cipher_not_set(self): # not setting the cipher should default to AES256 @@ -151,10 +356,5 @@ class TestVaultLib(unittest.TestCase): raise SkipTest v = VaultLib('ansible') data = "ansible" - error_hit = False - try: - enc_data = v.encrypt(data) - except errors.AnsibleError as e: - error_hit = True - assert not error_hit, "An error was thrown when trying to encrypt data without the cipher set" - assert v.cipher_name == "AES256", "cipher name is not set to AES256: %s" % v.cipher_name + v.encrypt(data) + self.assertEquals(v.cipher_name, "AES256") diff --git a/test/units/parsing/vault/test_vault_editor.py b/test/units/parsing/vault/test_vault_editor.py index e943b00868..0fa09c8537 100644 --- a/test/units/parsing/vault/test_vault_editor.py +++ b/test/units/parsing/vault/test_vault_editor.py @@ -22,13 +22,8 @@ __metaclass__ = type #!/usr/bin/env python import sys -import getpass import os -import shutil -import time import tempfile -from binascii import unhexlify -from binascii import hexlify from nose.plugins.skip import SkipTest from ansible.compat.tests import unittest diff --git a/test/units/parsing/yaml/test_dumper.py b/test/units/parsing/yaml/test_dumper.py new file mode 100644 index 0000000000..540979baef --- /dev/null +++ b/test/units/parsing/yaml/test_dumper.py @@ -0,0 +1,64 @@ +# coding: utf-8 +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import io +import yaml + +try: + from _yaml import ParserError +except ImportError: + from yaml.parser import ParserError + +from ansible.parsing.yaml import dumper +from ansible.parsing.yaml.loader import AnsibleLoader + +from ansible.compat.tests import unittest +from ansible.parsing.yaml import objects +from ansible.parsing import vault + +from units.mock.yaml_helper import YamlTestUtils + +class TestAnsibleDumper(unittest.TestCase, YamlTestUtils): + def setUp(self): + self.vault_password = "hunter42" + self.good_vault = vault.VaultLib(self.vault_password) + self.vault = self.good_vault + self.stream = self._build_stream() + self.dumper = dumper.AnsibleDumper + + def _build_stream(self,yaml_text=None): + text = yaml_text or u'' + stream = io.StringIO(text) + return stream + + def _loader(self, stream): + return AnsibleLoader(stream, vault_password=self.vault_password) + + def test(self): + plaintext = 'This is a string we are going to encrypt.' + avu = objects.AnsibleVaultEncryptedUnicode.from_plaintext(plaintext, vault=self.vault) + + yaml_out = self._dump_string(avu, dumper=self.dumper) + stream = self._build_stream(yaml_out) + loader = self._loader(stream) + + data_from_yaml = loader.get_single_data() + + self.assertEquals(plaintext, data_from_yaml.data) diff --git a/test/units/parsing/yaml/test_loader.py b/test/units/parsing/yaml/test_loader.py index 253ba4d591..1e224f6704 100644 --- a/test/units/parsing/yaml/test_loader.py +++ b/test/units/parsing/yaml/test_loader.py @@ -26,20 +26,26 @@ from six import text_type, binary_type from collections import Sequence, Set, Mapping from ansible.compat.tests import unittest -from ansible.compat.tests.mock import patch +from ansible import errors from ansible.parsing.yaml.loader import AnsibleLoader +from ansible.parsing import vault +from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode +from ansible.parsing.yaml.dumper import AnsibleDumper +from ansible.utils.unicode import to_bytes + +from units.mock.yaml_helper import YamlTestUtils try: from _yaml import ParserError except ImportError: from yaml.parser import ParserError - class NameStringIO(StringIO): """In py2.6, StringIO doesn't let you set name because a baseclass has it as readonly property""" name = None + def __init__(self, *args, **kwargs): super(NameStringIO, self).__init__(*args, **kwargs) @@ -159,6 +165,124 @@ class TestAnsibleLoaderBasic(unittest.TestCase): self.assertEqual(data[0][u'baz'].ansible_pos, ('myfile.yml', 2, 9)) +class TestAnsibleLoaderVault(unittest.TestCase, YamlTestUtils): + def setUp(self): + self.vault_password = "hunter42" + self.vault = vault.VaultLib(self.vault_password) + + def test_wrong_password(self): + plaintext = u"Ansible" + bob_password = "this is a different password" + + bobs_vault = vault.VaultLib(bob_password) + + ciphertext = bobs_vault.encrypt(plaintext) + + try: + self.vault.decrypt(ciphertext) + except Exception as e: + self.assertIsInstance(e, errors.AnsibleError) + self.assertEqual(e.message, 'Decryption failed') + + def _encrypt_plaintext(self, plaintext): + # Construct a yaml repr of a vault by hand + vaulted_var_bytes = self.vault.encrypt(plaintext) + + # add yaml tag + vaulted_var = vaulted_var_bytes.decode() + lines = vaulted_var.splitlines() + lines2 = [] + for line in lines: + lines2.append(' %s' % line) + + vaulted_var = '\n'.join(lines2) + tagged_vaulted_var = u"""!vault-encrypted |\n%s""" % vaulted_var + return tagged_vaulted_var + + def _build_stream(self, yaml_text): + stream = NameStringIO(yaml_text) + stream.name = 'my.yml' + return stream + + def _loader(self, stream): + return AnsibleLoader(stream, vault_password=self.vault_password) + + def _load_yaml(self, yaml_text, password): + stream = self._build_stream(yaml_text) + loader = self._loader(stream) + + data_from_yaml = loader.get_single_data() + + return data_from_yaml + + def test_dump_load_cycle(self): + avu = AnsibleVaultEncryptedUnicode.from_plaintext('The plaintext for test_dump_load_cycle.', vault=self.vault) + self._dump_load_cycle(avu) + + def test_embedded_vault_from_dump(self): + avu = AnsibleVaultEncryptedUnicode.from_plaintext('setec astronomy', vault=self.vault) + blip = {'stuff1': [{'a dict key': 24}, + {'shhh-ssh-secrets': avu, + 'nothing to see here': 'move along'}], + 'another key': 24.1} + + blip = ['some string', 'another string', avu] + stream = NameStringIO() + + self._dump_stream(blip, stream, dumper=AnsibleDumper) + + print(stream.getvalue()) + stream.seek(0) + + stream.seek(0) + + loader = self._loader(stream) + + data_from_yaml = loader.get_data() + stream2 = NameStringIO(u'') + # verify we can dump the object again + self._dump_stream(data_from_yaml, stream2, dumper=AnsibleDumper) + + def test_embedded_vault(self): + plaintext_var = u"""This is the plaintext string.""" + tagged_vaulted_var = self._encrypt_plaintext(plaintext_var) + another_vaulted_var = self._encrypt_plaintext(plaintext_var) + + different_var = u"""A different string that is not the same as the first one.""" + different_vaulted_var = self._encrypt_plaintext(different_var) + + yaml_text = u"""---\nwebster: daniel\noed: oxford\nthe_secret: %s\nanother_secret: %s\ndifferent_secret: %s""" % (tagged_vaulted_var, another_vaulted_var, different_vaulted_var) + + data_from_yaml = self._load_yaml(yaml_text, self.vault_password) + vault_string = data_from_yaml['the_secret'] + + self.assertEquals(plaintext_var, data_from_yaml['the_secret']) + + test_dict = {} + test_dict[vault_string] = 'did this work?' + + self.assertEquals(vault_string.data, vault_string) + + # This looks weird and useless, but the object in question has a custom __eq__ + self.assertEquals(vault_string, vault_string) + + another_vault_string = data_from_yaml['another_secret'] + different_vault_string = data_from_yaml['different_secret'] + + self.assertEquals(vault_string, another_vault_string) + self.assertNotEquals(vault_string, different_vault_string) + + # More testing of __eq__/__ne__ + self.assertTrue('some string' != vault_string) + self.assertNotEquals('some string', vault_string) + + # Note this is a compare of the str/unicode of these, they are diferent types + # so we want to test self == other, and other == self etc + self.assertEquals(plaintext_var, vault_string) + self.assertEquals(vault_string, plaintext_var) + self.assertFalse(plaintext_var != vault_string) + self.assertFalse(vault_string != plaintext_var) + class TestAnsibleLoaderPlay(unittest.TestCase): def setUp(self): @@ -242,7 +366,7 @@ class TestAnsibleLoaderPlay(unittest.TestCase): def check_vars(self): # Numbers don't have line/col information yet - #self.assertEqual(self.data[0][u'vars'][u'number'].ansible_pos, (self.play_filename, 4, 21)) + # self.assertEqual(self.data[0][u'vars'][u'number'].ansible_pos, (self.play_filename, 4, 21)) self.assertEqual(self.data[0][u'vars'][u'string'].ansible_pos, (self.play_filename, 5, 29)) self.assertEqual(self.data[0][u'vars'][u'utf8_string'].ansible_pos, (self.play_filename, 6, 34)) @@ -255,8 +379,8 @@ class TestAnsibleLoaderPlay(unittest.TestCase): self.assertEqual(self.data[0][u'vars'][u'list'][0].ansible_pos, (self.play_filename, 11, 25)) self.assertEqual(self.data[0][u'vars'][u'list'][1].ansible_pos, (self.play_filename, 12, 25)) # Numbers don't have line/col info yet - #self.assertEqual(self.data[0][u'vars'][u'list'][2].ansible_pos, (self.play_filename, 13, 25)) - #self.assertEqual(self.data[0][u'vars'][u'list'][3].ansible_pos, (self.play_filename, 14, 25)) + # self.assertEqual(self.data[0][u'vars'][u'list'][2].ansible_pos, (self.play_filename, 13, 25)) + # self.assertEqual(self.data[0][u'vars'][u'list'][3].ansible_pos, (self.play_filename, 14, 25)) def check_tasks(self): # diff --git a/test/units/parsing/yaml/test_objects.py b/test/units/parsing/yaml/test_objects.py new file mode 100644 index 0000000000..a7af0a19bc --- /dev/null +++ b/test/units/parsing/yaml/test_objects.py @@ -0,0 +1,129 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +# Copyright 2016, Adrian Likins + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.compat.tests import unittest + + +from ansible.parsing import vault +from ansible.parsing.yaml.loader import AnsibleLoader + +# module under test +from ansible.parsing.yaml import objects + +from units.mock.yaml_helper import YamlTestUtils + + +class TestAnsibleVaultUnicodeNoVault(unittest.TestCase, YamlTestUtils): + def test_empty_init(self): + self.assertRaises(TypeError, objects.AnsibleVaultEncryptedUnicode) + + def test_empty_string_init(self): + seq = ''.encode('utf8') + self.assert_values(seq) + + def test_empty_byte_string_init(self): + seq = b'' + self.assert_values(seq) + + def _assert_values(self, avu, seq): + self.assertIsInstance(avu, objects.AnsibleVaultEncryptedUnicode) + self.assertTrue(avu.vault is None) + # AnsibleVaultEncryptedUnicode without a vault should never == any string + self.assertNotEquals(avu, seq) + + def assert_values(self, seq): + avu = objects.AnsibleVaultEncryptedUnicode(seq) + self._assert_values(avu, seq) + + def test_single_char(self): + seq = 'a'.encode('utf8') + self.assert_values(seq) + + def test_string(self): + seq = 'some letters' + self.assert_values(seq) + + def test_byte_string(self): + seq = 'some letters'.encode('utf8') + self.assert_values(seq) + + +class TestAnsibleVaultEncryptedUnicode(unittest.TestCase, YamlTestUtils): + def setUp(self): + self.vault_password = "hunter42" + self.good_vault = vault.VaultLib(self.vault_password) + + self.wrong_vault_password = 'not-hunter42' + self.wrong_vault = vault.VaultLib(self.wrong_vault_password) + + self.vault = self.good_vault + + def _loader(self, stream): + return AnsibleLoader(stream, vault_password=self.vault_password) + + def test_dump_load_cycle(self): + aveu = self._from_plaintext('the test string for TestAnsibleVaultEncryptedUnicode.test_dump_load_cycle') + self._dump_load_cycle(aveu) + + def assert_values(self, avu, seq): + self.assertIsInstance(avu, objects.AnsibleVaultEncryptedUnicode) + + self.assertEquals(avu, seq) + self.assertTrue(avu.vault is self.vault) + self.assertIsInstance(avu.vault, vault.VaultLib) + + def _from_plaintext(self, seq): + return objects.AnsibleVaultEncryptedUnicode.from_plaintext(seq, vault=self.vault) + + def _from_ciphertext(self, ciphertext): + avu = objects.AnsibleVaultEncryptedUnicode(ciphertext) + avu.vault = self.vault + return avu + + def test_empty_init(self): + self.assertRaises(TypeError, objects.AnsibleVaultEncryptedUnicode) + + def test_empty_string_init_from_plaintext(self): + seq = '' + avu = self._from_plaintext(seq) + self.assert_values(avu,seq) + + def test_empty_unicode_init_from_plaintext(self): + seq = u'' + avu = self._from_plaintext(seq) + self.assert_values(avu,seq) + + def test_string_from_plaintext(self): + seq = 'some letters' + avu = self._from_plaintext(seq) + self.assert_values(avu,seq) + + def test_unicode_from_plaintext(self): + seq = u'some letters' + avu = self._from_plaintext(seq) + self.assert_values(avu,seq) + + # TODO/FIXME: make sure bad password fails differently than 'thats not encrypted' + def test_empty_string_wrong_password(self): + seq = '' + self.vault = self.wrong_vault + avu = self._from_plaintext(seq) + self.assert_values(avu, seq)