From 28227546fa65c8925f8eb0b5c23aff3378a23222 Mon Sep 17 00:00:00 2001 From: Toshio Kuratomi Date: Thu, 1 Sep 2016 04:19:03 -0700 Subject: [PATCH] Various python3 updates for module_utils: (#17345) * Port set_*_if_different functions to python3 * Add surrogate_or_strict and surrogate_or_replace error handlers for to_text, to_bytes, to_native * Set default error handler to surrogate_or_replace * Make use of the new error handlers in the already ported code * Move the unittests for module_utils._text as they aren't in basic.py * Cleanup around SEQUENCETYPE. On python2.6+ SEQUENCETYPE includes strings so make sure code omits those explicitly if necessary * Allow arg_spec aliases to be other sequence types --- lib/ansible/module_utils/_text.py | 45 +++++++- lib/ansible/module_utils/basic.py | 106 +++++++++--------- lib/ansible/module_utils/facts.py | 4 +- test/units/module_utils/test_basic.py | 8 +- .../module_utils/{basic => }/test_text.py | 0 5 files changed, 101 insertions(+), 62 deletions(-) rename test/units/module_utils/{basic => }/test_text.py (100%) diff --git a/lib/ansible/module_utils/_text.py b/lib/ansible/module_utils/_text.py index 7dec647580..fe5acf7f27 100644 --- a/lib/ansible/module_utils/_text.py +++ b/lib/ansible/module_utils/_text.py @@ -35,6 +35,13 @@ from ansible.module_utils.six import PY3, text_type, binary_type +import codecs +try: + codecs.lookup_error('surrogateescape') + HAS_SURROGATEESCAPE = True +except LookupError: + HAS_SURROGATEESCAPE = False + def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): """Make sure that a string is a byte string @@ -47,8 +54,22 @@ def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): :kwarg errors: The error handler to use if the text string is not encodable using the specified encoding. Any valid `codecs error handler `_ - may be specified. On Python3 this defaults to 'surrogateescape'. On - Python2, this defaults to 'replace'. + may be specified. There are two additional error strategies + specifically aimed at helping people to port code: + + :surrogate_or_strict: Will use surrogateescape if it is a valid + handler, otherwise it will use strict + :surrogate_or_replace: Will use surrogateescape if it is a valid + handler, otherwise it will use replace. + + Because surrogateescape was added in Python3 this usually means that + Python3 will use surrogateescape and Python2 will use the fallback + error handler. Note that the code checks for surrogateescape when the + module is imported. If you have a backport of surrogateescape for + python2, be sure to register the error handler prior to importing this + module. + + The default is `surrogate_or_replace` :kwarg nonstring: The strategy to use if a nonstring is specified in ``obj``. Default is 'simplerepr'. Valid values are: @@ -71,11 +92,16 @@ def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): if isinstance(obj, binary_type): return obj - if errors is None: - if PY3: + if errors in (None, 'surrogate_or_replace'): + if HAS_SURROGATEESCAPE: errors = 'surrogateescape' else: errors = 'replace' + elif errors == 'surrogate_or_strict': + if HAS_SURROGATEESCAPE: + errors = 'surrogateescape' + else: + errors = 'strict' if isinstance(obj, text_type): return obj.encode(encoding, errors) @@ -126,6 +152,17 @@ def to_text(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): if isinstance(obj, text_type): return obj + if errors in (None, 'surrogate_or_replace'): + if HAS_SURROGATEESCAPE: + errors = 'surrogateescape' + else: + errors = 'replace' + elif errors == 'surrogate_or_strict': + if HAS_SURROGATEESCAPE: + errors = 'surrogateescape' + else: + errors = 'strict' + if errors is None: if PY3: errors = 'surrogateescape' diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index 3493d5ad13..c940d6459a 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -87,6 +87,8 @@ except ImportError: Sequence = (list, tuple) Mapping = (dict,) +# Note: When getting Sequence from collections, it matches with strings. If +# this matters, make sure to check for strings before checking for sequencetype try: from collections.abc import KeysView SEQUENCETYPE = (Sequence, KeysView) @@ -292,7 +294,7 @@ def load_platform_subclass(cls, *args, **kwargs): return super(cls, subclass).__new__(subclass) -def json_dict_unicode_to_bytes(d, encoding='utf-8'): +def json_dict_unicode_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'): ''' Recursively convert dict keys and values to byte str Specialized for json return because this only handles, lists, tuples, @@ -300,17 +302,17 @@ def json_dict_unicode_to_bytes(d, encoding='utf-8'): ''' if isinstance(d, text_type): - return d.encode(encoding) + return to_bytes(d, encoding=encoding, errors=errors) elif isinstance(d, dict): - return dict(map(json_dict_unicode_to_bytes, iteritems(d), repeat(encoding))) + return dict(map(json_dict_unicode_to_bytes, iteritems(d), repeat(encoding), repeat(errors))) elif isinstance(d, list): - return list(map(json_dict_unicode_to_bytes, d, repeat(encoding))) + return list(map(json_dict_unicode_to_bytes, d, repeat(encoding), repeat(errors))) elif isinstance(d, tuple): - return tuple(map(json_dict_unicode_to_bytes, d, repeat(encoding))) + return tuple(map(json_dict_unicode_to_bytes, d, repeat(encoding), repeat(errors))) else: return d -def json_dict_bytes_to_unicode(d, encoding='utf-8'): +def json_dict_bytes_to_unicode(d, encoding='utf-8', errors='surrogate_or_strict'): ''' Recursively convert dict keys and values to byte str Specialized for json return because this only handles, lists, tuples, @@ -319,13 +321,13 @@ def json_dict_bytes_to_unicode(d, encoding='utf-8'): if isinstance(d, binary_type): # Warning, can traceback - return d.decode(encoding) + return to_text(d, encoding=encoding, errors=errors) elif isinstance(d, dict): - return dict(map(json_dict_bytes_to_unicode, iteritems(d), repeat(encoding))) + return dict(map(json_dict_bytes_to_unicode, iteritems(d), repeat(encoding), repeat(errors))) elif isinstance(d, list): - return list(map(json_dict_bytes_to_unicode, d, repeat(encoding))) + return list(map(json_dict_bytes_to_unicode, d, repeat(encoding), repeat(errors))) elif isinstance(d, tuple): - return tuple(map(json_dict_bytes_to_unicode, d, repeat(encoding))) + return tuple(map(json_dict_bytes_to_unicode, d, repeat(encoding), repeat(errors))) else: return d @@ -335,14 +337,7 @@ def return_values(obj): For use with removing sensitive values pre-jsonification.""" if isinstance(obj, (text_type, binary_type)): if obj: - if isinstance(obj, text_type) and PY2: - # Unicode objects should all convert to utf-8 - yield obj.encode('utf-8') - elif isinstance(obj, binary_type) and PY3: - yield obj.decode('utf-8', 'surrogateescape') - else: - # Already native string for this python version - yield obj + yield to_native(obj, errors='surrogate_or_strict') return elif isinstance(obj, SEQUENCETYPE): for element in obj: @@ -356,7 +351,7 @@ def return_values(obj): # This must come before int because bools are also ints return elif isinstance(obj, NUMBERTYPES): - yield str(obj) + yield to_native(obj, nonstring='simplerepr') else: raise TypeError('Unknown parameter type: %s, %s' % (type(obj), obj)) @@ -369,11 +364,11 @@ def remove_values(value, no_log_strings): if isinstance(value, text_type): value_is_text = True if PY2: - native_str_value = value.encode('utf-8') + native_str_value = to_bytes(value, encoding='utf-8', errors='surrogate_or_strict') elif isinstance(value, binary_type): value_is_text = False if PY3: - native_str_value = value.decode('utf-8', 'surrogateescape') + native_str_value = to_text(value, encoding='utf-8', errors='surrogate_or_strict') if native_str_value in no_log_strings: return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' @@ -381,9 +376,9 @@ def remove_values(value, no_log_strings): native_str_value = native_str_value.replace(omit_me, '*' * 8) if value_is_text and isinstance(native_str_value, binary_type): - value = native_str_value.decode('utf-8', 'replace') + value = to_text(native_str_value, encoding='utf-8', errors='surrogate_or_replace') elif not value_is_text and isinstance(native_str_value, text_type): - value = native_str_value.encode('utf-8', 'surrogateescape') + value = to_bytes(native_str_value, encoding='utf-8', errors='surrogate_or_replace') else: value = native_str_value elif isinstance(value, SEQUENCETYPE): @@ -391,7 +386,7 @@ def remove_values(value, no_log_strings): elif isinstance(value, Mapping): return dict((k, remove_values(v, no_log_strings)) for k, v in value.items()) elif isinstance(value, tuple(chain(NUMBERTYPES, (bool, NoneType)))): - stringy_value = str(value) + stringy_value = to_native(value, encoding='utf-8', errors='surrogate_or_strict') if stringy_value in no_log_strings: return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' for omit_me in no_log_strings: @@ -735,12 +730,14 @@ class AnsibleModule(object): if path is None: return {} else: - path = os.path.expanduser(path) + path = os.path.expanduser(os.path.expandvars(path)) + b_path = to_bytes(path, errors='surrogate_or_strict') # if the path is a symlink, and we're following links, get # the target of the link instead for testing - if params.get('follow', False) and os.path.islink(path): - path = os.path.realpath(path) + if params.get('follow', False) and os.path.islink(b_path): + b_path = os.path.realpath(b_path) + path = to_native(b_path) mode = params.get('mode', None) owner = params.get('owner', None) @@ -838,8 +835,9 @@ class AnsibleModule(object): return context def user_and_group(self, filename): - filename = os.path.expanduser(filename) - st = os.lstat(filename) + filename = os.path.expanduser(os.path.expandvars(filename)) + b_filename = to_bytes(filename, errors='surrogate_or_strict') + st = os.lstat(b_filename) uid = st.st_uid gid = st.st_gid return (uid, gid) @@ -922,7 +920,8 @@ class AnsibleModule(object): return changed def set_owner_if_different(self, path, owner, changed, diff=None): - path = os.path.expanduser(path) + path = os.path.expanduser(os.path.expandvars(path)) + b_path = to_bytes(path, errors='surrogate_or_strict') if owner is None: return changed orig_uid, orig_gid = self.user_and_group(path) @@ -946,17 +945,18 @@ class AnsibleModule(object): if self.check_mode: return True try: - os.lchown(path, uid, -1) + os.lchown(b_path, uid, -1) except OSError: self.fail_json(path=path, msg='chown failed') changed = True return changed def set_group_if_different(self, path, group, changed, diff=None): - path = os.path.expanduser(path) + path = os.path.expanduser(os.path.expandvars(path)) + b_path = to_bytes(path, errors='surrogate_or_strict') if group is None: return changed - orig_uid, orig_gid = self.user_and_group(path) + orig_uid, orig_gid = self.user_and_group(b_path) try: gid = int(group) except ValueError: @@ -977,15 +977,15 @@ class AnsibleModule(object): if self.check_mode: return True try: - os.lchown(path, -1, gid) + os.lchown(b_path, -1, gid) except OSError: self.fail_json(path=path, msg='chgrp failed') changed = True return changed def set_mode_if_different(self, path, mode, changed, diff=None): - b_path = to_bytes(path) - b_path = os.path.expanduser(b_path) + b_path = to_bytes(path, errors='surrogate_or_strict') + b_path = os.path.expanduser(os.path.expandvars(b_path)) path_stat = os.lstat(b_path) if mode is None: @@ -1183,7 +1183,8 @@ class AnsibleModule(object): path = kwargs.get('path', kwargs.get('dest', None)) if path is None: return kwargs - if os.path.exists(path): + b_path = to_bytes(path, errors='surrogate_or_strict') + if os.path.exists(b_path): (uid, gid) = self.user_and_group(path) kwargs['uid'] = uid kwargs['gid'] = gid @@ -1197,14 +1198,14 @@ class AnsibleModule(object): group = str(gid) kwargs['owner'] = user kwargs['group'] = group - st = os.lstat(path) + st = os.lstat(b_path) kwargs['mode'] = '0%03o' % stat.S_IMODE(st[stat.ST_MODE]) # secontext not yet supported - if os.path.islink(path): + if os.path.islink(b_path): kwargs['state'] = 'link' - elif os.path.isdir(path): + elif os.path.isdir(b_path): kwargs['state'] = 'directory' - elif os.stat(path).st_nlink > 1: + elif os.stat(b_path).st_nlink > 1: kwargs['state'] = 'hard' else: kwargs['state'] = 'file' @@ -1249,8 +1250,8 @@ class AnsibleModule(object): raise Exception("internal error: required and default are mutually exclusive for %s" % k) if aliases is None: continue - if type(aliases) != list: - raise Exception('internal error: aliases must be a list') + if not isinstance(aliases, SEQUENCETYPE) or isinstance(aliases, (binary_type, text_type)): + raise Exception('internal error: aliases must be a list or tuple') for alias in aliases: self._legal_inputs.append(alias) aliases_results[alias] = k @@ -1363,10 +1364,11 @@ class AnsibleModule(object): choices = v.get('choices',None) if choices is None: continue - if isinstance(choices, SEQUENCETYPE): + if isinstance(choices, SEQUENCETYPE) and not isinstance(choices, (binary_type, text_type)): if k in self.params: if self.params[k] not in choices: - # PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking the value. If we can't figure this out, module author is responsible. + # PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking + # the value. If we can't figure this out, module author is responsible. lowered_choices = None if self.params[k] == 'False': lowered_choices = _lenient_lowercase(choices) @@ -1385,7 +1387,7 @@ class AnsibleModule(object): (self.params[k],) = overlap if self.params[k] not in choices: - choices_str=",".join([str(c) for c in choices]) + choices_str=",".join([to_native(c) for c in choices]) msg="value of %s must be one of: %s, got: %s" % (k, choices_str, self.params[k]) self.fail_json(msg=msg) else: @@ -1749,7 +1751,7 @@ class AnsibleModule(object): def boolean(self, arg): ''' return a bool for the arg ''' - if arg is None or type(arg) == bool: + if arg is None or isinstance(arg, bool): return arg if isinstance(arg, string_types): arg = arg.lower() @@ -1903,8 +1905,8 @@ class AnsibleModule(object): to work around limitations, corner cases and ensure selinux context is saved if possible''' context = None dest_stat = None - b_src = to_bytes(src) - b_dest = to_bytes(dest) + b_src = to_bytes(src, errors='surrogate_or_strict') + b_dest = to_bytes(dest, errors='surrogate_or_strict') if os.path.exists(b_dest): try: dest_stat = os.stat(b_dest) @@ -1957,7 +1959,7 @@ class AnsibleModule(object): except (OSError, IOError): e = get_exception() self.fail_json(msg='The destination directory (%s) is not writable by the current user. Error was: %s' % (os.path.dirname(dest), e)) - b_tmp_dest_name = to_bytes(tmp_dest_name) + b_tmp_dest_name = to_bytes(tmp_dest_name, errors='surrogate_or_strict') try: try: @@ -2056,7 +2058,7 @@ class AnsibleModule(object): # On python2.6 and below, shlex has problems with text type # On python3, shlex needs a text type. if PY2: - args = to_bytes(args) + args = to_bytes(args, errors='surrogate_or_strict') elif PY3: args = to_text(args, errors='surrogateescape') args = shlex.split(args) @@ -2070,7 +2072,7 @@ class AnsibleModule(object): if PY3: prompt_regex = to_bytes(prompt_regex, errors='surrogateescape') elif PY2: - prompt_regex = to_bytes(prompt_regex) + prompt_regex = to_bytes(prompt_regex, errors='surrogate_or_strict') try: prompt_re = re.compile(prompt_regex, re.MULTILINE) except re.error: diff --git a/lib/ansible/module_utils/facts.py b/lib/ansible/module_utils/facts.py index a173e36dc1..3103f6f86e 100644 --- a/lib/ansible/module_utils/facts.py +++ b/lib/ansible/module_utils/facts.py @@ -34,7 +34,7 @@ import pwd from ansible.module_utils.basic import get_all_subclasses from ansible.module_utils.six import PY3, iteritems -from ansible.module_utils._text import to_text +from ansible.module_utils._text import to_native # py2 vs py3; replace with six via ansiballz try: @@ -358,7 +358,7 @@ class Facts(object): proc_1 = os.path.basename(proc_1) if proc_1 is not None: - proc_1 = to_text(proc_1) + proc_1 = to_native(proc_1) proc_1 = proc_1.strip() if proc_1 == 'init' or proc_1.endswith('sh'): diff --git a/test/units/module_utils/test_basic.py b/test/units/module_utils/test_basic.py index a7fc328d64..25400ed1e1 100644 --- a/test/units/module_utils/test_basic.py +++ b/test/units/module_utils/test_basic.py @@ -678,7 +678,7 @@ class TestModuleUtilsBasic(ModuleTestCase): with patch('os.lchown', return_value=None) as m: self.assertEqual(am.set_owner_if_different('/path/to/file', 0, False), True) - m.assert_called_with('/path/to/file', 0, -1) + m.assert_called_with(b'/path/to/file', 0, -1) def _mock_getpwnam(*args, **kwargs): mock_pw = MagicMock() @@ -688,7 +688,7 @@ class TestModuleUtilsBasic(ModuleTestCase): m.reset_mock() with patch('pwd.getpwnam', side_effect=_mock_getpwnam): self.assertEqual(am.set_owner_if_different('/path/to/file', 'root', False), True) - m.assert_called_with('/path/to/file', 0, -1) + m.assert_called_with(b'/path/to/file', 0, -1) with patch('pwd.getpwnam', side_effect=KeyError): self.assertRaises(SystemExit, am.set_owner_if_different, '/path/to/file', 'root', False) @@ -717,7 +717,7 @@ class TestModuleUtilsBasic(ModuleTestCase): with patch('os.lchown', return_value=None) as m: self.assertEqual(am.set_group_if_different('/path/to/file', 0, False), True) - m.assert_called_with('/path/to/file', -1, 0) + m.assert_called_with(b'/path/to/file', -1, 0) def _mock_getgrnam(*args, **kwargs): mock_gr = MagicMock() @@ -727,7 +727,7 @@ class TestModuleUtilsBasic(ModuleTestCase): m.reset_mock() with patch('grp.getgrnam', side_effect=_mock_getgrnam): self.assertEqual(am.set_group_if_different('/path/to/file', 'root', False), True) - m.assert_called_with('/path/to/file', -1, 0) + m.assert_called_with(b'/path/to/file', -1, 0) with patch('grp.getgrnam', side_effect=KeyError): self.assertRaises(SystemExit, am.set_group_if_different, '/path/to/file', 'root', False) diff --git a/test/units/module_utils/basic/test_text.py b/test/units/module_utils/test_text.py similarity index 100% rename from test/units/module_utils/basic/test_text.py rename to test/units/module_utils/test_text.py