Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def _is_bad_path(path, base):
bool: True if the path is not rooted under the base directory, False otherwise.
"""
# joinpath will ignore base if path is absolute
return not _get_resolved_path(joinpath(base, path)).startswith(base)
resolved = _get_resolved_path(joinpath(base, path))
return os.path.commonpath([resolved, base]) != base


def _is_bad_link(info, base):
Expand All @@ -77,19 +78,18 @@ def _is_bad_link(info, base):
return _is_bad_path(info.linkname, base=tip)


def _get_safe_members(members):
def _get_safe_members(members, base):
"""A generator that yields members that are safe to extract.

It filters out bad paths and bad links.

Args:
members (list): A list of members to check.
base (str): The base directory for extraction.

Yields:
tarfile.TarInfo: The tar file info.
"""
base = _get_resolved_path("")

for file_info in members:
if _is_bad_path(file_info.name, base):
logger.error("%s is blocked (illegal path)", file_info.name)
Expand Down Expand Up @@ -120,7 +120,8 @@ def custom_extractall_tarfile(tar, extract_path):
if hasattr(tarfile, "data_filter"):
tar.extractall(path=extract_path, filter="data")
else:
tar.extractall(path=extract_path, members=_get_safe_members(tar))
base = _get_resolved_path(extract_path)
tar.extractall(path=extract_path, members=_get_safe_members(tar.getmembers(), base))


def repack(inference_script, model_archive, source_dir=None): # pragma: no cover
Expand Down
8 changes: 4 additions & 4 deletions sagemaker-mlops/tests/unit/workflow/test_repack_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_get_safe_members_all_safe():
mock_member2.islnk = Mock(return_value=False)

members = [mock_member1, mock_member2]
safe_members = list(_get_safe_members(members))
safe_members = list(_get_safe_members(members, "/tmp/extract"))

assert len(safe_members) == 2
assert mock_member1 in safe_members
Expand All @@ -128,7 +128,7 @@ def test_get_safe_members_filters_bad_path():
mock_is_bad.side_effect = lambda name, base: name == "/etc/passwd"

members = [mock_member_safe, mock_member_bad]
safe_members = list(_get_safe_members(members))
safe_members = list(_get_safe_members(members, "/tmp/extract"))

assert len(safe_members) == 1
assert mock_member_safe in safe_members
Expand All @@ -152,7 +152,7 @@ def test_get_safe_members_filters_bad_symlink():
mock_is_bad_link.return_value = True

members = [mock_member_safe, mock_member_symlink]
safe_members = list(_get_safe_members(members))
safe_members = list(_get_safe_members(members, "/tmp/extract"))

assert len(safe_members) == 1
assert mock_member_safe in safe_members
Expand All @@ -176,7 +176,7 @@ def test_get_safe_members_filters_bad_hardlink():
mock_is_bad_link.return_value = True

members = [mock_member_safe, mock_member_hardlink]
safe_members = list(_get_safe_members(members))
safe_members = list(_get_safe_members(members, "/tmp/extract"))

assert len(safe_members) == 1
assert mock_member_safe in safe_members
Expand Down