
archive - fixing determination of archive root when root is '/' (#3036)

* Initial commit

* Fixing unit tests and path joins

* Ensuring paths are consistently ordered

* Adding changelog fragment

* Using os.path.join to ensure trailing slashes are present

* Optimizing use of root in add_targets

* Applying initial review suggestions
Ajpantuso committed 2021-07-24 16:10:56 -04:00 (committed by GitHub)
parent d057b2e3b2
commit 31189e9645
3 changed files with 122 additions and 50 deletions
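At the core of the change, a hand-rolled character-by-character root computation is replaced by a helper built on `os.path.commonprefix`. Below is a minimal, text-only sketch of that helper (the real `common_path` in the diff also handles bytes paths via `six`); the sample paths are illustrative only:

```python
import os

def common_path(paths):
    # os.path.commonprefix compares character by character, so two siblings
    # like '/foo/bar1' and '/foo/bar2' could otherwise share a bogus prefix.
    # Taking dirname of each path first and joining with '' keeps the
    # comparison at directory granularity and guarantees a trailing separator.
    return os.path.join(
        os.path.dirname(os.path.commonprefix([os.path.join(os.path.dirname(p), '') for p in paths])), ''
    )

print(common_path(['/foo/bar/baz', '/foo/bar']))  # '/foo/'
print(common_path(['/foo', '/bar', '/baz']))      # '/' -- the case this PR fixes
print(os.path.join('/foo', ''))                   # '/foo/' -- the trailing-slash trick
```

Joining a path with an empty component is what the commit message means by "using os.path.join to ensure trailing slashes are present": it appends exactly one separator without string concatenation.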

Changelog fragment (new file)

@@ -0,0 +1,4 @@
+---
+bugfixes:
+  - archive - fixing archive root determination when longest common root is ``/``
+    (https://github.com/ansible-collections/community.general/pull/3036).

plugins/modules/files/archive.py

@@ -204,7 +204,6 @@ else:
     LZMA_IMP_ERR = format_exc()
     HAS_LZMA = False
 
-PATH_SEP = to_bytes(os.sep)
 PY27 = version_info[0:2] >= (2, 7)
 
 STATE_ABSENT = 'absent'
@@ -213,16 +212,12 @@ STATE_COMPRESSED = 'compress'
 STATE_INCOMPLETE = 'incomplete'
 
 
-def _to_bytes(s):
-    return to_bytes(s, errors='surrogate_or_strict')
-
-
-def _to_native(s):
-    return to_native(s, errors='surrogate_or_strict')
-
-
-def _to_native_ascii(s):
-    return to_native(s, errors='surrogate_or_strict', encoding='ascii')
+def common_path(paths):
+    empty = b'' if paths and isinstance(paths[0], six.binary_type) else ''
+
+    return os.path.join(
+        os.path.dirname(os.path.commonprefix([os.path.join(os.path.dirname(p), empty) for p in paths])), empty
+    )
 
 
 def expand_paths(paths):
@@ -239,10 +234,6 @@ def expand_paths(paths):
     return expanded_path, is_globby
 
 
-def is_archive(path):
-    return re.search(br'\.(tar|tar\.(gz|bz2|xz)|tgz|tbz2|zip)$', os.path.basename(path), re.IGNORECASE)
-
-
 def legacy_filter(path, exclusion_patterns):
     return matches_exclusion_patterns(path, exclusion_patterns)
@@ -251,6 +242,26 @@ def matches_exclusion_patterns(path, exclusion_patterns):
     return any(fnmatch(path, p) for p in exclusion_patterns)
 
 
+def is_archive(path):
+    return re.search(br'\.(tar|tar\.(gz|bz2|xz)|tgz|tbz2|zip)$', os.path.basename(path), re.IGNORECASE)
+
+
+def strip_prefix(prefix, string):
+    return string[len(prefix):] if string.startswith(prefix) else string
+
+
+def _to_bytes(s):
+    return to_bytes(s, errors='surrogate_or_strict')
+
+
+def _to_native(s):
+    return to_native(s, errors='surrogate_or_strict')
+
+
+def _to_native_ascii(s):
+    return to_native(s, errors='surrogate_or_strict', encoding='ascii')
+
+
 @six.add_metaclass(abc.ABCMeta)
 class Archive(object):
     def __init__(self, module):
@@ -266,7 +277,6 @@ class Archive(object):
         self.destination_state = STATE_ABSENT
         self.errors = []
         self.file = None
-        self.root = b''
         self.successes = []
         self.targets = []
         self.not_found = []
@@ -275,7 +285,7 @@ class Archive(object):
         self.expanded_paths, has_globs = expand_paths(paths)
         self.expanded_exclude_paths = expand_paths(module.params['exclude_path'])[0]
 
-        self.paths = list(set(self.expanded_paths) - set(self.expanded_exclude_paths))
+        self.paths = sorted(set(self.expanded_paths) - set(self.expanded_exclude_paths))
 
         if not self.paths:
             module.fail_json(
@@ -285,6 +295,8 @@ class Archive(object):
                 msg='Error, no source paths were found'
             )
 
+        self.root = common_path(self.paths)
+
         if not self.must_archive:
             self.must_archive = any([has_globs, os.path.isdir(self.paths[0]), len(self.paths) > 1])
@@ -298,6 +310,9 @@ class Archive(object):
                 msg='Error, must specify "dest" when archiving multiple files or trees'
             )
 
+        if self.remove:
+            self._check_removal_safety()
+
         self.original_size = self.destination_size()
 
     def add(self, path, archive_name):
@@ -310,9 +325,8 @@ class Archive(object):
 
     def add_single_target(self, path):
         if self.format in ('zip', 'tar'):
-            archive_name = re.sub(br'^%s' % re.escape(self.root), b'', path)
             self.open()
-            self.add(path, archive_name)
+            self.add(path, strip_prefix(self.root, path))
             self.close()
             self.destination_state = STATE_ARCHIVED
         else:
@@ -333,25 +347,18 @@ class Archive(object):
     def add_targets(self):
         self.open()
         try:
-            match_root = re.compile(br'^%s' % re.escape(self.root))
             for target in self.targets:
                 if os.path.isdir(target):
                     for directory_path, directory_names, file_names in os.walk(target, topdown=True):
-                        if not directory_path.endswith(PATH_SEP):
-                            directory_path += PATH_SEP
-
                         for directory_name in directory_names:
-                            full_path = directory_path + directory_name
-                            archive_name = match_root.sub(b'', full_path)
-                            self.add(full_path, archive_name)
+                            full_path = os.path.join(directory_path, directory_name)
+                            self.add(full_path, strip_prefix(self.root, full_path))
 
                         for file_name in file_names:
-                            full_path = directory_path + file_name
-                            archive_name = match_root.sub(b'', full_path)
-                            self.add(full_path, archive_name)
+                            full_path = os.path.join(directory_path, file_name)
+                            self.add(full_path, strip_prefix(self.root, full_path))
                 else:
-                    archive_name = match_root.sub(b'', target)
-                    self.add(target, archive_name)
+                    self.add(target, strip_prefix(self.root, target))
         except Exception as e:
             if self.format in ('zip', 'tar'):
                 archive_format = self.format
@@ -384,26 +391,6 @@ class Archive(object):
 
     def find_targets(self):
         for path in self.paths:
-            # Use the longest common directory name among all the files as the archive root path
-            if self.root == b'':
-                self.root = os.path.dirname(path) + PATH_SEP
-            else:
-                for i in range(len(self.root)):
-                    if path[i] != self.root[i]:
-                        break
-                if i < len(self.root):
-                    self.root = os.path.dirname(self.root[0:i + 1])
-                self.root += PATH_SEP
-
-            # Don't allow archives to be created anywhere within paths to be removed
-            if self.remove and os.path.isdir(path):
-                prefix = path if path.endswith(PATH_SEP) else path + PATH_SEP
-                if self.destination.startswith(prefix):
-                    self.module.fail_json(
-                        path=', '.join(self.paths),
-                        msg='Error, created archive can not be contained in source paths when remove=true'
-                    )
-
             if not os.path.lexists(path):
                 self.not_found.append(path)
             else:
@@ -470,6 +457,14 @@ class Archive(object):
             'expanded_exclude_paths': [_to_native(p) for p in self.expanded_exclude_paths],
         }
 
+    def _check_removal_safety(self):
+        for path in self.paths:
+            if os.path.isdir(path) and self.destination.startswith(os.path.join(path, b'')):
+                self.module.fail_json(
+                    path=b', '.join(self.paths),
+                    msg='Error, created archive can not be contained in source paths when remove=true'
+                )
+
     def _open_compressed_file(self, path, mode):
         f = None
         if self.format == 'gz':
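To make the renaming and safety logic above concrete, here is a small self-contained illustration; `strip_prefix` and the `os.path.join(path, b'')` idiom are taken verbatim from the diff, while the sample paths are invented for the demo:

```python
import os

def strip_prefix(prefix, string):
    return string[len(prefix):] if string.startswith(prefix) else string

# Archive member names are now simply the full path minus the common root.
root = b'/data/'
print(strip_prefix(root, b'/data/logs/app.log'))   # b'logs/app.log'

# _check_removal_safety joins each source directory with b'' so the prefix
# always ends in a separator: a destination under b'/foobar/' must not be
# mistaken for one under b'/foo'.
prefix = os.path.join(b'/foo', b'')                # b'/foo/'
print(b'/foobar/dest.tgz'.startswith(prefix))      # False
print(b'/foo/dest.tgz'.startswith(prefix))         # True
```

This replaces the old approach of substituting `self.root` out of each path with a compiled regex, which broke when the common root was `/` itself.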

Unit tests for the archive module (new file)

@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import pytest
+
+from ansible.module_utils.basic import AnsibleModule
+from ansible_collections.community.general.tests.unit.compat.mock import Mock, patch
+from ansible_collections.community.general.tests.unit.plugins.modules.utils import ModuleTestCase, set_module_args
+from ansible_collections.community.general.plugins.modules.files.archive import get_archive, common_path
+
+
+class TestArchive(ModuleTestCase):
+    def setUp(self):
+        super(TestArchive, self).setUp()
+
+        self.mock_os_path_isdir = patch('os.path.isdir')
+        self.os_path_isdir = self.mock_os_path_isdir.start()
+
+    def tearDown(self):
+        self.os_path_isdir = self.mock_os_path_isdir.stop()
+
+    def test_archive_removal_safety(self):
+        set_module_args(
+            dict(
+                path=['/foo', '/bar', '/baz'],
+                dest='/foo/destination.tgz',
+                remove=True
+            )
+        )
+
+        module = AnsibleModule(
+            argument_spec=dict(
+                path=dict(type='list', elements='path', required=True),
+                format=dict(type='str', default='gz', choices=['bz2', 'gz', 'tar', 'xz', 'zip']),
+                dest=dict(type='path'),
+                exclude_path=dict(type='list', elements='path', default=[]),
+                exclusion_patterns=dict(type='list', elements='path'),
+                force_archive=dict(type='bool', default=False),
+                remove=dict(type='bool', default=False),
+            ),
+            add_file_common_args=True,
+            supports_check_mode=True,
+        )
+
+        self.os_path_isdir.side_effect = [True, False, False, True]
+        module.fail_json = Mock()
+
+        archive = get_archive(module)
+
+        module.fail_json.assert_called_once_with(
+            path=b', '.join(archive.paths),
+            msg='Error, created archive can not be contained in source paths when remove=true'
+        )
+
+
+PATHS = (
+    ([], ''),
+    (['/'], '/'),
+    ([b'/'], b'/'),
+    (['/foo', '/bar', '/baz', '/foobar', '/barbaz', '/foo/bar'], '/'),
+    ([b'/foo', b'/bar', b'/baz', b'/foobar', b'/barbaz', b'/foo/bar'], b'/'),
+    (['/foo/bar/baz', '/foo/bar'], '/foo/'),
+    (['/foo/bar/baz', '/foo/bar/'], '/foo/bar/'),
+)
+
+
+@pytest.mark.parametrize("paths,root", PATHS)
+def test_common_path(paths, root):
+    assert common_path(paths) == root
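For reference, the parametrized cases above reduce to plain assertions like the following (assuming the collection is installed so the import path used in the test file resolves):

```python
from ansible_collections.community.general.plugins.modules.files.archive import common_path

# Text and bytes paths are both supported; the root always carries a
# trailing separator except for the empty-input case.
assert common_path([]) == ''
assert common_path(['/foo', '/bar', '/baz']) == '/'
assert common_path([b'/foo/bar/baz', b'/foo/bar']) == b'/foo/'
```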