mirror of
https://github.com/ansible-collections/community.general.git
synced 2024-09-14 20:13:21 +02:00
Fix mixing of bytes and str in module replacer (caused traceback on python3)
This commit is contained in:
parent
b0bed27211
commit
c29f51804b
2 changed files with 123 additions and 38 deletions
|
@ -21,7 +21,7 @@ from __future__ import (absolute_import, division, print_function)
|
|||
__metaclass__ = type
|
||||
|
||||
# from python and deps
|
||||
from ansible.compat.six.moves import StringIO
|
||||
from io import BytesIO
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
|
@ -30,20 +30,20 @@ import shlex
|
|||
from ansible import __version__
|
||||
from ansible import constants as C
|
||||
from ansible.errors import AnsibleError
|
||||
from ansible.utils.unicode import to_bytes
|
||||
from ansible.utils.unicode import to_bytes, to_unicode
|
||||
|
||||
REPLACER = "#<<INCLUDE_ANSIBLE_MODULE_COMMON>>"
|
||||
REPLACER_ARGS = "\"<<INCLUDE_ANSIBLE_MODULE_ARGS>>\""
|
||||
REPLACER_COMPLEX = "\"<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>\""
|
||||
REPLACER_WINDOWS = "# POWERSHELL_COMMON"
|
||||
REPLACER_WINARGS = "<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
|
||||
REPLACER_JSONARGS = "<<INCLUDE_ANSIBLE_MODULE_JSON_ARGS>>"
|
||||
REPLACER_VERSION = "\"<<ANSIBLE_VERSION>>\""
|
||||
REPLACER_SELINUX = "<<SELINUX_SPECIAL_FILESYSTEMS>>"
|
||||
REPLACER = b"#<<INCLUDE_ANSIBLE_MODULE_COMMON>>"
|
||||
REPLACER_ARGS = b"\"<<INCLUDE_ANSIBLE_MODULE_ARGS>>\""
|
||||
REPLACER_COMPLEX = b"\"<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>\""
|
||||
REPLACER_WINDOWS = b"# POWERSHELL_COMMON"
|
||||
REPLACER_WINARGS = b"<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
|
||||
REPLACER_JSONARGS = b"<<INCLUDE_ANSIBLE_MODULE_JSON_ARGS>>"
|
||||
REPLACER_VERSION = b"\"<<ANSIBLE_VERSION>>\""
|
||||
REPLACER_SELINUX = b"<<SELINUX_SPECIAL_FILESYSTEMS>>"
|
||||
|
||||
# We could end up writing out parameters with unicode characters so we need to
|
||||
# specify an encoding for the python source file
|
||||
ENCODING_STRING = '# -*- coding: utf-8 -*-'
|
||||
ENCODING_STRING = b'# -*- coding: utf-8 -*-'
|
||||
|
||||
# we've moved the module_common relative to the snippets, so fix the path
|
||||
_SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils')
|
||||
|
@ -53,7 +53,7 @@ _SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils')
|
|||
def _slurp(path):
|
||||
if not os.path.exists(path):
|
||||
raise AnsibleError("imported module support code does not exist at %s" % path)
|
||||
fd = open(path)
|
||||
fd = open(path, 'rb')
|
||||
data = fd.read()
|
||||
fd.close()
|
||||
return data
|
||||
|
@ -71,49 +71,49 @@ def _find_snippet_imports(module_data, module_path, strip_comments):
|
|||
module_style = 'new'
|
||||
elif REPLACER_JSONARGS in module_data:
|
||||
module_style = 'new'
|
||||
elif 'from ansible.module_utils.' in module_data:
|
||||
elif b'from ansible.module_utils.' in module_data:
|
||||
module_style = 'new'
|
||||
elif 'WANT_JSON' in module_data:
|
||||
elif b'WANT_JSON' in module_data:
|
||||
module_style = 'non_native_want_json'
|
||||
|
||||
output = StringIO()
|
||||
lines = module_data.split('\n')
|
||||
output = BytesIO()
|
||||
lines = module_data.split(b'\n')
|
||||
snippet_names = []
|
||||
|
||||
for line in lines:
|
||||
|
||||
if REPLACER in line:
|
||||
output.write(_slurp(os.path.join(_SNIPPET_PATH, "basic.py")))
|
||||
snippet_names.append('basic')
|
||||
snippet_names.append(b'basic')
|
||||
if REPLACER_WINDOWS in line:
|
||||
ps_data = _slurp(os.path.join(_SNIPPET_PATH, "powershell.ps1"))
|
||||
output.write(ps_data)
|
||||
snippet_names.append('powershell')
|
||||
elif line.startswith('from ansible.module_utils.'):
|
||||
tokens=line.split(".")
|
||||
snippet_names.append(b'powershell')
|
||||
elif line.startswith(b'from ansible.module_utils.'):
|
||||
tokens=line.split(b".")
|
||||
import_error = False
|
||||
if len(tokens) != 3:
|
||||
import_error = True
|
||||
if " import *" not in line:
|
||||
if b" import *" not in line:
|
||||
import_error = True
|
||||
if import_error:
|
||||
raise AnsibleError("error importing module in %s, expecting format like 'from ansible.module_utils.<lib name> import *'" % module_path)
|
||||
snippet_name = tokens[2].split()[0]
|
||||
snippet_names.append(snippet_name)
|
||||
output.write(_slurp(os.path.join(_SNIPPET_PATH, snippet_name + ".py")))
|
||||
output.write(_slurp(os.path.join(_SNIPPET_PATH, to_unicode(snippet_name) + ".py")))
|
||||
else:
|
||||
if strip_comments and line.startswith("#") or line == '':
|
||||
if strip_comments and line.startswith(b"#") or line == b'':
|
||||
pass
|
||||
output.write(line)
|
||||
output.write("\n")
|
||||
output.write(b"\n")
|
||||
|
||||
if not module_path.endswith(".ps1"):
|
||||
# Unixy modules
|
||||
if len(snippet_names) > 0 and not 'basic' in snippet_names:
|
||||
if len(snippet_names) > 0 and not b'basic' in snippet_names:
|
||||
raise AnsibleError("missing required import in %s: from ansible.module_utils.basic import *" % module_path)
|
||||
else:
|
||||
# Windows modules
|
||||
if len(snippet_names) > 0 and not 'powershell' in snippet_names:
|
||||
if len(snippet_names) > 0 and not b'powershell' in snippet_names:
|
||||
raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path)
|
||||
|
||||
return (output.getvalue(), module_style)
|
||||
|
@ -158,28 +158,28 @@ def modify_module(module_path, module_args, task_vars=dict(), strip_comments=Fal
|
|||
# * Cache the modified module? If only the args are different and we do
|
||||
# that as the last step we could cache all the work up to that point.
|
||||
|
||||
with open(module_path) as f:
|
||||
with open(module_path, 'rb') as f:
|
||||
|
||||
# read in the module source
|
||||
module_data = f.read()
|
||||
|
||||
(module_data, module_style) = _find_snippet_imports(module_data, module_path, strip_comments)
|
||||
|
||||
module_args_json = json.dumps(module_args).encode('utf-8')
|
||||
python_repred_args = repr(module_args_json)
|
||||
module_args_json = to_bytes(json.dumps(module_args))
|
||||
python_repred_args = to_bytes(repr(module_args_json))
|
||||
|
||||
# these strings should be part of the 'basic' snippet which is required to be included
|
||||
module_data = module_data.replace(REPLACER_VERSION, repr(__version__))
|
||||
module_data = module_data.replace(REPLACER_VERSION, to_bytes(__version__, nonstring='repr'))
|
||||
module_data = module_data.replace(REPLACER_COMPLEX, python_repred_args)
|
||||
module_data = module_data.replace(REPLACER_WINARGS, module_args_json)
|
||||
module_data = module_data.replace(REPLACER_JSONARGS, module_args_json)
|
||||
module_data = module_data.replace(REPLACER_SELINUX, ','.join(C.DEFAULT_SELINUX_SPECIAL_FS))
|
||||
module_data = module_data.replace(REPLACER_SELINUX, to_bytes(','.join(C.DEFAULT_SELINUX_SPECIAL_FS)))
|
||||
|
||||
if module_style == 'new':
|
||||
facility = C.DEFAULT_SYSLOG_FACILITY
|
||||
if 'ansible_syslog_facility' in task_vars:
|
||||
facility = task_vars['ansible_syslog_facility']
|
||||
module_data = module_data.replace('syslog.LOG_USER', "syslog.%s" % facility)
|
||||
module_data = module_data.replace(b'syslog.LOG_USER', to_bytes("syslog.%s" % facility))
|
||||
|
||||
lines = module_data.split(b"\n", 1)
|
||||
shebang = None
|
||||
|
@ -188,12 +188,13 @@ def modify_module(module_path, module_args, task_vars=dict(), strip_comments=Fal
|
|||
args = shlex.split(str(shebang[2:]))
|
||||
interpreter = args[0]
|
||||
interpreter_config = 'ansible_%s_interpreter' % os.path.basename(interpreter)
|
||||
interpreter = to_bytes(interpreter)
|
||||
|
||||
if interpreter_config in task_vars:
|
||||
interpreter = to_bytes(task_vars[interpreter_config], errors='strict')
|
||||
lines[0] = shebang = b"#!{0} {1}".format(interpreter, b" ".join(args[1:]))
|
||||
|
||||
if os.path.basename(interpreter).startswith('python'):
|
||||
if os.path.basename(interpreter).startswith(b'python'):
|
||||
lines.insert(1, ENCODING_STRING)
|
||||
else:
|
||||
# No shebang, assume a binary module?
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# (c) 2015, Florian Apolloner <florian@apolloner.eu>
|
||||
#
|
||||
# This file is part of Ansible
|
||||
|
@ -34,15 +35,17 @@ from ansible import constants as C
|
|||
from ansible.compat.six import text_type
|
||||
from ansible.compat.tests import unittest
|
||||
from ansible.compat.tests.mock import patch, MagicMock, mock_open
|
||||
|
||||
from ansible.errors import AnsibleError
|
||||
from ansible.playbook.play_context import PlayContext
|
||||
from ansible.plugins import PluginLoader
|
||||
from ansible.plugins.action import ActionBase
|
||||
from ansible.template import Templar
|
||||
from ansible.utils.unicode import to_bytes
|
||||
|
||||
from units.mock.loader import DictDataLoader
|
||||
|
||||
python_module_replacers = """
|
||||
python_module_replacers = b"""
|
||||
#!/usr/bin/python
|
||||
|
||||
#ANSIBLE_VERSION = "<<ANSIBLE_VERSION>>"
|
||||
|
@ -50,14 +53,95 @@ python_module_replacers = """
|
|||
#MODULE_COMPLEX_ARGS = "<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>"
|
||||
#SELINUX_SPECIAL_FS="<<SELINUX_SPECIAL_FILESYSTEMS>>"
|
||||
|
||||
test = u'Toshio \u304f\u3089\u3068\u307f'
|
||||
from ansible.module_utils.basic import *
|
||||
"""
|
||||
|
||||
powershell_module_replacers = """
|
||||
powershell_module_replacers = b"""
|
||||
WINDOWS_ARGS = "<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
|
||||
# POWERSHELL_COMMON
|
||||
"""
|
||||
|
||||
# Prior to 3.4.4, mock_open cannot handle binary read_data
|
||||
if version_info >= (3,) and version_info < (3, 4, 4):
|
||||
file_spec = None
|
||||
|
||||
def _iterate_read_data(read_data):
|
||||
# Helper for mock_open:
|
||||
# Retrieve lines from read_data via a generator so that separate calls to
|
||||
# readline, read, and readlines are properly interleaved
|
||||
sep = b'\n' if isinstance(read_data, bytes) else '\n'
|
||||
data_as_list = [l + sep for l in read_data.split(sep)]
|
||||
|
||||
if data_as_list[-1] == sep:
|
||||
# If the last line ended in a newline, the list comprehension will have an
|
||||
# extra entry that's just a newline. Remove this.
|
||||
data_as_list = data_as_list[:-1]
|
||||
else:
|
||||
# If there wasn't an extra newline by itself, then the file being
|
||||
# emulated doesn't have a newline to end the last line remove the
|
||||
# newline that our naive format() added
|
||||
data_as_list[-1] = data_as_list[-1][:-1]
|
||||
|
||||
for line in data_as_list:
|
||||
yield line
|
||||
|
||||
def mock_open(mock=None, read_data=''):
|
||||
"""
|
||||
A helper function to create a mock to replace the use of `open`. It works
|
||||
for `open` called directly or used as a context manager.
|
||||
|
||||
The `mock` argument is the mock object to configure. If `None` (the
|
||||
default) then a `MagicMock` will be created for you, with the API limited
|
||||
to methods or attributes available on standard file handles.
|
||||
|
||||
`read_data` is a string for the `read` methoddline`, and `readlines` of the
|
||||
file handle to return. This is an empty string by default.
|
||||
"""
|
||||
def _readlines_side_effect(*args, **kwargs):
|
||||
if handle.readlines.return_value is not None:
|
||||
return handle.readlines.return_value
|
||||
return list(_data)
|
||||
|
||||
def _read_side_effect(*args, **kwargs):
|
||||
if handle.read.return_value is not None:
|
||||
return handle.read.return_value
|
||||
return type(read_data)().join(_data)
|
||||
|
||||
def _readline_side_effect():
|
||||
if handle.readline.return_value is not None:
|
||||
while True:
|
||||
yield handle.readline.return_value
|
||||
for line in _data:
|
||||
yield line
|
||||
|
||||
|
||||
global file_spec
|
||||
if file_spec is None:
|
||||
import _io
|
||||
file_spec = list(set(dir(_io.TextIOWrapper)).union(set(dir(_io.BytesIO))))
|
||||
|
||||
if mock is None:
|
||||
mock = MagicMock(name='open', spec=open)
|
||||
|
||||
handle = MagicMock(spec=file_spec)
|
||||
handle.__enter__.return_value = handle
|
||||
|
||||
_data = _iterate_read_data(read_data)
|
||||
|
||||
handle.write.return_value = None
|
||||
handle.read.return_value = None
|
||||
handle.readline.return_value = None
|
||||
handle.readlines.return_value = None
|
||||
|
||||
handle.read.side_effect = _read_side_effect
|
||||
handle.readline.side_effect = _readline_side_effect()
|
||||
handle.readlines.side_effect = _readlines_side_effect
|
||||
|
||||
mock.return_value = handle
|
||||
return mock
|
||||
|
||||
|
||||
class DerivedActionBase(ActionBase):
|
||||
def run(self, tmp=None, task_vars=None):
|
||||
# We're not testing the plugin run() method, just the helper
|
||||
|
@ -124,18 +208,18 @@ class TestActionBase(unittest.TestCase):
|
|||
)
|
||||
|
||||
# test python module formatting
|
||||
with patch.object(builtins, 'open', mock_open(read_data=text_type(python_module_replacers.strip()))) as m:
|
||||
with patch.object(builtins, 'open', mock_open(read_data=to_bytes(python_module_replacers.strip(), encoding='utf-8'))) as m:
|
||||
mock_task.args = dict(a=1)
|
||||
mock_connection.module_implementation_preferences = ('',)
|
||||
(style, shebang, data) = action_base._configure_module(mock_task.action, mock_task.args)
|
||||
self.assertEqual(style, "new")
|
||||
self.assertEqual(shebang, "#!/usr/bin/python")
|
||||
self.assertEqual(shebang, b"#!/usr/bin/python")
|
||||
|
||||
# test module not found
|
||||
self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args)
|
||||
|
||||
# test powershell module formatting
|
||||
with patch.object(builtins, 'open', mock_open(read_data=text_type(powershell_module_replacers.strip()))) as m:
|
||||
with patch.object(builtins, 'open', mock_open(read_data=to_bytes(powershell_module_replacers.strip(), encoding='utf-8'))) as m:
|
||||
mock_task.action = 'win_copy'
|
||||
mock_task.args = dict(b=2)
|
||||
mock_connection.module_implementation_preferences = ('.ps1',)
|
||||
|
|
Loading…
Reference in a new issue