From 52a8efefbae9192ed153b89fa907a4086242666e Mon Sep 17 00:00:00 2001 From: James Tanner Date: Mon, 24 Feb 2014 13:09:36 -0500 Subject: [PATCH] Vault rewrite, pass 1 --- bin/ansible-vault | 56 ++-- lib/ansible/utils/__init__.py | 19 +- lib/ansible/utils/vault.py | 508 +++++++++++++--------------------- 3 files changed, 236 insertions(+), 347 deletions(-) diff --git a/bin/ansible-vault b/bin/ansible-vault index 6c793b871a..7d3c7f208c 100755 --- a/bin/ansible-vault +++ b/bin/ansible-vault @@ -20,13 +20,13 @@ # example playbook to bootstrap this script in the examples/ dir which # installs ansible and sets it up to run on cron. +import os import sys import traceback from ansible import utils from ansible import errors -from ansible.utils.vault import * -from ansible.utils.vault import Vault +from ansible.utils.vault import VaultEditor from optparse import OptionParser @@ -100,32 +100,30 @@ def get_opt(options, k, defval=""): # Command functions #------------------------------------------------------------------------------------- -def _get_vault(filename, options, password): - this_vault = Vault() - this_vault.filename = filename - this_vault.vault_password = password - this_vault.password = password - return this_vault - def execute_create(args, options, parser): if len(args) > 1: - raise errors.AnsibleError("create does not accept more than one filename") - + raise errors.AnsibleError("'create' does not accept more than one filename") password, new_password = utils.ask_vault_passwords(ask_vault_pass=True, confirm_vault=True) - this_vault = _get_vault(args[0], options, password) - if not hasattr(options, 'cipher'): - this_vault.cipher = 'AES' - this_vault.create() + cipher = 'AES' + if hasattr(options, 'cipher'): + cipher = options.cipher + + this_editor = VaultEditor(cipher, password, args[0]) + this_editor.create_file() def execute_decrypt(args, options, parser): password, new_password = utils.ask_vault_passwords(ask_vault_pass=True) + cipher = 'AES' + if hasattr(options, 'cipher'): + cipher = options.cipher + for f in args: - this_vault = _get_vault(f, options, password) - this_vault.decrypt() + this_editor = VaultEditor(cipher, password, f) + this_editor.decrypt_file() print "Decryption successful" @@ -136,29 +134,35 @@ def execute_edit(args, options, parser): password, new_password = utils.ask_vault_passwords(ask_vault_pass=True) + cipher = None + for f in args: - this_vault = _get_vault(f, options, password) - this_vault.edit() + this_editor = VaultEditor(cipher, password, f) + this_editor.edit_file() def execute_encrypt(args, options, parser): + if len(args) > 1: + raise errors.AnsibleError("'create' does not accept more than one filename") password, new_password = utils.ask_vault_passwords(ask_vault_pass=True, confirm_vault=True) + cipher = 'AES' + if hasattr(options, 'cipher'): + cipher = options.cipher + for f in args: - this_vault = _get_vault(f, options, password) - if not hasattr(options, 'cipher'): - this_vault.cipher = 'AES' - this_vault.encrypt() + this_editor = VaultEditor(cipher, password, f) + this_editor.encrypt_file() print "Encryption successful" def execute_rekey(args, options, parser): password, new_password = utils.ask_vault_passwords(ask_vault_pass=True, ask_new_vault_pass=True, confirm_new=True) - + cipher = None for f in args: - this_vault = _get_vault(f, options, password) - this_vault.rekey(new_password) + this_editor = VaultEditor(cipher, password, f) + this_editor.rekey_file(new_password) print "Rekey successful" diff --git a/lib/ansible/utils/__init__.py b/lib/ansible/utils/__init__.py index 3df2ddfbed..4f2ad73419 100644 --- a/lib/ansible/utils/__init__.py +++ b/lib/ansible/utils/__init__.py @@ -43,7 +43,8 @@ import getpass import sys import textwrap -import vault +#import vault +from vault import VaultLib VERBOSITY=0 @@ -501,14 +502,14 @@ def parse_yaml_from_file(path, vault_password=None): data = None - #VAULT - if vault.is_encrypted(path): - data = vault.decrypt(path, vault_password) - else: - try: - data = open(path).read() - except IOError: - raise errors.AnsibleError("file could not read: %s" % path) + try: + data = open(path).read() + except IOError: + raise errors.AnsibleError("file could not read: %s" % path) + + vault = VaultLib(password=vault_password) + if vault.is_encrypted(data): + data = vault.decrypt(data) try: return parse_yaml(data) diff --git a/lib/ansible/utils/vault.py b/lib/ansible/utils/vault.py index b25f19c155..b5987151a3 100644 --- a/lib/ansible/utils/vault.py +++ b/lib/ansible/utils/vault.py @@ -32,101 +32,122 @@ from ansible import constants as C # AES IMPORTS try: - from Crypto.Cipher import AES as AES_ + from Crypto.Cipher import AES as AES HAS_AES = True except ImportError: HAS_AES = False HEADER='$ANSIBLE_VAULT' +CIPHER_WHITELIST=['AES'] -def is_encrypted(filename): +class VaultLib(object): - ''' - Check a file for the encrypted header and return True or False - - The first line should start with the header - defined by the global HEADER. If true, we - assume this is a properly encrypted file. - ''' - - # read first line of the file - with open(filename) as f: - try: - head = f.next() - except StopIteration: - # empty file, so not encrypted - return False - - if head.startswith(HEADER): - return True - else: - return False - -def decrypt(filename, password): - - ''' - Return a decrypted string of the contents in an encrypted file - - This is used by the yaml loading code in ansible - to automatically determine the encryption type - and return a plaintext string of the unencrypted - data. - ''' - - if password is None: - raise errors.AnsibleError("A vault password must be specified to decrypt %s" % filename) - - V = Vault(filename=filename, vault_password=password) - return_data = V._decrypt_to_string() - - if not V._verify_decryption(return_data): - raise errors.AnsibleError("Decryption of %s failed" % filename) - - this_sha, return_data = V._strip_sha(return_data) - return return_data.strip() - - -class Vault(object): - def __init__(self, filename=None, cipher=None, vault_password=None): - self.filename = filename - self.vault_password = vault_password - self.cipher = cipher + def __init__(self, password): + self.password = password + self.cipher_name = None self.version = '1.0' - ############### - # PUBLIC - ############### - - def eval_header(self): - - """ Read first line of the file and parse header """ - - # read first line - with open(self.filename) as f: - #head=[f.next() for x in xrange(1)] - head = f.next() - - this_version = None - this_cipher = None - - # split segments - if len(head.split(';')) == 3: - this_version = head.split(';')[1].strip() - this_cipher = head.split(';')[2].strip() + def is_encrypted(self, data): + if data.startswith(HEADER): + return True else: - raise errors.AnsibleError("%s has an invalid header" % self.filename) + return False - # validate acceptable version - this_version = float(this_version) - if this_version < C.VAULT_VERSION_MIN or this_version > C.VAULT_VERSION_MAX: - raise errors.AnsibleError("%s must have a version between %s and %s " % (self.filename, - C.VAULT_VERSION_MIN, - C.VAULT_VERSION_MAX)) - # set properties - self.cipher = this_cipher - self.version = this_version + def encrypt(self, data): - def create(self): + if self.is_encrypted(data): + raise errors.AnsibleError("data is already encrypted") + + if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: + cipher = globals()['Vault' + self.cipher_name] + this_cipher = cipher() + else: + raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name) + + # combine sha + data + this_sha = sha256(data).hexdigest() + tmp_data = this_sha + "\n" + data + # encrypt sha + data + tmp_data = this_cipher.encrypt(tmp_data, self.password) + # add header + tmp_data = self._add_headers_and_hexify_encrypted_data(tmp_data) + return tmp_data + + def decrypt(self, data): + if self.password is None: + raise errors.AnsibleError("A vault password must be specified to decrypt data") + + if not self.is_encrypted(data): + raise errors.AnsibleError("data is not encrypted") + + # clean out header, hex and sha + data = self._split_headers_and_get_unhexified_data(data) + + # create the cipher object + if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: + cipher = globals()['Vault' + self.cipher_name] + this_cipher = cipher() + else: + raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name) + + # try to unencrypt data + data = this_cipher.decrypt(data, self.password) + + # split out sha and verify decryption + split_data = data.split("\n") + this_sha = split_data[0] + this_data = '\n'.join(split_data[1:]) + test_sha = sha256(this_data).hexdigest() + if this_sha != test_sha: + raise errors.AnsibleError("Decryption of %s failed" % filename) + + return this_data + + def _add_headers_and_hexify_encrypted_data(self, data): + # combine header and hexlified encrypted data in 80 char columns + + tmpdata = hexlify(data) + tmpdata = [tmpdata[i:i+80] for i in range(0, len(tmpdata), 80)] + + if not self.cipher_name: + raise errors.AnsibleError("the cipher must be set before adding a header") + + dirty_data = HEADER + ";" + str(self.version) + ";" + self.cipher_name + "\n" + for l in tmpdata: + dirty_data += l + '\n' + + return dirty_data + + + def _split_headers_and_get_unhexified_data(self, data): + # used by decrypt + + tmpdata = data.split('\n') + tmpheader = tmpdata[0].strip().split(';') + + self.version = str(tmpheader[1].strip()) + self.cipher_name = str(tmpheader[2].strip()) + clean_data = ''.join(tmpdata[1:]) + + # strip out newline, join, unhex + clean_data = [ x.strip() for x in clean_data ] + clean_data = unhexlify(''.join(clean_data)) + + return clean_data + +class VaultEditor(object): + # uses helper methods for write_file(self, filename, data) + # to write a file so that code isn't duplicated for simple + # file I/O, ditto read_file(self, filename) and launch_editor(self, filename) + # ... "Don't Repeat Yourself", etc. + + def __init__(self, cipher_name, password, filename): + # instantiates a member variable for VaultLib + self.cipher_name = cipher_name + self.password = password + self.filename = filename + + def create_file(self): """ create a new encrypted file """ if os.path.isfile(self.filename): @@ -135,250 +156,100 @@ class Vault(object): # drop the user into vim on file EDITOR = os.environ.get('EDITOR','vim') call([EDITOR, self.filename]) + tmpdata = self.read_data(self.filename) + this_vault = VaultLib(self.password) + this_vault.cipher_name = self.cipher_name + enc_data = this_vault.encrypt(tmpdata) + self.write_data(enc_data, self.filename) - self.encrypt() - - def decrypt(self): - """ unencrypt a file inplace """ - - if not is_encrypted(self.filename): - raise errors.AnsibleError("%s is not encrypted" % self.filename) - - # set cipher based on file header - self.eval_header() - - # decrypt it - data = self._decrypt_to_string() - - # verify sha and then strip it out - if not self._verify_decryption(data): - raise errors.AnsibleError("decryption of %s failed" % self.filename) - this_sha, clean_data = self._strip_sha(data) + def decrypt_file(self): + if not os.path.isfile(self.filename): + raise errors.AnsibleError("%s does not exist" % self.filename) - # write back to original file - f = open(self.filename, "wb") - f.write(clean_data) - f.close() - - def edit(self, filename=None, password=None, cipher=None, version=None): - - if not is_encrypted(self.filename): + tmpdata = self.read_data(self.filename) + this_vault = VaultLib(self.password) + if this_vault.is_encrypted(tmpdata): + dec_data = this_vault.decrypt(tmpdata) + self.write_data(dec_data, self.filename) + else: raise errors.AnsibleError("%s is not encrypted" % self.filename) - #decrypt to string - data = self._decrypt_to_string() + def edit_file(self): - # verify sha and then strip it out - if not self._verify_decryption(data): - raise errors.AnsibleError("decryption of %s failed" % self.filename) - this_sha, clean_data = self._strip_sha(data) + # decrypt to tmpfile + tmpdata = self.read_data(self.filename) + this_vault = VaultLib(self.password) + dec_data = this_vault.decrypt(tmpdata) + _, tmp_path = tempfile.mkstemp() + self.write_data(dec_data, tmp_path) - # rewrite file without sha - _, in_path = tempfile.mkstemp() - f = open(in_path, "wb") - tmpdata = f.write(clean_data) - f.close() - - # drop the user into vim on the unencrypted tmp file + # drop the user into vim on the tmp file EDITOR = os.environ.get('EDITOR','vim') - call([EDITOR, in_path]) + call([EDITOR, tmp_path]) + new_data = self.read_data(tmp_path) - f = open(in_path, "rb") - tmpdata = f.read() - f.close() + # create new vault and set cipher to old + new_vault = VaultLib(self.password) + new_vault.cipher_name = this_vault.cipher_name - self._string_to_encrypted_file(tmpdata, self.filename) + # encrypt new data a write out to tmp + enc_data = new_vault.encrypt(new_data) + self.write_data(enc_data, tmp_path) + # shuffle tmp file into place + self.shuffle_files(tmp_path, self.filename) - def encrypt(self): - """ encrypt a file inplace """ - - if is_encrypted(self.filename): + def encrypt_file(self): + if not os.path.isfile(self.filename): + raise errors.AnsibleError("%s does not exist" % self.filename) + + tmpdata = self.read_data(self.filename) + this_vault = VaultLib(self.password) + this_vault.cipher_name = self.cipher_name + if not this_vault.is_encrypted(tmpdata): + enc_data = this_vault.encrypt(tmpdata) + self.write_data(enc_data, self.filename) + else: raise errors.AnsibleError("%s is already encrypted" % self.filename) - #self.eval_header() - self.__load_cipher() + def rekey_file(self, new_password): + # decrypt + tmpdata = self.read_data(self.filename) + this_vault = VaultLib(self.password) + dec_data = this_vault.decrypt(tmpdata) - # read data - f = open(self.filename, "rb") + # create new vault, set cipher to old and password to new + new_vault = VaultLib(new_password) + new_vault.cipher_name = this_vault.cipher_name + + # re-encrypt data and re-write file + enc_data = new_vault.encrypt(dec_data) + self.write_data(enc_data, self.filename) + + def read_data(self, filename): + f = open(filename, "rb") tmpdata = f.read() f.close() + return tmpdata - self._string_to_encrypted_file(tmpdata, self.filename) - - - def rekey(self, newpassword): - - """ unencrypt file then encrypt with new password """ - - if not is_encrypted(self.filename): - raise errors.AnsibleError("%s is not encrypted" % self.filename) - - # unencrypt to string with old password - data = self._decrypt_to_string() - - # verify sha and then strip it out - if not self._verify_decryption(data): - raise errors.AnsibleError("decryption of %s failed" % self.filename) - this_sha, clean_data = self._strip_sha(data) - - # set password - self.vault_password = newpassword - - self._string_to_encrypted_file(clean_data, self.filename) - - - ############### - # PRIVATE - ############### - - def __load_cipher(self): - - """ - Load a cipher class by it's name - - This is a lightweight "plugin" implementation to allow - for future support of other cipher types - """ - - whitelist = ['AES'] - - if self.cipher in whitelist: - self.cipher_obj = None - if self.cipher in globals(): - this_cipher = globals()[self.cipher] - self.cipher_obj = this_cipher() - else: - raise errors.AnsibleError("%s cipher could not be loaded" % self.cipher) - else: - raise errors.AnsibleError("%s is not an allowed encryption cipher" % self.cipher) - - - - def _decrypt_to_string(self): - - """ decrypt file to string """ - - if not is_encrypted(self.filename): - raise errors.AnsibleError("%s is not encrypted" % self.filename) - - # figure out what this is - self.eval_header() - self.__load_cipher() - - # strip out header and unhex the file - clean_stream = self._dirty_file_to_clean_file(self.filename) - - # reset pointer - clean_stream.seek(0) - - # create a byte stream to hold unencrypted - dst = BytesIO() - - # decrypt from src stream to dst stream - self.cipher_obj.decrypt(clean_stream, dst, self.vault_password) - - # read data from the unencrypted stream - data = dst.read() - - return data - - def _dirty_file_to_clean_file(self, dirty_filename): - """ Strip out headers from a file, unhex and write to new file""" - - - _, in_path = tempfile.mkstemp() - #_, out_path = tempfile.mkstemp() - - # strip header from data, write rest to tmp file - f = open(dirty_filename, "rb") - tmpdata = f.readlines() - f.close() - - tmpheader = tmpdata[0].strip() - tmpdata = ''.join(tmpdata[1:]) - - # strip out newline, join, unhex - tmpdata = [ x.strip() for x in tmpdata ] - tmpdata = unhexlify(''.join(tmpdata)) - - # create and return stream - clean_stream = BytesIO(tmpdata) - return clean_stream - - def _clean_stream_to_dirty_stream(self, clean_stream): - - # combine header and hexlified encrypted data in 80 char columns - clean_stream.seek(0) - tmpdata = clean_stream.read() - tmpdata = hexlify(tmpdata) - tmpdata = [tmpdata[i:i+80] for i in range(0, len(tmpdata), 80)] - - dirty_data = HEADER + ";" + str(self.version) + ";" + self.cipher + "\n" - for l in tmpdata: - dirty_data += l + '\n' - - dirty_stream = BytesIO(dirty_data) - return dirty_stream - - def _string_to_encrypted_file(self, tmpdata, filename): - - """ Write a string of data to a file with the format ... - - HEADER;VERSION;CIPHER - HEX(ENCRYPTED(SHA256(STRING)+STRING)) - """ - - # sha256 the data - this_sha = sha256(tmpdata).hexdigest() - - # combine sha + data to tmpfile - tmpdata = this_sha + "\n" + tmpdata - src_stream = BytesIO(tmpdata) - dst_stream = BytesIO() - - # encrypt tmpfile - self.cipher_obj.encrypt(src_stream, dst_stream, self.password) - - # hexlify tmpfile and combine with header - dirty_stream = self._clean_stream_to_dirty_stream(dst_stream) - - if os.path.isfile(filename): + def write_data(self, data, filename): + if os.path.isfile(filename): os.remove(filename) - - # write back to original file - dirty_stream.seek(0) f = open(filename, "wb") - f.write(dirty_stream.read()) + f.write(data) f.close() + def shuffle_files(self, src, dest): + # overwrite dest with src + if os.path.isfile(dest): + os.remove(dest) + shutil.move(src, dest) - def _verify_decryption(self, data): +######################################## +# CIPHERS # +######################################## - """ Split data to sha/data and check the sha """ - - # split the sha and other data - this_sha, clean_data = self._strip_sha(data) - - # does the decrypted data match the sha ? - clean_sha = sha256(clean_data).hexdigest() - - # compare, return result - if this_sha == clean_sha: - return True - else: - return False - - def _strip_sha(self, data): - # is the first line a sha? - lines = data.split("\n") - this_sha = lines[0] - - clean_data = '\n'.join(lines[1:]) - return this_sha, clean_data - - -class AES(object): +class VaultAES(object): # http://stackoverflow.com/a/16761459 @@ -400,18 +271,22 @@ class AES(object): return key, iv - def encrypt(self, in_file, out_file, password, key_length=32): + def encrypt(self, data, password, key_length=32): """ Read plaintext data from in_file and write encrypted to out_file """ - bs = AES_.block_size + in_file = BytesIO(data) + in_file.seek(0) + out_file = BytesIO() + + bs = AES.block_size # Get a block of random data. EL does not have Crypto.Random.new() # so os.urandom is used for cross platform purposes salt = os.urandom(bs - len('Salted__')) key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) - cipher = AES_.new(key, AES_.MODE_CBC, iv) + cipher = AES.new(key, AES.MODE_CBC, iv) out_file.write('Salted__' + salt) finished = False while not finished: @@ -422,16 +297,23 @@ class AES(object): finished = True out_file.write(cipher.encrypt(chunk)) - def decrypt(self, in_file, out_file, password, key_length=32): + out_file.seek(0) + return out_file.read() + + def decrypt(self, data, password, key_length=32): """ Read encrypted data from in_file and write decrypted to out_file """ # http://stackoverflow.com/a/14989032 - bs = AES_.block_size + in_file = BytesIO(data) + in_file.seek(0) + out_file = BytesIO() + + bs = AES.block_size salt = in_file.read(bs)[len('Salted__'):] key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) - cipher = AES_.new(key, AES_.MODE_CBC, iv) + cipher = AES.new(key, AES.MODE_CBC, iv) next_chunk = '' finished = False @@ -444,5 +326,7 @@ class AES(object): out_file.write(chunk) # reset the stream pointer to the beginning - if hasattr(out_file, 'seek'): - out_file.seek(0) + out_file.seek(0) + return out_file.read() + +