diff --git a/modelaudit/core.py b/modelaudit/core.py
index 1a2a7eff..cc2c4168 100644
--- a/modelaudit/core.py
+++ b/modelaudit/core.py
@@ -1347,6 +1347,7 @@ def _scan_file_internal(path: str, config: dict[str, Any] | None = None) -> Scan
             "safetensors": "safetensors",
             "tensorflow_directory": "tf_savedmodel",
             "protobuf": "tf_savedmodel",
+            "tar": "tar",
             "zip": "zip",
             "onnx": "onnx",
             "gguf": "gguf",
diff --git a/modelaudit/scanners/tar_scanner.py b/modelaudit/scanners/tar_scanner.py
index 533e7903..d6cf2a2f 100644
--- a/modelaudit/scanners/tar_scanner.py
+++ b/modelaudit/scanners/tar_scanner.py
@@ -28,6 +28,12 @@
 ]
 
 DEFAULT_MAX_TAR_ENTRY_SIZE = 1024 * 1024 * 1024
+DEFAULT_MAX_DECOMPRESSED_BYTES = 512 * 1024 * 1024
+DEFAULT_MAX_DECOMPRESSION_RATIO = 250.0
+
+_GZIP_MAGIC = b"\x1f\x8b"
+_BZIP2_MAGIC = b"BZh"
+_XZ_MAGIC = b"\xfd7zXZ\x00"
 
 
 class TarScanner(BaseScanner):
@@ -49,17 +55,18 @@ def __init__(self, config: dict[str, Any] | None = None) -> None:
         super().__init__(config)
         self.max_depth = self.config.get("max_tar_depth", 5)
         self.max_entries = self.config.get("max_tar_entries", 10000)
+        self.max_decompressed_bytes = int(
+            self.config.get("compressed_max_decompressed_bytes", DEFAULT_MAX_DECOMPRESSED_BYTES),
+        )
+        self.max_decompression_ratio = float(
+            self.config.get("compressed_max_decompression_ratio", DEFAULT_MAX_DECOMPRESSION_RATIO),
+        )
 
     @classmethod
     def can_handle(cls, path: str) -> bool:
         if not os.path.isfile(path):
             return False
 
-        # Check for compound extensions like .tar.gz
-        path_lower = path.lower()
-        if not any(path_lower.endswith(ext) for ext in cls.supported_extensions):
-            return False
-
         try:
             return tarfile.is_tarfile(path)
         except Exception:
