Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions modelaudit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,7 @@ def _scan_file_internal(path: str, config: dict[str, Any] | None = None) -> Scan
"safetensors": "safetensors",
"tensorflow_directory": "tf_savedmodel",
"protobuf": "tf_savedmodel",
"tar": "tar",
"zip": "zip",
"onnx": "onnx",
"gguf": "gguf",
Expand Down
224 changes: 197 additions & 27 deletions modelaudit/scanners/tar_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@
]

DEFAULT_MAX_TAR_ENTRY_SIZE = 1024 * 1024 * 1024
DEFAULT_MAX_DECOMPRESSED_BYTES = 512 * 1024 * 1024
DEFAULT_MAX_DECOMPRESSION_RATIO = 250.0

_GZIP_MAGIC = b"\x1f\x8b"
_BZIP2_MAGIC = b"BZh"
_XZ_MAGIC = b"\xfd7zXZ\x00"


class TarScanner(BaseScanner):
Expand All @@ -49,17 +55,18 @@ def __init__(self, config: dict[str, Any] | None = None) -> None:
super().__init__(config)
self.max_depth = self.config.get("max_tar_depth", 5)
self.max_entries = self.config.get("max_tar_entries", 10000)
self.max_decompressed_bytes = int(
self.config.get("compressed_max_decompressed_bytes", DEFAULT_MAX_DECOMPRESSED_BYTES),
)
self.max_decompression_ratio = float(
self.config.get("compressed_max_decompression_ratio", DEFAULT_MAX_DECOMPRESSION_RATIO),
)

@classmethod
def can_handle(cls, path: str) -> bool:
if not os.path.isfile(path):
return False

# Check for compound extensions like .tar.gz
path_lower = path.lower()
if not any(path_lower.endswith(ext) for ext in cls.supported_extensions):
return False

