From c8a8f546d454ee7c5c2b501521132f71df707103 Mon Sep 17 00:00:00 2001 From: Toshio Kuratomi Date: Thu, 14 Jul 2016 09:22:54 -0700 Subject: [PATCH] A little unittest refactoring (#16704) A little unittest refactoring * Add a class decorator to generate tests when using a unittest.TestCase base class * Add a TestCase subclass with setUp() and tearDown() that sets up module parameter parsing * Move test_safe_eval to use the class decorator and ModuleTestCase base class * Move testing of set_mode_if_different into its own file and separate some test methods out so we get better errors and more coverage in case of errors. * Naming convention for test cases doesn't need to duplicate information that's already in the file path. --- test/units/mock/generator.py | 71 +++++++++ test/units/mock/procenv.py | 20 ++- .../module_utils/basic/test_safe_eval.py | 103 +++++++----- .../basic/test_set_mode_if_different.py | 149 ++++++++++++++++++ test/units/module_utils/test_basic.py | 72 +-------- 5 files changed, 306 insertions(+), 109 deletions(-) create mode 100644 test/units/mock/generator.py create mode 100644 test/units/module_utils/basic/test_set_mode_if_different.py diff --git a/test/units/mock/generator.py b/test/units/mock/generator.py new file mode 100644 index 0000000000..3ae91f8951 --- /dev/null +++ b/test/units/mock/generator.py @@ -0,0 +1,71 @@ +# Copyright 2016 Toshio Kuratomi +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from collections import Mapping + +def make_method(func, args, kwargs): + + def test_method(self): + func(self, *args, **kwargs) + + # Format the argument string + arg_string = ', '.join(repr(a) for a in args) + kwarg_string = ', '.join('{0}={1}'.format(item[0], repr(item[1])) for item in kwargs.items()) + arg_list = [] + if arg_string: + arg_list.append(arg_string) + if kwarg_string: + arg_list.append(kwarg_string) + + test_method.__name__ = 'test_{0}({1})'.format(func.__name__, ', '.join(arg_list)) + return test_method + + +def add_method(func, *combined_args): + """ + Add a test case via a class decorator. + + nose uses generators for this but doesn't work with unittest.TestCase + subclasses. So we have to write our own. + + The first argument to this decorator is a test function. All subsequent + arguments are the arguments to create each generated test function with in + the following format: + + Each set of arguments is a two-tuple. The first element is an iterable of + positional arguments. the second is a dict representing the kwargs. + """ + def wrapper(cls): + for combined_arg in combined_args: + if len(combined_arg) == 2: + args = combined_arg[0] + kwargs = combined_arg[1] + elif isinstance(combined_arg[0], Mapping): + args = [] + kwargs = combined_arg[0] + else: + args = combined_arg[0] + kwargs = {} + test_method = make_method(func, args, kwargs) + setattr(cls, test_method.__name__, test_method) + return cls + + return wrapper diff --git a/test/units/mock/procenv.py b/test/units/mock/procenv.py index ae0ea5abf5..66019be6bd 100644 --- a/test/units/mock/procenv.py +++ b/test/units/mock/procenv.py @@ -1,4 +1,5 @@ # (c) 2016, Matt Davis +# (c) 2016, Toshio Kuratomi # # This file is part of Ansible # @@ -20,10 +21,12 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type import sys +import json from contextlib import contextmanager from io import BytesIO, StringIO from ansible.compat.six import PY3 +from ansible.compat.tests import unittest from ansible.utils.unicode import to_bytes @contextmanager @@ -54,4 +57,19 @@ def swap_stdout(): fake_stream = BytesIO() sys.stdout = fake_stream yield fake_stream - sys.stdout = old_stdout \ No newline at end of file + sys.stdout = old_stdout + +class ModuleTestCase(unittest.TestCase): + def setUp(self, module_args=None): + if module_args is None: + module_args = {} + + args = json.dumps(dict(ANSIBLE_MODULE_ARGS=module_args)) + + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap = swap_stdin_and_argv(stdin_data=args) + self.stdin_swap.__enter__() + + def tearDown(self): + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap.__exit__(None, None, None) diff --git a/test/units/module_utils/basic/test_safe_eval.py b/test/units/module_utils/basic/test_safe_eval.py index 393416c0fd..0983002049 100644 --- a/test/units/module_utils/basic/test_safe_eval.py +++ b/test/units/module_utils/basic/test_safe_eval.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# (c) 2015, Toshio Kuratomi +# (c) 2015-2016, Toshio Kuratomi # # This file is part of Ansible # @@ -24,47 +24,72 @@ import sys import json from ansible.compat.tests import unittest -from units.mock.procenv import swap_stdin_and_argv +from units.mock.procenv import ModuleTestCase +from units.mock.generator import add_method -class TestAnsibleModuleExitJson(unittest.TestCase): - def test_module_utils_basic_safe_eval(self): +# Strings that should be converted into a typed value +VALID_STRINGS = ( + [("'a'", 'a')], + [("'1'", '1')], + [("1", 1)], + [("True", True)], + [("False", False)], + [("{}", {})], + ) + +# Passing things that aren't strings should just return the object +NONSTRINGS = ( + [({'a':1}, {'a':1})], + ) + +# These strings are not basic types. For security, these should not be +# executed. We return the same string and get an exception for some +INVALID_STRINGS = ( + [("a=1", "a=1", SyntaxError)], + [("a.foo()", "a.foo()", None)], + [("import foo", "import foo", None)], + [("__import__('foo')", "__import__('foo')", ValueError)], + ) + + +def _check_simple_types(self, code, expected): + # test some basic usage for various types + self.assertEqual(self.am.safe_eval(code), expected) + +def _check_simple_types_with_exceptions(self, code, expected): + # Test simple types with exceptions requested + self.assertEqual(self.am.safe_eval(code, include_exceptions=True), (expected, None)) + +def _check_invalid_strings(self, code, expected): + self.assertEqual(self.am.safe_eval(code), expected) + +def _check_invalid_strings_with_exceptions(self, code, expected, exception): + res = self.am.safe_eval("a=1", include_exceptions=True) + self.assertEqual(res[0], "a=1") + self.assertEqual(type(res[1]), SyntaxError) + +@add_method(_check_simple_types, *VALID_STRINGS) +@add_method(_check_simple_types, *NONSTRINGS) +@add_method(_check_simple_types_with_exceptions, *VALID_STRINGS) +@add_method(_check_simple_types_with_exceptions, *NONSTRINGS) +@add_method(_check_invalid_strings, *[[i[0][0:-1]] for i in INVALID_STRINGS]) +@add_method(_check_invalid_strings_with_exceptions, *INVALID_STRINGS) +class TestSafeEval(ModuleTestCase): + + def setUp(self): + super(TestSafeEval, self).setUp() + from ansible.module_utils import basic + self.old_ansible_args = basic._ANSIBLE_ARGS - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={})) + basic._ANSIBLE_ARGS = None + self.am = basic.AnsibleModule( + argument_spec=dict(), + ) - with swap_stdin_and_argv(stdin_data=args): - basic._ANSIBLE_ARGS = None - am = basic.AnsibleModule( - argument_spec=dict(), - ) - - # test some basic usage - # string (and with exceptions included), integer, bool - self.assertEqual(am.safe_eval("'a'"), 'a') - self.assertEqual(am.safe_eval("'a'", include_exceptions=True), ('a', None)) - self.assertEqual(am.safe_eval("1"), 1) - self.assertEqual(am.safe_eval("True"), True) - self.assertEqual(am.safe_eval("False"), False) - self.assertEqual(am.safe_eval("{}"), {}) - # not passing in a string to convert - self.assertEqual(am.safe_eval({'a':1}), {'a':1}) - self.assertEqual(am.safe_eval({'a':1}, include_exceptions=True), ({'a':1}, None)) - # invalid literal eval - self.assertEqual(am.safe_eval("a=1"), "a=1") - res = am.safe_eval("a=1", include_exceptions=True) - self.assertEqual(res[0], "a=1") - self.assertEqual(type(res[1]), SyntaxError) - self.assertEqual(am.safe_eval("a.foo()"), "a.foo()") - res = am.safe_eval("a.foo()", include_exceptions=True) - self.assertEqual(res[0], "a.foo()") - self.assertEqual(res[1], None) - self.assertEqual(am.safe_eval("import foo"), "import foo") - res = am.safe_eval("import foo", include_exceptions=True) - self.assertEqual(res[0], "import foo") - self.assertEqual(res[1], None) - self.assertEqual(am.safe_eval("__import__('foo')"), "__import__('foo')") - res = am.safe_eval("__import__('foo')", include_exceptions=True) - self.assertEqual(res[0], "__import__('foo')") - self.assertEqual(type(res[1]), ValueError) + def tearDown(self): + super(TestSafeEval, self).tearDown() + from ansible.module_utils import basic + basic._ANSIBLE_ARGS = self.old_ansible_args diff --git a/test/units/module_utils/basic/test_set_mode_if_different.py b/test/units/module_utils/basic/test_set_mode_if_different.py new file mode 100644 index 0000000000..d36a4a3622 --- /dev/null +++ b/test/units/module_utils/basic/test_set_mode_if_different.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- +# (c) 2016, Toshio Kuratomi +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +try: + import builtins +except ImportError: + import __builtin__ as builtins + +from ansible.compat.tests import unittest +from ansible.compat.tests.mock import patch, MagicMock +from ansible.module_utils import known_hosts + +from units.mock.procenv import ModuleTestCase +from units.mock.generator import add_method + +class TestSetModeIfDifferentBase(ModuleTestCase): + + def setUp(self): + self.mock_stat1 = MagicMock() + self.mock_stat1.st_mode = 0o444 + self.mock_stat2 = MagicMock() + self.mock_stat2.st_mode = 0o660 + + super(TestSetModeIfDifferentBase, self).setUp() + from ansible.module_utils import basic + self.old_ANSIBLE_ARGS = basic._ANSIBLE_ARGS + basic._ANSIBLE_ARGS = None + + self.am = basic.AnsibleModule( + argument_spec = dict(), + ) + + def tearDown(self): + super(TestSetModeIfDifferentBase, self).tearDown() + from ansible.module_utils import basic + basic._ANSIBLE_ARGS = self.old_ANSIBLE_ARGS + + +def _check_no_mode_given_returns_previous_changes(self, previous_changes=True): + with patch('os.lstat', side_effect=[self.mock_stat1]): + self.assertEqual(self.am.set_mode_if_different('/path/to/file', None, previous_changes), previous_changes) + +def _check_mode_changed_to_0660(self, mode): + # Note: This is for checking that all the different ways of specifying + # 0660 mode work. It cannot be used to check that setting a mode that is + # not equivalent to 0660 works. + with patch('os.lstat', side_effect=[self.mock_stat1, self.mock_stat2, self.mock_stat2]) as m_lstat: + with patch('os.lchmod', return_value=None, create=True) as m_lchmod: + self.assertEqual(self.am.set_mode_if_different('/path/to/file', mode, False), True) + m_lchmod.assert_called_with('/path/to/file', 0o660) + +def _check_mode_unchanged_when_already_0660(self, mode): + # Note: This is for checking that all the different ways of specifying + # 0660 mode work. It cannot be used to check that setting a mode that is + # not equivalent to 0660 works. + with patch('os.lstat', side_effect=[self.mock_stat2, self.mock_stat2, self.mock_stat2]) as m_lstat: + self.assertEqual(self.am.set_mode_if_different('/path/to/file', mode, False), False) + + +SYNONYMS_0660 = ( + [[0o660]], + [['0o660']], + [['660']], + ) + +@add_method(_check_no_mode_given_returns_previous_changes, + [dict(previous_changes=True)], + [dict(previous_changes=False)], + ) +@add_method(_check_mode_changed_to_0660, + *SYNONYMS_0660 + ) +@add_method(_check_mode_unchanged_when_already_0660, + *SYNONYMS_0660 + ) +class TestSetModeIfDifferent(TestSetModeIfDifferentBase): + def test_module_utils_basic_ansible_module_set_mode_if_different(self): + with patch('os.lstat') as m: + with patch('os.lchmod', return_value=None, create=True) as m_os: + m.side_effect = [self.mock_stat1, self.mock_stat2, self.mock_stat2] + self.am._symbolic_mode_to_octal = MagicMock(side_effect=Exception) + self.assertRaises(SystemExit, self.am.set_mode_if_different, '/path/to/file', 'o+w,g+w,a-r', False) + + original_hasattr = hasattr + def _hasattr(obj, name): + if obj == os and name == 'lchmod': + return False + return original_hasattr(obj, name) + + # FIXME: this isn't working yet + with patch('os.lstat', side_effect=[self.mock_stat1, self.mock_stat2]): + with patch.object(builtins, 'hasattr', side_effect=_hasattr): + with patch('os.path.islink', return_value=False): + with patch('os.chmod', return_value=None) as m_chmod: + self.assertEqual(self.am.set_mode_if_different('/path/to/file/no_lchmod', 0o660, False), True) + with patch('os.lstat', side_effect=[self.mock_stat1, self.mock_stat2]): + with patch.object(builtins, 'hasattr', side_effect=_hasattr): + with patch('os.path.islink', return_value=True): + with patch('os.chmod', return_value=None) as m_chmod: + with patch('os.stat', return_value=self.mock_stat2): + self.assertEqual(self.am.set_mode_if_different('/path/to/file', 0o660, False), True) + + +def _check_knows_to_change_to_0660_in_check_mode(self, mode): + # Note: This is for checking that all the different ways of specifying + # 0660 mode work. It cannot be used to check that setting a mode that is + # not equivalent to 0660 works. + with patch('os.lstat', side_effect=[self.mock_stat1, self.mock_stat2, self.mock_stat2]) as m_lstat: + self.assertEqual(self.am.set_mode_if_different('/path/to/file', mode, False), True) + +@add_method(_check_no_mode_given_returns_previous_changes, + [dict(previous_changes=True)], + [dict(previous_changes=False)], + ) +@add_method(_check_knows_to_change_to_0660_in_check_mode, + *SYNONYMS_0660 + ) +@add_method(_check_mode_unchanged_when_already_0660, + *SYNONYMS_0660 + ) +class TestSetModeIfDifferentWithCheckMode(TestSetModeIfDifferentBase): + def setUp(self): + super(TestSetModeIfDifferentWithCheckMode, self).setUp() + self.am.check_mode = True + + def tearDown(self): + super(TestSetModeIfDifferentWithCheckMode, self).tearDown() + self.am.check_mode = False diff --git a/test/units/module_utils/test_basic.py b/test/units/module_utils/test_basic.py index ac775400c8..758fe38a30 100644 --- a/test/units/module_utils/test_basic.py +++ b/test/units/module_utils/test_basic.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # (c) 2012-2014, Michael DeHaan +# (c) 2016 Toshio Kuratomi # # This file is part of Ansible # @@ -31,24 +32,14 @@ try: except ImportError: import __builtin__ as builtins -from units.mock.procenv import swap_stdin_and_argv +from units.mock.procenv import ModuleTestCase, swap_stdin_and_argv from ansible.compat.tests import unittest from ansible.compat.tests.mock import patch, MagicMock, mock_open, Mock, call realimport = builtins.__import__ -class TestModuleUtilsBasic(unittest.TestCase): - - def setUp(self): - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={})) - # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually - self.stdin_swap = swap_stdin_and_argv(stdin_data=args) - self.stdin_swap.__enter__() - - def tearDown(self): - # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually - self.stdin_swap.__exit__(None, None, None) +class TestModuleUtilsBasic(ModuleTestCase): def clear_modules(self, mods): for mod in mods: @@ -761,63 +752,6 @@ class TestModuleUtilsBasic(unittest.TestCase): with patch('os.lchown', side_effect=OSError) as m: self.assertRaises(SystemExit, am.set_group_if_different, '/path/to/file', 'root', False) - def test_module_utils_basic_ansible_module_set_mode_if_different(self): - from ansible.module_utils import basic - basic._ANSIBLE_ARGS = None - - am = basic.AnsibleModule( - argument_spec = dict(), - ) - - mock_stat1 = MagicMock() - mock_stat1.st_mode = 0o444 - mock_stat2 = MagicMock() - mock_stat2.st_mode = 0o660 - - with patch('os.lstat', side_effect=[mock_stat1]): - self.assertEqual(am.set_mode_if_different('/path/to/file', None, True), True) - with patch('os.lstat', side_effect=[mock_stat1]): - self.assertEqual(am.set_mode_if_different('/path/to/file', None, False), False) - - with patch('os.lstat') as m: - with patch('os.lchmod', return_value=None, create=True) as m_os: - m.side_effect = [mock_stat1, mock_stat2, mock_stat2] - self.assertEqual(am.set_mode_if_different('/path/to/file', 0o660, False), True) - m_os.assert_called_with('/path/to/file', 0o660) - - m.side_effect = [mock_stat1, mock_stat2, mock_stat2] - am._symbolic_mode_to_octal = MagicMock(return_value=0o660) - self.assertEqual(am.set_mode_if_different('/path/to/file', 'o+w,g+w,a-r', False), True) - m_os.assert_called_with('/path/to/file', 0o660) - - m.side_effect = [mock_stat1, mock_stat2, mock_stat2] - am._symbolic_mode_to_octal = MagicMock(side_effect=Exception) - self.assertRaises(SystemExit, am.set_mode_if_different, '/path/to/file', 'o+w,g+w,a-r', False) - - m.side_effect = [mock_stat1, mock_stat2, mock_stat2] - am.check_mode = True - self.assertEqual(am.set_mode_if_different('/path/to/file', 0o660, False), True) - am.check_mode = False - - original_hasattr = hasattr - def _hasattr(obj, name): - if obj == os and name == 'lchmod': - return False - return original_hasattr(obj, name) - - # FIXME: this isn't working yet - with patch('os.lstat', side_effect=[mock_stat1, mock_stat2]): - with patch.object(builtins, 'hasattr', side_effect=_hasattr): - with patch('os.path.islink', return_value=False): - with patch('os.chmod', return_value=None) as m_chmod: - self.assertEqual(am.set_mode_if_different('/path/to/file/no_lchmod', 0o660, False), True) - with patch('os.lstat', side_effect=[mock_stat1, mock_stat2]): - with patch.object(builtins, 'hasattr', side_effect=_hasattr): - with patch('os.path.islink', return_value=True): - with patch('os.chmod', return_value=None) as m_chmod: - with patch('os.stat', return_value=mock_stat2): - self.assertEqual(am.set_mode_if_different('/path/to/file', 0o660, False), True) - @patch('tempfile.NamedTemporaryFile') @patch('os.umask') @patch('shutil.copyfileobj')