Skip to content

Commit 0ba1168

Browse files
committed
Merge main into feat/migrate-poetry-to-rye - add Joblib and NumPy scanners while preserving Rye config
2 parents 6d80341 + 26f4973 commit 0ba1168

12 files changed

Lines changed: 945 additions & 13 deletions

README.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ ModelAudit scans ML model files for:
4949
- **Models with blacklisted names** or content patterns
5050
- **Malicious content in ZIP archives** including nested archives and zip bombs
5151
- **Anomalous weight patterns** that may indicate trojaned models (statistical analysis)
52+
- **Joblib serialization vulnerabilities** (compression bombs, embedded pickle content)
53+
- **NumPy array integrity issues** (malformed headers, dangerous dtypes)
5254

5355
## 🚀 Quick Start
5456

@@ -83,6 +85,9 @@ pip install modelaudit[yaml]
8385
# For SafeTensors model scanning
8486
pip install modelaudit[safetensors]
8587

88+
# For Joblib model scanning
89+
pip install modelaudit[joblib]
90+
8691
# Install all optional dependencies
8792
pip install modelaudit[all]
8893
```
@@ -124,7 +129,7 @@ modelaudit scan model.pkl
124129
modelaudit scan model.onnx
125130

126131
# Scan multiple models
127-
modelaudit scan model1.pkl model2.h5 model3.pt
132+
modelaudit scan model1.pkl model2.h5 model3.pt model4.joblib model5.npy
128133

129134
# Scan a directory
130135
modelaudit scan ./models/
@@ -164,7 +169,7 @@ Issues found: 2 critical, 1 warnings
164169
165170
### Core Capabilities
166171
167-
- **Multiple Format Support**: PyTorch (.pt, .pth, .bin), TensorFlow (SavedModel, .pb), Keras (.h5, .hdf5, .keras), SafeTensors (.safetensors), GGUF/GGML (.gguf, .ggml), Pickle (.pkl, .pickle, .ckpt), ZIP archives (.zip), Manifests (.json, .yaml, .xml, etc.)
172+
- **Multiple Format Support**: PyTorch (.pt, .pth, .bin), TensorFlow (SavedModel, .pb), Keras (.h5, .hdf5, .keras), SafeTensors (.safetensors), GGUF/GGML (.gguf, .ggml), Pickle (.pkl, .pickle, .ckpt), Joblib (.joblib), NumPy (.npy, .npz), ZIP archives (.zip), Manifests (.json, .yaml, .xml, etc.)
168173
- **Automatic Format Detection**: Identifies model formats automatically
169174
- **Deep Security Analysis**: Examines model internals, not just metadata
170175
- **Recursive Archive Scanning**: Scans contents of ZIP files and nested archives
@@ -200,6 +205,8 @@ ModelAudit provides specialized security scanners for different model formats:
200205
| **ONNX** | `.onnx` | Custom operators, external data validation, tensor integrity |
201206
| **SafeTensors** | `.safetensors` | Metadata integrity, tensor validation |
202207
| **GGUF/GGML** | `.gguf`, `.ggml` | Header validation, metadata integrity, suspicious patterns |
208+
| **Joblib** | `.joblib` | Compression bomb detection, embedded pickle analysis |
209+
| **NumPy** | `.npy`, `.npz` | Array integrity, dangerous dtypes, dimension validation |
203210
| **ZIP Archives** | `.zip` | Recursive content scanning, zip bombs, directory traversal |
204211
| **Manifests** | `.json`, `.yaml`, `.yml`, `.xml`, `.toml`, `.ini`, `.cfg`, `.config`, `.manifest`, `.model`, `.metadata` | Suspicious keys, credential exposure, blacklisted patterns |
205212
@@ -357,7 +364,7 @@ pip install -e .[all]
357364

358365
# If optional dependencies fail, install base package first
359366
pip install modelaudit
360-
pip install tensorflow h5py torch pyyaml safetensors onnx # Add what you need
367+
pip install tensorflow h5py torch pyyaml safetensors onnx joblib # Add what you need
361368
```
362369
363370
**Large Models:**

