1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

[PR #3036/31189e96 backport][stable-2] archive - fixing determination of archive root when root is '/' (#3065)

* archive - fixing determination of archive root when root is '/' (#3036)

* Initial commit

* Fixing units and path joins

* Ensuring paths are consistently ordered

* Adding changelog fragment

* Using os.path.join to ensure trailing slashes are present

* optimizing use of root in add_targets

* Applying initial review suggestions

(cherry picked from commit 31189e9645)

* removing unneccessary addition
This commit is contained in:
Ajpantuso 2021-07-25 03:52:37 -04:00 committed by GitHub
parent 4281331639
commit 1da0db4984
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 118 additions and 46 deletions

View file

@ -0,0 +1,4 @@
---
bugfixes:
- archive - fixing archive root determination when longest common root is ``/``
(https://github.com/ansible-collections/community.general/pull/3036).

View file

@ -184,7 +184,6 @@ else:
LZMA_IMP_ERR = format_exc()
HAS_LZMA = False
PATH_SEP = to_bytes(os.sep)
PY27 = version_info[0:2] >= (2, 7)
STATE_ABSENT = 'absent'
@ -193,16 +192,12 @@ STATE_COMPRESSED = 'compress'
STATE_INCOMPLETE = 'incomplete'
def _to_bytes(s):
return to_bytes(s, errors='surrogate_or_strict')
def common_path(paths):
empty = b'' if paths and isinstance(paths[0], six.binary_type) else ''
def _to_native(s):
return to_native(s, errors='surrogate_or_strict')
def _to_native_ascii(s):
return to_native(s, errors='surrogate_or_strict', encoding='ascii')
return os.path.join(
os.path.dirname(os.path.commonprefix([os.path.join(os.path.dirname(p), empty) for p in paths])), empty
)
def expand_paths(paths):
@ -223,6 +218,22 @@ def is_archive(path):
return re.search(br'\.(tar|tar\.(gz|bz2|xz)|tgz|tbz2|zip)$', os.path.basename(path), re.IGNORECASE)
def strip_prefix(prefix, string):
return string[len(prefix):] if string.startswith(prefix) else string
def _to_bytes(s):
return to_bytes(s, errors='surrogate_or_strict')
def _to_native(s):
return to_native(s, errors='surrogate_or_strict')
def _to_native_ascii(s):
return to_native(s, errors='surrogate_or_strict', encoding='ascii')
@six.add_metaclass(abc.ABCMeta)
class Archive(object):
def __init__(self, module):
@ -237,7 +248,6 @@ class Archive(object):
self.destination_state = STATE_ABSENT
self.errors = []
self.file = None
self.root = b''
self.successes = []
self.targets = []
self.not_found = []
@ -246,7 +256,7 @@ class Archive(object):
self.expanded_paths, has_globs = expand_paths(paths)
self.expanded_exclude_paths = expand_paths(module.params['exclude_path'])[0]
self.paths = list(set(self.expanded_paths) - set(self.expanded_exclude_paths))
self.paths = sorted(set(self.expanded_paths) - set(self.expanded_exclude_paths))
if not self.paths:
module.fail_json(
@ -256,6 +266,8 @@ class Archive(object):
msg='Error, no source paths were found'
)
self.root = common_path(self.paths)
if not self.must_archive:
self.must_archive = any([has_globs, os.path.isdir(self.paths[0]), len(self.paths) > 1])
@ -269,6 +281,9 @@ class Archive(object):
msg='Error, must specify "dest" when archiving multiple files or trees'
)
if self.remove:
self._check_removal_safety()
def add(self, path, archive_name):
try:
self._add(_to_native_ascii(path), _to_native(archive_name))
@ -279,9 +294,8 @@ class Archive(object):
def add_single_target(self, path):
if self.format in ('zip', 'tar'):
archive_name = re.sub(br'^%s' % re.escape(self.root), b'', path)
self.open()
self.add(path, archive_name)
self.add(path, strip_prefix(self.root, path))
self.close()
self.destination_state = STATE_ARCHIVED
else:
@ -302,25 +316,18 @@ class Archive(object):
def add_targets(self):
self.open()
try:
match_root = re.compile(br'^%s' % re.escape(self.root))
for target in self.targets:
if os.path.isdir(target):
for directory_path, directory_names, file_names in os.walk(target, topdown=True):
if not directory_path.endswith(PATH_SEP):
directory_path += PATH_SEP
for directory_name in directory_names:
full_path = directory_path + directory_name
archive_name = match_root.sub(b'', full_path)
self.add(full_path, archive_name)
full_path = os.path.join(directory_path, directory_name)
self.add(full_path, strip_prefix(self.root, full_path))
for file_name in file_names:
full_path = directory_path + file_name
archive_name = match_root.sub(b'', full_path)
self.add(full_path, archive_name)
full_path = os.path.join(directory_path, file_name)
self.add(full_path, strip_prefix(self.root, full_path))
else:
archive_name = match_root.sub(b'', target)
self.add(target, archive_name)
self.add(target, strip_prefix(self.root, target))
except Exception as e:
if self.format in ('zip', 'tar'):
archive_format = self.format
@ -347,26 +354,6 @@ class Archive(object):
def find_targets(self):
for path in self.paths:
# Use the longest common directory name among all the files as the archive root path
if self.root == b'':
self.root = os.path.dirname(path) + PATH_SEP
else:
for i in range(len(self.root)):
if path[i] != self.root[i]:
break
if i < len(self.root):
self.root = os.path.dirname(self.root[0:i + 1])
self.root += PATH_SEP
# Don't allow archives to be created anywhere within paths to be removed
if self.remove and os.path.isdir(path):
prefix = path if path.endswith(PATH_SEP) else path + PATH_SEP
if self.destination.startswith(prefix):
self.module.fail_json(
path=', '.join(self.paths),
msg='Error, created archive can not be contained in source paths when remove=true'
)
if not os.path.lexists(path):
self.not_found.append(path)
else:
@ -423,6 +410,14 @@ class Archive(object):
'expanded_exclude_paths': [_to_native(p) for p in self.expanded_exclude_paths],
}
def _check_removal_safety(self):
for path in self.paths:
if os.path.isdir(path) and self.destination.startswith(os.path.join(path, b'')):
self.module.fail_json(
path=b', '.join(self.paths),
msg='Error, created archive can not be contained in source paths when remove=true'
)
def _open_compressed_file(self, path):
f = None
if self.format == 'gz':

View file

@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import pytest
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.tests.unit.compat.mock import Mock, patch
from ansible_collections.community.general.tests.unit.plugins.modules.utils import ModuleTestCase, set_module_args
from ansible_collections.community.general.plugins.modules.files.archive import get_archive, common_path
class TestArchive(ModuleTestCase):
def setUp(self):
super(TestArchive, self).setUp()
self.mock_os_path_isdir = patch('os.path.isdir')
self.os_path_isdir = self.mock_os_path_isdir.start()
def tearDown(self):
self.os_path_isdir = self.mock_os_path_isdir.stop()
def test_archive_removal_safety(self):
set_module_args(
dict(
path=['/foo', '/bar', '/baz'],
dest='/foo/destination.tgz',
remove=True
)
)
module = AnsibleModule(
argument_spec=dict(
path=dict(type='list', elements='path', required=True),
format=dict(type='str', default='gz', choices=['bz2', 'gz', 'tar', 'xz', 'zip']),
dest=dict(type='path'),
exclude_path=dict(type='list', elements='path', default=[]),
exclusion_patterns=dict(type='list', elements='path'),
force_archive=dict(type='bool', default=False),
remove=dict(type='bool', default=False),
),
add_file_common_args=True,
supports_check_mode=True,
)
self.os_path_isdir.side_effect = [True, False, False, True]
module.fail_json = Mock()
archive = get_archive(module)
module.fail_json.assert_called_once_with(
path=b', '.join(archive.paths),
msg='Error, created archive can not be contained in source paths when remove=true'
)
PATHS = (
([], ''),
(['/'], '/'),
([b'/'], b'/'),
(['/foo', '/bar', '/baz', '/foobar', '/barbaz', '/foo/bar'], '/'),
([b'/foo', b'/bar', b'/baz', b'/foobar', b'/barbaz', b'/foo/bar'], b'/'),
(['/foo/bar/baz', '/foo/bar'], '/foo/'),
(['/foo/bar/baz', '/foo/bar/'], '/foo/bar/'),
)
@pytest.mark.parametrize("paths,root", PATHS)
def test_common_path(paths, root):
assert common_path(paths) == root