1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

Fix mixing of bytes and str in module replacer (caused traceback on python3)

This commit is contained in:
Toshio Kuratomi 2016-02-26 16:41:13 -08:00
parent b0bed27211
commit c29f51804b
2 changed files with 123 additions and 38 deletions

View file

@ -21,7 +21,7 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
# from python and deps # from python and deps
from ansible.compat.six.moves import StringIO from io import BytesIO
import json import json
import os import os
import shlex import shlex
@ -30,20 +30,20 @@ import shlex
from ansible import __version__ from ansible import __version__
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.utils.unicode import to_bytes from ansible.utils.unicode import to_bytes, to_unicode
REPLACER = "#<<INCLUDE_ANSIBLE_MODULE_COMMON>>" REPLACER = b"#<<INCLUDE_ANSIBLE_MODULE_COMMON>>"
REPLACER_ARGS = "\"<<INCLUDE_ANSIBLE_MODULE_ARGS>>\"" REPLACER_ARGS = b"\"<<INCLUDE_ANSIBLE_MODULE_ARGS>>\""
REPLACER_COMPLEX = "\"<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>\"" REPLACER_COMPLEX = b"\"<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>\""
REPLACER_WINDOWS = "# POWERSHELL_COMMON" REPLACER_WINDOWS = b"# POWERSHELL_COMMON"
REPLACER_WINARGS = "<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>" REPLACER_WINARGS = b"<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
REPLACER_JSONARGS = "<<INCLUDE_ANSIBLE_MODULE_JSON_ARGS>>" REPLACER_JSONARGS = b"<<INCLUDE_ANSIBLE_MODULE_JSON_ARGS>>"
REPLACER_VERSION = "\"<<ANSIBLE_VERSION>>\"" REPLACER_VERSION = b"\"<<ANSIBLE_VERSION>>\""
REPLACER_SELINUX = "<<SELINUX_SPECIAL_FILESYSTEMS>>" REPLACER_SELINUX = b"<<SELINUX_SPECIAL_FILESYSTEMS>>"
# We could end up writing out parameters with unicode characters so we need to # We could end up writing out parameters with unicode characters so we need to
# specify an encoding for the python source file # specify an encoding for the python source file
ENCODING_STRING = '# -*- coding: utf-8 -*-' ENCODING_STRING = b'# -*- coding: utf-8 -*-'
# we've moved the module_common relative to the snippets, so fix the path # we've moved the module_common relative to the snippets, so fix the path
_SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils') _SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils')
@ -53,7 +53,7 @@ _SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils')
def _slurp(path): def _slurp(path):
if not os.path.exists(path): if not os.path.exists(path):
raise AnsibleError("imported module support code does not exist at %s" % path) raise AnsibleError("imported module support code does not exist at %s" % path)
fd = open(path) fd = open(path, 'rb')
data = fd.read() data = fd.read()
fd.close() fd.close()
return data return data
@ -71,49 +71,49 @@ def _find_snippet_imports(module_data, module_path, strip_comments):
module_style = 'new' module_style = 'new'
elif REPLACER_JSONARGS in module_data: elif REPLACER_JSONARGS in module_data:
module_style = 'new' module_style = 'new'
elif 'from ansible.module_utils.' in module_data: elif b'from ansible.module_utils.' in module_data:
module_style = 'new' module_style = 'new'
elif 'WANT_JSON' in module_data: elif b'WANT_JSON' in module_data:
module_style = 'non_native_want_json' module_style = 'non_native_want_json'
output = StringIO() output = BytesIO()
lines = module_data.split('\n') lines = module_data.split(b'\n')
snippet_names = [] snippet_names = []
for line in lines: for line in lines:
if REPLACER in line: if REPLACER in line:
output.write(_slurp(os.path.join(_SNIPPET_PATH, "basic.py"))) output.write(_slurp(os.path.join(_SNIPPET_PATH, "basic.py")))
snippet_names.append('basic') snippet_names.append(b'basic')
if REPLACER_WINDOWS in line: if REPLACER_WINDOWS in line:
ps_data = _slurp(os.path.join(_SNIPPET_PATH, "powershell.ps1")) ps_data = _slurp(os.path.join(_SNIPPET_PATH, "powershell.ps1"))
output.write(ps_data) output.write(ps_data)
snippet_names.append('powershell') snippet_names.append(b'powershell')
elif line.startswith('from ansible.module_utils.'): elif line.startswith(b'from ansible.module_utils.'):
tokens=line.split(".") tokens=line.split(b".")
import_error = False import_error = False
if len(tokens) != 3: if len(tokens) != 3:
import_error = True import_error = True
if " import *" not in line: if b" import *" not in line:
import_error = True import_error = True
if import_error: if import_error:
raise AnsibleError("error importing module in %s, expecting format like 'from ansible.module_utils.<lib name> import *'" % module_path) raise AnsibleError("error importing module in %s, expecting format like 'from ansible.module_utils.<lib name> import *'" % module_path)
snippet_name = tokens[2].split()[0] snippet_name = tokens[2].split()[0]
snippet_names.append(snippet_name) snippet_names.append(snippet_name)
output.write(_slurp(os.path.join(_SNIPPET_PATH, snippet_name + ".py"))) output.write(_slurp(os.path.join(_SNIPPET_PATH, to_unicode(snippet_name) + ".py")))
else: else:
if strip_comments and line.startswith("#") or line == '': if strip_comments and line.startswith(b"#") or line == b'':
pass pass
output.write(line) output.write(line)
output.write("\n") output.write(b"\n")
if not module_path.endswith(".ps1"): if not module_path.endswith(".ps1"):
# Unixy modules # Unixy modules
if len(snippet_names) > 0 and not 'basic' in snippet_names: if len(snippet_names) > 0 and not b'basic' in snippet_names:
raise AnsibleError("missing required import in %s: from ansible.module_utils.basic import *" % module_path) raise AnsibleError("missing required import in %s: from ansible.module_utils.basic import *" % module_path)
else: else:
# Windows modules # Windows modules
if len(snippet_names) > 0 and not 'powershell' in snippet_names: if len(snippet_names) > 0 and not b'powershell' in snippet_names:
raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path) raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path)
return (output.getvalue(), module_style) return (output.getvalue(), module_style)
@ -158,28 +158,28 @@ def modify_module(module_path, module_args, task_vars=dict(), strip_comments=Fal
# * Cache the modified module? If only the args are different and we do # * Cache the modified module? If only the args are different and we do
# that as the last step we could cache all the work up to that point. # that as the last step we could cache all the work up to that point.
with open(module_path) as f: with open(module_path, 'rb') as f:
# read in the module source # read in the module source
module_data = f.read() module_data = f.read()
(module_data, module_style) = _find_snippet_imports(module_data, module_path, strip_comments) (module_data, module_style) = _find_snippet_imports(module_data, module_path, strip_comments)
module_args_json = json.dumps(module_args).encode('utf-8') module_args_json = to_bytes(json.dumps(module_args))
python_repred_args = repr(module_args_json) python_repred_args = to_bytes(repr(module_args_json))
# these strings should be part of the 'basic' snippet which is required to be included # these strings should be part of the 'basic' snippet which is required to be included
module_data = module_data.replace(REPLACER_VERSION, repr(__version__)) module_data = module_data.replace(REPLACER_VERSION, to_bytes(__version__, nonstring='repr'))
module_data = module_data.replace(REPLACER_COMPLEX, python_repred_args) module_data = module_data.replace(REPLACER_COMPLEX, python_repred_args)
module_data = module_data.replace(REPLACER_WINARGS, module_args_json) module_data = module_data.replace(REPLACER_WINARGS, module_args_json)
module_data = module_data.replace(REPLACER_JSONARGS, module_args_json) module_data = module_data.replace(REPLACER_JSONARGS, module_args_json)
module_data = module_data.replace(REPLACER_SELINUX, ','.join(C.DEFAULT_SELINUX_SPECIAL_FS)) module_data = module_data.replace(REPLACER_SELINUX, to_bytes(','.join(C.DEFAULT_SELINUX_SPECIAL_FS)))
if module_style == 'new': if module_style == 'new':
facility = C.DEFAULT_SYSLOG_FACILITY facility = C.DEFAULT_SYSLOG_FACILITY
if 'ansible_syslog_facility' in task_vars: if 'ansible_syslog_facility' in task_vars:
facility = task_vars['ansible_syslog_facility'] facility = task_vars['ansible_syslog_facility']
module_data = module_data.replace('syslog.LOG_USER', "syslog.%s" % facility) module_data = module_data.replace(b'syslog.LOG_USER', to_bytes("syslog.%s" % facility))
lines = module_data.split(b"\n", 1) lines = module_data.split(b"\n", 1)
shebang = None shebang = None
@ -188,12 +188,13 @@ def modify_module(module_path, module_args, task_vars=dict(), strip_comments=Fal
args = shlex.split(str(shebang[2:])) args = shlex.split(str(shebang[2:]))
interpreter = args[0] interpreter = args[0]
interpreter_config = 'ansible_%s_interpreter' % os.path.basename(interpreter) interpreter_config = 'ansible_%s_interpreter' % os.path.basename(interpreter)
interpreter = to_bytes(interpreter)
if interpreter_config in task_vars: if interpreter_config in task_vars:
interpreter = to_bytes(task_vars[interpreter_config], errors='strict') interpreter = to_bytes(task_vars[interpreter_config], errors='strict')
lines[0] = shebang = b"#!{0} {1}".format(interpreter, b" ".join(args[1:])) lines[0] = shebang = b"#!{0} {1}".format(interpreter, b" ".join(args[1:]))
if os.path.basename(interpreter).startswith('python'): if os.path.basename(interpreter).startswith(b'python'):
lines.insert(1, ENCODING_STRING) lines.insert(1, ENCODING_STRING)
else: else:
# No shebang, assume a binary module? # No shebang, assume a binary module?

View file

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# (c) 2015, Florian Apolloner <florian@apolloner.eu> # (c) 2015, Florian Apolloner <florian@apolloner.eu>
# #
# This file is part of Ansible # This file is part of Ansible
@ -34,15 +35,17 @@ from ansible import constants as C
from ansible.compat.six import text_type from ansible.compat.six import text_type
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock, mock_open from ansible.compat.tests.mock import patch, MagicMock, mock_open
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.playbook.play_context import PlayContext from ansible.playbook.play_context import PlayContext
from ansible.plugins import PluginLoader from ansible.plugins import PluginLoader
from ansible.plugins.action import ActionBase from ansible.plugins.action import ActionBase
from ansible.template import Templar from ansible.template import Templar
from ansible.utils.unicode import to_bytes
from units.mock.loader import DictDataLoader from units.mock.loader import DictDataLoader
python_module_replacers = """ python_module_replacers = b"""
#!/usr/bin/python #!/usr/bin/python
#ANSIBLE_VERSION = "<<ANSIBLE_VERSION>>" #ANSIBLE_VERSION = "<<ANSIBLE_VERSION>>"
@ -50,14 +53,95 @@ python_module_replacers = """
#MODULE_COMPLEX_ARGS = "<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>" #MODULE_COMPLEX_ARGS = "<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>"
#SELINUX_SPECIAL_FS="<<SELINUX_SPECIAL_FILESYSTEMS>>" #SELINUX_SPECIAL_FS="<<SELINUX_SPECIAL_FILESYSTEMS>>"
test = u'Toshio \u304f\u3089\u3068\u307f'
from ansible.module_utils.basic import * from ansible.module_utils.basic import *
""" """
powershell_module_replacers = """ powershell_module_replacers = b"""
WINDOWS_ARGS = "<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>" WINDOWS_ARGS = "<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
# POWERSHELL_COMMON # POWERSHELL_COMMON
""" """
# Prior to 3.4.4, mock_open cannot handle binary read_data
if version_info >= (3,) and version_info < (3, 4, 4):
file_spec = None
def _iterate_read_data(read_data):
# Helper for mock_open:
# Retrieve lines from read_data via a generator so that separate calls to
# readline, read, and readlines are properly interleaved
sep = b'\n' if isinstance(read_data, bytes) else '\n'
data_as_list = [l + sep for l in read_data.split(sep)]
if data_as_list[-1] == sep:
# If the last line ended in a newline, the list comprehension will have an
# extra entry that's just a newline. Remove this.
data_as_list = data_as_list[:-1]
else:
# If there wasn't an extra newline by itself, then the file being
# emulated doesn't have a newline to end the last line remove the
# newline that our naive format() added
data_as_list[-1] = data_as_list[-1][:-1]
for line in data_as_list:
yield line
def mock_open(mock=None, read_data=''):
"""
A helper function to create a mock to replace the use of `open`. It works
for `open` called directly or used as a context manager.
The `mock` argument is the mock object to configure. If `None` (the
default) then a `MagicMock` will be created for you, with the API limited
to methods or attributes available on standard file handles.
`read_data` is a string for the `read` methoddline`, and `readlines` of the
file handle to return. This is an empty string by default.
"""
def _readlines_side_effect(*args, **kwargs):
if handle.readlines.return_value is not None:
return handle.readlines.return_value
return list(_data)
def _read_side_effect(*args, **kwargs):
if handle.read.return_value is not None:
return handle.read.return_value
return type(read_data)().join(_data)
def _readline_side_effect():
if handle.readline.return_value is not None:
while True:
yield handle.readline.return_value
for line in _data:
yield line
global file_spec
if file_spec is None:
import _io
file_spec = list(set(dir(_io.TextIOWrapper)).union(set(dir(_io.BytesIO))))
if mock is None:
mock = MagicMock(name='open', spec=open)
handle = MagicMock(spec=file_spec)
handle.__enter__.return_value = handle
_data = _iterate_read_data(read_data)
handle.write.return_value = None
handle.read.return_value = None
handle.readline.return_value = None
handle.readlines.return_value = None
handle.read.side_effect = _read_side_effect
handle.readline.side_effect = _readline_side_effect()
handle.readlines.side_effect = _readlines_side_effect
mock.return_value = handle
return mock
class DerivedActionBase(ActionBase): class DerivedActionBase(ActionBase):
def run(self, tmp=None, task_vars=None): def run(self, tmp=None, task_vars=None):
# We're not testing the plugin run() method, just the helper # We're not testing the plugin run() method, just the helper
@ -124,18 +208,18 @@ class TestActionBase(unittest.TestCase):
) )
# test python module formatting # test python module formatting
with patch.object(builtins, 'open', mock_open(read_data=text_type(python_module_replacers.strip()))) as m: with patch.object(builtins, 'open', mock_open(read_data=to_bytes(python_module_replacers.strip(), encoding='utf-8'))) as m:
mock_task.args = dict(a=1) mock_task.args = dict(a=1)
mock_connection.module_implementation_preferences = ('',) mock_connection.module_implementation_preferences = ('',)
(style, shebang, data) = action_base._configure_module(mock_task.action, mock_task.args) (style, shebang, data) = action_base._configure_module(mock_task.action, mock_task.args)
self.assertEqual(style, "new") self.assertEqual(style, "new")
self.assertEqual(shebang, "#!/usr/bin/python") self.assertEqual(shebang, b"#!/usr/bin/python")
# test module not found # test module not found
self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args) self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args)
# test powershell module formatting # test powershell module formatting
with patch.object(builtins, 'open', mock_open(read_data=text_type(powershell_module_replacers.strip()))) as m: with patch.object(builtins, 'open', mock_open(read_data=to_bytes(powershell_module_replacers.strip(), encoding='utf-8'))) as m:
mock_task.action = 'win_copy' mock_task.action = 'win_copy'
mock_task.args = dict(b=2) mock_task.args = dict(b=2)
mock_connection.module_implementation_preferences = ('.ps1',) mock_connection.module_implementation_preferences = ('.ps1',)