From 712e4b46e648bf8b3164bd2eda07e477858c73ba Mon Sep 17 00:00:00 2001 From: Michael Date: Mon, 16 Mar 2026 05:24:21 -0400 Subject: [PATCH 1/3] fix: enforce decompression limits for compressed tar wrappers --- modelaudit/scanners/tar_scanner.py | 71 ++++++++++++++++++++++++++++++ tests/scanners/test_tar_scanner.py | 38 +++++++++++++++- 2 files changed, 108 insertions(+), 1 deletion(-) diff --git a/modelaudit/scanners/tar_scanner.py b/modelaudit/scanners/tar_scanner.py index 533e7903..f6456089 100644 --- a/modelaudit/scanners/tar_scanner.py +++ b/modelaudit/scanners/tar_scanner.py @@ -28,6 +28,10 @@ ] DEFAULT_MAX_TAR_ENTRY_SIZE = 1024 * 1024 * 1024 +DEFAULT_MAX_DECOMPRESSED_BYTES = 512 * 1024 * 1024 +DEFAULT_MAX_DECOMPRESSION_RATIO = 250.0 + +COMPRESSED_TAR_EXTENSIONS = {".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar.xz", ".txz"} class TarScanner(BaseScanner): @@ -49,6 +53,12 @@ def __init__(self, config: dict[str, Any] | None = None) -> None: super().__init__(config) self.max_depth = self.config.get("max_tar_depth", 5) self.max_entries = self.config.get("max_tar_entries", 10000) + self.max_decompressed_bytes = int( + self.config.get("compressed_max_decompressed_bytes", DEFAULT_MAX_DECOMPRESSED_BYTES), + ) + self.max_decompression_ratio = float( + self.config.get("compressed_max_decompression_ratio", DEFAULT_MAX_DECOMPRESSION_RATIO), + ) @classmethod def can_handle(cls, path: str) -> bool: @@ -197,6 +207,7 @@ def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult: with tarfile.open(path, "r:*") as tar: members = tar.getmembers() + total_member_size = sum(member.size for member in members if member.isfile()) if len(members) > self.max_entries: result.add_check( name="Entry Count Limit Check", @@ -218,6 +229,66 @@ def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult: rule_code=None, # Passing check ) + compressed_size = os.path.getsize(path) + lower_path = path.lower() + if any(lower_path.endswith(ext) for ext in COMPRESSED_TAR_EXTENSIONS): + if total_member_size > self.max_decompressed_bytes: + result.add_check( + name="Compressed Wrapper Decompression Limits", + passed=False, + message=( + f"Decompressed size exceeded limit ({total_member_size} > {self.max_decompressed_bytes})" + ), + severity=IssueSeverity.WARNING, + location=path, + details={ + "decompressed_size": total_member_size, + "compressed_size": compressed_size, + "max_decompressed_size": self.max_decompressed_bytes, + }, + rule_code="S902", + ) + return result + + if compressed_size > 0 and (total_member_size / compressed_size) > self.max_decompression_ratio: + result.add_check( + name="Compressed Wrapper Decompression Limits", + passed=False, + message=( + "Decompression ratio exceeded limit " + f"({total_member_size / compressed_size:.1f}x > " + f"{self.max_decompression_ratio:.1f}x)" + ), + severity=IssueSeverity.WARNING, + location=path, + details={ + "decompressed_size": total_member_size, + "compressed_size": compressed_size, + "max_ratio": self.max_decompression_ratio, + "actual_ratio": total_member_size / compressed_size, + }, + rule_code="S902", + ) + return result + + result.add_check( + name="Compressed Wrapper Decompression Limits", + passed=True, + message=( + "Decompressed size/ratio are within limits " + f"({total_member_size} bytes, " + f"{(total_member_size / compressed_size) if compressed_size else 0:.1f}x)" + ), + location=path, + details={ + "decompressed_size": total_member_size, + "compressed_size": compressed_size, + "max_decompressed_size": self.max_decompressed_bytes, + "max_ratio": self.max_decompression_ratio, + }, + rule_code=None, + ) + for member in members: name = member.name temp_base = os.path.join(tempfile.gettempdir(), "extract_tar") diff --git a/tests/scanners/test_tar_scanner.py b/tests/scanners/test_tar_scanner.py index 46e5ba05..ec8686f5 100644 --- a/tests/scanners/test_tar_scanner.py +++ b/tests/scanners/test_tar_scanner.py @@ -5,7 +5,7 @@ import pytest -from modelaudit.scanners.base import IssueSeverity +from modelaudit.scanners.base import CheckStatus, IssueSeverity from modelaudit.scanners.tar_scanner import DEFAULT_MAX_TAR_ENTRY_SIZE, TarScanner @@ -511,3 +511,39 @@ def test_scan_truncated_tar(self, tmp_path: Path) -> None: assert len(format_checks) == 1 assert "not a valid tar file" in format_checks[0].message.lower() assert any("not a valid tar file" in issue.message.lower() for issue in result.issues) + + def test_scan_tar_gz_enforces_decompression_ratio_limit(self, tmp_path: Path) -> None: + """Compressed TAR wrappers should enforce decompression ratio limits.""" + archive_path = tmp_path / "ratio_limit.tar.gz" + payload = b"A" * 1_000_000 + + with tarfile.open(archive_path, "w:gz") as archive: + info = tarfile.TarInfo("payload.bin") + info.size = len(payload) + archive.addfile(info, tarfile.io.BytesIO(payload)) # type: ignore[attr-defined] + + scanner = TarScanner(config={"compressed_max_decompression_ratio": 2.0}) + result = scanner.scan(str(archive_path)) + + limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"] + assert len(limit_checks) == 1 + assert limit_checks[0].status == CheckStatus.FAILED + assert "decompression ratio exceeded" in limit_checks[0].message.lower() + + def test_scan_tar_gz_enforces_decompressed_size_limit(self, tmp_path: Path) -> None: + """Compressed TAR wrappers should enforce decompressed size limits.""" + archive_path = tmp_path / "size_limit.tar.gz" + payload = b"B" * 10_000 + + with tarfile.open(archive_path, "w:gz") as archive: + info = tarfile.TarInfo("payload.bin") + info.size = len(payload) + archive.addfile(info, tarfile.io.BytesIO(payload)) # type: ignore[attr-defined] + + scanner = TarScanner(config={"compressed_max_decompressed_bytes": 1024}) + result = scanner.scan(str(archive_path)) + + limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"] + assert len(limit_checks) == 1 + assert limit_checks[0].status == CheckStatus.FAILED + assert "decompressed size exceeded" in limit_checks[0].message.lower() From 31faa504effcd0b9f3ba37f58858a007c9a67a43 Mon Sep 17 00:00:00 2001 From: mldangelo Date: Mon, 16 Mar 2026 04:00:18 -0700 Subject: [PATCH 2/3] fix: enforce tar wrapper limits by content --- modelaudit/scanners/tar_scanner.py | 40 ++++++++++++-- tests/scanners/test_tar_scanner.py | 87 +++++++++++++++++++++++++++--- 2 files changed, 115 insertions(+), 12 deletions(-) diff --git a/modelaudit/scanners/tar_scanner.py b/modelaudit/scanners/tar_scanner.py index f6456089..a8298ac1 100644 --- a/modelaudit/scanners/tar_scanner.py +++ b/modelaudit/scanners/tar_scanner.py @@ -31,7 +31,9 @@ DEFAULT_MAX_DECOMPRESSED_BYTES = 512 * 1024 * 1024 DEFAULT_MAX_DECOMPRESSION_RATIO = 250.0 -COMPRESSED_TAR_EXTENSIONS = {".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar.xz", ".txz"} +_GZIP_MAGIC = b"\x1f\x8b" +_BZIP2_MAGIC = b"BZh" +_XZ_MAGIC = b"\xfd7zXZ\x00" class TarScanner(BaseScanner): @@ -180,6 +182,32 @@ def _extract_member_to_tempfile( assert tmp_path is not None return tmp_path, total_size + @staticmethod + def _detect_compressed_tar_wrapper(path: str) -> str | None: + """Detect compressed TAR wrappers by content, not by filename suffix.""" + with open(path, "rb") as file_obj: + header = file_obj.read(6) + + if header.startswith(_GZIP_MAGIC): + return "gzip" + if header.startswith(_BZIP2_MAGIC): + return "bzip2" + if header.startswith(_XZ_MAGIC): + return "xz" + return None + + @staticmethod + def _estimate_tar_stream_size(members: list[tarfile.TarInfo]) -> int: + """Estimate the decompressed TAR wrapper size including headers, padding, and EOF blocks.""" + total_size = 1024 # Two 512-byte EOF blocks terminate the TAR stream. + + for member in members: + total_size += 512 # Each TAR entry has a 512-byte header. + if member.isfile(): + total_size += ((member.size + 511) // 512) * 512 + + return total_size + def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult: result = ScanResult(scanner_name=self.name) contents: list[dict[str, Any]] = [] @@ -207,7 +235,6 @@ def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult: with tarfile.open(path, "r:*") as tar: members = tar.getmembers() - total_member_size = sum(member.size for member in members if member.isfile()) if len(members) > self.max_entries: result.add_check( name="Entry Count Limit Check", @@ -230,8 +257,10 @@ def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult: ) compressed_size = os.path.getsize(path) - lower_path = path.lower() - if any(lower_path.endswith(ext) for ext in COMPRESSED_TAR_EXTENSIONS): + compression_codec = self._detect_compressed_tar_wrapper(path) + if compression_codec is not None: + total_member_size = self._estimate_tar_stream_size(members) + if total_member_size > self.max_decompressed_bytes: result.add_check( name="Compressed Wrapper Decompression Limits", @@ -245,6 +274,7 @@ def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult: "decompressed_size": total_member_size, "compressed_size": compressed_size, "max_decompressed_size": self.max_decompressed_bytes, + "compression": compression_codec, }, rule_code="S902", ) @@ -266,6 +296,7 @@ def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult: "compressed_size": compressed_size, "max_ratio": self.max_decompression_ratio, "actual_ratio": total_member_size / compressed_size, + "compression": compression_codec, }, rule_code="S902", ) @@ -285,6 +316,7 @@ def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult: "compressed_size": compressed_size, "max_decompressed_size": self.max_decompressed_bytes, "max_ratio": self.max_decompression_ratio, + "compression": compression_codec, }, rule_code=None, ) diff --git a/tests/scanners/test_tar_scanner.py b/tests/scanners/test_tar_scanner.py index ec8686f5..7fd0991b 100644 --- a/tests/scanners/test_tar_scanner.py +++ b/tests/scanners/test_tar_scanner.py @@ -2,6 +2,7 @@ import tarfile import tempfile from pathlib import Path +from typing import Literal import pytest @@ -512,12 +513,22 @@ def test_scan_truncated_tar(self, tmp_path: Path) -> None: assert "not a valid tar file" in format_checks[0].message.lower() assert any("not a valid tar file" in issue.message.lower() for issue in result.issues) - def test_scan_tar_gz_enforces_decompression_ratio_limit(self, tmp_path: Path) -> None: - """Compressed TAR wrappers should enforce decompression ratio limits.""" - archive_path = tmp_path / "ratio_limit.tar.gz" + @pytest.mark.parametrize( + ("suffix", "mode"), + [ + (".tar.gz", "w:gz"), + (".tar.bz2", "w:bz2"), + (".tar.xz", "w:xz"), + ], + ) + def test_scan_compressed_tar_enforces_decompression_ratio_limit( + self, tmp_path: Path, suffix: str, mode: Literal["w:gz", "w:bz2", "w:xz"] + ) -> None: + """Compressed TAR wrappers should enforce decompression ratio limits across supported codecs.""" + archive_path = tmp_path / f"ratio_limit{suffix}" payload = b"A" * 1_000_000 - with tarfile.open(archive_path, "w:gz") as archive: + with tarfile.open(archive_path, mode) as archive: info = tarfile.TarInfo("payload.bin") info.size = len(payload) archive.addfile(info, tarfile.io.BytesIO(payload)) # type: ignore[attr-defined] @@ -530,12 +541,22 @@ def test_scan_tar_gz_enforces_decompression_ratio_limit(self, tmp_path: Path) -> assert limit_checks[0].status == CheckStatus.FAILED assert "decompression ratio exceeded" in limit_checks[0].message.lower() - def test_scan_tar_gz_enforces_decompressed_size_limit(self, tmp_path: Path) -> None: - """Compressed TAR wrappers should enforce decompressed size limits.""" - archive_path = tmp_path / "size_limit.tar.gz" + @pytest.mark.parametrize( + ("suffix", "mode"), + [ + (".tar.gz", "w:gz"), + (".tar.bz2", "w:bz2"), + (".tar.xz", "w:xz"), + ], + ) + def test_scan_compressed_tar_enforces_decompressed_size_limit( + self, tmp_path: Path, suffix: str, mode: Literal["w:gz", "w:bz2", "w:xz"] + ) -> None: + """Compressed TAR wrappers should enforce size limits across supported codecs.""" + archive_path = tmp_path / f"size_limit{suffix}" payload = b"B" * 10_000 - with tarfile.open(archive_path, "w:gz") as archive: + with tarfile.open(archive_path, mode) as archive: info = tarfile.TarInfo("payload.bin") info.size = len(payload) archive.addfile(info, tarfile.io.BytesIO(payload)) # type: ignore[attr-defined] @@ -547,3 +568,53 @@ def test_scan_tar_gz_enforces_decompressed_size_limit(self, tmp_path: Path) -> N assert len(limit_checks) == 1 assert limit_checks[0].status == CheckStatus.FAILED assert "decompressed size exceeded" in limit_checks[0].message.lower() + + @pytest.mark.parametrize( + ("suffix", "mode"), + [ + (".tar.gz", "w:gz"), + (".tar.bz2", "w:bz2"), + (".tar.xz", "w:xz"), + ], + ) + def test_scan_compressed_tar_within_limits_passes_decompression_checks( + self, tmp_path: Path, suffix: str, mode: Literal["w:gz", "w:bz2", "w:xz"] + ) -> None: + """Compressed TAR wrappers within safe bounds should produce a passing decompression check.""" + archive_path = tmp_path / f"within_limit{suffix}" + payload = b"safe-payload" + + with tarfile.open(archive_path, mode) as archive: + info = tarfile.TarInfo("payload.bin") + info.size = len(payload) + archive.addfile(info, tarfile.io.BytesIO(payload)) # type: ignore[attr-defined] + + scanner = TarScanner( + config={ + "compressed_max_decompression_ratio": 1_000.0, + "compressed_max_decompressed_bytes": 10_000, + } + ) + result = scanner.scan(str(archive_path)) + + limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"] + assert len(limit_checks) == 1 + assert limit_checks[0].status == CheckStatus.PASSED + + def test_scan_compressed_tar_detects_wrapper_by_content_not_suffix(self, tmp_path: Path) -> None: + """Compressed TARs with plain .tar suffix should still enforce wrapper limits by magic bytes.""" + archive_path = tmp_path / "disguised_compressed.tar" + payload = b"C" * 1_000_000 + + with tarfile.open(archive_path, "w:gz") as archive: + info = tarfile.TarInfo("payload.bin") + info.size = len(payload) + archive.addfile(info, tarfile.io.BytesIO(payload)) # type: ignore[attr-defined] + + scanner = TarScanner(config={"compressed_max_decompression_ratio": 2.0}) + result = scanner.scan(str(archive_path)) + + limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"] + assert len(limit_checks) == 1 + assert limit_checks[0].status == CheckStatus.FAILED + assert "decompression ratio exceeded" in limit_checks[0].message.lower() From 099d45cae6c2486ee63da51001e611cf448bab97 Mon Sep 17 00:00:00 2001 From: mldangelo Date: Tue, 17 Mar 2026 07:21:08 -0700 Subject: [PATCH 3/3] fix(tar): route disguised tar wrappers through scanner --- modelaudit/core.py | 1 + modelaudit/scanners/tar_scanner.py | 267 ++++++++++++++++++----------- modelaudit/utils/file/detection.py | 13 ++ tests/scanners/test_tar_scanner.py | 67 +++++++- tests/utils/file/test_filetype.py | 13 ++ 5 files changed, 260 insertions(+), 101 deletions(-) diff --git a/modelaudit/core.py b/modelaudit/core.py index 1a2a7eff..cc2c4168 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -1347,6 +1347,7 @@ def _scan_file_internal(path: str, config: dict[str, Any] | None = None) -> Scan "safetensors": "safetensors", "tensorflow_directory": "tf_savedmodel", "protobuf": "tf_savedmodel", + "tar": "tar", "zip": "zip", "onnx": "onnx", "gguf": "gguf", diff --git a/modelaudit/scanners/tar_scanner.py b/modelaudit/scanners/tar_scanner.py index a8298ac1..d6cf2a2f 100644 --- a/modelaudit/scanners/tar_scanner.py +++ b/modelaudit/scanners/tar_scanner.py @@ -67,11 +67,6 @@ def can_handle(cls, path: str) -> bool: if not os.path.isfile(path): return False - # Check for compound extensions like .tar.gz - path_lower = path.lower() - if not any(path_lower.endswith(ext) for ext in cls.supported_extensions): - return False - try: return tarfile.is_tarfile(path) except Exception: @@ -197,16 +192,165 @@ def _detect_compressed_tar_wrapper(path: str) -> str | None: return None @staticmethod - def _estimate_tar_stream_size(members: list[tarfile.TarInfo]) -> int: - """Estimate the decompressed TAR wrapper size including headers, padding, and EOF blocks.""" - total_size = 1024 # Two 512-byte EOF blocks terminate the TAR stream. + def _finalize_tar_stream_size(consumed_size: int) -> int: + """Return the minimum TAR stream size after EOF blocks and record padding.""" + total_size = max(consumed_size + (2 * tarfile.BLOCKSIZE), tarfile.RECORDSIZE) + return ((total_size + tarfile.RECORDSIZE - 1) // tarfile.RECORDSIZE) * tarfile.RECORDSIZE + + def _add_compressed_wrapper_limit_check( + self, + result: ScanResult, + *, + passed: bool, + path: str, + message: str, + decompressed_size: int, + compressed_size: int, + compression_codec: str, + actual_ratio: float | None = None, + ) -> None: + """Record compressed-wrapper policy checks with consistent details.""" + details: dict[str, Any] = { + "decompressed_size": decompressed_size, + "compressed_size": compressed_size, + "max_decompressed_size": self.max_decompressed_bytes, + "max_ratio": self.max_decompression_ratio, + "compression": compression_codec, + } + if actual_ratio is not None: + details["actual_ratio"] = actual_ratio + + result.add_check( + name="Compressed Wrapper Decompression Limits", + passed=passed, + message=message, + severity=None if passed else IssueSeverity.WARNING, + location=path, + details=details, + rule_code=None if passed else "S902", + ) + + def _preflight_tar_archive(self, path: str, result: ScanResult) -> bool: + """Stream TAR headers once to enforce entry-count and wrapper-size limits before extraction.""" + entry_count = 0 + compressed_size = os.path.getsize(path) + compression_codec = self._detect_compressed_tar_wrapper(path) + consumed_size = 0 + + with tarfile.open(path, "r:*") as tar: + while True: + member = tar.next() + if member is None: + break + + entry_count += 1 + if entry_count > self.max_entries: + result.add_check( + name="Entry Count Limit Check", + passed=False, + message=f"TAR file contains too many entries ({entry_count} > {self.max_entries})", + rule_code="S902", + severity=IssueSeverity.WARNING, + location=path, + details={"entries": entry_count, "max_entries": self.max_entries}, + ) + return False + + if compression_codec is not None: + consumed_size = max(consumed_size, tar.offset) + estimated_stream_size = self._finalize_tar_stream_size(consumed_size) + actual_ratio = (estimated_stream_size / compressed_size) if compressed_size > 0 else 0.0 + + if estimated_stream_size > self.max_decompressed_bytes: + self._add_compressed_wrapper_limit_check( + result, + passed=False, + path=path, + message=( + f"Decompressed size exceeded limit " + f"({estimated_stream_size} > {self.max_decompressed_bytes})" + ), + decompressed_size=estimated_stream_size, + compressed_size=compressed_size, + compression_codec=compression_codec, + actual_ratio=actual_ratio, + ) + return False + + if compressed_size > 0 and actual_ratio > self.max_decompression_ratio: + self._add_compressed_wrapper_limit_check( + result, + passed=False, + path=path, + message=( + "Decompression ratio exceeded limit " + f"({actual_ratio:.1f}x > {self.max_decompression_ratio:.1f}x)" + ), + decompressed_size=estimated_stream_size, + compressed_size=compressed_size, + compression_codec=compression_codec, + actual_ratio=actual_ratio, + ) + return False + + result.add_check( + name="Entry Count Limit Check", + passed=True, + message=f"Entry count ({entry_count}) is within limits", + location=path, + details={"entries": entry_count, "max_entries": self.max_entries}, + rule_code=None, + ) + + if compression_codec is not None: + final_stream_size = self._finalize_tar_stream_size(max(consumed_size, tar.offset)) + actual_ratio = (final_stream_size / compressed_size) if compressed_size > 0 else 0.0 + + if final_stream_size > self.max_decompressed_bytes: + self._add_compressed_wrapper_limit_check( + result, + passed=False, + path=path, + message=( + f"Decompressed size exceeded limit ({final_stream_size} > {self.max_decompressed_bytes})" + ), + decompressed_size=final_stream_size, + compressed_size=compressed_size, + compression_codec=compression_codec, + actual_ratio=actual_ratio, + ) + return False + + if compressed_size > 0 and actual_ratio > self.max_decompression_ratio: + self._add_compressed_wrapper_limit_check( + result, + passed=False, + path=path, + message=( + "Decompression ratio exceeded limit " + f"({actual_ratio:.1f}x > {self.max_decompression_ratio:.1f}x)" + ), + decompressed_size=final_stream_size, + compressed_size=compressed_size, + compression_codec=compression_codec, + actual_ratio=actual_ratio, + ) + return False - for member in members: - total_size += 512 # Each TAR entry has a 512-byte header. - if member.isfile(): - total_size += ((member.size + 511) // 512) * 512 + self._add_compressed_wrapper_limit_check( + result, + passed=True, + path=path, + message=( + f"Decompressed size/ratio are within limits ({final_stream_size} bytes, {actual_ratio:.1f}x)" + ), + decompressed_size=final_stream_size, + compressed_size=compressed_size, + compression_codec=compression_codec, + actual_ratio=actual_ratio, + ) - return total_size + return True def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult: result = ScanResult(scanner_name=self.name) @@ -233,95 +377,18 @@ def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult: rule_code=None, # Passing check ) - with tarfile.open(path, "r:*") as tar: - members = tar.getmembers() - if len(members) > self.max_entries: - result.add_check( - name="Entry Count Limit Check", - passed=False, - message=f"TAR file contains too many entries ({len(members)} > {self.max_entries})", - rule_code="S902", - severity=IssueSeverity.WARNING, - location=path, - details={"entries": len(members), "max_entries": self.max_entries}, - ) - return result - else: - result.add_check( - name="Entry Count Limit Check", - passed=True, - message=f"Entry count ({len(members)}) is within limits", - location=path, - details={"entries": len(members), "max_entries": self.max_entries}, - rule_code=None, # Passing check - ) - - compressed_size = os.path.getsize(path) - compression_codec = self._detect_compressed_tar_wrapper(path) - if compression_codec is not None: - total_member_size = self._estimate_tar_stream_size(members) - - if total_member_size > self.max_decompressed_bytes: - result.add_check( - name="Compressed Wrapper Decompression Limits", - passed=False, - message=( - f"Decompressed size exceeded limit ({total_member_size} > {self.max_decompressed_bytes})" - ), - severity=IssueSeverity.WARNING, - location=path, - details={ - "decompressed_size": total_member_size, - "compressed_size": compressed_size, - "max_decompressed_size": self.max_decompressed_bytes, - "compression": compression_codec, - }, - rule_code="S902", - ) - return result - - if compressed_size > 0 and (total_member_size / compressed_size) > self.max_decompression_ratio: - result.add_check( - name="Compressed Wrapper Decompression Limits", - passed=False, - message=( - "Decompression ratio exceeded limit " - f"({total_member_size / compressed_size:.1f}x > " - f"{self.max_decompression_ratio:.1f}x)" - ), - severity=IssueSeverity.WARNING, - location=path, - details={ - "decompressed_size": total_member_size, - "compressed_size": compressed_size, - "max_ratio": self.max_decompression_ratio, - "actual_ratio": total_member_size / compressed_size, - "compression": compression_codec, - }, - rule_code="S902", - ) - return result + if not self._preflight_tar_archive(path, result): + result.metadata["contents"] = contents + result.metadata["file_size"] = os.path.getsize(path) + result.finish(success=not result.has_errors) + return result - result.add_check( - name="Compressed Wrapper Decompression Limits", - passed=True, - message=( - "Decompressed size/ratio are within limits " - f"({total_member_size} bytes, " - f"{(total_member_size / compressed_size) if compressed_size else 0:.1f}x)" - ), - location=path, - details={ - "decompressed_size": total_member_size, - "compressed_size": compressed_size, - "max_decompressed_size": self.max_decompressed_bytes, - "max_ratio": self.max_decompression_ratio, - "compression": compression_codec, - }, - rule_code=None, - ) + with tarfile.open(path, "r:*") as tar: + while True: + member = tar.next() + if member is None: + break - for member in members: name = member.name temp_base = os.path.join(tempfile.gettempdir(), "extract_tar") resolved_name, is_safe = sanitize_archive_path(name, temp_base) diff --git a/modelaudit/utils/file/detection.py b/modelaudit/utils/file/detection.py index 33793ca5..8981954e 100644 --- a/modelaudit/utils/file/detection.py +++ b/modelaudit/utils/file/detection.py @@ -2,6 +2,7 @@ import pickletools import re import struct +import tarfile import zipfile from pathlib import Path, PurePosixPath @@ -215,6 +216,14 @@ def is_torchserve_mar_archive(path: str) -> bool: return False +def _is_tar_archive(path: str) -> bool: + """Return whether a path is a TAR archive, including compressed wrappers.""" + try: + return tarfile.is_tarfile(path) + except Exception: + return False + + def is_zipfile(path: str) -> bool: """Check if file is a ZIP by reading the signature.""" file_path = Path(path) @@ -539,10 +548,14 @@ def detect_file_format(path: str) -> str: compression_format = _detect_compression_format(header) if ext in _COMPRESSED_EXTENSION_CODECS: + if _is_tar_archive(path): + return "tar" expected_codec = _COMPRESSED_EXTENSION_CODECS[ext] if compression_format == expected_codec: return "compressed" return "unknown" + if _is_tar_archive(path): + return "tar" # Check ZIP magic first (for .pt/.pth files that are actually zips) if magic4.startswith(b"PK"): if ext == ".mar" and is_torchserve_mar_archive(path): diff --git a/tests/scanners/test_tar_scanner.py b/tests/scanners/test_tar_scanner.py index 7fd0991b..5d3a7692 100644 --- a/tests/scanners/test_tar_scanner.py +++ b/tests/scanners/test_tar_scanner.py @@ -6,6 +6,7 @@ import pytest +from modelaudit import core from modelaudit.scanners.base import CheckStatus, IssueSeverity from modelaudit.scanners.tar_scanner import DEFAULT_MAX_TAR_ENTRY_SIZE, TarScanner @@ -592,7 +593,7 @@ def test_scan_compressed_tar_within_limits_passes_decompression_checks( scanner = TarScanner( config={ "compressed_max_decompression_ratio": 1_000.0, - "compressed_max_decompressed_bytes": 10_000, + "compressed_max_decompressed_bytes": 20_000, } ) result = scanner.scan(str(archive_path)) @@ -601,6 +602,49 @@ def test_scan_compressed_tar_within_limits_passes_decompression_checks( assert len(limit_checks) == 1 assert limit_checks[0].status == CheckStatus.PASSED + def test_scan_compressed_tar_accounts_for_tar_record_padding(self, tmp_path: Path) -> None: + """Wrapper limits should account for TAR record padding, even on tiny archives.""" + archive_path = tmp_path / "tiny.tar.gz" + + with tarfile.open(archive_path, "w:gz") as archive: + info = tarfile.TarInfo("payload.bin") + payload = b"tiny" + info.size = len(payload) + archive.addfile(info, tarfile.io.BytesIO(payload)) # type: ignore[attr-defined] + + scanner = TarScanner(config={"compressed_max_decompressed_bytes": 4_096}) + result = scanner.scan(str(archive_path)) + + limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"] + assert len(limit_checks) == 1 + assert limit_checks[0].status == CheckStatus.FAILED + assert "decompressed size exceeded" in limit_checks[0].message.lower() + + def test_scan_tar_preflight_streams_members_without_getmembers( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Preflight should stream TAR members instead of materializing them with getmembers().""" + archive_path = tmp_path / "streamed.tar.gz" + + with tarfile.open(archive_path, "w:gz") as archive: + for index in range(3): + info = tarfile.TarInfo(f"payload-{index}.bin") + payload = f"payload-{index}".encode() + info.size = len(payload) + archive.addfile(info, tarfile.io.BytesIO(payload)) # type: ignore[attr-defined] + + def fail_getmembers(self: tarfile.TarFile) -> list[tarfile.TarInfo]: + raise AssertionError("TarScanner should not call getmembers() during preflight") + + monkeypatch.setattr(tarfile.TarFile, "getmembers", fail_getmembers) + + result = self.scanner.scan(str(archive_path)) + + assert result.success is True + entry_checks = [check for check in result.checks if check.name == "Entry Count Limit Check"] + assert len(entry_checks) == 1 + assert entry_checks[0].status == CheckStatus.PASSED + def test_scan_compressed_tar_detects_wrapper_by_content_not_suffix(self, tmp_path: Path) -> None: """Compressed TARs with plain .tar suffix should still enforce wrapper limits by magic bytes.""" archive_path = tmp_path / "disguised_compressed.tar" @@ -618,3 +662,24 @@ def test_scan_compressed_tar_detects_wrapper_by_content_not_suffix(self, tmp_pat assert len(limit_checks) == 1 assert limit_checks[0].status == CheckStatus.FAILED assert "decompression ratio exceeded" in limit_checks[0].message.lower() + + def test_core_routes_disguised_compressed_tar_without_tar_suffix(self, tmp_path: Path) -> None: + """Compressed TAR wrappers renamed to generic suffixes should still route to TarScanner.""" + archive_path = tmp_path / "disguised_payload.bin" + payload = b"D" * 1_000_000 + + with tarfile.open(archive_path, "w:gz") as archive: + info = tarfile.TarInfo("payload.bin") + info.size = len(payload) + archive.addfile(info, tarfile.io.BytesIO(payload)) # type: ignore[attr-defined] + + result = core.scan_file( + str(archive_path), + config={"compressed_max_decompression_ratio": 2.0}, + ) + + limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"] + assert result.scanner_name == "tar" + assert len(limit_checks) == 1 + assert limit_checks[0].status == CheckStatus.FAILED + assert "decompression ratio exceeded" in limit_checks[0].message.lower() diff --git a/tests/utils/file/test_filetype.py b/tests/utils/file/test_filetype.py index 8d5a92f3..ebf2e92d 100644 --- a/tests/utils/file/test_filetype.py +++ b/tests/utils/file/test_filetype.py @@ -379,6 +379,19 @@ def test_detect_file_format_tar_wrappers_preserve_tar_routing(tmp_path: Path) -> assert validate_file_type(str(tar_gz)) is True +def test_detect_file_format_disguised_compressed_tar_by_content(tmp_path: Path) -> None: + archive_path = tmp_path / "archive.bin" + with tarfile.open(archive_path, "w:gz") as archive: + info = tarfile.TarInfo("payload.bin") + payload = b"payload" + info.size = len(payload) + archive.addfile(info, io.BytesIO(payload)) + + assert detect_file_format(str(archive_path)) == "tar" + assert detect_file_format_from_magic(str(archive_path)) == "gzip" + assert validate_file_type(str(archive_path)) is False + + def test_zip_magic_variants(tmp_path): """Ensure alternate PK signatures are detected as ZIP.""" for sig in (b"PK\x06\x06", b"PK\x06\x07"):