1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

Various python3 updates for module_utils: (#17345)

* Port set_*_if_different functions to python3
* Add surrogate_or_strict and surrogate_or_replace error handlers for
  to_text, to_bytes, to_native
* Set default error handler to surrogate_or_replace
* Make use of the new error handlers in the already ported code
* Move the unittests for module_utils._text as they aren't in basic.py
* Cleanup around SEQUENCETYPE.  On python2.6+ SEQUENCETYPE includes
  strings so make sure code omits those explicitly if necessary
* Allow arg_spec aliases to be other sequence types
This commit is contained in:
Toshio Kuratomi 2016-09-01 04:19:03 -07:00 committed by GitHub
parent d8f0ceee61
commit 28227546fa
5 changed files with 101 additions and 62 deletions

View file

@ -35,6 +35,13 @@
from ansible.module_utils.six import PY3, text_type, binary_type from ansible.module_utils.six import PY3, text_type, binary_type
import codecs
try:
codecs.lookup_error('surrogateescape')
HAS_SURROGATEESCAPE = True
except LookupError:
HAS_SURROGATEESCAPE = False
def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'):
"""Make sure that a string is a byte string """Make sure that a string is a byte string
@ -47,8 +54,22 @@ def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'):
:kwarg errors: The error handler to use if the text string is not :kwarg errors: The error handler to use if the text string is not
encodable using the specified encoding. Any valid `codecs error encodable using the specified encoding. Any valid `codecs error
handler <https://docs.python.org/2/library/codecs.html#codec-base-classes>`_ handler <https://docs.python.org/2/library/codecs.html#codec-base-classes>`_
may be specified. On Python3 this defaults to 'surrogateescape'. On may be specified. There are two additional error strategies
Python2, this defaults to 'replace'. specifically aimed at helping people to port code:
:surrogate_or_strict: Will use surrogateescape if it is a valid
handler, otherwise it will use strict
:surrogate_or_replace: Will use surrogateescape if it is a valid
handler, otherwise it will use replace.
Because surrogateescape was added in Python3 this usually means that
Python3 will use surrogateescape and Python2 will use the fallback
error handler. Note that the code checks for surrogateescape when the
module is imported. If you have a backport of surrogateescape for
python2, be sure to register the error handler prior to importing this
module.
The default is `surrogate_or_replace`
:kwarg nonstring: The strategy to use if a nonstring is specified in :kwarg nonstring: The strategy to use if a nonstring is specified in
``obj``. Default is 'simplerepr'. Valid values are: ``obj``. Default is 'simplerepr'. Valid values are:
@ -71,11 +92,16 @@ def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'):
if isinstance(obj, binary_type): if isinstance(obj, binary_type):
return obj return obj
if errors is None: if errors in (None, 'surrogate_or_replace'):
if PY3: if HAS_SURROGATEESCAPE:
errors = 'surrogateescape' errors = 'surrogateescape'
else: else:
errors = 'replace' errors = 'replace'
elif errors == 'surrogate_or_strict':
if HAS_SURROGATEESCAPE:
errors = 'surrogateescape'
else:
errors = 'strict'
if isinstance(obj, text_type): if isinstance(obj, text_type):
return obj.encode(encoding, errors) return obj.encode(encoding, errors)
@ -126,6 +152,17 @@ def to_text(obj, encoding='utf-8', errors=None, nonstring='simplerepr'):
if isinstance(obj, text_type): if isinstance(obj, text_type):
return obj return obj
if errors in (None, 'surrogate_or_replace'):
if HAS_SURROGATEESCAPE:
errors = 'surrogateescape'
else:
errors = 'replace'
elif errors == 'surrogate_or_strict':
if HAS_SURROGATEESCAPE:
errors = 'surrogateescape'
else:
errors = 'strict'
if errors is None: if errors is None:
if PY3: if PY3:
errors = 'surrogateescape' errors = 'surrogateescape'

View file

@ -87,6 +87,8 @@ except ImportError:
Sequence = (list, tuple) Sequence = (list, tuple)
Mapping = (dict,) Mapping = (dict,)
# Note: When getting Sequence from collections, it matches with strings. If
# this matters, make sure to check for strings before checking for sequencetype
try: try:
from collections.abc import KeysView from collections.abc import KeysView
SEQUENCETYPE = (Sequence, KeysView) SEQUENCETYPE = (Sequence, KeysView)
@ -292,7 +294,7 @@ def load_platform_subclass(cls, *args, **kwargs):
return super(cls, subclass).__new__(subclass) return super(cls, subclass).__new__(subclass)
def json_dict_unicode_to_bytes(d, encoding='utf-8'): def json_dict_unicode_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'):
''' Recursively convert dict keys and values to byte str ''' Recursively convert dict keys and values to byte str
Specialized for json return because this only handles, lists, tuples, Specialized for json return because this only handles, lists, tuples,
@ -300,17 +302,17 @@ def json_dict_unicode_to_bytes(d, encoding='utf-8'):
''' '''
if isinstance(d, text_type): if isinstance(d, text_type):
return d.encode(encoding) return to_bytes(d, encoding=encoding, errors=errors)
elif isinstance(d, dict): elif isinstance(d, dict):
return dict(map(json_dict_unicode_to_bytes, iteritems(d), repeat(encoding))) return dict(map(json_dict_unicode_to_bytes, iteritems(d), repeat(encoding), repeat(errors)))
elif isinstance(d, list): elif isinstance(d, list):
return list(map(json_dict_unicode_to_bytes, d, repeat(encoding))) return list(map(json_dict_unicode_to_bytes, d, repeat(encoding), repeat(errors)))
elif isinstance(d, tuple): elif isinstance(d, tuple):
return tuple(map(json_dict_unicode_to_bytes, d, repeat(encoding))) return tuple(map(json_dict_unicode_to_bytes, d, repeat(encoding), repeat(errors)))
else: else:
return d return d
def json_dict_bytes_to_unicode(d, encoding='utf-8'): def json_dict_bytes_to_unicode(d, encoding='utf-8', errors='surrogate_or_strict'):
''' Recursively convert dict keys and values to byte str ''' Recursively convert dict keys and values to byte str
Specialized for json return because this only handles, lists, tuples, Specialized for json return because this only handles, lists, tuples,
@ -319,13 +321,13 @@ def json_dict_bytes_to_unicode(d, encoding='utf-8'):
if isinstance(d, binary_type): if isinstance(d, binary_type):
# Warning, can traceback # Warning, can traceback
return d.decode(encoding) return to_text(d, encoding=encoding, errors=errors)
elif isinstance(d, dict): elif isinstance(d, dict):
return dict(map(json_dict_bytes_to_unicode, iteritems(d), repeat(encoding))) return dict(map(json_dict_bytes_to_unicode, iteritems(d), repeat(encoding), repeat(errors)))
elif isinstance(d, list): elif isinstance(d, list):
return list(map(json_dict_bytes_to_unicode, d, repeat(encoding))) return list(map(json_dict_bytes_to_unicode, d, repeat(encoding), repeat(errors)))
elif isinstance(d, tuple): elif isinstance(d, tuple):
return tuple(map(json_dict_bytes_to_unicode, d, repeat(encoding))) return tuple(map(json_dict_bytes_to_unicode, d, repeat(encoding), repeat(errors)))
else: else:
return d return d
@ -335,14 +337,7 @@ def return_values(obj):
For use with removing sensitive values pre-jsonification.""" For use with removing sensitive values pre-jsonification."""
if isinstance(obj, (text_type, binary_type)): if isinstance(obj, (text_type, binary_type)):
if obj: if obj:
if isinstance(obj, text_type) and PY2: yield to_native(obj, errors='surrogate_or_strict')
# Unicode objects should all convert to utf-8
yield obj.encode('utf-8')
elif isinstance(obj, binary_type) and PY3:
yield obj.decode('utf-8', 'surrogateescape')
else:
# Already native string for this python version
yield obj
return return
elif isinstance(obj, SEQUENCETYPE): elif isinstance(obj, SEQUENCETYPE):
for element in obj: for element in obj:
@ -356,7 +351,7 @@ def return_values(obj):
# This must come before int because bools are also ints # This must come before int because bools are also ints
return return
elif isinstance(obj, NUMBERTYPES): elif isinstance(obj, NUMBERTYPES):
yield str(obj) yield to_native(obj, nonstring='simplerepr')
else: else:
raise TypeError('Unknown parameter type: %s, %s' % (type(obj), obj)) raise TypeError('Unknown parameter type: %s, %s' % (type(obj), obj))
@ -369,11 +364,11 @@ def remove_values(value, no_log_strings):
if isinstance(value, text_type): if isinstance(value, text_type):
value_is_text = True value_is_text = True
if PY2: if PY2:
native_str_value = value.encode('utf-8') native_str_value = to_bytes(value, encoding='utf-8', errors='surrogate_or_strict')
elif isinstance(value, binary_type): elif isinstance(value, binary_type):
value_is_text = False value_is_text = False
if PY3: if PY3:
native_str_value = value.decode('utf-8', 'surrogateescape') native_str_value = to_text(value, encoding='utf-8', errors='surrogate_or_strict')
if native_str_value in no_log_strings: if native_str_value in no_log_strings:
return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER'
@ -381,9 +376,9 @@ def remove_values(value, no_log_strings):
native_str_value = native_str_value.replace(omit_me, '*' * 8) native_str_value = native_str_value.replace(omit_me, '*' * 8)
if value_is_text and isinstance(native_str_value, binary_type): if value_is_text and isinstance(native_str_value, binary_type):
value = native_str_value.decode('utf-8', 'replace') value = to_text(native_str_value, encoding='utf-8', errors='surrogate_or_replace')
elif not value_is_text and isinstance(native_str_value, text_type): elif not value_is_text and isinstance(native_str_value, text_type):
value = native_str_value.encode('utf-8', 'surrogateescape') value = to_bytes(native_str_value, encoding='utf-8', errors='surrogate_or_replace')
else: else:
value = native_str_value value = native_str_value
elif isinstance(value, SEQUENCETYPE): elif isinstance(value, SEQUENCETYPE):
@ -391,7 +386,7 @@ def remove_values(value, no_log_strings):
elif isinstance(value, Mapping): elif isinstance(value, Mapping):
return dict((k, remove_values(v, no_log_strings)) for k, v in value.items()) return dict((k, remove_values(v, no_log_strings)) for k, v in value.items())
elif isinstance(value, tuple(chain(NUMBERTYPES, (bool, NoneType)))): elif isinstance(value, tuple(chain(NUMBERTYPES, (bool, NoneType)))):
stringy_value = str(value) stringy_value = to_native(value, encoding='utf-8', errors='surrogate_or_strict')
if stringy_value in no_log_strings: if stringy_value in no_log_strings:
return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER'
for omit_me in no_log_strings: for omit_me in no_log_strings:
@ -735,12 +730,14 @@ class AnsibleModule(object):
if path is None: if path is None:
return {} return {}
else: else:
path = os.path.expanduser(path) path = os.path.expanduser(os.path.expandvars(path))
b_path = to_bytes(path, errors='surrogate_or_strict')
# if the path is a symlink, and we're following links, get # if the path is a symlink, and we're following links, get
# the target of the link instead for testing # the target of the link instead for testing
if params.get('follow', False) and os.path.islink(path): if params.get('follow', False) and os.path.islink(b_path):
path = os.path.realpath(path) b_path = os.path.realpath(b_path)
path = to_native(b_path)
mode = params.get('mode', None) mode = params.get('mode', None)
owner = params.get('owner', None) owner = params.get('owner', None)
@ -838,8 +835,9 @@ class AnsibleModule(object):
return context return context
def user_and_group(self, filename): def user_and_group(self, filename):
filename = os.path.expanduser(filename) filename = os.path.expanduser(os.path.expandvars(filename))
st = os.lstat(filename) b_filename = to_bytes(filename, errors='surrogate_or_strict')
st = os.lstat(b_filename)
uid = st.st_uid uid = st.st_uid
gid = st.st_gid gid = st.st_gid
return (uid, gid) return (uid, gid)
@ -922,7 +920,8 @@ class AnsibleModule(object):
return changed return changed
def set_owner_if_different(self, path, owner, changed, diff=None): def set_owner_if_different(self, path, owner, changed, diff=None):
path = os.path.expanduser(path) path = os.path.expanduser(os.path.expandvars(path))
b_path = to_bytes(path, errors='surrogate_or_strict')
if owner is None: if owner is None:
return changed return changed
orig_uid, orig_gid = self.user_and_group(path) orig_uid, orig_gid = self.user_and_group(path)
@ -946,17 +945,18 @@ class AnsibleModule(object):
if self.check_mode: if self.check_mode:
return True return True
try: try:
os.lchown(path, uid, -1) os.lchown(b_path, uid, -1)
except OSError: except OSError:
self.fail_json(path=path, msg='chown failed') self.fail_json(path=path, msg='chown failed')
changed = True changed = True
return changed return changed
def set_group_if_different(self, path, group, changed, diff=None): def set_group_if_different(self, path, group, changed, diff=None):
path = os.path.expanduser(path) path = os.path.expanduser(os.path.expandvars(path))
b_path = to_bytes(path, errors='surrogate_or_strict')
if group is None: if group is None:
return changed return changed
orig_uid, orig_gid = self.user_and_group(path) orig_uid, orig_gid = self.user_and_group(b_path)
try: try:
gid = int(group) gid = int(group)
except ValueError: except ValueError:
@ -977,15 +977,15 @@ class AnsibleModule(object):
if self.check_mode: if self.check_mode:
return True return True
try: try:
os.lchown(path, -1, gid) os.lchown(b_path, -1, gid)
except OSError: except OSError:
self.fail_json(path=path, msg='chgrp failed') self.fail_json(path=path, msg='chgrp failed')
changed = True changed = True
return changed return changed
def set_mode_if_different(self, path, mode, changed, diff=None): def set_mode_if_different(self, path, mode, changed, diff=None):
b_path = to_bytes(path) b_path = to_bytes(path, errors='surrogate_or_strict')
b_path = os.path.expanduser(b_path) b_path = os.path.expanduser(os.path.expandvars(b_path))
path_stat = os.lstat(b_path) path_stat = os.lstat(b_path)
if mode is None: if mode is None:
@ -1183,7 +1183,8 @@ class AnsibleModule(object):
path = kwargs.get('path', kwargs.get('dest', None)) path = kwargs.get('path', kwargs.get('dest', None))
if path is None: if path is None:
return kwargs return kwargs
if os.path.exists(path): b_path = to_bytes(path, errors='surrogate_or_strict')
if os.path.exists(b_path):
(uid, gid) = self.user_and_group(path) (uid, gid) = self.user_and_group(path)
kwargs['uid'] = uid kwargs['uid'] = uid
kwargs['gid'] = gid kwargs['gid'] = gid
@ -1197,14 +1198,14 @@ class AnsibleModule(object):
group = str(gid) group = str(gid)
kwargs['owner'] = user kwargs['owner'] = user
kwargs['group'] = group kwargs['group'] = group
st = os.lstat(path) st = os.lstat(b_path)
kwargs['mode'] = '0%03o' % stat.S_IMODE(st[stat.ST_MODE]) kwargs['mode'] = '0%03o' % stat.S_IMODE(st[stat.ST_MODE])
# secontext not yet supported # secontext not yet supported
if os.path.islink(path): if os.path.islink(b_path):
kwargs['state'] = 'link' kwargs['state'] = 'link'
elif os.path.isdir(path): elif os.path.isdir(b_path):
kwargs['state'] = 'directory' kwargs['state'] = 'directory'
elif os.stat(path).st_nlink > 1: elif os.stat(b_path).st_nlink > 1:
kwargs['state'] = 'hard' kwargs['state'] = 'hard'
else: else:
kwargs['state'] = 'file' kwargs['state'] = 'file'
@ -1249,8 +1250,8 @@ class AnsibleModule(object):
raise Exception("internal error: required and default are mutually exclusive for %s" % k) raise Exception("internal error: required and default are mutually exclusive for %s" % k)
if aliases is None: if aliases is None:
continue continue
if type(aliases) != list: if not isinstance(aliases, SEQUENCETYPE) or isinstance(aliases, (binary_type, text_type)):
raise Exception('internal error: aliases must be a list') raise Exception('internal error: aliases must be a list or tuple')
for alias in aliases: for alias in aliases:
self._legal_inputs.append(alias) self._legal_inputs.append(alias)
aliases_results[alias] = k aliases_results[alias] = k
@ -1363,10 +1364,11 @@ class AnsibleModule(object):
choices = v.get('choices',None) choices = v.get('choices',None)
if choices is None: if choices is None:
continue continue
if isinstance(choices, SEQUENCETYPE): if isinstance(choices, SEQUENCETYPE) and not isinstance(choices, (binary_type, text_type)):
if k in self.params: if k in self.params:
if self.params[k] not in choices: if self.params[k] not in choices:
# PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking the value. If we can't figure this out, module author is responsible. # PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking
# the value. If we can't figure this out, module author is responsible.
lowered_choices = None lowered_choices = None
if self.params[k] == 'False': if self.params[k] == 'False':
lowered_choices = _lenient_lowercase(choices) lowered_choices = _lenient_lowercase(choices)
@ -1385,7 +1387,7 @@ class AnsibleModule(object):
(self.params[k],) = overlap (self.params[k],) = overlap
if self.params[k] not in choices: if self.params[k] not in choices:
choices_str=",".join([str(c) for c in choices]) choices_str=",".join([to_native(c) for c in choices])
msg="value of %s must be one of: %s, got: %s" % (k, choices_str, self.params[k]) msg="value of %s must be one of: %s, got: %s" % (k, choices_str, self.params[k])
self.fail_json(msg=msg) self.fail_json(msg=msg)
else: else:
@ -1749,7 +1751,7 @@ class AnsibleModule(object):
def boolean(self, arg): def boolean(self, arg):
''' return a bool for the arg ''' ''' return a bool for the arg '''
if arg is None or type(arg) == bool: if arg is None or isinstance(arg, bool):
return arg return arg
if isinstance(arg, string_types): if isinstance(arg, string_types):
arg = arg.lower() arg = arg.lower()
@ -1903,8 +1905,8 @@ class AnsibleModule(object):
to work around limitations, corner cases and ensure selinux context is saved if possible''' to work around limitations, corner cases and ensure selinux context is saved if possible'''
context = None context = None
dest_stat = None dest_stat = None
b_src = to_bytes(src) b_src = to_bytes(src, errors='surrogate_or_strict')
b_dest = to_bytes(dest) b_dest = to_bytes(dest, errors='surrogate_or_strict')
if os.path.exists(b_dest): if os.path.exists(b_dest):
try: try:
dest_stat = os.stat(b_dest) dest_stat = os.stat(b_dest)
@ -1957,7 +1959,7 @@ class AnsibleModule(object):
except (OSError, IOError): except (OSError, IOError):
e = get_exception() e = get_exception()
self.fail_json(msg='The destination directory (%s) is not writable by the current user. Error was: %s' % (os.path.dirname(dest), e)) self.fail_json(msg='The destination directory (%s) is not writable by the current user. Error was: %s' % (os.path.dirname(dest), e))
b_tmp_dest_name = to_bytes(tmp_dest_name) b_tmp_dest_name = to_bytes(tmp_dest_name, errors='surrogate_or_strict')
try: try:
try: try:
@ -2056,7 +2058,7 @@ class AnsibleModule(object):
# On python2.6 and below, shlex has problems with text type # On python2.6 and below, shlex has problems with text type
# On python3, shlex needs a text type. # On python3, shlex needs a text type.
if PY2: if PY2:
args = to_bytes(args) args = to_bytes(args, errors='surrogate_or_strict')
elif PY3: elif PY3:
args = to_text(args, errors='surrogateescape') args = to_text(args, errors='surrogateescape')
args = shlex.split(args) args = shlex.split(args)
@ -2070,7 +2072,7 @@ class AnsibleModule(object):
if PY3: if PY3:
prompt_regex = to_bytes(prompt_regex, errors='surrogateescape') prompt_regex = to_bytes(prompt_regex, errors='surrogateescape')
elif PY2: elif PY2:
prompt_regex = to_bytes(prompt_regex) prompt_regex = to_bytes(prompt_regex, errors='surrogate_or_strict')
try: try:
prompt_re = re.compile(prompt_regex, re.MULTILINE) prompt_re = re.compile(prompt_regex, re.MULTILINE)
except re.error: except re.error:

View file

@ -34,7 +34,7 @@ import pwd
from ansible.module_utils.basic import get_all_subclasses from ansible.module_utils.basic import get_all_subclasses
from ansible.module_utils.six import PY3, iteritems from ansible.module_utils.six import PY3, iteritems
from ansible.module_utils._text import to_text from ansible.module_utils._text import to_native
# py2 vs py3; replace with six via ansiballz # py2 vs py3; replace with six via ansiballz
try: try:
@ -358,7 +358,7 @@ class Facts(object):
proc_1 = os.path.basename(proc_1) proc_1 = os.path.basename(proc_1)
if proc_1 is not None: if proc_1 is not None:
proc_1 = to_text(proc_1) proc_1 = to_native(proc_1)
proc_1 = proc_1.strip() proc_1 = proc_1.strip()
if proc_1 == 'init' or proc_1.endswith('sh'): if proc_1 == 'init' or proc_1.endswith('sh'):

View file

@ -678,7 +678,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
with patch('os.lchown', return_value=None) as m: with patch('os.lchown', return_value=None) as m:
self.assertEqual(am.set_owner_if_different('/path/to/file', 0, False), True) self.assertEqual(am.set_owner_if_different('/path/to/file', 0, False), True)
m.assert_called_with('/path/to/file', 0, -1) m.assert_called_with(b'/path/to/file', 0, -1)
def _mock_getpwnam(*args, **kwargs): def _mock_getpwnam(*args, **kwargs):
mock_pw = MagicMock() mock_pw = MagicMock()
@ -688,7 +688,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
m.reset_mock() m.reset_mock()
with patch('pwd.getpwnam', side_effect=_mock_getpwnam): with patch('pwd.getpwnam', side_effect=_mock_getpwnam):
self.assertEqual(am.set_owner_if_different('/path/to/file', 'root', False), True) self.assertEqual(am.set_owner_if_different('/path/to/file', 'root', False), True)
m.assert_called_with('/path/to/file', 0, -1) m.assert_called_with(b'/path/to/file', 0, -1)
with patch('pwd.getpwnam', side_effect=KeyError): with patch('pwd.getpwnam', side_effect=KeyError):
self.assertRaises(SystemExit, am.set_owner_if_different, '/path/to/file', 'root', False) self.assertRaises(SystemExit, am.set_owner_if_different, '/path/to/file', 'root', False)
@ -717,7 +717,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
with patch('os.lchown', return_value=None) as m: with patch('os.lchown', return_value=None) as m:
self.assertEqual(am.set_group_if_different('/path/to/file', 0, False), True) self.assertEqual(am.set_group_if_different('/path/to/file', 0, False), True)
m.assert_called_with('/path/to/file', -1, 0) m.assert_called_with(b'/path/to/file', -1, 0)
def _mock_getgrnam(*args, **kwargs): def _mock_getgrnam(*args, **kwargs):
mock_gr = MagicMock() mock_gr = MagicMock()
@ -727,7 +727,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
m.reset_mock() m.reset_mock()
with patch('grp.getgrnam', side_effect=_mock_getgrnam): with patch('grp.getgrnam', side_effect=_mock_getgrnam):
self.assertEqual(am.set_group_if_different('/path/to/file', 'root', False), True) self.assertEqual(am.set_group_if_different('/path/to/file', 'root', False), True)
m.assert_called_with('/path/to/file', -1, 0) m.assert_called_with(b'/path/to/file', -1, 0)
with patch('grp.getgrnam', side_effect=KeyError): with patch('grp.getgrnam', side_effect=KeyError):
self.assertRaises(SystemExit, am.set_group_if_different, '/path/to/file', 'root', False) self.assertRaises(SystemExit, am.set_group_if_different, '/path/to/file', 'root', False)