modelaudit/scanners/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from . import (
22
base,
33
gguf_scanner,
4+
joblib_scanner,
45
keras_h5_scanner,
56
manifest_scanner,
7+
numpy_scanner,
68
onnx_scanner,
79
pickle_scanner,
810
pytorch_binary_scanner,
@@ -16,8 +18,10 @@
1618
# Import scanner classes for direct use
1719
from .base import BaseScanner, Issue, IssueSeverity, ScanResult
1820
from .gguf_scanner import GgufScanner
21+
from .joblib_scanner import JoblibScanner
1922
from .keras_h5_scanner import KerasH5Scanner
2023
from .manifest_scanner import ManifestScanner
24+
from .numpy_scanner import NumPyScanner
2125
from .onnx_scanner import OnnxScanner
2226
from .pickle_scanner import PickleScanner
2327
from .pytorch_binary_scanner import PyTorchBinaryScanner
@@ -39,6 +43,8 @@
3943
ManifestScanner,
4044
WeightDistributionScanner,
4145
GgufScanner,
46+
JoblibScanner,
47+
NumPyScanner,
4248
SafeTensorsScanner,
4349
ZipScanner, # Generic zip scanner should be last
4450
# Add new scanners here as they are implemented
@@ -56,7 +62,8 @@
5662
"manifest_scanner",
5763
"weight_distribution_scanner",
5864
"gguf_scanner",
59-
"safetensors_scanner",
65+
"joblib_scanner",
66+
"numpy_scanner",
6067
"zip_scanner",
6168
"BaseScanner",
6269
"ScanResult",
@@ -72,6 +79,8 @@
7279
"ManifestScanner",
7380
"WeightDistributionScanner",
7481
"GgufScanner",
82+
"JoblibScanner",
83+
"NumPyScanner",
7584
"ZipScanner",
7685
"SCANNER_REGISTRY",
7786
]
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from __future__ import annotations
2+
3+
import io
4+
import lzma
5+
import os
6+
import zlib
7+
from typing import Any, Optional
8+
9+
from ..utils.filetype import read_magic_bytes
10+
from .base import BaseScanner, IssueSeverity, ScanResult
11+
from .pickle_scanner import PickleScanner
12+
13+
14+
class JoblibScanner(BaseScanner):
    """Scanner for joblib serialized files.

    A ``.joblib`` file is dispatched on its magic bytes: ``PK`` means a zip
    archive (delegated to ``ZipScanner``), ``\\x80`` means a raw pickle
    stream (delegated to ``PickleScanner``), and anything else is treated as
    a zlib- or lzma-compressed pickle payload. Decompression is guarded by
    ratio and absolute-size limits to detect compression bombs before any
    payload is handed to the pickle scanner.
    """

    name = "joblib"
    description = "Scans joblib files by decompressing and analyzing embedded pickle"
    supported_extensions = [".joblib"]

    def __init__(self, config: Optional[dict[str, Any]] = None):
        """Initialize the scanner and its security limits.

        All limits are overridable through ``config``:
        ``max_decompression_ratio``, ``max_decompressed_size``,
        ``max_file_read_size``, ``chunk_size``.
        """
        super().__init__(config)
        self.pickle_scanner = PickleScanner(config)
        # Security limits guarding against decompression bombs and
        # unbounded file reads.
        self.max_decompression_ratio = self.config.get("max_decompression_ratio", 100.0)
        self.max_decompressed_size = self.config.get(
            "max_decompressed_size", 100 * 1024 * 1024
        )  # 100MB
        self.max_file_read_size = self.config.get(
            "max_file_read_size", 100 * 1024 * 1024
        )  # 100MB
        self.chunk_size = self.config.get("chunk_size", 8192)  # 8KB chunks

    @classmethod
    def can_handle(cls, path: str) -> bool:
        """Return True for regular files with a ``.joblib`` extension."""
        return os.path.isfile(path) and os.path.splitext(path)[1].lower() == ".joblib"

    def _read_file_safely(self, path: str) -> bytes:
        """Read the file in chunks, enforcing ``max_file_read_size``.

        Raises:
            ValueError: if the reported file size, or the number of bytes
                actually read, exceeds ``max_file_read_size``.
        """
        file_size = self.get_file_size(path)

        if file_size > self.max_file_read_size:
            raise ValueError(
                f"File too large: {file_size} bytes (max: {self.max_file_read_size})"
            )

        # Accumulate into a bytearray: repeated ``bytes += chunk`` copies the
        # whole buffer on every concatenation and is quadratic for large files.
        buf = bytearray()
        with open(path, "rb") as f:
            while True:
                chunk = f.read(self.chunk_size)
                if not chunk:
                    break
                buf.extend(chunk)
                # Re-check during the read in case the file grew (or the
                # reported size was wrong) between stat and read.
                if len(buf) > self.max_file_read_size:
                    raise ValueError(f"File read exceeds limit: {len(buf)} bytes")
        return bytes(buf)

    def _safe_decompress(self, data: bytes) -> bytes:
        """Decompress ``data`` with compression-bomb protection.

        Tries zlib first, then lzma. Raises:
            ValueError: if the data decompresses with neither codec, if the
                expansion ratio exceeds ``max_decompression_ratio``, or if
                the decompressed payload exceeds ``max_decompressed_size``.
        """
        compressed_size = len(data)

        # Try zlib first (joblib's default), then fall back to lzma.
        decompressed = None
        try:
            decompressed = zlib.decompress(data)
        except Exception:
            try:
                decompressed = lzma.decompress(data)
            except Exception as e:
                # Chain the original cause so the codec error is not lost.
                raise ValueError(f"Unable to decompress joblib file: {e}") from e

        # A tiny input exploding into a huge output is the signature of a
        # compression bomb; reject before the payload is processed further.
        if compressed_size > 0:
            ratio = len(decompressed) / compressed_size
            if ratio > self.max_decompression_ratio:
                raise ValueError(
                    f"Suspicious compression ratio: {ratio:.1f}x "
                    f"(max: {self.max_decompression_ratio}x) - possible compression bomb"
                )

        # Absolute cap, independent of the ratio check.
        if len(decompressed) > self.max_decompressed_size:
            raise ValueError(
                f"Decompressed size too large: {len(decompressed)} bytes "
                f"(max: {self.max_decompressed_size})"
            )

        return decompressed

    def scan(self, path: str) -> ScanResult:
        """Scan a joblib file and return the aggregated ScanResult.

        Dispatches on the file's magic bytes (zip / pickle / compressed) and
        merges the delegated scanner's findings into a single result.
        """
        path_check_result = self._check_path(path)
        if path_check_result:
            return path_check_result

        result = self._create_result()
        file_size = self.get_file_size(path)
        result.metadata["file_size"] = file_size

        try:
            self.current_file_path = path
            magic = read_magic_bytes(path, 4)
            data = self._read_file_safely(path)

            if magic.startswith(b"PK"):
                # Zip archive: delegate wholesale to the zip scanner, which
                # handles recursive content and zip-bomb checks itself.
                # Imported locally to avoid a circular import at module load.
                from .zip_scanner import ZipScanner

                zip_scanner = ZipScanner(self.config)
                sub_result = zip_scanner.scan(path)
                result.merge(sub_result)
                result.bytes_scanned = sub_result.bytes_scanned
                result.metadata.update(sub_result.metadata)
                result.finish(success=sub_result.success)
                return result

            if magic.startswith(b"\x80"):
                # Raw (uncompressed) pickle stream: scan the bytes directly.
                file_like = io.BytesIO(data)
                sub_result = self.pickle_scanner._scan_pickle_bytes(
                    file_like, len(data)
                )
                result.merge(sub_result)
                result.bytes_scanned = len(data)
            else:
                # Compressed payload: decompress with bomb protection, then
                # scan the inner pickle.
                try:
                    decompressed = self._safe_decompress(data)
                except ValueError as e:
                    # Limit violations are reported as findings, not crashes.
                    result.add_issue(
                        str(e),
                        severity=IssueSeverity.CRITICAL,
                        location=path,
                        details={"security_check": "compression_bomb_detection"},
                    )
                    result.finish(success=False)
                    return result
                except Exception as e:
                    result.add_issue(
                        f"Error decompressing joblib file: {e}",
                        severity=IssueSeverity.CRITICAL,
                        location=path,
                    )
                    result.finish(success=False)
                    return result
                file_like = io.BytesIO(decompressed)
                sub_result = self.pickle_scanner._scan_pickle_bytes(
                    file_like, len(decompressed)
                )
                result.merge(sub_result)
                result.bytes_scanned = len(decompressed)
        except Exception as e:  # pragma: no cover
            result.add_issue(
                f"Error scanning joblib file: {e}",
                severity=IssueSeverity.CRITICAL,
                location=path,
                details={"exception": str(e), "exception_type": type(e).__name__},
            )
            result.finish(success=False)
            return result

        result.finish(success=True)
        return result

0 commit comments

Comments
 (0)