@@ -170,6 +177,163 @@ def _extract_member_to_tempfile(
         assert tmp_path is not None
         return tmp_path, total_size
 
+    @staticmethod
+    def _detect_compressed_tar_wrapper(path: str) -> str | None:
+        """Detect compressed TAR wrappers by content, not by filename suffix."""
+        with open(path, "rb") as file_obj:
+            header = file_obj.read(6)
+
+        if header.startswith(_GZIP_MAGIC):
+            return "gzip"
+        if header.startswith(_BZIP2_MAGIC):
+            return "bzip2"
+        if header.startswith(_XZ_MAGIC):
+            return "xz"
+        return None
+
+    @staticmethod
+    def _finalize_tar_stream_size(consumed_size: int) -> int:
+        """Return the minimum TAR stream size after EOF blocks and record padding."""
+        total_size = max(consumed_size + (2 * tarfile.BLOCKSIZE), tarfile.RECORDSIZE)
+        return ((total_size + tarfile.RECORDSIZE - 1) // tarfile.RECORDSIZE) * tarfile.RECORDSIZE
+
+    def _add_compressed_wrapper_limit_check(
+        self,
+        result: ScanResult,
+        *,
+        passed: bool,
+        path: str,
+        message: str,
+        decompressed_size: int,
+        compressed_size: int,
+        compression_codec: str,
+        actual_ratio: float | None = None,
+    ) -> None:
+        """Record compressed-wrapper policy checks with consistent details."""
+        details: dict[str, Any] = {
+            "decompressed_size": decompressed_size,
+            "compressed_size": compressed_size,
+            "max_decompressed_size": self.max_decompressed_bytes,
+            "max_ratio": self.max_decompression_ratio,
+            "compression": compression_codec,
+        }
+        if actual_ratio is not None:
+            details["actual_ratio"] = actual_ratio
+
+        result.add_check(
+            name="Compressed Wrapper Decompression Limits",
+            passed=passed,
+            message=message,
+            severity=None if passed else IssueSeverity.WARNING,
+            location=path,
+            details=details,
+            rule_code=None if passed else "S902",
+        )
+
+    def _enforce_compressed_wrapper_limits(
+        self,
+        result: ScanResult,
+        *,
+        path: str,
+        stream_size: int,
+        compressed_size: int,
+        compression_codec: str,
+        record_pass: bool = False,
+    ) -> bool:
+        """Check an estimated stream size against the wrapper size and ratio limits.
+
+        Records a failing check and returns False for the first violated limit. A passing
+        check is recorded only when ``record_pass`` is set, so the limits can be probed
+        after every header without emitting one check per member.
+        """
+        actual_ratio = (stream_size / compressed_size) if compressed_size > 0 else 0.0
+
+        passed = False
+        if stream_size > self.max_decompressed_bytes:
+            message = f"Decompressed size exceeded limit ({stream_size} > {self.max_decompressed_bytes})"
+        elif compressed_size > 0 and actual_ratio > self.max_decompression_ratio:
+            message = (
+                "Decompression ratio exceeded limit "
+                f"({actual_ratio:.1f}x > {self.max_decompression_ratio:.1f}x)"
+            )
+        elif record_pass:
+            passed = True
+            message = f"Decompressed size/ratio are within limits ({stream_size} bytes, {actual_ratio:.1f}x)"
+        else:
+            # Within limits mid-stream: nothing to record yet.
+            return True
+
+        self._add_compressed_wrapper_limit_check(
+            result,
+            passed=passed,
+            path=path,
+            message=message,
+            decompressed_size=stream_size,
+            compressed_size=compressed_size,
+            compression_codec=compression_codec,
+            actual_ratio=actual_ratio,
+        )
+        return passed
+
+    def _preflight_tar_archive(self, path: str, result: ScanResult) -> bool:
+        """Stream TAR headers once to enforce entry-count and wrapper-size limits before extraction."""
+        entry_count = 0
+        compressed_size = os.path.getsize(path)
+        compression_codec = self._detect_compressed_tar_wrapper(path)
+        consumed_size = 0
+
+        with tarfile.open(path, "r:*") as tar:
+            while True:
+                member = tar.next()
+                if member is None:
+                    break
+
+                entry_count += 1
+                if entry_count > self.max_entries:
+                    result.add_check(
+                        name="Entry Count Limit Check",
+                        passed=False,
+                        message=f"TAR file contains too many entries ({entry_count} > {self.max_entries})",
+                        rule_code="S902",
+                        severity=IssueSeverity.WARNING,
+                        location=path,
+                        details={"entries": entry_count, "max_entries": self.max_entries},
+                    )
+                    return False
+
+                if compression_codec is not None:
+                    consumed_size = max(consumed_size, tar.offset)
+                    if not self._enforce_compressed_wrapper_limits(
+                        result,
+                        path=path,
+                        stream_size=self._finalize_tar_stream_size(consumed_size),
+                        compressed_size=compressed_size,
+                        compression_codec=compression_codec,
+                    ):
+                        return False
+
+            result.add_check(
+                name="Entry Count Limit Check",
+                passed=True,
+                message=f"Entry count ({entry_count}) is within limits",
+                location=path,
+                details={"entries": entry_count, "max_entries": self.max_entries},
+                rule_code=None,
+            )
+
+            # Re-check at EOF so the passing wrapper check is recorded exactly once.
+            if compression_codec is not None and not self._enforce_compressed_wrapper_limits(
+                result,
+                path=path,
+                stream_size=self._finalize_tar_stream_size(max(consumed_size, tar.offset)),
+                compressed_size=compressed_size,
+                compression_codec=compression_codec,
+                record_pass=True,
+            ):
+                return False
+
+        return True
+
     def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult:
         result = ScanResult(scanner_name=self.name)
         contents: list[dict[str, Any]] = []
@@ -195,30 +359,18 @@ def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult:
                 rule_code=None,  # Passing check
             )
 
+        if not self._preflight_tar_archive(path, result):
+            result.metadata["contents"] = contents
+            result.metadata["file_size"] = os.path.getsize(path)
+            result.finish(success=not result.has_errors)
+            return result
+
         with tarfile.open(path, "r:*") as tar:
-            members = tar.getmembers()
-            if len(members) > self.max_entries:
-                result.add_check(
-                    name="Entry Count Limit Check",
-                    passed=False,
-                    message=f"TAR file contains too many entries ({len(members)} > {self.max_entries})",
-                    rule_code="S902",
-                    severity=IssueSeverity.WARNING,
-                    location=path,
-                    details={"entries": len(members), "max_entries": self.max_entries},
-                )
-                return result
-            else:
-                result.add_check(
-                    name="Entry Count Limit Check",
-                    passed=True,
-                    message=f"Entry count ({len(members)}) is within limits",
-                    location=path,
-                    details={"entries": len(members), "max_entries": self.max_entries},
-                    rule_code=None,  # Passing check
-                )
+            while True:
+                member = tar.next()
+                if member is None:
+                    break
 
-            for member in members:
                 name = member.name
                 temp_base = os.path.join(tempfile.gettempdir(), "extract_tar")
                 resolved_name, is_safe = sanitize_archive_path(name, temp_base)
diff --git a/modelaudit/utils/file/detection.py b/modelaudit/utils/file/detection.py
index 33793ca5..8981954e 100644
--- a/modelaudit/utils/file/detection.py
+++ b/modelaudit/utils/file/detection.py
@@ -2,6 +2,7 @@
 import pickletools
 import re
 import struct
+import tarfile
 import zipfile
 from pathlib import Path, PurePosixPath
 
@@ -215,6 +216,14 @@ def is_torchserve_mar_archive(path: str) -> bool:
         return False
 
 
+def _is_tar_archive(path: str) -> bool:
+    """Return whether a path is a TAR archive, including compressed wrappers."""
+    try:
+        return tarfile.is_tarfile(path)
+    except Exception:
+        return False
+
+
 def is_zipfile(path: str) -> bool:
     """Check if file is a ZIP by reading the signature."""
     file_path = Path(path)
@@ -539,10 +548,14 @@ def detect_file_format(path: str) -> str:
 
     compression_format = _detect_compression_format(header)
     if ext in _COMPRESSED_EXTENSION_CODECS:
+        if _is_tar_archive(path):
+            return "tar"
         expected_codec = _COMPRESSED_EXTENSION_CODECS[ext]
         if compression_format == expected_codec:
             return "compressed"
         return "unknown"
+    if _is_tar_archive(path):
+        return "tar"
     # Check ZIP magic first (for .pt/.pth files that are actually zips)
     if magic4.startswith(b"PK"):
         if ext == ".mar" and is_torchserve_mar_archive(path):
diff --git a/tests/scanners/test_tar_scanner.py b/tests/scanners/test_tar_scanner.py
index 46e5ba05..5d3a7692 100644
--- a/tests/scanners/test_tar_scanner.py
+++ b/tests/scanners/test_tar_scanner.py
@@ -2,10 +2,12 @@
 import tarfile
 import tempfile
 from pathlib import Path
+from typing import Literal
 
 import pytest
 
-from modelaudit.scanners.base import IssueSeverity
+from modelaudit import core
+from modelaudit.scanners.base import CheckStatus, IssueSeverity
 from modelaudit.scanners.tar_scanner import DEFAULT_MAX_TAR_ENTRY_SIZE, TarScanner
 
 
@@ -511,3 +513,173 @@ def test_scan_truncated_tar(self, tmp_path: Path) -> None:
         assert len(format_checks) == 1
         assert "not a valid tar file" in format_checks[0].message.lower()
         assert any("not a valid tar file" in issue.message.lower() for issue in result.issues)
+
+    @pytest.mark.parametrize(
+        ("suffix", "mode"),
+        [
+            (".tar.gz", "w:gz"),
+            (".tar.bz2", "w:bz2"),
+            (".tar.xz", "w:xz"),
+        ],
+    )
+    def test_scan_compressed_tar_enforces_decompression_ratio_limit(
+        self, tmp_path: Path, suffix: str, mode: Literal["w:gz", "w:bz2", "w:xz"]
+    ) -> None:
+        """Compressed TAR wrappers should enforce decompression ratio limits across supported codecs."""
+        archive_path = tmp_path / f"ratio_limit{suffix}"
+        payload = b"A" * 1_000_000
+
+        with tarfile.open(archive_path, mode) as archive:
+            info = tarfile.TarInfo("payload.bin")
+            info.size = len(payload)
+            archive.addfile(info, tarfile.io.BytesIO(payload))  # type: ignore[attr-defined]
+
+        scanner = TarScanner(config={"compressed_max_decompression_ratio": 2.0})
+        result = scanner.scan(str(archive_path))
+
+        limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"]
+        assert len(limit_checks) == 1
+        assert limit_checks[0].status == CheckStatus.FAILED
+        assert "decompression ratio exceeded" in limit_checks[0].message.lower()
+
+    @pytest.mark.parametrize(
+        ("suffix", "mode"),
+        [
+            (".tar.gz", "w:gz"),
+            (".tar.bz2", "w:bz2"),
+            (".tar.xz", "w:xz"),
+        ],
+    )
+    def test_scan_compressed_tar_enforces_decompressed_size_limit(
+        self, tmp_path: Path, suffix: str, mode: Literal["w:gz", "w:bz2", "w:xz"]
+    ) -> None:
+        """Compressed TAR wrappers should enforce size limits across supported codecs."""
+        archive_path = tmp_path / f"size_limit{suffix}"
+        payload = b"B" * 10_000
+
+        with tarfile.open(archive_path, mode) as archive:
+            info = tarfile.TarInfo("payload.bin")
+            info.size = len(payload)
+            archive.addfile(info, tarfile.io.BytesIO(payload))  # type: ignore[attr-defined]
+
+        scanner = TarScanner(config={"compressed_max_decompressed_bytes": 1024})
+        result = scanner.scan(str(archive_path))
+
+        limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"]
+        assert len(limit_checks) == 1
+        assert limit_checks[0].status == CheckStatus.FAILED
+        assert "decompressed size exceeded" in limit_checks[0].message.lower()
+
+    @pytest.mark.parametrize(
+        ("suffix", "mode"),
+        [
+            (".tar.gz", "w:gz"),
+            (".tar.bz2", "w:bz2"),
+            (".tar.xz", "w:xz"),
+        ],
+    )
+    def test_scan_compressed_tar_within_limits_passes_decompression_checks(
+        self, tmp_path: Path, suffix: str, mode: Literal["w:gz", "w:bz2", "w:xz"]
+    ) -> None:
+        """Compressed TAR wrappers within safe bounds should produce a passing decompression check."""
+        archive_path = tmp_path / f"within_limit{suffix}"
+        payload = b"safe-payload"
+
+        with tarfile.open(archive_path, mode) as archive:
+            info = tarfile.TarInfo("payload.bin")
+            info.size = len(payload)
+            archive.addfile(info, tarfile.io.BytesIO(payload))  # type: ignore[attr-defined]
+
+        scanner = TarScanner(
+            config={
+                "compressed_max_decompression_ratio": 1_000.0,
+                "compressed_max_decompressed_bytes": 20_000,
+            }
+        )
+        result = scanner.scan(str(archive_path))
+
+        limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"]
+        assert len(limit_checks) == 1
+        assert limit_checks[0].status == CheckStatus.PASSED
+
+    def test_scan_compressed_tar_accounts_for_tar_record_padding(self, tmp_path: Path) -> None:
+        """Wrapper limits should account for TAR record padding, even on tiny archives."""
+        archive_path = tmp_path / "tiny.tar.gz"
+
+        with tarfile.open(archive_path, "w:gz") as archive:
+            info = tarfile.TarInfo("payload.bin")
+            payload = b"tiny"
+            info.size = len(payload)
+            archive.addfile(info, tarfile.io.BytesIO(payload))  # type: ignore[attr-defined]
+
+        scanner = TarScanner(config={"compressed_max_decompressed_bytes": 4_096})
+        result = scanner.scan(str(archive_path))
+
+        limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"]
+        assert len(limit_checks) == 1
+        assert limit_checks[0].status == CheckStatus.FAILED
+        assert "decompressed size exceeded" in limit_checks[0].message.lower()
+
+    def test_scan_tar_preflight_streams_members_without_getmembers(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """Preflight should stream TAR members instead of materializing them with getmembers()."""
+        archive_path = tmp_path / "streamed.tar.gz"
+
+        with tarfile.open(archive_path, "w:gz") as archive:
+            for index in range(3):
+                info = tarfile.TarInfo(f"payload-{index}.bin")
+                payload = f"payload-{index}".encode()
+                info.size = len(payload)
+                archive.addfile(info, tarfile.io.BytesIO(payload))  # type: ignore[attr-defined]
+
+        def fail_getmembers(self: tarfile.TarFile) -> list[tarfile.TarInfo]:
+            raise AssertionError("TarScanner should not call getmembers() during preflight")
+
+        monkeypatch.setattr(tarfile.TarFile, "getmembers", fail_getmembers)
+
+        result = self.scanner.scan(str(archive_path))
+
+        assert result.success is True
+        entry_checks = [check for check in result.checks if check.name == "Entry Count Limit Check"]
+        assert len(entry_checks) == 1
+        assert entry_checks[0].status == CheckStatus.PASSED
+
+    def test_scan_compressed_tar_detects_wrapper_by_content_not_suffix(self, tmp_path: Path) -> None:
+        """Compressed TARs with plain .tar suffix should still enforce wrapper limits by magic bytes."""
+        archive_path = tmp_path / "disguised_compressed.tar"
+        payload = b"C" * 1_000_000
+
+        with tarfile.open(archive_path, "w:gz") as archive:
+            info = tarfile.TarInfo("payload.bin")
+            info.size = len(payload)
+            archive.addfile(info, tarfile.io.BytesIO(payload))  # type: ignore[attr-defined]
+
+        scanner = TarScanner(config={"compressed_max_decompression_ratio": 2.0})
+        result = scanner.scan(str(archive_path))
+
+        limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"]
+        assert len(limit_checks) == 1
+        assert limit_checks[0].status == CheckStatus.FAILED
+        assert "decompression ratio exceeded" in limit_checks[0].message.lower()
+
+    def test_core_routes_disguised_compressed_tar_without_tar_suffix(self, tmp_path: Path) -> None:
+        """Compressed TAR wrappers renamed to generic suffixes should still route to TarScanner."""
+        archive_path = tmp_path / "disguised_payload.bin"
+        payload = b"D" * 1_000_000
+
+        with tarfile.open(archive_path, "w:gz") as archive:
+            info = tarfile.TarInfo("payload.bin")
+            info.size = len(payload)
+            archive.addfile(info, tarfile.io.BytesIO(payload))  # type: ignore[attr-defined]
+
+        result = core.scan_file(
+            str(archive_path),
+            config={"compressed_max_decompression_ratio": 2.0},
+        )
+
+        limit_checks = [check for check in result.checks if check.name == "Compressed Wrapper Decompression Limits"]
+        assert result.scanner_name == "tar"
+        assert len(limit_checks) == 1
+        assert limit_checks[0].status == CheckStatus.FAILED
+        assert "decompression ratio exceeded" in limit_checks[0].message.lower()
diff --git a/tests/utils/file/test_filetype.py b/tests/utils/file/test_filetype.py
index 8d5a92f3..ebf2e92d 100644
--- a/tests/utils/file/test_filetype.py
+++ b/tests/utils/file/test_filetype.py
@@ -379,6 +379,19 @@ def test_detect_file_format_tar_wrappers_preserve_tar_routing(tmp_path: Path) ->
     assert validate_file_type(str(tar_gz)) is True
 
 
+def test_detect_file_format_disguised_compressed_tar_by_content(tmp_path: Path) -> None:
+    archive_path = tmp_path / "archive.bin"
+    with tarfile.open(archive_path, "w:gz") as archive:
+        info = tarfile.TarInfo("payload.bin")
+        payload = b"payload"
+        info.size = len(payload)
+        archive.addfile(info, io.BytesIO(payload))
+
+    assert detect_file_format(str(archive_path)) == "tar"
+    assert detect_file_format_from_magic(str(archive_path)) == "gzip"
+    assert validate_file_type(str(archive_path)) is False
+
+
 def test_zip_magic_variants(tmp_path):
     """Ensure alternate PK signatures are detected as ZIP."""
     for sig in (b"PK\x06\x06", b"PK\x06\x07"):