
Commit 5eefa15

fix: tighten dill MemoryError downgrade gating
Includes follow-up review fixes to require stronger serialization evidence for MemoryError scanner-limitation downgrades while preserving legitimate dill handling.
1 parent 55de730 commit 5eefa15

4 files changed

Lines changed: 314 additions & 65 deletions
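
The gating in question decides how a MemoryError raised mid-parse is reported. From the test assertions in this commit, the two outcomes look roughly like this (check names and detail keys are copied from the tests; the shape is a sketch, not the scanner's exact report format):

# Outcome 1, downgraded (strong serialization evidence, no dangerous globals):
#   check "Pickle Parse Resource Limit": status=FAILED, severity=INFO,
#   details={"reason": "memory_limit_on_legitimate_model", "exception_type": "MemoryError"}
#
# Outcome 2, not downgraded (weak or dangerous evidence):
#   check "Pickle Format Validation": status=FAILED, severity=WARNING,
#   details={"exception_type": "MemoryError"}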


modelaudit/scanners/pickle_scanner.py

Lines changed: 99 additions & 31 deletions
@@ -2854,6 +2854,76 @@ def _is_legitimate_serialization_file(path: str) -> bool:
     Validate that a file is a legitimate joblib or dill serialization file.
     This helps prevent security bypass by simply renaming malicious files.
     """
+
+    def _analyze_sample_globals(sample: bytes) -> tuple[bool, bool]:
+        if not sample:
+            return False, False
+
+        validation_context = {"is_ml_content": False, "overall_confidence": 0.0, "frameworks": {}}
+        has_dangerous_global = False
+        has_joblib_like_global = False
+
+        def _record_global(mod: str, func: str) -> None:
+            nonlocal has_dangerous_global, has_joblib_like_global
+            if mod in {"joblib", "sklearn", "numpy"} or mod.startswith(("joblib.", "sklearn.", "numpy.")):
+                has_joblib_like_global = True
+            if _is_actually_dangerous_global(mod, func, validation_context):
+                has_dangerous_global = True
+
+        # Raw protocol 0/1 GLOBAL parsing keeps this heuristic usable even when
+        # pickletools itself is the code path that hits MemoryError.
+        cursor = 0
+        max_global_len = 256
+        while cursor < len(sample):
+            global_pos = sample.find(b"c", cursor)
+            if global_pos == -1:
+                break
+
+            module_end = sample.find(b"\n", global_pos + 1, global_pos + 1 + max_global_len)
+            if module_end == -1:
+                cursor = global_pos + 1
+                continue
+
+            function_end = sample.find(b"\n", module_end + 1, module_end + 1 + max_global_len)
+            if function_end == -1:
+                cursor = global_pos + 1
+                continue
+
+            try:
+                module = sample[global_pos + 1 : module_end].decode("utf-8")
+                function = sample[module_end + 1 : function_end].decode("utf-8")
+            except UnicodeDecodeError:
+                cursor = global_pos + 1
+                continue
+
+            if module and function:
+                _record_global(module, function)
+            cursor = function_end + 1
+
+        try:
+            opcodes = list(_genops_with_fallback(io.BytesIO(sample), max_items=128))
+        except (_GenopsBudgetExceeded, ValueError, struct.error, UnicodeDecodeError, EOFError):
+            return has_dangerous_global, has_joblib_like_global
+        except Exception:
+            return has_dangerous_global, has_joblib_like_global
+
+        stack_global_refs, _callable_refs, _origin_refs, _origin_is_ext, _malformed = _simulate_symbolic_reference_maps(
+            opcodes
+        )
+
+        for idx, (opcode, arg, _pos) in enumerate(opcodes):
+            op_name = getattr(opcode, "name", "")
+            if op_name in {"GLOBAL", "INST"} and isinstance(arg, str):
+                parsed = _parse_module_function(arg)
+                if parsed:
+                    _record_global(parsed[0], parsed[1])
+            elif op_name == "STACK_GLOBAL":
+                stack_ref = stack_global_refs.get(idx)
+                if stack_ref:
+                    _record_global(stack_ref[0], stack_ref[1])
+
+        return has_dangerous_global, has_joblib_like_global
+
     try:
         with open(path, "rb") as f:
             # Read first few bytes to check for pickle magic
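
The raw scan in `_analyze_sample_globals` targets the protocol 0/1 GLOBAL wire format: the opcode byte `c`, a newline-terminated module name, then a newline-terminated function name. A minimal self-contained sketch of that layout (the payload is invented for illustration; modelaudit's internal helpers are not reproduced):

import pickletools

# Protocol-0 payload whose GLOBAL opcode references os.system, ending in STOP.
payload = b"cos\nsystem\n."

# The byte-level scan recovers the pair without running pickletools at all:
pos = payload.find(b"c")
module_end = payload.find(b"\n", pos + 1)
function_end = payload.find(b"\n", module_end + 1)
print(payload[pos + 1 : module_end].decode(), payload[module_end + 1 : function_end].decode())
# -> os system

# pickletools reports the same reference when it is able to run to completion:
for opcode, arg, _pos in pickletools.genops(payload):
    if opcode.name == "GLOBAL":
        print(arg)  # -> os system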
@@ -2874,34 +2944,24 @@ def _is_legitimate_serialization_file(path: str) -> bool:
                 # Common pickle opcode starts for protocols 0-1
                 return False

-            # For joblib files, look for joblib-specific patterns
-            # Also check extensionless files (e.g. HuggingFace cache blob hashes)
+            f.seek(0)
+            sample = f.read(64 * 1024)
+            has_dangerous_global, has_joblib_like_global = _analyze_sample_globals(sample)
+            if has_dangerous_global:
+                return False
+
+            # For joblib files and extensionless cache blobs, require opcode-level
+            # framework evidence instead of marker strings. Extension/substring
+            # checks alone are too easy to spoof.
             ext_lower = os.path.splitext(path)[1].lower()
             if ext_lower == ".joblib" or not ext_lower:
-                f.seek(0)
-                # Try to find joblib-specific markers in first 2KB
-                sample = f.read(2048)
-                # Look for joblib-specific indicators
-                joblib_indicators = [
-                    b"joblib",
-                    b"sklearn",
-                    b"numpy",
-                    b"_joblib",
-                    b"__main__",
-                    b"_pickle",
-                    b"NumpyArrayWrapper",
-                ]
-                if any(marker in sample for marker in joblib_indicators):
-                    return True
-                # For extensionless files, only return False if no indicators found
-                # (don't fall through to dill check)
-                if not ext_lower:
-                    return False
+                return bool(has_joblib_like_global)

-            # For dill files, they're usually just enhanced pickle
+            # Dill can serialize plain pickle-compatible objects without
+            # embedding obvious dill globals near the front of the stream, so
+            # a .dill extension remains a legitimacy signal after bounded
+            # dangerous-global rejection above.
             if ext_lower == ".dill":
-                # Dill files should contain standard pickle format
-                # Additional validation could check for dill-specific patterns
                 return True

             return False
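
The spoofing concern behind this change is easy to demonstrate: any pickle can carry the old marker bytes verbatim, so substring checks prove nothing. A hedged illustration (the payload class is invented, with a harmless `print` standing in for a dangerous callable):

import pickle

class Exploit:
    def __reduce__(self):
        # Stand-in for a genuinely dangerous callable such as os.system.
        return (print, ("attacker-controlled code would run here",))

# Old heuristic: b"sklearn" anywhere in the first 2 KB passed the check,
# even when appended to an attacker-controlled pickle.
spoofed = pickle.dumps(Exploit()) + b"sklearn numpy joblib"
assert b"sklearn" in spoofed[:2048]  # would have been treated as legitimate

# New heuristic: legitimacy requires a parsed joblib/sklearn/numpy global,
# and any dangerous global found in the sampled prefix rejects the file outright.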
@@ -5663,6 +5723,10 @@ def get_depth(x):
             mod in {"joblib", "sklearn", "numpy"} or mod.startswith(("joblib.", "sklearn.", "numpy."))
             for mod, _func, _opcode in advanced_globals
         )
+        has_dill_globals = any(
+            mod in {"dill", "_dill"} or mod.startswith(("dill.", "_dill.", "dill._dill"))
+            for mod, _func, _opcode in advanced_globals
+        )
         is_joblib_content = is_serialization_ext or (not file_ext and has_joblib_globals)

         # Check for recursion errors on legitimate ML model files
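
A quick illustrative check of what the new `has_dill_globals` membership test accepts and rejects (a standalone re-statement, not scanner code):

def looks_like_dill(mod: str) -> bool:
    # Same membership test as has_dill_globals in the hunk above.
    return mod in {"dill", "_dill"} or mod.startswith(("dill.", "_dill.", "dill._dill"))

assert looks_like_dill("dill")
assert looks_like_dill("_dill")
assert looks_like_dill("dill._dill")    # internal dill module path
assert not looks_like_dill("dillute")   # prefix match requires the dot
assert not looks_like_dill("builtins")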
@@ -5703,13 +5767,17 @@ def get_depth(x):
             and (has_pytorch_advanced_global or has_ordereddict_global)
             and not has_dangerous_advanced_global
         )
-        # For serialization content, require positive evidence of legitimacy.
-        # Extension-validated files (.joblib/.dill) rely on extension +
-        # _is_legitimate_serialization_file() + no dangerous globals.
-        # Extensionless blobs require positive joblib/sklearn/numpy globals.
-        has_legitimate_serialization_globals = (is_serialization_ext and not has_dangerous_advanced_global) or (
-            not file_ext and bool(advanced_globals) and has_joblib_globals and not has_dangerous_advanced_global
-        )
+        # Require positive opcode-level framework globals for .joblib files;
+        # marker bytes and extensions alone are too weak. Dill stays more
+        # permissive because legitimate plain-object dill payloads may not
+        # expose dill globals before a resource limit hits.
+        has_extension_based_serialization_globals = (
+            file_ext == ".joblib" and bool(advanced_globals) and has_joblib_globals
+        ) or (file_ext == ".dill" and (has_dill_globals or not advanced_globals))
+        has_legitimate_serialization_globals = (
+            has_extension_based_serialization_globals
+            or (not file_ext and bool(advanced_globals) and has_joblib_globals)
+        ) and not has_dangerous_advanced_global
         passes_global_gate = (
             has_legitimate_pytorch_globals
             if file_ext == ".bin"
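
Restated as a standalone predicate, under the assumption that `advanced_globals` holds `(module, function, opcode)` tuples as the tests model it, the rewritten gate behaves as follows (a sketch, not the scanner's actual API):

def passes_serialization_gate(
    file_ext: str,
    advanced_globals: set[tuple[str, str, str]],
    has_dangerous_advanced_global: bool,
) -> bool:
    # Mirrors has_legitimate_serialization_globals from the hunk above.
    has_joblib_globals = any(
        mod in {"joblib", "sklearn", "numpy"} or mod.startswith(("joblib.", "sklearn.", "numpy."))
        for mod, _func, _opcode in advanced_globals
    )
    has_dill_globals = any(
        mod in {"dill", "_dill"} or mod.startswith(("dill.", "_dill."))
        for mod, _func, _opcode in advanced_globals
    )
    extension_based = (
        file_ext == ".joblib" and bool(advanced_globals) and has_joblib_globals
    ) or (file_ext == ".dill" and (has_dill_globals or not advanced_globals))
    extensionless = not file_ext and bool(advanced_globals) and has_joblib_globals
    return (extension_based or extensionless) and not has_dangerous_advanced_global

# The new tests below map directly onto this predicate, for example:
assert passes_serialization_gate(".dill", set(), False)        # plain-object dill
assert not passes_serialization_gate(".joblib", set(), False)  # marker bytes alone
assert passes_serialization_gate(
    ".joblib", {("joblib.numpy_pickle", "NumpyArrayWrapper", "GLOBAL")}, False
)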

tests/scanners/test_pickle_scanner.py

Lines changed: 202 additions & 1 deletion
@@ -2160,6 +2160,207 @@ def _raise_memory_error(*args: object, **kwargs: object) -> object:
     assert not any(issue.severity in {IssueSeverity.WARNING, IssueSeverity.CRITICAL} for issue in result.issues)


+def test_scan_dill_memory_error_without_dill_globals_not_downgraded(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Dangerous-looking dill prefixes must not qualify for the INFO downgrade."""
+    model_path = tmp_path / "suspicious.dill"
+    model_path.write_bytes(b"\x80\x04cdill\nloads\nq\x00." + b"dill" + b"\x00" * (256 * 1024))
+
+    def _raise_memory_error(*args: object, **kwargs: object) -> object:
+        raise MemoryError("simulated parser memory limit")
+
+    monkeypatch.setattr("modelaudit.scanners.pickle_scanner.pickletools.genops", _raise_memory_error)
+    monkeypatch.setattr(
+        PickleScanner,
+        "_extract_globals_advanced",
+        lambda self, file_obj, multiple_pickles=True, scan_start_time=None: set(),
+    )
+
+    result = PickleScanner().scan(str(model_path))
+
+    assert not any(check.name == "Pickle Parse Resource Limit" for check in result.checks)
+    format_validation_checks = [check for check in result.checks if check.name == "Pickle Format Validation"]
+    assert len(format_validation_checks) == 1
+    assert format_validation_checks[0].status == CheckStatus.FAILED
+    assert format_validation_checks[0].severity == IssueSeverity.WARNING
+    assert format_validation_checks[0].details["exception_type"] == "MemoryError"
+
+
+def test_scan_joblib_memory_error_requires_joblib_globals(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+    """Only .joblib files with parsed framework globals should downgrade to INFO."""
+    model_path = tmp_path / "legitimate.joblib"
+    model_path.write_bytes(b"\x80\x04cjoblib.numpy_pickle\nNumpyArrayWrapper\nq\x00." + b"\x00" * (256 * 1024))
+
+    def _raise_memory_error(*args: object, **kwargs: object) -> object:
+        raise MemoryError("simulated parser memory limit")
+
+    monkeypatch.setattr("modelaudit.scanners.pickle_scanner.pickletools.genops", _raise_memory_error)
+    monkeypatch.setattr(
+        PickleScanner,
+        "_extract_globals_advanced",
+        lambda self, file_obj, multiple_pickles=True, scan_start_time=None: {
+            ("joblib.numpy_pickle", "NumpyArrayWrapper", "GLOBAL")
+        },
+    )
+
+    result = PickleScanner().scan(str(model_path))
+
+    resource_limit_checks = [check for check in result.checks if check.name == "Pickle Parse Resource Limit"]
+    assert len(resource_limit_checks) == 1
+    resource_limit_check = resource_limit_checks[0]
+    assert resource_limit_check.status == CheckStatus.FAILED
+    assert resource_limit_check.severity == IssueSeverity.INFO
+    assert resource_limit_check.details["reason"] == "memory_limit_on_legitimate_model"
+    assert resource_limit_check.details["exception_type"] == "MemoryError"
+    assert not any(
+        issue.severity in {IssueSeverity.WARNING, IssueSeverity.CRITICAL}
+        and "Unable to parse pickle file" in issue.message
+        for issue in result.issues
+    )
+
+
+def test_scan_joblib_memory_error_without_joblib_globals_not_downgraded(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Marker bytes alone must not qualify a .joblib file for INFO downgrade."""
+    model_path = tmp_path / "suspicious.joblib"
+    model_path.write_bytes(b"\x80\x04joblibsklearn" + b"\x00" * (256 * 1024))
+
+    def _raise_memory_error(*args: object, **kwargs: object) -> object:
+        raise MemoryError("simulated parser memory limit")
+
+    monkeypatch.setattr("modelaudit.scanners.pickle_scanner.pickletools.genops", _raise_memory_error)
+    monkeypatch.setattr(
+        PickleScanner,
+        "_extract_globals_advanced",
+        lambda self, file_obj, multiple_pickles=True, scan_start_time=None: set(),
+    )
+
+    result = PickleScanner().scan(str(model_path))
+
+    assert not any(check.name == "Pickle Parse Resource Limit" for check in result.checks)
+    format_validation_checks = [check for check in result.checks if check.name == "Pickle Format Validation"]
+    assert len(format_validation_checks) == 1
+    assert format_validation_checks[0].status == CheckStatus.FAILED
+    assert format_validation_checks[0].severity == IssueSeverity.WARNING
+    assert format_validation_checks[0].details["exception_type"] == "MemoryError"
+
+
+def test_scan_dill_memory_error_with_dill_globals_is_informational(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Legitimate dill globals should still allow the scanner-limitation downgrade."""
+    model_path = tmp_path / "legitimate.dill"
+    model_path.write_bytes(b"\x80\x04" + b"\x00" * (256 * 1024))
+
+    def _raise_memory_error(*args: object, **kwargs: object) -> object:
+        raise MemoryError("simulated parser memory limit")
+
+    monkeypatch.setattr("modelaudit.scanners.pickle_scanner.pickletools.genops", _raise_memory_error)
+    monkeypatch.setattr(
+        PickleScanner,
+        "_extract_globals_advanced",
+        lambda self, file_obj, multiple_pickles=True, scan_start_time=None: {("dill", "dump", "GLOBAL")},
+    )
+
+    result = PickleScanner().scan(str(model_path))
+
+    resource_limit_checks = [check for check in result.checks if check.name == "Pickle Parse Resource Limit"]
+    assert len(resource_limit_checks) == 1
+    resource_limit_check = resource_limit_checks[0]
+    assert resource_limit_check.status == CheckStatus.FAILED
+    assert resource_limit_check.severity == IssueSeverity.INFO
+    assert resource_limit_check.details["reason"] == "memory_limit_on_legitimate_model"
+    assert resource_limit_check.details["exception_type"] == "MemoryError"
+    assert not any(issue.severity in {IssueSeverity.WARNING, IssueSeverity.CRITICAL} for issue in result.issues)
+
+
+def test_scan_plain_dill_memory_error_without_globals_is_informational(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Plain-object dill files should keep the scanner-limitation downgrade path."""
+    model_path = tmp_path / "plain.dill"
+    model_path.write_bytes(dill.dumps([1, 2, 3]))
+
+    def _raise_memory_error(*args: object, **kwargs: object) -> object:
+        raise MemoryError("simulated parser memory limit")
+
+    monkeypatch.setattr("modelaudit.scanners.pickle_scanner.pickletools.genops", _raise_memory_error)
+    monkeypatch.setattr(
+        PickleScanner,
+        "_extract_globals_advanced",
+        lambda self, file_obj, multiple_pickles=True, scan_start_time=None: set(),
+    )
+
+    result = PickleScanner().scan(str(model_path))
+
+    resource_limit_checks = [check for check in result.checks if check.name == "Pickle Parse Resource Limit"]
+    assert len(resource_limit_checks) == 1
+    resource_limit_check = resource_limit_checks[0]
+    assert resource_limit_check.status == CheckStatus.FAILED
+    assert resource_limit_check.severity == IssueSeverity.INFO
+    assert resource_limit_check.details["reason"] == "memory_limit_on_legitimate_model"
+    assert resource_limit_check.details["exception_type"] == "MemoryError"
+    assert not any(issue.severity in {IssueSeverity.WARNING, IssueSeverity.CRITICAL} for issue in result.issues)
+
+
+def test_scan_dill_memory_error_with_internal_dill_globals_is_informational(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Internal dill globals should qualify for the scanner-limitation downgrade."""
+    model_path = tmp_path / "legitimate.dill"
+    model_path.write_bytes(b"\x80\x04" + b"\x00" * (256 * 1024))
+
+    def _raise_memory_error(*args: object, **kwargs: object) -> object:
+        raise MemoryError("simulated parser memory limit")
+
+    monkeypatch.setattr("modelaudit.scanners.pickle_scanner.pickletools.genops", _raise_memory_error)
+    monkeypatch.setattr(
+        PickleScanner,
+        "_extract_globals_advanced",
+        lambda self, file_obj, multiple_pickles=True, scan_start_time=None: {("_dill", "dump", "GLOBAL")},
+    )
+
+    result = PickleScanner().scan(str(model_path))
+
+    resource_limit_checks = [check for check in result.checks if check.name == "Pickle Parse Resource Limit"]
+    assert len(resource_limit_checks) == 1
+    resource_limit_check = resource_limit_checks[0]
+    assert resource_limit_check.status == CheckStatus.FAILED
+    assert resource_limit_check.severity == IssueSeverity.INFO
+    assert resource_limit_check.details["reason"] == "memory_limit_on_legitimate_model"
+    assert resource_limit_check.details["exception_type"] == "MemoryError"
+    assert not any(issue.severity in {IssueSeverity.WARNING, IssueSeverity.CRITICAL} for issue in result.issues)
+
+
+def test_scan_joblib_memory_error_with_dangerous_prefix_not_downgraded(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Marker bytes must not hide a dangerous pickle prefix in .joblib files."""
+    model_path = tmp_path / "suspicious.joblib"
+    model_path.write_bytes(b"\x80\x02cbuiltins\neval\nq\x00." + b"joblibsklearnnumpy" + b"\x00" * (256 * 1024))
+
+    def _raise_memory_error(*args: object, **kwargs: object) -> object:
+        raise MemoryError("simulated parser memory limit")
+
+    monkeypatch.setattr("modelaudit.scanners.pickle_scanner.pickletools.genops", _raise_memory_error)
+    monkeypatch.setattr(
+        PickleScanner,
+        "_extract_globals_advanced",
+        lambda self, file_obj, multiple_pickles=True, scan_start_time=None: set(),
+    )
+
+    result = PickleScanner().scan(str(model_path))
+
+    assert not any(check.name == "Pickle Parse Resource Limit" for check in result.checks)
+    format_validation_checks = [check for check in result.checks if check.name == "Pickle Format Validation"]
+    assert len(format_validation_checks) == 1
+    assert format_validation_checks[0].status == CheckStatus.FAILED
+    assert format_validation_checks[0].severity == IssueSeverity.WARNING
+    assert format_validation_checks[0].details["exception_type"] == "MemoryError"
+
+
 def test_scan_memory_error_with_dangerous_globals_not_downgraded(
     tmp_path: Path, monkeypatch: pytest.MonkeyPatch
 ) -> None:
@@ -2179,7 +2380,7 @@ def _raise_memory_error(*args: object, **kwargs: object) -> object:
     monkeypatch.setattr(
         PickleScanner,
         "_extract_globals_advanced",
-        lambda self, file_obj, multiple_pickles=True: {("builtins", "eval", "GLOBAL")},
+        lambda self, file_obj, multiple_pickles=True, scan_start_time=None: {("builtins", "eval", "GLOBAL")},
    )

    result = PickleScanner().scan(str(model_path))
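
To exercise just these regression tests locally, something like the following should work, assuming the repository's standard pytest layout (the `-k` filter is illustrative):

import pytest

# Select the MemoryError-downgrade tests added in this commit.
pytest.main(["tests/scanners/test_pickle_scanner.py", "-k", "memory_error", "-q"])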