try:
return tarfile.is_tarfile(path)
except Exception:
Expand Down Expand Up @@ -170,6 +177,181 @@ def _extract_member_to_tempfile(
assert tmp_path is not None
return tmp_path, total_size

@staticmethod
def _detect_compressed_tar_wrapper(path: str) -> str | None:
    """Identify a gzip/bzip2/xz wrapper around *path* from its leading bytes.

    Detection is content-based only; the filename suffix is never consulted.
    Returns the codec name ("gzip", "bzip2", or "xz"), or ``None`` when the
    file does not start with a known compression magic.
    """
    with open(path, "rb") as handle:
        # Six bytes is enough for the longest magic we check (xz).
        prefix = handle.read(6)

    for magic, codec in (
        (_GZIP_MAGIC, "gzip"),
        (_BZIP2_MAGIC, "bzip2"),
        (_XZ_MAGIC, "xz"),
    ):
        if prefix.startswith(magic):
            return codec
    return None

@staticmethod
def _finalize_tar_stream_size(consumed_size: int) -> int:
"""Return the minimum TAR stream size after EOF blocks and record padding."""
total_size = max(consumed_size + (2 * tarfile.BLOCKSIZE), tarfile.RECORDSIZE)
return ((total_size + tarfile.RECORDSIZE - 1) // tarfile.RECORDSIZE) * tarfile.RECORDSIZE

def _add_compressed_wrapper_limit_check(
    self,
    result: ScanResult,
    *,
    passed: bool,
    path: str,
    message: str,
    decompressed_size: int,
    compressed_size: int,
    compression_codec: str,
    actual_ratio: float | None = None,
) -> None:
    """Record a compressed-wrapper policy check with a consistent detail shape.

    Failing checks are tagged rule S902 at WARNING severity; passing checks
    carry neither a severity nor a rule code. ``actual_ratio`` is attached to
    the details only when the caller supplies one.
    """
    severity = None if passed else IssueSeverity.WARNING
    rule = None if passed else "S902"

    detail_fields: dict[str, Any] = {
        "decompressed_size": decompressed_size,
        "compressed_size": compressed_size,
        "max_decompressed_size": self.max_decompressed_bytes,
        "max_ratio": self.max_decompression_ratio,
        "compression": compression_codec,
    }
    if actual_ratio is not None:
        detail_fields["actual_ratio"] = actual_ratio

    result.add_check(
        name="Compressed Wrapper Decompression Limits",
        passed=passed,
        message=message,
        severity=severity,
        location=path,
        details=detail_fields,
        rule_code=rule,
    )

def _preflight_tar_archive(self, path: str, result: ScanResult) -> bool:
"""Stream TAR headers once to enforce entry-count and wrapper-size limits before extraction."""
entry_count = 0
compressed_size = os.path.getsize(path)
compression_codec = self._detect_compressed_tar_wrapper(path)
consumed_size = 0

with tarfile.open(path, "r:*") as tar:
while True:
member = tar.next()
if member is None:
break

entry_count += 1
if entry_count > self.max_entries:
result.add_check(
name="Entry Count Limit Check",
passed=False,
message=f"TAR file contains too many entries ({entry_count} > {self.max_entries})",
rule_code="S902",
severity=IssueSeverity.WARNING,
location=path,
details={"entries": entry_count, "max_entries": self.max_entries},
)
return False

if compression_codec is not None:
consumed_size = max(consumed_size, tar.offset)
estimated_stream_size = self._finalize_tar_stream_size(consumed_size)
actual_ratio = (estimated_stream_size / compressed_size) if compressed_size > 0 else 0.0

if estimated_stream_size > self.max_decompressed_bytes:
self._add_compressed_wrapper_limit_check(
result,
passed=False,
path=path,
message=(
f"Decompressed size exceeded limit "
f"({estimated_stream_size} > {self.max_decompressed_bytes})"
),
decompressed_size=estimated_stream_size,
compressed_size=compressed_size,
compression_codec=compression_codec,
actual_ratio=actual_ratio,
)
return False

if compressed_size > 0 and actual_ratio > self.max_decompression_ratio:
self._add_compressed_wrapper_limit_check(
result,
passed=False,
path=path,
message=(
"Decompression ratio exceeded limit "
f"({actual_ratio:.1f}x > {self.max_decompression_ratio:.1f}x)"
),
decompressed_size=estimated_stream_size,
compressed_size=compressed_size,
compression_codec=compression_codec,
actual_ratio=actual_ratio,
)
return False

result.add_check(
name="Entry Count Limit Check",
passed=True,
message=f"Entry count ({entry_count}) is within limits",
location=path,
details={"entries": entry_count, "max_entries": self.max_entries},
rule_code=None,
)

if compression_codec is not None:
final_stream_size = self._finalize_tar_stream_size(max(consumed_size, tar.offset))
actual_ratio = (final_stream_size / compressed_size) if compressed_size > 0 else 0.0

if final_stream_size > self.max_decompressed_bytes:
self._add_compressed_wrapper_limit_check(
result,
passed=False,
path=path,
message=(
f"Decompressed size exceeded limit ({final_stream_size} > {self.max_decompressed_bytes})"
),
decompressed_size=final_stream_size,
compressed_size=compressed_size,
compression_codec=compression_codec,
actual_ratio=actual_ratio,
)
return False

if compressed_size > 0 and actual_ratio > self.max_decompression_ratio:
self._add_compressed_wrapper_limit_check(
result,
passed=False,
path=path,
message=(
"Decompression ratio exceeded limit "
f"({actual_ratio:.1f}x > {self.max_decompression_ratio:.1f}x)"
),
decompressed_size=final_stream_size,
compressed_size=compressed_size,
compression_codec=compression_codec,
actual_ratio=actual_ratio,
)
return False

self._add_compressed_wrapper_limit_check(
result,
passed=True,
path=path,
message=(
f"Decompressed size/ratio are within limits ({final_stream_size} bytes, {actual_ratio:.1f}x)"
),
decompressed_size=final_stream_size,
compressed_size=compressed_size,
compression_codec=compression_codec,
actual_ratio=actual_ratio,
)

return True

def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult:
result = ScanResult(scanner_name=self.name)
contents: list[dict[str, Any]] = []
Expand All @@ -195,30 +377,18 @@ def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult:
rule_code=None, # Passing check
)

if not self._preflight_tar_archive(path, result):
result.metadata["contents"] = contents
result.metadata["file_size"] = os.path.getsize(path)
result.finish(success=not result.has_errors)
return result

with tarfile.open(path, "r:*") as tar:
members = tar.getmembers()
if len(members) > self.max_entries:
result.add_check(
name="Entry Count Limit Check",
passed=False,
message=f"TAR file contains too many entries ({len(members)} > {self.max_entries})",
rule_code="S902",
severity=IssueSeverity.WARNING,
location=path,
details={"entries": len(members), "max_entries": self.max_entries},
)
return result
else:
result.add_check(
name="Entry Count Limit Check",
passed=True,
message=f"Entry count ({len(members)}) is within limits",
location=path,
details={"entries": len(members), "max_entries": self.max_entries},
rule_code=None, # Passing check
)
while True:
member = tar.next()
if member is None:
break

for member in members:
name = member.name
temp_base = os.path.join(tempfile.gettempdir(), "extract_tar")
resolved_name, is_safe = sanitize_archive_path(name, temp_base)
Expand Down
13 changes: 13 additions & 0 deletions modelaudit/utils/file/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickletools
import re
import struct
import tarfile
import zipfile
from pathlib import Path, PurePosixPath

Expand Down Expand Up @@ -215,6 +216,14 @@ def is_torchserve_mar_archive(path: str) -> bool:
return False


def _is_tar_archive(path: str) -> bool:
"""Return whether a path is a TAR archive, including compressed wrappers."""
try:
return tarfile.is_tarfile(path)
except Exception:
return False


def is_zipfile(path: str) -> bool:
"""Check if file is a ZIP by reading the signature."""
file_path = Path(path)
Expand Down Expand Up @@ -539,10 +548,14 @@ def detect_file_format(path: str) -> str:

compression_format = _detect_compression_format(header)
if ext in _COMPRESSED_EXTENSION_CODECS:
if _is_tar_archive(path):
return "tar"
expected_codec = _COMPRESSED_EXTENSION_CODECS[ext]
if compression_format == expected_codec:
return "compressed"
return "unknown"
if _is_tar_archive(path):
return "tar"
# Check ZIP magic first (for .pt/.pth files that are actually zips)
if magic4.startswith(b"PK"):
if ext == ".mar" and is_torchserve_mar_archive(path):
Expand Down
Loading
Loading