diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py index 109ac2cc72..15540fcd1f 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py @@ -57,7 +57,8 @@ def _is_bad_path(path, base): bool: True if the path is not rooted under the base directory, False otherwise. """ # joinpath will ignore base if path is absolute - return not _get_resolved_path(joinpath(base, path)).startswith(base) + resolved = _get_resolved_path(joinpath(base, path)) + return os.path.commonpath([resolved, base]) != base def _is_bad_link(info, base): @@ -77,19 +78,18 @@ def _is_bad_link(info, base): return _is_bad_path(info.linkname, base=tip) -def _get_safe_members(members): +def _get_safe_members(members, base): """A generator that yields members that are safe to extract. It filters out bad paths and bad links. Args: members (list): A list of members to check. + base (str): The base directory for extraction. Yields: tarfile.TarInfo: The tar file info. """ - base = _get_resolved_path("") - for file_info in members: if _is_bad_path(file_info.name, base): logger.error("%s is blocked (illegal path)", file_info.name) @@ -120,7 +120,8 @@ def custom_extractall_tarfile(tar, extract_path): if hasattr(tarfile, "data_filter"): tar.extractall(path=extract_path, filter="data") else: - tar.extractall(path=extract_path, members=_get_safe_members(tar)) + base = _get_resolved_path(extract_path) + tar.extractall(path=extract_path, members=_get_safe_members(tar.getmembers(), base)) def repack(inference_script, model_archive, source_dir=None): # pragma: no cover diff --git a/sagemaker-mlops/tests/unit/workflow/test_repack_model.py b/sagemaker-mlops/tests/unit/workflow/test_repack_model.py index 5d8059a874..24936594be 100644 --- a/sagemaker-mlops/tests/unit/workflow/test_repack_model.py +++ b/sagemaker-mlops/tests/unit/workflow/test_repack_model.py @@ -105,7 +105,7 @@ def test_get_safe_members_all_safe(): mock_member2.islnk = Mock(return_value=False) members = [mock_member1, mock_member2] - safe_members = list(_get_safe_members(members)) + safe_members = list(_get_safe_members(members, "/tmp/extract")) assert len(safe_members) == 2 assert mock_member1 in safe_members @@ -128,7 +128,7 @@ def test_get_safe_members_filters_bad_path(): mock_is_bad.side_effect = lambda name, base: name == "/etc/passwd" members = [mock_member_safe, mock_member_bad] - safe_members = list(_get_safe_members(members)) + safe_members = list(_get_safe_members(members, "/tmp/extract")) assert len(safe_members) == 1 assert mock_member_safe in safe_members @@ -152,7 +152,7 @@ def test_get_safe_members_filters_bad_symlink(): mock_is_bad_link.return_value = True members = [mock_member_safe, mock_member_symlink] - safe_members = list(_get_safe_members(members)) + safe_members = list(_get_safe_members(members, "/tmp/extract")) assert len(safe_members) == 1 assert mock_member_safe in safe_members @@ -176,7 +176,7 @@ def test_get_safe_members_filters_bad_hardlink(): mock_is_bad_link.return_value = True members = [mock_member_safe, mock_member_hardlink] - safe_members = list(_get_safe_members(members)) + safe_members = list(_get_safe_members(members, "/tmp/extract")) assert len(safe_members) == 1 assert mock_member_safe in safe_members