mirror of
https://github.com/ansible-collections/community.general.git
synced 2024-09-14 20:13:21 +02:00
Fix @contextmanager leak on exception. (#21031)
* Fix @contextmanager leak on exception. * Fix test leaks of global module args cache.
This commit is contained in:
parent
bb9ee0cf6f
commit
272ff10fa1
4 changed files with 25 additions and 13 deletions
|
@ -36,18 +36,22 @@ def swap_stdin_and_argv(stdin_data='', argv_data=tuple()):
|
||||||
context manager that temporarily masks the test runner's values for stdin and argv
|
context manager that temporarily masks the test runner's values for stdin and argv
|
||||||
"""
|
"""
|
||||||
real_stdin = sys.stdin
|
real_stdin = sys.stdin
|
||||||
|
real_argv = sys.argv
|
||||||
|
|
||||||
if PY3:
|
if PY3:
|
||||||
sys.stdin = StringIO(stdin_data)
|
fake_stream = StringIO(stdin_data)
|
||||||
sys.stdin.buffer = BytesIO(to_bytes(stdin_data))
|
fake_stream.buffer = BytesIO(to_bytes(stdin_data))
|
||||||
else:
|
else:
|
||||||
sys.stdin = BytesIO(to_bytes(stdin_data))
|
fake_stream = BytesIO(to_bytes(stdin_data))
|
||||||
|
|
||||||
real_argv = sys.argv
|
try:
|
||||||
sys.argv = argv_data
|
sys.stdin = fake_stream
|
||||||
yield
|
sys.argv = argv_data
|
||||||
sys.stdin = real_stdin
|
|
||||||
sys.argv = real_argv
|
yield
|
||||||
|
finally:
|
||||||
|
sys.stdin = real_stdin
|
||||||
|
sys.argv = real_argv
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@ -56,13 +60,18 @@ def swap_stdout():
|
||||||
context manager that temporarily replaces stdout for tests that need to verify output
|
context manager that temporarily replaces stdout for tests that need to verify output
|
||||||
"""
|
"""
|
||||||
old_stdout = sys.stdout
|
old_stdout = sys.stdout
|
||||||
|
|
||||||
if PY3:
|
if PY3:
|
||||||
fake_stream = StringIO()
|
fake_stream = StringIO()
|
||||||
else:
|
else:
|
||||||
fake_stream = BytesIO()
|
fake_stream = BytesIO()
|
||||||
sys.stdout = fake_stream
|
|
||||||
yield fake_stream
|
try:
|
||||||
sys.stdout = old_stdout
|
sys.stdout = fake_stream
|
||||||
|
|
||||||
|
yield fake_stream
|
||||||
|
finally:
|
||||||
|
sys.stdout = old_stdout
|
||||||
|
|
||||||
|
|
||||||
class ModuleTestCase(unittest.TestCase):
|
class ModuleTestCase(unittest.TestCase):
|
||||||
|
|
|
@ -40,6 +40,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
||||||
from ansible.module_utils import basic
|
from ansible.module_utils import basic
|
||||||
|
|
||||||
# test basic log invocation
|
# test basic log invocation
|
||||||
|
basic._ANSIBLE_ARGS = None
|
||||||
am = basic.AnsibleModule(
|
am = basic.AnsibleModule(
|
||||||
argument_spec=dict(
|
argument_spec=dict(
|
||||||
foo = dict(default=True, type='bool'),
|
foo = dict(default=True, type='bool'),
|
||||||
|
|
|
@ -311,6 +311,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
|
||||||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo":"hello", "bar": "bad", "bam": "bad"}))
|
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo":"hello", "bar": "bad", "bam": "bad"}))
|
||||||
|
|
||||||
with swap_stdin_and_argv(stdin_data=args):
|
with swap_stdin_and_argv(stdin_data=args):
|
||||||
|
basic._ANSIBLE_ARGS = None
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
SystemExit,
|
SystemExit,
|
||||||
basic.AnsibleModule,
|
basic.AnsibleModule,
|
||||||
|
@ -327,6 +328,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
|
||||||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"bam": "bad"}))
|
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"bam": "bad"}))
|
||||||
|
|
||||||
with swap_stdin_and_argv(stdin_data=args):
|
with swap_stdin_and_argv(stdin_data=args):
|
||||||
|
basic._ANSIBLE_ARGS = None
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
SystemExit,
|
SystemExit,
|
||||||
basic.AnsibleModule,
|
basic.AnsibleModule,
|
||||||
|
@ -580,12 +582,11 @@ class TestModuleUtilsBasic(ModuleTestCase):
|
||||||
|
|
||||||
def test_module_utils_basic_ansible_module_is_special_selinux_path(self):
|
def test_module_utils_basic_ansible_module_is_special_selinux_path(self):
|
||||||
from ansible.module_utils import basic
|
from ansible.module_utils import basic
|
||||||
basic._ANSIBLE_ARGS = None
|
|
||||||
|
|
||||||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'_ansible_selinux_special_fs': "nfs,nfsd,foos"}))
|
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'_ansible_selinux_special_fs': "nfs,nfsd,foos"}))
|
||||||
|
|
||||||
with swap_stdin_and_argv(stdin_data=args):
|
with swap_stdin_and_argv(stdin_data=args):
|
||||||
|
basic._ANSIBLE_ARGS = None
|
||||||
am = basic.AnsibleModule(
|
am = basic.AnsibleModule(
|
||||||
argument_spec = dict(),
|
argument_spec = dict(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -721,6 +721,7 @@ def test_distribution_version():
|
||||||
|
|
||||||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}))
|
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}))
|
||||||
with swap_stdin_and_argv(stdin_data=args):
|
with swap_stdin_and_argv(stdin_data=args):
|
||||||
|
basic._ANSIBLE_ARGS = None
|
||||||
module = basic.AnsibleModule(argument_spec=dict())
|
module = basic.AnsibleModule(argument_spec=dict())
|
||||||
|
|
||||||
for t in TESTSETS:
|
for t in TESTSETS:
|
||||||
|
|
Loading…
Reference in a new issue