diff --git a/v2/ansible/parsing/mod_args.py b/v2/ansible/parsing/mod_args.py index 6650355ba3..e3fdba093d 100644 --- a/v2/ansible/parsing/mod_args.py +++ b/v2/ansible/parsing/mod_args.py @@ -20,7 +20,6 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type from six import iteritems, string_types -from types import NoneType from ansible.errors import AnsibleParserError from ansible.plugins import module_loader @@ -165,7 +164,7 @@ class ModuleArgsParser: # form is like: local_action: copy src=a dest=b ... pretty common check_raw = action in ('command', 'shell', 'script') args = parse_kv(thing, check_raw=check_raw) - elif isinstance(thing, NoneType): + elif thing is None: # this can happen with modules which take no params, like ping: args = None else: diff --git a/v2/ansible/parsing/vault/__init__.py b/v2/ansible/parsing/vault/__init__.py index 92c99fdad5..80c48a3b69 100644 --- a/v2/ansible/parsing/vault/__init__.py +++ b/v2/ansible/parsing/vault/__init__.py @@ -22,6 +22,7 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import sys import os import shlex import shutil @@ -35,7 +36,10 @@ from hashlib import sha256 from hashlib import md5 from binascii import hexlify from binascii import unhexlify +from six import binary_type, byte2int, PY2, text_type from ansible import constants as C +from ansible.utils.unicode import to_unicode, to_bytes + try: from Crypto.Hash import SHA256, HMAC @@ -60,15 +64,16 @@ except ImportError: # AES IMPORTS try: from Crypto.Cipher import AES as AES - HAS_AES = True + HAS_AES = True except ImportError: - HAS_AES = False + HAS_AES = False CRYPTO_UPGRADE = "ansible-vault requires a newer version of pycrypto than the one installed on your platform. You may fix this with OS-specific commands such as: yum install python-devel; rpm -e --nodeps python-crypto; pip install pycrypto" -HEADER='$ANSIBLE_VAULT' +HEADER=u'$ANSIBLE_VAULT' CIPHER_WHITELIST=['AES', 'AES256'] + class VaultLib(object): def __init__(self, password): @@ -76,26 +81,28 @@ class VaultLib(object): self.cipher_name = None self.version = '1.1' - def is_encrypted(self, data): + def is_encrypted(self, data): + data = to_unicode(data) if data.startswith(HEADER): return True else: return False def encrypt(self, data): + data = to_unicode(data) if self.is_encrypted(data): raise errors.AnsibleError("data is already encrypted") if not self.cipher_name: self.cipher_name = "AES256" - #raise errors.AnsibleError("the cipher must be set before encrypting data") + # raise errors.AnsibleError("the cipher must be set before encrypting data") - if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: + if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: cipher = globals()['Vault' + self.cipher_name] this_cipher = cipher() else: - raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name) + raise errors.AnsibleError("{} cipher could not be found".format(self.cipher_name)) """ # combine sha + data @@ -106,11 +113,13 @@ class VaultLib(object): # encrypt sha + data enc_data = this_cipher.encrypt(data, self.password) - # add header + # add header tmp_data = self._add_header(enc_data) return tmp_data def decrypt(self, data): + data = to_bytes(data) + if self.password is None: raise errors.AnsibleError("A vault password must be specified to decrypt data") @@ -121,48 +130,47 @@ class VaultLib(object): data = self._split_header(data) # create the cipher object - if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: - cipher = globals()['Vault' + self.cipher_name] + ciphername = to_unicode(self.cipher_name) + if 'Vault' + ciphername in globals() and ciphername in CIPHER_WHITELIST: + cipher = globals()['Vault' + ciphername] this_cipher = cipher() else: - raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name) + raise errors.AnsibleError("{} cipher could not be found".format(ciphername)) # try to unencrypt data data = this_cipher.decrypt(data, self.password) if data is None: raise errors.AnsibleError("Decryption failed") - return data + return data - def _add_header(self, data): + def _add_header(self, data): # combine header and encrypted data in 80 char columns #tmpdata = hexlify(data) - tmpdata = [data[i:i+80] for i in range(0, len(data), 80)] - + tmpdata = [to_bytes(data[i:i+80]) for i in range(0, len(data), 80)] if not self.cipher_name: raise errors.AnsibleError("the cipher must be set before adding a header") - dirty_data = HEADER + ";" + str(self.version) + ";" + self.cipher_name + "\n" - + dirty_data = to_bytes(HEADER + ";" + self.version + ";" + self.cipher_name + "\n") for l in tmpdata: - dirty_data += l + '\n' + dirty_data += l + b'\n' return dirty_data - def _split_header(self, data): + def _split_header(self, data): # used by decrypt - tmpdata = data.split('\n') - tmpheader = tmpdata[0].strip().split(';') + tmpdata = data.split(b'\n') + tmpheader = tmpdata[0].strip().split(b';') - self.version = str(tmpheader[1].strip()) - self.cipher_name = str(tmpheader[2].strip()) - clean_data = '\n'.join(tmpdata[1:]) + self.version = to_unicode(tmpheader[1].strip()) + self.cipher_name = to_unicode(tmpheader[2].strip()) + clean_data = b'\n'.join(tmpdata[1:]) """ - # strip out newline, join, unhex + # strip out newline, join, unhex clean_data = [ x.strip() for x in clean_data ] clean_data = unhexlify(''.join(clean_data)) """ @@ -176,9 +184,9 @@ class VaultLib(object): pass class VaultEditor(object): - # uses helper methods for write_file(self, filename, data) - # to write a file so that code isn't duplicated for simple - # file I/O, ditto read_file(self, filename) and launch_editor(self, filename) + # uses helper methods for write_file(self, filename, data) + # to write a file so that code isn't duplicated for simple + # file I/O, ditto read_file(self, filename) and launch_editor(self, filename) # ... "Don't Repeat Yourself", etc. def __init__(self, cipher_name, password, filename): @@ -302,7 +310,7 @@ class VaultEditor(object): if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH: raise errors.AnsibleError(CRYPTO_UPGRADE) - # decrypt + # decrypt tmpdata = self.read_data(self.filename) this_vault = VaultLib(self.password) dec_data = this_vault.decrypt(tmpdata) @@ -324,10 +332,10 @@ class VaultEditor(object): return tmpdata def write_data(self, data, filename): - if os.path.isfile(filename): + if os.path.isfile(filename): os.remove(filename) f = open(filename, "wb") - f.write(data) + f.write(to_bytes(data)) f.close() def shuffle_files(self, src, dest): @@ -369,9 +377,10 @@ class VaultAES(object): """ Create a key and an initialization vector """ - d = d_i = '' + d = d_i = b'' while len(d) < key_length + iv_length: - d_i = md5(d_i + password + salt).digest() + text = "{}{}{}".format(d_i, password, salt) + d_i = md5(to_bytes(text)).digest() d += d_i key = d[:key_length] @@ -385,28 +394,29 @@ class VaultAES(object): # combine sha + data - this_sha = sha256(data).hexdigest() + this_sha = sha256(to_bytes(data)).hexdigest() tmp_data = this_sha + "\n" + data - in_file = BytesIO(tmp_data) + in_file = BytesIO(to_bytes(tmp_data)) in_file.seek(0) out_file = BytesIO() bs = AES.block_size - # Get a block of random data. EL does not have Crypto.Random.new() + # Get a block of random data. EL does not have Crypto.Random.new() # so os.urandom is used for cross platform purposes salt = os.urandom(bs - len('Salted__')) key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) cipher = AES.new(key, AES.MODE_CBC, iv) - out_file.write('Salted__' + salt) + full = to_bytes(b'Salted__' + salt) + out_file.write(full) finished = False while not finished: chunk = in_file.read(1024 * bs) if len(chunk) == 0 or len(chunk) % bs != 0: padding_length = (bs - len(chunk) % bs) or bs - chunk += padding_length * chr(padding_length) + chunk += to_bytes(padding_length * chr(padding_length)) finished = True out_file.write(cipher.encrypt(chunk)) @@ -416,14 +426,14 @@ class VaultAES(object): return tmp_data - + def decrypt(self, data, password, key_length=32): """ Read encrypted data from in_file and write decrypted to out_file """ # http://stackoverflow.com/a/14989032 - data = ''.join(data.split('\n')) + data = b''.join(data.split(b'\n')) data = unhexlify(data) in_file = BytesIO(data) @@ -431,41 +441,49 @@ class VaultAES(object): out_file = BytesIO() bs = AES.block_size - salt = in_file.read(bs)[len('Salted__'):] + tmpsalt = in_file.read(bs) + salt = tmpsalt[len('Salted__'):] key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) cipher = AES.new(key, AES.MODE_CBC, iv) - next_chunk = '' + next_chunk = b'' finished = False while not finished: chunk, next_chunk = next_chunk, cipher.decrypt(in_file.read(1024 * bs)) if len(next_chunk) == 0: - padding_length = ord(chunk[-1]) + if PY2: + padding_length = ord(chunk[-1]) + else: + padding_length = chunk[-1] + chunk = chunk[:-padding_length] finished = True + out_file.write(chunk) + out_file.flush() # reset the stream pointer to the beginning out_file.seek(0) - new_data = out_file.read() + out_data = out_file.read() + out_file.close() + new_data = to_unicode(out_data) # split out sha and verify decryption split_data = new_data.split("\n") this_sha = split_data[0] this_data = '\n'.join(split_data[1:]) - test_sha = sha256(this_data).hexdigest() + test_sha = sha256(to_bytes(this_data)).hexdigest() if this_sha != test_sha: raise errors.AnsibleError("Decryption failed") - #return out_file.read() return this_data class VaultAES256(object): """ - Vault implementation using AES-CTR with an HMAC-SHA256 authentication code. + Vault implementation using AES-CTR with an HMAC-SHA256 authentication code. Keys are derived using PBKDF2 """ @@ -481,7 +499,7 @@ class VaultAES256(object): keylength = 32 # match the size used for counter.new to avoid extra work - ivlength = 16 + ivlength = 16 hash_function = SHA256 @@ -489,7 +507,7 @@ class VaultAES256(object): pbkdf2_prf = lambda p, s: HMAC.new(p, s, hash_function).digest() - derivedkey = PBKDF2(password, salt, dkLen=(2 * keylength) + ivlength, + derivedkey = PBKDF2(password, salt, dkLen=(2 * keylength) + ivlength, count=10000, prf=pbkdf2_prf) key1 = derivedkey[:keylength] @@ -523,28 +541,28 @@ class VaultAES256(object): cipher = AES.new(key1, AES.MODE_CTR, counter=ctr) # ENCRYPT PADDED DATA - cryptedData = cipher.encrypt(data) + cryptedData = cipher.encrypt(data) # COMBINE SALT, DIGEST AND DATA hmac = HMAC.new(key2, cryptedData, SHA256) - message = "%s\n%s\n%s" % ( hexlify(salt), hmac.hexdigest(), hexlify(cryptedData) ) + message = b''.join([hexlify(salt), b"\n", to_bytes(hmac.hexdigest()), b"\n", hexlify(cryptedData)]) message = hexlify(message) return message def decrypt(self, data, password): # SPLIT SALT, DIGEST, AND DATA - data = ''.join(data.split("\n")) + data = b''.join(data.split(b"\n")) data = unhexlify(data) - salt, cryptedHmac, cryptedData = data.split("\n", 2) + salt, cryptedHmac, cryptedData = data.split(b"\n", 2) salt = unhexlify(salt) cryptedData = unhexlify(cryptedData) key1, key2, iv = self.gen_key_initctr(password, salt) - # EXIT EARLY IF DIGEST DOESN'T MATCH + # EXIT EARLY IF DIGEST DOESN'T MATCH hmacDecrypt = HMAC.new(key2, cryptedData, SHA256) - if not self.is_equal(cryptedHmac, hmacDecrypt.hexdigest()): + if not self.is_equal(cryptedHmac, to_bytes(hmacDecrypt.hexdigest())): return None # SET THE COUNTER AND THE CIPHER @@ -555,19 +573,31 @@ class VaultAES256(object): decryptedData = cipher.decrypt(cryptedData) # UNPAD DATA - padding_length = ord(decryptedData[-1]) + try: + padding_length = ord(decryptedData[-1]) + except TypeError: + padding_length = decryptedData[-1] + decryptedData = decryptedData[:-padding_length] - return decryptedData + return to_unicode(decryptedData) def is_equal(self, a, b): + """ + Comparing 2 byte arrrays in constant time + to avoid timing attacks. + + It would be nice if there was a library for this but + hey. + """ # http://codahale.com/a-lesson-in-timing-attacks/ if len(a) != len(b): return False - + result = 0 for x, y in zip(a, b): - result |= ord(x) ^ ord(y) - return result == 0 - - + if PY2: + result |= ord(x) ^ ord(y) + else: + result |= x ^ y + return result == 0 diff --git a/v2/ansible/parsing/yaml/objects.py b/v2/ansible/parsing/yaml/objects.py index fe37eaab94..33ea1ad37e 100644 --- a/v2/ansible/parsing/yaml/objects.py +++ b/v2/ansible/parsing/yaml/objects.py @@ -19,14 +19,17 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -class AnsibleBaseYAMLObject: +from six import text_type + + +class AnsibleBaseYAMLObject(object): ''' the base class used to sub-class python built-in objects so that we can add attributes to them during yaml parsing ''' - _data_source = None - _line_number = 0 + _data_source = None + _line_number = 0 _column_number = 0 def _get_ansible_position(self): @@ -36,21 +39,27 @@ class AnsibleBaseYAMLObject: try: (src, line, col) = obj except (TypeError, ValueError): - raise AssertionError('ansible_pos can only be set with a tuple/list of three values: source, line number, column number') - self._data_source = src - self._line_number = line + raise AssertionError( + 'ansible_pos can only be set with a tuple/list ' + 'of three values: source, line number, column number' + ) + self._data_source = src + self._line_number = line self._column_number = col ansible_pos = property(_get_ansible_position, _set_ansible_position) + class AnsibleMapping(AnsibleBaseYAMLObject, dict): ''' sub class for dictionaries ''' pass -class AnsibleUnicode(AnsibleBaseYAMLObject, unicode): + +class AnsibleUnicode(AnsibleBaseYAMLObject, text_type): ''' sub class for unicode objects ''' pass + class AnsibleSequence(AnsibleBaseYAMLObject, list): ''' sub class for lists ''' pass diff --git a/v2/ansible/utils/unicode.py b/v2/ansible/utils/unicode.py index 7bd035c007..2cff2e5e45 100644 --- a/v2/ansible/utils/unicode.py +++ b/v2/ansible/utils/unicode.py @@ -19,6 +19,8 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +from six import string_types, text_type, binary_type, PY3 + # to_bytes and to_unicode were written by Toshio Kuratomi for the # python-kitchen library https://pypi.python.org/pypi/kitchen # They are licensed in kitchen under the terms of the GPLv2+ @@ -35,6 +37,9 @@ _LATIN1_ALIASES = frozenset(('latin-1', 'LATIN-1', 'latin1', 'LATIN1', # EXCEPTION_CONVERTERS is defined below due to using to_unicode +if PY3: + basestring = (str, bytes) + def to_unicode(obj, encoding='utf-8', errors='replace', nonstring=None): '''Convert an object into a :class:`unicode` string @@ -89,12 +94,12 @@ def to_unicode(obj, encoding='utf-8', errors='replace', nonstring=None): # Could use isbasestring/isunicode here but we want this code to be as # fast as possible if isinstance(obj, basestring): - if isinstance(obj, unicode): + if isinstance(obj, text_type): return obj if encoding in _UTF8_ALIASES: - return unicode(obj, 'utf-8', errors) + return text_type(obj, 'utf-8', errors) if encoding in _LATIN1_ALIASES: - return unicode(obj, 'latin-1', errors) + return text_type(obj, 'latin-1', errors) return obj.decode(encoding, errors) if not nonstring: @@ -110,19 +115,19 @@ def to_unicode(obj, encoding='utf-8', errors='replace', nonstring=None): simple = None if not simple: try: - simple = str(obj) + simple = text_type(obj) except UnicodeError: try: simple = obj.__str__() except (UnicodeError, AttributeError): simple = u'' - if isinstance(simple, str): - return unicode(simple, encoding, errors) + if isinstance(simple, binary_type): + return text_type(simple, encoding, errors) return simple elif nonstring in ('repr', 'strict'): obj_repr = repr(obj) - if isinstance(obj_repr, str): - obj_repr = unicode(obj_repr, encoding, errors) + if isinstance(obj_repr, binary_type): + obj_repr = text_type(obj_repr, encoding, errors) if nonstring == 'repr': return obj_repr raise TypeError('to_unicode was given "%(obj)s" which is neither' @@ -198,19 +203,19 @@ def to_bytes(obj, encoding='utf-8', errors='replace', nonstring=None): # Could use isbasestring, isbytestring here but we want this to be as fast # as possible if isinstance(obj, basestring): - if isinstance(obj, str): + if isinstance(obj, binary_type): return obj return obj.encode(encoding, errors) if not nonstring: nonstring = 'simplerepr' if nonstring == 'empty': - return '' + return b'' elif nonstring == 'passthru': return obj elif nonstring == 'simplerepr': try: - simple = str(obj) + simple = binary_type(obj) except UnicodeError: try: simple = obj.__str__() @@ -220,19 +225,19 @@ def to_bytes(obj, encoding='utf-8', errors='replace', nonstring=None): try: simple = obj.__unicode__() except (AttributeError, UnicodeError): - simple = '' - if isinstance(simple, unicode): + simple = b'' + if isinstance(simple, text_type): simple = simple.encode(encoding, 'replace') return simple elif nonstring in ('repr', 'strict'): try: obj_repr = obj.__repr__() except (AttributeError, UnicodeError): - obj_repr = '' - if isinstance(obj_repr, unicode): + obj_repr = b'' + if isinstance(obj_repr, text_type): obj_repr = obj_repr.encode(encoding, errors) else: - obj_repr = str(obj_repr) + obj_repr = binary_type(obj_repr) if nonstring == 'repr': return obj_repr raise TypeError('to_bytes was given "%(obj)s" which is neither' diff --git a/v2/test/parsing/test_data_loader.py b/v2/test/parsing/test_data_loader.py index 75ceb662f7..b9c37cdd0c 100644 --- a/v2/test/parsing/test_data_loader.py +++ b/v2/test/parsing/test_data_loader.py @@ -19,6 +19,7 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +from six import PY2 from yaml.scanner import ScannerError from ansible.compat.tests import unittest @@ -79,6 +80,11 @@ class TestDataLoaderWithVault(unittest.TestCase): 3135306561356164310a343937653834643433343734653137383339323330626437313562306630 3035 """ - with patch('__builtin__.open', mock_open(read_data=vaulted_data)): + if PY2: + builtins_name = '__builtin__' + else: + builtins_name = 'builtins' + + with patch(builtins_name + '.open', mock_open(read_data=vaulted_data)): output = self._loader.load_from_file('dummy_vault.txt') self.assertEqual(output, dict(foo='bar')) diff --git a/v2/test/parsing/vault/test_vault.py b/v2/test/parsing/vault/test_vault.py index d24573c729..2aaac27fc7 100644 --- a/v2/test/parsing/vault/test_vault.py +++ b/v2/test/parsing/vault/test_vault.py @@ -24,11 +24,14 @@ import os import shutil import time import tempfile +import six + from binascii import unhexlify from binascii import hexlify from nose.plugins.skip import SkipTest from ansible.compat.tests import unittest +from ansible.utils.unicode import to_bytes, to_unicode from ansible import errors from ansible.parsing.vault import VaultLib @@ -63,13 +66,13 @@ class TestVaultLib(unittest.TestCase): 'decrypt', '_add_header', '_split_header',] - for slot in slots: + for slot in slots: assert hasattr(v, slot), "VaultLib is missing the %s method" % slot def test_is_encrypted(self): v = VaultLib(None) - assert not v.is_encrypted("foobar"), "encryption check on plaintext failed" - data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify("ansible") + assert not v.is_encrypted(u"foobar"), "encryption check on plaintext failed" + data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") assert v.is_encrypted(data), "encryption check on headered text failed" def test_add_header(self): @@ -77,22 +80,22 @@ class TestVaultLib(unittest.TestCase): v.cipher_name = "TEST" sensitive_data = "ansible" data = v._add_header(sensitive_data) - lines = data.split('\n') + lines = data.split(b'\n') assert len(lines) > 1, "failed to properly add header" - header = lines[0] + header = to_unicode(lines[0]) assert header.endswith(';TEST'), "header does end with cipher name" header_parts = header.split(';') - assert len(header_parts) == 3, "header has the wrong number of parts" + assert len(header_parts) == 3, "header has the wrong number of parts" assert header_parts[0] == '$ANSIBLE_VAULT', "header does not start with $ANSIBLE_VAULT" assert header_parts[1] == v.version, "header version is incorrect" assert header_parts[2] == 'TEST', "header does end with cipher name" def test_split_header(self): v = VaultLib('ansible') - data = "$ANSIBLE_VAULT;9.9;TEST\nansible" - rdata = v._split_header(data) - lines = rdata.split('\n') - assert lines[0] == "ansible" + data = b"$ANSIBLE_VAULT;9.9;TEST\nansible" + rdata = v._split_header(data) + lines = rdata.split(b'\n') + assert lines[0] == b"ansible" assert v.cipher_name == 'TEST', "cipher name was not set" assert v.version == "9.9" @@ -100,11 +103,11 @@ class TestVaultLib(unittest.TestCase): if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: raise SkipTest v = VaultLib('ansible') - v.cipher_name = 'AES' + v.cipher_name = u'AES' enc_data = v.encrypt("foobar") dec_data = v.decrypt(enc_data) assert enc_data != "foobar", "encryption failed" - assert dec_data == "foobar", "decryption failed" + assert dec_data == "foobar", "decryption failed" def test_encrypt_decrypt_aes256(self): if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: @@ -114,20 +117,20 @@ class TestVaultLib(unittest.TestCase): enc_data = v.encrypt("foobar") dec_data = v.decrypt(enc_data) assert enc_data != "foobar", "encryption failed" - assert dec_data == "foobar", "decryption failed" + assert dec_data == "foobar", "decryption failed" def test_encrypt_encrypted(self): if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: raise SkipTest v = VaultLib('ansible') v.cipher_name = 'AES' - data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify("ansible") + data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(six.b("ansible")) error_hit = False try: enc_data = v.encrypt(data) except errors.AnsibleError as e: error_hit = True - assert error_hit, "No error was thrown when trying to encrypt data with a header" + assert error_hit, "No error was thrown when trying to encrypt data with a header" def test_decrypt_decrypted(self): if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: @@ -139,7 +142,7 @@ class TestVaultLib(unittest.TestCase): dec_data = v.decrypt(data) except errors.AnsibleError as e: error_hit = True - assert error_hit, "No error was thrown when trying to decrypt data without a header" + assert error_hit, "No error was thrown when trying to decrypt data without a header" def test_cipher_not_set(self): # not setting the cipher should default to AES256 @@ -152,5 +155,5 @@ class TestVaultLib(unittest.TestCase): enc_data = v.encrypt(data) except errors.AnsibleError as e: error_hit = True - assert not error_hit, "An error was thrown when trying to encrypt data without the cipher set" - assert v.cipher_name == "AES256", "cipher name is not set to AES256: %s" % v.cipher_name + assert not error_hit, "An error was thrown when trying to encrypt data without the cipher set" + assert v.cipher_name == "AES256", "cipher name is not set to AES256: %s" % v.cipher_name diff --git a/v2/test/parsing/vault/test_vault_editor.py b/v2/test/parsing/vault/test_vault_editor.py index c788df54ae..2ddf3de27a 100644 --- a/v2/test/parsing/vault/test_vault_editor.py +++ b/v2/test/parsing/vault/test_vault_editor.py @@ -21,6 +21,7 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type #!/usr/bin/env python +import sys import getpass import os import shutil @@ -32,6 +33,7 @@ from nose.plugins.skip import SkipTest from ansible.compat.tests import unittest from ansible.compat.tests.mock import patch +from ansible.utils.unicode import to_bytes, to_unicode from ansible import errors from ansible.parsing.vault import VaultLib @@ -88,12 +90,12 @@ class TestVaultEditor(unittest.TestCase): 'read_data', 'write_data', 'shuffle_files'] - for slot in slots: + for slot in slots: assert hasattr(v, slot), "VaultLib is missing the %s method" % slot @patch.object(VaultEditor, '_editor_shell_command') def test_create_file(self, mock_editor_shell_command): - + def sc_side_effect(filename): return ['touch', filename] mock_editor_shell_command.side_effect = sc_side_effect @@ -107,12 +109,16 @@ class TestVaultEditor(unittest.TestCase): self.assertTrue(os.path.exists(tmp_file.name)) def test_decrypt_1_0(self): - if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: + """ + Skip testing decrypting 1.0 files if we don't have access to AES, KDF or + Counter, or we are running on python3 since VaultAES hasn't been backported. + """ + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or sys.version > '3': raise SkipTest v10_file = tempfile.NamedTemporaryFile(delete=False) with v10_file as f: - f.write(v10_data) + f.write(to_bytes(v10_data)) ve = VaultEditor(None, "ansible", v10_file.name) @@ -125,13 +131,13 @@ class TestVaultEditor(unittest.TestCase): # verify decrypted content f = open(v10_file.name, "rb") - fdata = f.read() + fdata = to_unicode(f.read()) f.close() os.unlink(v10_file.name) - assert error_hit == False, "error decrypting 1.0 file" - assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip() + assert error_hit == False, "error decrypting 1.0 file" + assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip() def test_decrypt_1_1(self): @@ -140,7 +146,7 @@ class TestVaultEditor(unittest.TestCase): v11_file = tempfile.NamedTemporaryFile(delete=False) with v11_file as f: - f.write(v11_data) + f.write(to_bytes(v11_data)) ve = VaultEditor(None, "ansible", v11_file.name) @@ -153,28 +159,32 @@ class TestVaultEditor(unittest.TestCase): # verify decrypted content f = open(v11_file.name, "rb") - fdata = f.read() + fdata = to_unicode(f.read()) f.close() os.unlink(v11_file.name) - assert error_hit == False, "error decrypting 1.0 file" - assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip() + assert error_hit == False, "error decrypting 1.0 file" + assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip() def test_rekey_migration(self): - if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: + """ + Skip testing rekeying files if we don't have access to AES, KDF or + Counter, or we are running on python3 since VaultAES hasn't been backported. + """ + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or sys.version > '3': raise SkipTest v10_file = tempfile.NamedTemporaryFile(delete=False) with v10_file as f: - f.write(v10_data) + f.write(to_bytes(v10_data)) ve = VaultEditor(None, "ansible", v10_file.name) # make sure the password functions for the cipher error_hit = False - try: + try: ve.rekey_file('ansible2') except errors.AnsibleError as e: error_hit = True @@ -184,7 +194,7 @@ class TestVaultEditor(unittest.TestCase): fdata = f.read() f.close() - assert error_hit == False, "error rekeying 1.0 file to 1.1" + assert error_hit == False, "error rekeying 1.0 file to 1.1" # ensure filedata can be decrypted, is 1.1 and is AES256 vl = VaultLib("ansible2") @@ -198,7 +208,7 @@ class TestVaultEditor(unittest.TestCase): os.unlink(v10_file.name) assert vl.cipher_name == "AES256", "wrong cipher name set after rekey: %s" % vl.cipher_name - assert error_hit == False, "error decrypting migrated 1.0 file" + assert error_hit == False, "error decrypting migrated 1.0 file" assert dec_data.strip() == "foo", "incorrect decryption of rekeyed/migrated file: %s" % dec_data diff --git a/v2/test/parsing/yaml/test_loader.py b/v2/test/parsing/yaml/test_loader.py index d393d72a00..37eeabff83 100644 --- a/v2/test/parsing/yaml/test_loader.py +++ b/v2/test/parsing/yaml/test_loader.py @@ -20,6 +20,7 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +from six import text_type, binary_type from six.moves import StringIO from collections import Sequence, Set, Mapping @@ -28,6 +29,7 @@ from ansible.compat.tests.mock import patch from ansible.parsing.yaml.loader import AnsibleLoader + class TestAnsibleLoaderBasic(unittest.TestCase): def setUp(self): @@ -52,7 +54,7 @@ class TestAnsibleLoaderBasic(unittest.TestCase): loader = AnsibleLoader(stream, 'myfile.yml') data = loader.get_single_data() self.assertEqual(data, u'Ansible') - self.assertIsInstance(data, unicode) + self.assertIsInstance(data, text_type) self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) @@ -63,7 +65,7 @@ class TestAnsibleLoaderBasic(unittest.TestCase): loader = AnsibleLoader(stream, 'myfile.yml') data = loader.get_single_data() self.assertEqual(data, u'Cafè Eñyei') - self.assertIsInstance(data, unicode) + self.assertIsInstance(data, text_type) self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) @@ -76,8 +78,8 @@ class TestAnsibleLoaderBasic(unittest.TestCase): data = loader.get_single_data() self.assertEqual(data, {'webster': 'daniel', 'oed': 'oxford'}) self.assertEqual(len(data), 2) - self.assertIsInstance(data.keys()[0], unicode) - self.assertIsInstance(data.values()[0], unicode) + self.assertIsInstance(list(data.keys())[0], text_type) + self.assertIsInstance(list(data.values())[0], text_type) # Beginning of the first key self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) @@ -94,7 +96,7 @@ class TestAnsibleLoaderBasic(unittest.TestCase): data = loader.get_single_data() self.assertEqual(data, [u'a', u'b']) self.assertEqual(len(data), 2) - self.assertIsInstance(data[0], unicode) + self.assertIsInstance(data[0], text_type) self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) @@ -204,10 +206,10 @@ class TestAnsibleLoaderPlay(unittest.TestCase): def walk(self, data): # Make sure there's no str in the data - self.assertNotIsInstance(data, str) + self.assertNotIsInstance(data, binary_type) # Descend into various container types - if isinstance(data, unicode): + if isinstance(data, text_type): # strings are a sequence so we have to be explicit here return elif isinstance(data, (Sequence, Set)):