diff --git a/changelogs/fragments/3036-archive-root-path-fix.yml b/changelogs/fragments/3036-archive-root-path-fix.yml
new file mode 100644
index 0000000000..fa460f82b9
--- /dev/null
+++ b/changelogs/fragments/3036-archive-root-path-fix.yml
@@ -0,0 +1,4 @@
+---
+bugfixes:
+  - archive - fixing archive root determination when longest common root is ``/``
+    (https://github.com/ansible-collections/community.general/pull/3036).
diff --git a/plugins/modules/files/archive.py b/plugins/modules/files/archive.py
index 91a8f688f5..30c4de5aa8 100644
--- a/plugins/modules/files/archive.py
+++ b/plugins/modules/files/archive.py
@@ -204,7 +204,6 @@ else:
     LZMA_IMP_ERR = format_exc()
     HAS_LZMA = False
 
-PATH_SEP = to_bytes(os.sep)
 PY27 = version_info[0:2] >= (2, 7)
 
 STATE_ABSENT = 'absent'
@@ -213,16 +212,12 @@ STATE_COMPRESSED = 'compress'
 STATE_INCOMPLETE = 'incomplete'
 
 
-def _to_bytes(s):
-    return to_bytes(s, errors='surrogate_or_strict')
+def common_path(paths):
+    empty = b'' if paths and isinstance(paths[0], six.binary_type) else ''
 
-
-def _to_native(s):
-    return to_native(s, errors='surrogate_or_strict')
-
-
-def _to_native_ascii(s):
-    return to_native(s, errors='surrogate_or_strict', encoding='ascii')
+    return os.path.join(
+        os.path.dirname(os.path.commonprefix([os.path.join(os.path.dirname(p), empty) for p in paths])), empty
+    )
 
 
 def expand_paths(paths):
@@ -239,10 +234,6 @@ def expand_paths(paths):
     return expanded_path, is_globby
 
 
-def is_archive(path):
-    return re.search(br'\.(tar|tar\.(gz|bz2|xz)|tgz|tbz2|zip)$', os.path.basename(path), re.IGNORECASE)
-
-
 def legacy_filter(path, exclusion_patterns):
     return matches_exclusion_patterns(path, exclusion_patterns)
 
@@ -251,6 +242,26 @@ def matches_exclusion_patterns(path, exclusion_patterns):
     return any(fnmatch(path, p) for p in exclusion_patterns)
 
 
+def is_archive(path):
+    return re.search(br'\.(tar|tar\.(gz|bz2|xz)|tgz|tbz2|zip)$', os.path.basename(path), re.IGNORECASE)
+
+
+def strip_prefix(prefix, string):
+    return string[len(prefix):] if string.startswith(prefix) else string
+
+
+def _to_bytes(s):
+    return to_bytes(s, errors='surrogate_or_strict')
+
+
+def _to_native(s):
+    return to_native(s, errors='surrogate_or_strict')
+
+
+def _to_native_ascii(s):
+    return to_native(s, errors='surrogate_or_strict', encoding='ascii')
+
+
 @six.add_metaclass(abc.ABCMeta)
 class Archive(object):
     def __init__(self, module):
@@ -266,7 +277,6 @@ class Archive(object):
         self.destination_state = STATE_ABSENT
         self.errors = []
         self.file = None
-        self.root = b''
         self.successes = []
         self.targets = []
         self.not_found = []
@@ -275,7 +285,7 @@ class Archive(object):
         self.expanded_paths, has_globs = expand_paths(paths)
         self.expanded_exclude_paths = expand_paths(module.params['exclude_path'])[0]
 
-        self.paths = list(set(self.expanded_paths) - set(self.expanded_exclude_paths))
+        self.paths = sorted(set(self.expanded_paths) - set(self.expanded_exclude_paths))
 
         if not self.paths:
             module.fail_json(
@@ -285,6 +295,8 @@ class Archive(object):
                 msg='Error, no source paths were found'
             )
 
+        self.root = common_path(self.paths)
+
         if not self.must_archive:
             self.must_archive = any([has_globs, os.path.isdir(self.paths[0]), len(self.paths) > 1])
 
@@ -298,6 +310,9 @@ class Archive(object):
                 msg='Error, must specify "dest" when archiving multiple files or trees'
             )
 
+        if self.remove:
+            self._check_removal_safety()
+
         self.original_size = self.destination_size()
 
     def add(self, path, archive_name):
@@ -310,9 +325,8 @@ class Archive(object):
 
     def add_single_target(self, path):
         if self.format in ('zip', 'tar'):
-            archive_name = re.sub(br'^%s' % re.escape(self.root), b'', path)
            self.open()
-            self.add(path, archive_name)
+            self.add(path, strip_prefix(self.root, path))
             self.close()
             self.destination_state = STATE_ARCHIVED
         else:
@@ -333,25 +347,18 @@ class Archive(object):
     def add_targets(self):
         self.open()
         try:
-            match_root = re.compile(br'^%s' % re.escape(self.root))
             for target in self.targets:
                 if os.path.isdir(target):
                     for directory_path, directory_names, file_names in os.walk(target, topdown=True):
-                        if not directory_path.endswith(PATH_SEP):
-                            directory_path += PATH_SEP
-
                         for directory_name in directory_names:
-                            full_path = directory_path + directory_name
-                            archive_name = match_root.sub(b'', full_path)
-                            self.add(full_path, archive_name)
+                            full_path = os.path.join(directory_path, directory_name)
+                            self.add(full_path, strip_prefix(self.root, full_path))
 
                         for file_name in file_names:
-                            full_path = directory_path + file_name
-                            archive_name = match_root.sub(b'', full_path)
-                            self.add(full_path, archive_name)
+                            full_path = os.path.join(directory_path, file_name)
+                            self.add(full_path, strip_prefix(self.root, full_path))
                 else:
-                    archive_name = match_root.sub(b'', target)
-                    self.add(target, archive_name)
+                    self.add(target, strip_prefix(self.root, target))
         except Exception as e:
             if self.format in ('zip', 'tar'):
                 archive_format = self.format
@@ -384,26 +391,6 @@ class Archive(object):
 
     def find_targets(self):
         for path in self.paths:
-            # Use the longest common directory name among all the files as the archive root path
-            if self.root == b'':
-                self.root = os.path.dirname(path) + PATH_SEP
-            else:
-                for i in range(len(self.root)):
-                    if path[i] != self.root[i]:
-                        break
-
-                if i < len(self.root):
-                    self.root = os.path.dirname(self.root[0:i + 1])
-
-                self.root += PATH_SEP
-            # Don't allow archives to be created anywhere within paths to be removed
-            if self.remove and os.path.isdir(path):
-                prefix = path if path.endswith(PATH_SEP) else path + PATH_SEP
-                if self.destination.startswith(prefix):
-                    self.module.fail_json(
-                        path=', '.join(self.paths),
-                        msg='Error, created archive can not be contained in source paths when remove=true'
-                    )
             if not os.path.lexists(path):
                 self.not_found.append(path)
             else:
@@ -470,6 +457,14 @@ class Archive(object):
             'expanded_exclude_paths': [_to_native(p) for p in self.expanded_exclude_paths],
         }
 
+    def _check_removal_safety(self):
+        for path in self.paths:
+            if os.path.isdir(path) and self.destination.startswith(os.path.join(path, b'')):
+                self.module.fail_json(
+                    path=b', '.join(self.paths),
+                    msg='Error, created archive can not be contained in source paths when remove=true'
+                )
+
     def _open_compressed_file(self, path, mode):
         f = None
         if self.format == 'gz':
diff --git a/tests/unit/plugins/modules/files/test_archive.py b/tests/unit/plugins/modules/files/test_archive.py
new file mode 100644
index 0000000000..9fae51e7b7
--- /dev/null
+++ b/tests/unit/plugins/modules/files/test_archive.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import pytest
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible_collections.community.general.tests.unit.compat.mock import Mock, patch
+from ansible_collections.community.general.tests.unit.plugins.modules.utils import ModuleTestCase, set_module_args
+from ansible_collections.community.general.plugins.modules.files.archive import get_archive, common_path
+
+
+class TestArchive(ModuleTestCase):
+    def setUp(self):
+        super(TestArchive, self).setUp()
+
+        self.mock_os_path_isdir = patch('os.path.isdir')
+        self.os_path_isdir = self.mock_os_path_isdir.start()
+
+    def tearDown(self):
+        self.os_path_isdir = self.mock_os_path_isdir.stop()
+
+    def test_archive_removal_safety(self):
+        set_module_args(
+            dict(
+                path=['/foo', '/bar', '/baz'],
+                dest='/foo/destination.tgz',
+                remove=True
+            )
+        )
+
+        module = AnsibleModule(
+            argument_spec=dict(
+                path=dict(type='list', elements='path', required=True),
+                format=dict(type='str', default='gz', choices=['bz2', 'gz', 'tar', 'xz', 'zip']),
+                dest=dict(type='path'),
+                exclude_path=dict(type='list', elements='path', default=[]),
+                exclusion_patterns=dict(type='list', elements='path'),
+                force_archive=dict(type='bool', default=False),
+                remove=dict(type='bool', default=False),
+            ),
+            add_file_common_args=True,
+            supports_check_mode=True,
+        )
+
+        self.os_path_isdir.side_effect = [True, False, False, True]
+
+        module.fail_json = Mock()
+
+        archive = get_archive(module)
+
+        module.fail_json.assert_called_once_with(
+            path=b', '.join(archive.paths),
+            msg='Error, created archive can not be contained in source paths when remove=true'
+        )
+
+
+PATHS = (
+    ([], ''),
+    (['/'], '/'),
+    ([b'/'], b'/'),
+    (['/foo', '/bar', '/baz', '/foobar', '/barbaz', '/foo/bar'], '/'),
+    ([b'/foo', b'/bar', b'/baz', b'/foobar', b'/barbaz', b'/foo/bar'], b'/'),
+    (['/foo/bar/baz', '/foo/bar'], '/foo/'),
+    (['/foo/bar/baz', '/foo/bar/'], '/foo/bar/'),
+)
+
+
+@pytest.mark.parametrize("paths,root", PATHS)
+def test_common_path(paths, root):
+    assert common_path(paths) == root
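
For reference (this note and the snippet below are not part of the patch): a minimal standalone sketch of the new root and member-name logic. It copies common_path and strip_prefix from the change above, with six.binary_type replaced by the built-in bytes so it runs on plain Python 3, and exercises the case this PR fixes, where the longest common root of the source paths is the filesystem root.

import os


def common_path(paths):
    # Longest common directory of all the paths, always ending with a separator.
    empty = b'' if paths and isinstance(paths[0], bytes) else ''

    return os.path.join(
        os.path.dirname(os.path.commonprefix([os.path.join(os.path.dirname(p), empty) for p in paths])), empty
    )


def strip_prefix(prefix, string):
    # Archive member names are the full paths with the common root stripped off.
    return string[len(prefix):] if string.startswith(prefix) else string


paths = [b'/foo', b'/bar', b'/baz']
root = common_path(paths)
print(root)                                    # b'/'
print([strip_prefix(root, p) for p in paths])  # [b'foo', b'bar', b'baz']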