mirror of
https://github.com/ansible-collections/community.general.git
synced 2024-09-14 20:13:21 +02:00
Fix mixing of bytes and str in module replacer (caused traceback on python3)
This commit is contained in:
parent
b0bed27211
commit
c29f51804b
2 changed files with 123 additions and 38 deletions
|
@ -21,7 +21,7 @@ from __future__ import (absolute_import, division, print_function)
|
||||||
__metaclass__ = type
|
__metaclass__ = type
|
||||||
|
|
||||||
# from python and deps
|
# from python and deps
|
||||||
from ansible.compat.six.moves import StringIO
|
from io import BytesIO
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
|
@ -30,20 +30,20 @@ import shlex
|
||||||
from ansible import __version__
|
from ansible import __version__
|
||||||
from ansible import constants as C
|
from ansible import constants as C
|
||||||
from ansible.errors import AnsibleError
|
from ansible.errors import AnsibleError
|
||||||
from ansible.utils.unicode import to_bytes
|
from ansible.utils.unicode import to_bytes, to_unicode
|
||||||
|
|
||||||
REPLACER = "#<<INCLUDE_ANSIBLE_MODULE_COMMON>>"
|
REPLACER = b"#<<INCLUDE_ANSIBLE_MODULE_COMMON>>"
|
||||||
REPLACER_ARGS = "\"<<INCLUDE_ANSIBLE_MODULE_ARGS>>\""
|
REPLACER_ARGS = b"\"<<INCLUDE_ANSIBLE_MODULE_ARGS>>\""
|
||||||
REPLACER_COMPLEX = "\"<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>\""
|
REPLACER_COMPLEX = b"\"<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>\""
|
||||||
REPLACER_WINDOWS = "# POWERSHELL_COMMON"
|
REPLACER_WINDOWS = b"# POWERSHELL_COMMON"
|
||||||
REPLACER_WINARGS = "<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
|
REPLACER_WINARGS = b"<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
|
||||||
REPLACER_JSONARGS = "<<INCLUDE_ANSIBLE_MODULE_JSON_ARGS>>"
|
REPLACER_JSONARGS = b"<<INCLUDE_ANSIBLE_MODULE_JSON_ARGS>>"
|
||||||
REPLACER_VERSION = "\"<<ANSIBLE_VERSION>>\""
|
REPLACER_VERSION = b"\"<<ANSIBLE_VERSION>>\""
|
||||||
REPLACER_SELINUX = "<<SELINUX_SPECIAL_FILESYSTEMS>>"
|
REPLACER_SELINUX = b"<<SELINUX_SPECIAL_FILESYSTEMS>>"
|
||||||
|
|
||||||
# We could end up writing out parameters with unicode characters so we need to
|
# We could end up writing out parameters with unicode characters so we need to
|
||||||
# specify an encoding for the python source file
|
# specify an encoding for the python source file
|
||||||
ENCODING_STRING = '# -*- coding: utf-8 -*-'
|
ENCODING_STRING = b'# -*- coding: utf-8 -*-'
|
||||||
|
|
||||||
# we've moved the module_common relative to the snippets, so fix the path
|
# we've moved the module_common relative to the snippets, so fix the path
|
||||||
_SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils')
|
_SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils')
|
||||||
|
@ -53,7 +53,7 @@ _SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils')
|
||||||
def _slurp(path):
|
def _slurp(path):
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
raise AnsibleError("imported module support code does not exist at %s" % path)
|
raise AnsibleError("imported module support code does not exist at %s" % path)
|
||||||
fd = open(path)
|
fd = open(path, 'rb')
|
||||||
data = fd.read()
|
data = fd.read()
|
||||||
fd.close()
|
fd.close()
|
||||||
return data
|
return data
|
||||||
|
@ -71,49 +71,49 @@ def _find_snippet_imports(module_data, module_path, strip_comments):
|
||||||
module_style = 'new'
|
module_style = 'new'
|
||||||
elif REPLACER_JSONARGS in module_data:
|
elif REPLACER_JSONARGS in module_data:
|
||||||
module_style = 'new'
|
module_style = 'new'
|
||||||
elif 'from ansible.module_utils.' in module_data:
|
elif b'from ansible.module_utils.' in module_data:
|
||||||
module_style = 'new'
|
module_style = 'new'
|
||||||
elif 'WANT_JSON' in module_data:
|
elif b'WANT_JSON' in module_data:
|
||||||
module_style = 'non_native_want_json'
|
module_style = 'non_native_want_json'
|
||||||
|
|
||||||
output = StringIO()
|
output = BytesIO()
|
||||||
lines = module_data.split('\n')
|
lines = module_data.split(b'\n')
|
||||||
snippet_names = []
|
snippet_names = []
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
|
|
||||||
if REPLACER in line:
|
if REPLACER in line:
|
||||||
output.write(_slurp(os.path.join(_SNIPPET_PATH, "basic.py")))
|
output.write(_slurp(os.path.join(_SNIPPET_PATH, "basic.py")))
|
||||||
snippet_names.append('basic')
|
snippet_names.append(b'basic')
|
||||||
if REPLACER_WINDOWS in line:
|
if REPLACER_WINDOWS in line:
|
||||||
ps_data = _slurp(os.path.join(_SNIPPET_PATH, "powershell.ps1"))
|
ps_data = _slurp(os.path.join(_SNIPPET_PATH, "powershell.ps1"))
|
||||||
output.write(ps_data)
|
output.write(ps_data)
|
||||||
snippet_names.append('powershell')
|
snippet_names.append(b'powershell')
|
||||||
elif line.startswith('from ansible.module_utils.'):
|
elif line.startswith(b'from ansible.module_utils.'):
|
||||||
tokens=line.split(".")
|
tokens=line.split(b".")
|
||||||
import_error = False
|
import_error = False
|
||||||
if len(tokens) != 3:
|
if len(tokens) != 3:
|
||||||
import_error = True
|
import_error = True
|
||||||
if " import *" not in line:
|
if b" import *" not in line:
|
||||||
import_error = True
|
import_error = True
|
||||||
if import_error:
|
if import_error:
|
||||||
raise AnsibleError("error importing module in %s, expecting format like 'from ansible.module_utils.<lib name> import *'" % module_path)
|
raise AnsibleError("error importing module in %s, expecting format like 'from ansible.module_utils.<lib name> import *'" % module_path)
|
||||||
snippet_name = tokens[2].split()[0]
|
snippet_name = tokens[2].split()[0]
|
||||||
snippet_names.append(snippet_name)
|
snippet_names.append(snippet_name)
|
||||||
output.write(_slurp(os.path.join(_SNIPPET_PATH, snippet_name + ".py")))
|
output.write(_slurp(os.path.join(_SNIPPET_PATH, to_unicode(snippet_name) + ".py")))
|
||||||
else:
|
else:
|
||||||
if strip_comments and line.startswith("#") or line == '':
|
if strip_comments and line.startswith(b"#") or line == b'':
|
||||||
pass
|
pass
|
||||||
output.write(line)
|
output.write(line)
|
||||||
output.write("\n")
|
output.write(b"\n")
|
||||||
|
|
||||||
if not module_path.endswith(".ps1"):
|
if not module_path.endswith(".ps1"):
|
||||||
# Unixy modules
|
# Unixy modules
|
||||||
if len(snippet_names) > 0 and not 'basic' in snippet_names:
|
if len(snippet_names) > 0 and not b'basic' in snippet_names:
|
||||||
raise AnsibleError("missing required import in %s: from ansible.module_utils.basic import *" % module_path)
|
raise AnsibleError("missing required import in %s: from ansible.module_utils.basic import *" % module_path)
|
||||||
else:
|
else:
|
||||||
# Windows modules
|
# Windows modules
|
||||||
if len(snippet_names) > 0 and not 'powershell' in snippet_names:
|
if len(snippet_names) > 0 and not b'powershell' in snippet_names:
|
||||||
raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path)
|
raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path)
|
||||||
|
|
||||||
return (output.getvalue(), module_style)
|
return (output.getvalue(), module_style)
|
||||||
|
@ -158,28 +158,28 @@ def modify_module(module_path, module_args, task_vars=dict(), strip_comments=Fal
|
||||||
# * Cache the modified module? If only the args are different and we do
|
# * Cache the modified module? If only the args are different and we do
|
||||||
# that as the last step we could cache all the work up to that point.
|
# that as the last step we could cache all the work up to that point.
|
||||||
|
|
||||||
with open(module_path) as f:
|
with open(module_path, 'rb') as f:
|
||||||
|
|
||||||
# read in the module source
|
# read in the module source
|
||||||
module_data = f.read()
|
module_data = f.read()
|
||||||
|
|
||||||
(module_data, module_style) = _find_snippet_imports(module_data, module_path, strip_comments)
|
(module_data, module_style) = _find_snippet_imports(module_data, module_path, strip_comments)
|
||||||
|
|
||||||
module_args_json = json.dumps(module_args).encode('utf-8')
|
module_args_json = to_bytes(json.dumps(module_args))
|
||||||
python_repred_args = repr(module_args_json)
|
python_repred_args = to_bytes(repr(module_args_json))
|
||||||
|
|
||||||
# these strings should be part of the 'basic' snippet which is required to be included
|
# these strings should be part of the 'basic' snippet which is required to be included
|
||||||
module_data = module_data.replace(REPLACER_VERSION, repr(__version__))
|
module_data = module_data.replace(REPLACER_VERSION, to_bytes(__version__, nonstring='repr'))
|
||||||
module_data = module_data.replace(REPLACER_COMPLEX, python_repred_args)
|
module_data = module_data.replace(REPLACER_COMPLEX, python_repred_args)
|
||||||
module_data = module_data.replace(REPLACER_WINARGS, module_args_json)
|
module_data = module_data.replace(REPLACER_WINARGS, module_args_json)
|
||||||
module_data = module_data.replace(REPLACER_JSONARGS, module_args_json)
|
module_data = module_data.replace(REPLACER_JSONARGS, module_args_json)
|
||||||
module_data = module_data.replace(REPLACER_SELINUX, ','.join(C.DEFAULT_SELINUX_SPECIAL_FS))
|
module_data = module_data.replace(REPLACER_SELINUX, to_bytes(','.join(C.DEFAULT_SELINUX_SPECIAL_FS)))
|
||||||
|
|
||||||
if module_style == 'new':
|
if module_style == 'new':
|
||||||
facility = C.DEFAULT_SYSLOG_FACILITY
|
facility = C.DEFAULT_SYSLOG_FACILITY
|
||||||
if 'ansible_syslog_facility' in task_vars:
|
if 'ansible_syslog_facility' in task_vars:
|
||||||
facility = task_vars['ansible_syslog_facility']
|
facility = task_vars['ansible_syslog_facility']
|
||||||
module_data = module_data.replace('syslog.LOG_USER', "syslog.%s" % facility)
|
module_data = module_data.replace(b'syslog.LOG_USER', to_bytes("syslog.%s" % facility))
|
||||||
|
|
||||||
lines = module_data.split(b"\n", 1)
|
lines = module_data.split(b"\n", 1)
|
||||||
shebang = None
|
shebang = None
|
||||||
|
@ -188,12 +188,13 @@ def modify_module(module_path, module_args, task_vars=dict(), strip_comments=Fal
|
||||||
args = shlex.split(str(shebang[2:]))
|
args = shlex.split(str(shebang[2:]))
|
||||||
interpreter = args[0]
|
interpreter = args[0]
|
||||||
interpreter_config = 'ansible_%s_interpreter' % os.path.basename(interpreter)
|
interpreter_config = 'ansible_%s_interpreter' % os.path.basename(interpreter)
|
||||||
|
interpreter = to_bytes(interpreter)
|
||||||
|
|
||||||
if interpreter_config in task_vars:
|
if interpreter_config in task_vars:
|
||||||
interpreter = to_bytes(task_vars[interpreter_config], errors='strict')
|
interpreter = to_bytes(task_vars[interpreter_config], errors='strict')
|
||||||
lines[0] = shebang = b"#!{0} {1}".format(interpreter, b" ".join(args[1:]))
|
lines[0] = shebang = b"#!{0} {1}".format(interpreter, b" ".join(args[1:]))
|
||||||
|
|
||||||
if os.path.basename(interpreter).startswith('python'):
|
if os.path.basename(interpreter).startswith(b'python'):
|
||||||
lines.insert(1, ENCODING_STRING)
|
lines.insert(1, ENCODING_STRING)
|
||||||
else:
|
else:
|
||||||
# No shebang, assume a binary module?
|
# No shebang, assume a binary module?
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
# (c) 2015, Florian Apolloner <florian@apolloner.eu>
|
# (c) 2015, Florian Apolloner <florian@apolloner.eu>
|
||||||
#
|
#
|
||||||
# This file is part of Ansible
|
# This file is part of Ansible
|
||||||
|
@ -34,15 +35,17 @@ from ansible import constants as C
|
||||||
from ansible.compat.six import text_type
|
from ansible.compat.six import text_type
|
||||||
from ansible.compat.tests import unittest
|
from ansible.compat.tests import unittest
|
||||||
from ansible.compat.tests.mock import patch, MagicMock, mock_open
|
from ansible.compat.tests.mock import patch, MagicMock, mock_open
|
||||||
|
|
||||||
from ansible.errors import AnsibleError
|
from ansible.errors import AnsibleError
|
||||||
from ansible.playbook.play_context import PlayContext
|
from ansible.playbook.play_context import PlayContext
|
||||||
from ansible.plugins import PluginLoader
|
from ansible.plugins import PluginLoader
|
||||||
from ansible.plugins.action import ActionBase
|
from ansible.plugins.action import ActionBase
|
||||||
from ansible.template import Templar
|
from ansible.template import Templar
|
||||||
|
from ansible.utils.unicode import to_bytes
|
||||||
|
|
||||||
from units.mock.loader import DictDataLoader
|
from units.mock.loader import DictDataLoader
|
||||||
|
|
||||||
python_module_replacers = """
|
python_module_replacers = b"""
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
#ANSIBLE_VERSION = "<<ANSIBLE_VERSION>>"
|
#ANSIBLE_VERSION = "<<ANSIBLE_VERSION>>"
|
||||||
|
@ -50,14 +53,95 @@ python_module_replacers = """
|
||||||
#MODULE_COMPLEX_ARGS = "<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>"
|
#MODULE_COMPLEX_ARGS = "<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>"
|
||||||
#SELINUX_SPECIAL_FS="<<SELINUX_SPECIAL_FILESYSTEMS>>"
|
#SELINUX_SPECIAL_FS="<<SELINUX_SPECIAL_FILESYSTEMS>>"
|
||||||
|
|
||||||
|
test = u'Toshio \u304f\u3089\u3068\u307f'
|
||||||
from ansible.module_utils.basic import *
|
from ansible.module_utils.basic import *
|
||||||
"""
|
"""
|
||||||
|
|
||||||
powershell_module_replacers = """
|
powershell_module_replacers = b"""
|
||||||
WINDOWS_ARGS = "<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
|
WINDOWS_ARGS = "<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
|
||||||
# POWERSHELL_COMMON
|
# POWERSHELL_COMMON
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Prior to 3.4.4, mock_open cannot handle binary read_data
|
||||||
|
if version_info >= (3,) and version_info < (3, 4, 4):
|
||||||
|
file_spec = None
|
||||||
|
|
||||||
|
def _iterate_read_data(read_data):
|
||||||
|
# Helper for mock_open:
|
||||||
|
# Retrieve lines from read_data via a generator so that separate calls to
|
||||||
|
# readline, read, and readlines are properly interleaved
|
||||||
|
sep = b'\n' if isinstance(read_data, bytes) else '\n'
|
||||||
|
data_as_list = [l + sep for l in read_data.split(sep)]
|
||||||
|
|
||||||
|
if data_as_list[-1] == sep:
|
||||||
|
# If the last line ended in a newline, the list comprehension will have an
|
||||||
|
# extra entry that's just a newline. Remove this.
|
||||||
|
data_as_list = data_as_list[:-1]
|
||||||
|
else:
|
||||||
|
# If there wasn't an extra newline by itself, then the file being
|
||||||
|
# emulated doesn't have a newline to end the last line remove the
|
||||||
|
# newline that our naive format() added
|
||||||
|
data_as_list[-1] = data_as_list[-1][:-1]
|
||||||
|
|
||||||
|
for line in data_as_list:
|
||||||
|
yield line
|
||||||
|
|
||||||
|
def mock_open(mock=None, read_data=''):
|
||||||
|
"""
|
||||||
|
A helper function to create a mock to replace the use of `open`. It works
|
||||||
|
for `open` called directly or used as a context manager.
|
||||||
|
|
||||||
|
The `mock` argument is the mock object to configure. If `None` (the
|
||||||
|
default) then a `MagicMock` will be created for you, with the API limited
|
||||||
|
to methods or attributes available on standard file handles.
|
||||||
|
|
||||||
|
`read_data` is a string for the `read` methoddline`, and `readlines` of the
|
||||||
|
file handle to return. This is an empty string by default.
|
||||||
|
"""
|
||||||
|
def _readlines_side_effect(*args, **kwargs):
|
||||||
|
if handle.readlines.return_value is not None:
|
||||||
|
return handle.readlines.return_value
|
||||||
|
return list(_data)
|
||||||
|
|
||||||
|
def _read_side_effect(*args, **kwargs):
|
||||||
|
if handle.read.return_value is not None:
|
||||||
|
return handle.read.return_value
|
||||||
|
return type(read_data)().join(_data)
|
||||||
|
|
||||||
|
def _readline_side_effect():
|
||||||
|
if handle.readline.return_value is not None:
|
||||||
|
while True:
|
||||||
|
yield handle.readline.return_value
|
||||||
|
for line in _data:
|
||||||
|
yield line
|
||||||
|
|
||||||
|
|
||||||
|
global file_spec
|
||||||
|
if file_spec is None:
|
||||||
|
import _io
|
||||||
|
file_spec = list(set(dir(_io.TextIOWrapper)).union(set(dir(_io.BytesIO))))
|
||||||
|
|
||||||
|
if mock is None:
|
||||||
|
mock = MagicMock(name='open', spec=open)
|
||||||
|
|
||||||
|
handle = MagicMock(spec=file_spec)
|
||||||
|
handle.__enter__.return_value = handle
|
||||||
|
|
||||||
|
_data = _iterate_read_data(read_data)
|
||||||
|
|
||||||
|
handle.write.return_value = None
|
||||||
|
handle.read.return_value = None
|
||||||
|
handle.readline.return_value = None
|
||||||
|
handle.readlines.return_value = None
|
||||||
|
|
||||||
|
handle.read.side_effect = _read_side_effect
|
||||||
|
handle.readline.side_effect = _readline_side_effect()
|
||||||
|
handle.readlines.side_effect = _readlines_side_effect
|
||||||
|
|
||||||
|
mock.return_value = handle
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
class DerivedActionBase(ActionBase):
|
class DerivedActionBase(ActionBase):
|
||||||
def run(self, tmp=None, task_vars=None):
|
def run(self, tmp=None, task_vars=None):
|
||||||
# We're not testing the plugin run() method, just the helper
|
# We're not testing the plugin run() method, just the helper
|
||||||
|
@ -124,18 +208,18 @@ class TestActionBase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# test python module formatting
|
# test python module formatting
|
||||||
with patch.object(builtins, 'open', mock_open(read_data=text_type(python_module_replacers.strip()))) as m:
|
with patch.object(builtins, 'open', mock_open(read_data=to_bytes(python_module_replacers.strip(), encoding='utf-8'))) as m:
|
||||||
mock_task.args = dict(a=1)
|
mock_task.args = dict(a=1)
|
||||||
mock_connection.module_implementation_preferences = ('',)
|
mock_connection.module_implementation_preferences = ('',)
|
||||||
(style, shebang, data) = action_base._configure_module(mock_task.action, mock_task.args)
|
(style, shebang, data) = action_base._configure_module(mock_task.action, mock_task.args)
|
||||||
self.assertEqual(style, "new")
|
self.assertEqual(style, "new")
|
||||||
self.assertEqual(shebang, "#!/usr/bin/python")
|
self.assertEqual(shebang, b"#!/usr/bin/python")
|
||||||
|
|
||||||
# test module not found
|
# test module not found
|
||||||
self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args)
|
self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args)
|
||||||
|
|
||||||
# test powershell module formatting
|
# test powershell module formatting
|
||||||
with patch.object(builtins, 'open', mock_open(read_data=text_type(powershell_module_replacers.strip()))) as m:
|
with patch.object(builtins, 'open', mock_open(read_data=to_bytes(powershell_module_replacers.strip(), encoding='utf-8'))) as m:
|
||||||
mock_task.action = 'win_copy'
|
mock_task.action = 'win_copy'
|
||||||
mock_task.args = dict(b=2)
|
mock_task.args = dict(b=2)
|
||||||
mock_connection.module_implementation_preferences = ('.ps1',)
|
mock_connection.module_implementation_preferences = ('.ps1',)
|
||||||
|
|
Loading…
Reference in a new issue