From b7cb604f08a756377f8b81a6722842c676d733e7 Mon Sep 17 00:00:00 2001 From: mldangelo Date: Fri, 13 Mar 2026 18:32:19 -0400 Subject: [PATCH 1/7] fix: recurse into numpy object pickle payloads --- CHANGELOG.md | 1 + modelaudit/scanners/numpy_scanner.py | 23 +++++- tests/scanners/test_numpy_scanner.py | 117 +++++++++++++++++++++++++++ tests/scanners/test_zip_scanner.py | 17 ++++ 4 files changed, 157 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 85b98db68..9b02214eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,6 +81,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- **security:** recurse into object-dtype `.npy` payloads and `.npz` object members with the pickle scanner while preserving CVE-2019-6446 warnings and archive-member context - **security:** harden TensorFlow weight extraction limits to bound actual tensor payload materialization, including malformed `tensor_content` and string-backed tensors, and continue scanning past oversized `Const` nodes - **security:** stream TAR members to temp files under size limits instead of buffering whole entries in memory during scan - **security:** inspect TensorFlow SavedModel function definitions when scanning for dangerous ops and protobuf string abuse, with function-aware finding locations diff --git a/modelaudit/scanners/numpy_scanner.py b/modelaudit/scanners/numpy_scanner.py index 23e895970..9a49e03c4 100644 --- a/modelaudit/scanners/numpy_scanner.py +++ b/modelaudit/scanners/numpy_scanner.py @@ -4,9 +4,10 @@ import sys import warnings -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar from .base import BaseScanner, IssueSeverity, ScanResult +from .pickle_scanner import PickleScanner # Import NumPy with compatibility handling try: @@ -88,6 +89,17 @@ def _validate_array_dimensions(self, shape: tuple[int, ...]) -> None: CVE_2019_6446_CVSS = 9.8 CVE_2019_6446_CWE = "CWE-502" + def _scan_embedded_pickle_payload( + self, + file_obj: BinaryIO, + payload_size: int, + context_path: str, + ) -> ScanResult: + """Reuse PickleScanner analysis for object-dtype NumPy payloads.""" + pickle_scanner = PickleScanner(config=self.config) + pickle_scanner.current_file_path = context_path + return pickle_scanner._scan_pickle_bytes(file_obj, payload_size) + def _validate_dtype(self, dtype: Any) -> None: """Validate numpy dtype for security""" # Check for problematic data types @@ -299,6 +311,15 @@ def scan(self, path: str) -> ScanResult: ), ) + f.seek(data_offset) + embedded_result = self._scan_embedded_pickle_payload( + f, + file_size - data_offset, + path, + ) + result.issues.extend(embedded_result.issues) + result.checks.extend(embedded_result.checks) + self._validate_dtype(dtype) result.add_check( name="Data Type Safety Check", diff --git a/tests/scanners/test_numpy_scanner.py b/tests/scanners/test_numpy_scanner.py index 4ef037e44..9a5f6386c 100644 --- a/tests/scanners/test_numpy_scanner.py +++ b/tests/scanners/test_numpy_scanner.py @@ -100,3 +100,120 @@ def test_structured_with_object_field_triggers_cve(self, tmp_path): cve_checks = [c for c in result.checks if "CVE-2019-6446" in (c.name + c.message)] assert len(cve_checks) > 0, "Structured dtype with object field should trigger CVE" + + +class _ExecPayload: + def __reduce__(self): + return (exec, ("print('owned')",)) + + +class _SSLPayload: + def __reduce__(self): + import ssl + + return (ssl.get_server_certificate, (("example.com", 443),)) + + +def _failed_checks(result): + return [c for c in result.checks if c.status.value == "failed"] + + +def test_object_dtype_numpy_recurses_into_pickle_exec(tmp_path) -> None: + arr = np.array([_ExecPayload()], dtype=object) + path = tmp_path / "malicious_object.npy" + np.save(path, arr, allow_pickle=True) + + scanner = NumPyScanner() + result = scanner.scan(str(path)) + + failed = _failed_checks(result) + assert any("CVE-2019-6446" in (c.name + c.message) for c in failed) + assert any("exec" in (c.message.lower()) for c in failed) + + +def test_object_dtype_numpy_recurses_into_pickle_ssl(tmp_path) -> None: + arr = np.array([_SSLPayload()], dtype=object) + path = tmp_path / "malicious_ssl_object.npy" + np.save(path, arr, allow_pickle=True) + + scanner = NumPyScanner() + result = scanner.scan(str(path)) + + failed = _failed_checks(result) + assert any("CVE-2019-6446" in (c.name + c.message) for c in failed) + assert any("ssl.get_server_certificate" in c.message for c in failed) + + +def test_numeric_npz_has_no_pickle_recursion_findings(tmp_path) -> None: + npz_path = tmp_path / "numeric_only.npz" + np.savez(npz_path, a=np.arange(4), b=np.ones((2, 2), dtype=np.float32)) + + from modelaudit.scanners.zip_scanner import ZipScanner + + result = ZipScanner().scan(str(npz_path)) + + assert not any("CVE-2019-6446" in (c.name + c.message) for c in result.checks) + assert not any("exec" in c.message.lower() for c in result.checks) + + +def test_object_npz_member_recurses_into_pickle_exec_with_member_context(tmp_path) -> None: + safe = np.array([1, 2, 3], dtype=np.int64) + malicious = np.array([_ExecPayload()], dtype=object) + npz_path = tmp_path / "mixed_object.npz" + np.savez(npz_path, safe=safe, payload=malicious) + + from modelaudit.scanners.zip_scanner import ZipScanner + + result = ZipScanner().scan(str(npz_path)) + + failed = _failed_checks(result) + assert any("CVE-2019-6446" in (c.name + c.message) and "payload.npy" in str(c.location) for c in failed) + assert any("exec" in i.message.lower() and i.details.get("zip_entry") == "payload.npy" for i in result.issues) + + +def test_benign_object_dtype_numpy_no_nested_critical(tmp_path) -> None: + arr = np.array([{"k": "v"}, [1, 2, 3]], dtype=object) + path = tmp_path / "benign_object.npy" + np.save(path, arr, allow_pickle=True) + + scanner = NumPyScanner() + result = scanner.scan(str(path)) + + assert any("CVE-2019-6446" in (c.name + c.message) for c in result.checks) + assert not any(i.severity == IssueSeverity.CRITICAL for i in result.issues if "CVE-2019-6446" not in i.message) + + +def test_benign_object_dtype_npz_no_nested_critical(tmp_path) -> None: + npz_path = tmp_path / "benign_object.npz" + np.savez(npz_path, safe=np.array([{"x": 1}], dtype=object)) + + from modelaudit.scanners.zip_scanner import ZipScanner + + result = ZipScanner().scan(str(npz_path)) + + assert any("CVE-2019-6446" in (c.name + c.message) for c in result.checks) + assert not any(i.severity == IssueSeverity.CRITICAL for i in result.issues) + + +def test_truncated_npy_fails_safely(tmp_path) -> None: + arr = np.array([_ExecPayload()], dtype=object) + path = tmp_path / "truncated.npy" + np.save(path, arr, allow_pickle=True) + path.write_bytes(path.read_bytes()[:-8]) + + scanner = NumPyScanner() + result = scanner.scan(str(path)) + + assert any(i.severity == IssueSeverity.INFO for i in result.issues) + + +def test_corrupted_npz_fails_safely(tmp_path) -> None: + npz_path = tmp_path / "corrupt.npz" + npz_path.write_bytes(b"not-a-zip") + + from modelaudit.scanners.zip_scanner import ZipScanner + + result = ZipScanner().scan(str(npz_path)) + + assert result.success is False + assert any(i.severity == IssueSeverity.INFO for i in result.issues) diff --git a/tests/scanners/test_zip_scanner.py b/tests/scanners/test_zip_scanner.py index 6c1ff5f77..25796b4c5 100644 --- a/tests/scanners/test_zip_scanner.py +++ b/tests/scanners/test_zip_scanner.py @@ -314,6 +314,23 @@ def test_scan_zip_with_prefixed_proto0_pickle_disguised_as_text(self, tmp_path: f"Expected critical os/posix.system issue, got: {critical_messages}" ) + def test_scan_npz_with_object_member_recurses_into_pickle(self, tmp_path: Path) -> None: + import numpy as np + + class _ExecPayload: + def __reduce__(self): + return (exec, ("print('owned')",)) + + archive_path = tmp_path / "payload.npz" + np.savez(archive_path, safe=np.arange(3), payload=np.array([_ExecPayload()], dtype=object)) + + result = self.scanner.scan(str(archive_path)) + assert result.success is True + + failed_checks = [c for c in result.checks if c.status.value == "failed"] + assert any("cve-2019-6446" in (c.name + c.message).lower() for c in failed_checks) + assert any("exec" in i.message.lower() and i.details.get("zip_entry") == "payload.npy" for i in result.issues) + def test_scan_zip_with_plain_text_global_prefix_not_treated_as_pickle(self, tmp_path: Path) -> None: """Plain text entries that start with GLOBAL-like bytes should not trigger pickle parse warnings.""" archive_path = tmp_path / "plain_text_payload.zip" From 407652a89a92f2f3c5a21e5c2560eb28c3657b57 Mon Sep 17 00:00:00 2001 From: mldangelo Date: Fri, 13 Mar 2026 18:49:46 -0400 Subject: [PATCH 2/7] test: type annotate numpy recursion regressions --- tests/scanners/test_numpy_scanner.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/scanners/test_numpy_scanner.py b/tests/scanners/test_numpy_scanner.py index 9a5f6386c..7f8f7345e 100644 --- a/tests/scanners/test_numpy_scanner.py +++ b/tests/scanners/test_numpy_scanner.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np from modelaudit.scanners.base import IssueSeverity @@ -118,7 +120,7 @@ def _failed_checks(result): return [c for c in result.checks if c.status.value == "failed"] -def test_object_dtype_numpy_recurses_into_pickle_exec(tmp_path) -> None: +def test_object_dtype_numpy_recurses_into_pickle_exec(tmp_path: Path) -> None: arr = np.array([_ExecPayload()], dtype=object) path = tmp_path / "malicious_object.npy" np.save(path, arr, allow_pickle=True) @@ -131,7 +133,7 @@ def test_object_dtype_numpy_recurses_into_pickle_exec(tmp_path) -> None: assert any("exec" in (c.message.lower()) for c in failed) -def test_object_dtype_numpy_recurses_into_pickle_ssl(tmp_path) -> None: +def test_object_dtype_numpy_recurses_into_pickle_ssl(tmp_path: Path) -> None: arr = np.array([_SSLPayload()], dtype=object) path = tmp_path / "malicious_ssl_object.npy" np.save(path, arr, allow_pickle=True) @@ -144,7 +146,7 @@ def test_object_dtype_numpy_recurses_into_pickle_ssl(tmp_path) -> None: assert any("ssl.get_server_certificate" in c.message for c in failed) -def test_numeric_npz_has_no_pickle_recursion_findings(tmp_path) -> None: +def test_numeric_npz_has_no_pickle_recursion_findings(tmp_path: Path) -> None: npz_path = tmp_path / "numeric_only.npz" np.savez(npz_path, a=np.arange(4), b=np.ones((2, 2), dtype=np.float32)) @@ -156,7 +158,7 @@ def test_numeric_npz_has_no_pickle_recursion_findings(tmp_path) -> None: assert not any("exec" in c.message.lower() for c in result.checks) -def test_object_npz_member_recurses_into_pickle_exec_with_member_context(tmp_path) -> None: +def test_object_npz_member_recurses_into_pickle_exec_with_member_context(tmp_path: Path) -> None: safe = np.array([1, 2, 3], dtype=np.int64) malicious = np.array([_ExecPayload()], dtype=object) npz_path = tmp_path / "mixed_object.npz" @@ -171,7 +173,7 @@ def test_object_npz_member_recurses_into_pickle_exec_with_member_context(tmp_pat assert any("exec" in i.message.lower() and i.details.get("zip_entry") == "payload.npy" for i in result.issues) -def test_benign_object_dtype_numpy_no_nested_critical(tmp_path) -> None: +def test_benign_object_dtype_numpy_no_nested_critical(tmp_path: Path) -> None: arr = np.array([{"k": "v"}, [1, 2, 3]], dtype=object) path = tmp_path / "benign_object.npy" np.save(path, arr, allow_pickle=True) @@ -183,7 +185,7 @@ def test_benign_object_dtype_numpy_no_nested_critical(tmp_path) -> None: assert not any(i.severity == IssueSeverity.CRITICAL for i in result.issues if "CVE-2019-6446" not in i.message) -def test_benign_object_dtype_npz_no_nested_critical(tmp_path) -> None: +def test_benign_object_dtype_npz_no_nested_critical(tmp_path: Path) -> None: npz_path = tmp_path / "benign_object.npz" np.savez(npz_path, safe=np.array([{"x": 1}], dtype=object)) @@ -195,7 +197,7 @@ def test_benign_object_dtype_npz_no_nested_critical(tmp_path) -> None: assert not any(i.severity == IssueSeverity.CRITICAL for i in result.issues) -def test_truncated_npy_fails_safely(tmp_path) -> None: +def test_truncated_npy_fails_safely(tmp_path: Path) -> None: arr = np.array([_ExecPayload()], dtype=object) path = tmp_path / "truncated.npy" np.save(path, arr, allow_pickle=True) @@ -207,7 +209,7 @@ def test_truncated_npy_fails_safely(tmp_path) -> None: assert any(i.severity == IssueSeverity.INFO for i in result.issues) -def test_corrupted_npz_fails_safely(tmp_path) -> None: +def test_corrupted_npz_fails_safely(tmp_path: Path) -> None: npz_path = tmp_path / "corrupt.npz" npz_path.write_bytes(b"not-a-zip") From 56782ee77c1c70817ebad53d01e2ea87e9186928 Mon Sep 17 00:00:00 2001 From: mldangelo Date: Fri, 13 Mar 2026 18:53:52 -0400 Subject: [PATCH 3/7] fix(numpy): preserve npz member check context --- modelaudit/core.py | 6 ++- modelaudit/scanners/zip_scanner.py | 63 +++++++++++++++------------- tests/scanners/test_numpy_scanner.py | 22 +++++----- tests/scanners/test_zip_scanner.py | 4 ++ tests/test_core_asset_extraction.py | 21 ++++++++++ 5 files changed, 75 insertions(+), 41 deletions(-) diff --git a/modelaudit/core.py b/modelaudit/core.py index 32e892b23..bbbdf7551 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -272,8 +272,12 @@ def _group_checks_by_asset(checks_list: list[Any]) -> dict[tuple[str, str], list check_name = check.get("name", "Unknown Check") location = check.get("location", "") primary_asset = _extract_primary_asset_from_location(location) + details = check.get("details") + zip_entry = details.get("zip_entry") if isinstance(details, dict) else None - group_key = (check_name, primary_asset) + asset_group = f"{primary_asset}:{zip_entry}" if isinstance(zip_entry, str) and zip_entry else primary_asset + + group_key = (check_name, asset_group) check_groups[group_key].append(check) return check_groups diff --git a/modelaudit/scanners/zip_scanner.py b/modelaudit/scanners/zip_scanner.py index 6594041f5..525b3cac4 100644 --- a/modelaudit/scanners/zip_scanner.py +++ b/modelaudit/scanners/zip_scanner.py @@ -116,6 +116,37 @@ def scan(self, path: str) -> ScanResult: result.metadata["file_size"] = os.path.getsize(path) return result + def _rewrite_nested_result_context( + self, scan_result: ScanResult, tmp_path: str, archive_path: str, entry_name: str + ) -> None: + """Rewrite nested result locations so archive members, not temp files, are reported.""" + archive_location = f"{archive_path}:{entry_name}" + + for issue in scan_result.issues: + if issue.location: + if issue.location.startswith(tmp_path): + issue.location = issue.location.replace(tmp_path, archive_location, 1) + else: + issue.location = f"{archive_location} {issue.location}" + else: + issue.location = archive_location + + if issue.details: + issue.details["zip_entry"] = entry_name + else: + issue.details = {"zip_entry": entry_name} + + for check in scan_result.checks: + if check.location: + if check.location.startswith(tmp_path): + check.location = check.location.replace(tmp_path, archive_location, 1) + else: + check.location = f"{archive_location} {check.location}" + else: + check.location = archive_location + + check.details["zip_entry"] = entry_name + def _scan_zip_file(self, path: str, depth: int = 0) -> ScanResult: """Recursively scan a ZIP file and its contents""" result = ScanResult(scanner_name=self.name) @@ -317,16 +348,7 @@ def _scan_zip_file(self, path: str, depth: int = 0) -> ScanResult: if name.lower().endswith(".zip"): try: nested_result = self._scan_zip_file(tmp_path, depth + 1) - # Update locations in nested results - for issue in nested_result.issues: - if issue.location and issue.location.startswith( - tmp_path, - ): - issue.location = issue.location.replace( - tmp_path, - f"{path}:{name}", - 1, - ) + self._rewrite_nested_result_context(nested_result, tmp_path, path, name) result.merge(nested_result) asset_entry = asset_from_scan_result( @@ -348,26 +370,7 @@ def _scan_zip_file(self, path: str, depth: int = 0) -> ScanResult: # Use core.scan_file to scan with appropriate scanner file_result = core.scan_file(tmp_path, self.config) - - # Update locations in file results - for issue in file_result.issues: - if issue.location: - if issue.location.startswith(tmp_path): - issue.location = issue.location.replace( - tmp_path, - f"{path}:{name}", - 1, - ) - else: - issue.location = f"{path}:{name} {issue.location}" - else: - issue.location = f"{path}:{name}" - - # Add zip entry name to details - if issue.details: - issue.details["zip_entry"] = name - else: - issue.details = {"zip_entry": name} + self._rewrite_nested_result_context(file_result, tmp_path, path, name) result.merge(file_result) diff --git a/tests/scanners/test_numpy_scanner.py b/tests/scanners/test_numpy_scanner.py index 9a5f6386c..47936808e 100644 --- a/tests/scanners/test_numpy_scanner.py +++ b/tests/scanners/test_numpy_scanner.py @@ -1,6 +1,8 @@ +from pathlib import Path + import numpy as np -from modelaudit.scanners.base import IssueSeverity +from modelaudit.scanners.base import Check, IssueSeverity, ScanResult from modelaudit.scanners.numpy_scanner import NumPyScanner @@ -114,11 +116,11 @@ def __reduce__(self): return (ssl.get_server_certificate, (("example.com", 443),)) -def _failed_checks(result): +def _failed_checks(result: ScanResult) -> list[Check]: return [c for c in result.checks if c.status.value == "failed"] -def test_object_dtype_numpy_recurses_into_pickle_exec(tmp_path) -> None: +def test_object_dtype_numpy_recurses_into_pickle_exec(tmp_path: Path) -> None: arr = np.array([_ExecPayload()], dtype=object) path = tmp_path / "malicious_object.npy" np.save(path, arr, allow_pickle=True) @@ -131,7 +133,7 @@ def test_object_dtype_numpy_recurses_into_pickle_exec(tmp_path) -> None: assert any("exec" in (c.message.lower()) for c in failed) -def test_object_dtype_numpy_recurses_into_pickle_ssl(tmp_path) -> None: +def test_object_dtype_numpy_recurses_into_pickle_ssl(tmp_path: Path) -> None: arr = np.array([_SSLPayload()], dtype=object) path = tmp_path / "malicious_ssl_object.npy" np.save(path, arr, allow_pickle=True) @@ -144,7 +146,7 @@ def test_object_dtype_numpy_recurses_into_pickle_ssl(tmp_path) -> None: assert any("ssl.get_server_certificate" in c.message for c in failed) -def test_numeric_npz_has_no_pickle_recursion_findings(tmp_path) -> None: +def test_numeric_npz_has_no_pickle_recursion_findings(tmp_path: Path) -> None: npz_path = tmp_path / "numeric_only.npz" np.savez(npz_path, a=np.arange(4), b=np.ones((2, 2), dtype=np.float32)) @@ -156,7 +158,7 @@ def test_numeric_npz_has_no_pickle_recursion_findings(tmp_path) -> None: assert not any("exec" in c.message.lower() for c in result.checks) -def test_object_npz_member_recurses_into_pickle_exec_with_member_context(tmp_path) -> None: +def test_object_npz_member_recurses_into_pickle_exec_with_member_context(tmp_path: Path) -> None: safe = np.array([1, 2, 3], dtype=np.int64) malicious = np.array([_ExecPayload()], dtype=object) npz_path = tmp_path / "mixed_object.npz" @@ -171,7 +173,7 @@ def test_object_npz_member_recurses_into_pickle_exec_with_member_context(tmp_pat assert any("exec" in i.message.lower() and i.details.get("zip_entry") == "payload.npy" for i in result.issues) -def test_benign_object_dtype_numpy_no_nested_critical(tmp_path) -> None: +def test_benign_object_dtype_numpy_no_nested_critical(tmp_path: Path) -> None: arr = np.array([{"k": "v"}, [1, 2, 3]], dtype=object) path = tmp_path / "benign_object.npy" np.save(path, arr, allow_pickle=True) @@ -183,7 +185,7 @@ def test_benign_object_dtype_numpy_no_nested_critical(tmp_path) -> None: assert not any(i.severity == IssueSeverity.CRITICAL for i in result.issues if "CVE-2019-6446" not in i.message) -def test_benign_object_dtype_npz_no_nested_critical(tmp_path) -> None: +def test_benign_object_dtype_npz_no_nested_critical(tmp_path: Path) -> None: npz_path = tmp_path / "benign_object.npz" np.savez(npz_path, safe=np.array([{"x": 1}], dtype=object)) @@ -195,7 +197,7 @@ def test_benign_object_dtype_npz_no_nested_critical(tmp_path) -> None: assert not any(i.severity == IssueSeverity.CRITICAL for i in result.issues) -def test_truncated_npy_fails_safely(tmp_path) -> None: +def test_truncated_npy_fails_safely(tmp_path: Path) -> None: arr = np.array([_ExecPayload()], dtype=object) path = tmp_path / "truncated.npy" np.save(path, arr, allow_pickle=True) @@ -207,7 +209,7 @@ def test_truncated_npy_fails_safely(tmp_path) -> None: assert any(i.severity == IssueSeverity.INFO for i in result.issues) -def test_corrupted_npz_fails_safely(tmp_path) -> None: +def test_corrupted_npz_fails_safely(tmp_path: Path) -> None: npz_path = tmp_path / "corrupt.npz" npz_path.write_bytes(b"not-a-zip") diff --git a/tests/scanners/test_zip_scanner.py b/tests/scanners/test_zip_scanner.py index 25796b4c5..46e71eded 100644 --- a/tests/scanners/test_zip_scanner.py +++ b/tests/scanners/test_zip_scanner.py @@ -329,6 +329,10 @@ def __reduce__(self): failed_checks = [c for c in result.checks if c.status.value == "failed"] assert any("cve-2019-6446" in (c.name + c.message).lower() for c in failed_checks) + assert any( + c.details.get("zip_entry") == "payload.npy" and c.location and f"{archive_path}:payload.npy" in c.location + for c in failed_checks + ) assert any("exec" in i.message.lower() and i.details.get("zip_entry") == "payload.npy" for i in result.issues) def test_scan_zip_with_plain_text_global_prefix_not_treated_as_pickle(self, tmp_path: Path) -> None: diff --git a/tests/test_core_asset_extraction.py b/tests/test_core_asset_extraction.py index 532dd7fd2..be02ec30e 100644 --- a/tests/test_core_asset_extraction.py +++ b/tests/test_core_asset_extraction.py @@ -4,6 +4,7 @@ from pathlib import Path from unittest.mock import patch +import numpy as np import pytest from modelaudit.core import _extract_primary_asset_from_location, scan_model_directory_or_file @@ -72,3 +73,23 @@ def test_check_consolidation_handles_newlines_in_file_paths(tmp_path: Path) -> N expected_paths = {str(group / "dup a.pkl"), str(group / "dup b.pkl")} assert set(check.details["duplicate_files"]) == expected_paths assert check.location in expected_paths + + +def test_check_consolidation_keeps_distinct_npz_member_findings(tmp_path: Path) -> None: + class ExecPayload: + def __reduce__(self): + return (exec, ("print('owned')",)) + + archive_path = tmp_path / "payload.npz" + np.savez(archive_path, safe=np.arange(3), payload=np.array([ExecPayload()], dtype=object)) + + result = scan_model_directory_or_file(str(archive_path)) + data_type_checks = [ + check + for check in result.checks + if check.name == "Data Type Safety Check" and check.status.value == "failed" + ] + + assert len(data_type_checks) == 1 + assert data_type_checks[0].location == f"{archive_path}:payload.npy" + assert data_type_checks[0].details.get("zip_entry") == "payload.npy" From f42f85475eda5e926f4924744fbfde1179dbb25a Mon Sep 17 00:00:00 2001 From: mldangelo Date: Fri, 13 Mar 2026 18:59:31 -0400 Subject: [PATCH 4/7] test: format asset extraction regressions --- tests/test_core_asset_extraction.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_core_asset_extraction.py b/tests/test_core_asset_extraction.py index be02ec30e..aa600bc29 100644 --- a/tests/test_core_asset_extraction.py +++ b/tests/test_core_asset_extraction.py @@ -85,9 +85,7 @@ def __reduce__(self): result = scan_model_directory_or_file(str(archive_path)) data_type_checks = [ - check - for check in result.checks - if check.name == "Data Type Safety Check" and check.status.value == "failed" + check for check in result.checks if check.name == "Data Type Safety Check" and check.status.value == "failed" ] assert len(data_type_checks) == 1 From 95ae02c13ae61c86df247d67665f88d4d5f1b2d0 Mon Sep 17 00:00:00 2001 From: mldangelo Date: Sat, 14 Mar 2026 17:36:49 -0400 Subject: [PATCH 5/7] test: type annotate numpy trailing-bytes regression --- tests/scanners/test_numpy_scanner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/scanners/test_numpy_scanner.py b/tests/scanners/test_numpy_scanner.py index e3ae700cd..9e210c731 100644 --- a/tests/scanners/test_numpy_scanner.py +++ b/tests/scanners/test_numpy_scanner.py @@ -273,7 +273,7 @@ def test_truncated_npy_fails_safely(tmp_path: Path) -> None: ), f"Expected a non-critical corruption finding, got: {[i.message for i in result.issues]}" -def test_object_dtype_numpy_trailing_bytes_fail_integrity(tmp_path) -> None: +def test_object_dtype_numpy_trailing_bytes_fail_integrity(tmp_path: Path) -> None: arr = np.array([{"k": "v"}], dtype=object) path = tmp_path / "trailing.npy" np.save(path, arr, allow_pickle=True) From 0e141f47d7d38922a69c17dd5fb1aa0f0fde66c4 Mon Sep 17 00:00:00 2001 From: mldangelo Date: Sun, 15 Mar 2026 09:33:35 -0400 Subject: [PATCH 6/7] fix: harden numpy recursion follow-up checks Cap embedded pickle prefix reads at 10MB, add the missing regression type annotations, and preserve the security exit code for streaming scans that find critical issues without operational failures. --- modelaudit/core.py | 5 +++- modelaudit/models.py | 2 +- modelaudit/scanners/pickle_scanner.py | 8 +++--- tests/test_cli.py | 20 +++++++++++++++ tests/test_core_asset_extraction.py | 2 +- tests/test_security_enhancements.py | 3 ++- tests/test_streaming_scan.py | 35 ++++++++++++++++++++++----- 7 files changed, 62 insertions(+), 13 deletions(-) diff --git a/modelaudit/core.py b/modelaudit/core.py index bbbdf7551..05134cba8 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -1606,7 +1606,10 @@ def scan_model_streaming( scan_result_dict = { "bytes_scanned": scan_result.bytes_scanned, "files_scanned": 1, # Each scan_result represents one file - "has_errors": scan_result.has_errors, + # ScanResult.has_errors means "critical findings", but + # ModelAuditResultModel.has_errors is reserved for + # operational scan failures. + "has_errors": not scan_result.success, "success": scan_result.success, "issues": [issue.__dict__ for issue in (scan_result.issues or [])], "checks": [check.__dict__ for check in (scan_result.checks or [])], diff --git a/modelaudit/models.py b/modelaudit/models.py index 0077629f8..d64a095e3 100644 --- a/modelaudit/models.py +++ b/modelaudit/models.py @@ -460,7 +460,7 @@ def aggregate_scan_result_direct(self, scan_result: Any) -> None: self.bytes_scanned += scan_result.bytes_scanned self.files_scanned += 1 # Each ScanResult represents one file scan - if scan_result.has_errors: + if not scan_result.success: self.has_errors = True # Update success status - only set to False for operational errors diff --git a/modelaudit/scanners/pickle_scanner.py b/modelaudit/scanners/pickle_scanner.py index e7858f59f..ddae4c740 100644 --- a/modelaudit/scanners/pickle_scanner.py +++ b/modelaudit/scanners/pickle_scanner.py @@ -4193,12 +4193,14 @@ def _scan_pickle_bytes(self, file_obj: BinaryIO, file_size: int) -> ScanResult: suspicious_count = 0 # For large files, use chunked reading to avoid memory issues - MAX_MEMORY_READ = 50 * 1024 * 1024 # 50MB max in memory at once + MAX_MEMORY_READ = 10 * 1024 * 1024 # 10MB max in memory at once current_pos = file_obj.tell() - # Read file data - either all at once for small files or first chunk for large files - # For large files, read first 50MB for pattern analysis (critical malicious code is usually at the beginning) + # Read file data - either all at once for small files or first chunk for large files. + # For large files, read only the first 10MB for pattern analysis to cap + # embedded-pickle memory usage while still inspecting the most security- + # relevant prefix. file_data = file_obj.read() if file_size <= MAX_MEMORY_READ else file_obj.read(MAX_MEMORY_READ) file_obj.seek(current_pos) # Reset position diff --git a/tests/test_cli.py b/tests/test_cli.py index 1136f89b1..b20f7d8ae 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1419,6 +1419,26 @@ def __reduce__(self): ) +def test_exit_code_security_issues_streaming_local_directory(tmp_path): + """Streaming local-directory scans should keep security findings as exit code 1.""" + import pickle + + evil_pickle_path = tmp_path / "malicious.pkl" + + class MaliciousClass: + def __reduce__(self): + return (os.system, ('echo "This is a malicious pickle"',)) + + with evil_pickle_path.open("wb") as f: + pickle.dump(MaliciousClass(), f) + + runner = CliRunner() + result = runner.invoke(cli, ["scan", "--stream", "--format", "text", str(tmp_path)]) + + assert result.exit_code == 1, f"Expected exit code 1, got {result.exit_code}. Output: {result.output}" + assert not evil_pickle_path.exists() + + def test_exit_code_scan_errors(tmp_path): """Test exit code 2 when errors occur during scanning.""" runner = CliRunner() diff --git a/tests/test_core_asset_extraction.py b/tests/test_core_asset_extraction.py index 7c0abdbd3..d76565840 100644 --- a/tests/test_core_asset_extraction.py +++ b/tests/test_core_asset_extraction.py @@ -80,7 +80,7 @@ def test_check_consolidation_handles_newlines_in_file_paths(tmp_path: Path) -> N def test_npz_member_checks_keep_archive_member_locations(tmp_path: Path) -> None: class _ExecPayload: - def __reduce__(self): + def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, ...]]: return (exec, ("print('owned')",)) archive_path = tmp_path / "payload.npz" diff --git a/tests/test_security_enhancements.py b/tests/test_security_enhancements.py index 7379aeec0..24e32dd46 100644 --- a/tests/test_security_enhancements.py +++ b/tests/test_security_enhancements.py @@ -6,6 +6,7 @@ import pickle import zipfile import zlib +from pathlib import Path import numpy as np @@ -270,7 +271,7 @@ def test_dimension_size_limit(self, tmp_path): size_issues = [issue for issue in result.issues if "too large" in issue.message.lower()] assert len(size_issues) > 0 - def test_dangerous_dtype_reports_cve_warning(self, tmp_path): + def test_dangerous_dtype_reports_cve_warning(self, tmp_path: Path) -> None: """Object dtype arrays should scan successfully while emitting CVE-2019-6446 warnings.""" scanner = NumPyScanner() npy_file = tmp_path / "object_dtype.npy" diff --git a/tests/test_streaming_scan.py b/tests/test_streaming_scan.py index afc68a071..287a19ae9 100644 --- a/tests/test_streaming_scan.py +++ b/tests/test_streaming_scan.py @@ -7,8 +7,8 @@ import pytest -from modelaudit.core import scan_model_directory_or_file, scan_model_streaming -from modelaudit.scanners.base import ScanResult +from modelaudit.core import determine_exit_code, scan_model_directory_or_file, scan_model_streaming +from modelaudit.scanners.base import IssueSeverity, ScanResult from modelaudit.utils.helpers.secure_hasher import compute_aggregate_hash @@ -27,11 +27,13 @@ def temp_test_files(): file_path.unlink() -def create_mock_scan_result(bytes_scanned: int = 1024) -> ScanResult: +def create_mock_scan_result(bytes_scanned: int = 1024, with_critical_issue: bool = False) -> ScanResult: """Create a mock ScanResult for testing.""" result = ScanResult(scanner_name="test_scanner") result.bytes_scanned = bytes_scanned result.success = True + if with_critical_issue: + result.add_issue("Detected malicious behavior", severity=IssueSeverity.CRITICAL, location="test.pkl") return result @@ -113,9 +115,30 @@ def file_generator(): for f in temp_test_files: assert not f.exists() - # Verify scan completed - assert result.files_scanned == 3 - assert result.content_hash is not None + # Verify scan completed + assert result.files_scanned == 3 + assert result.content_hash is not None + + +def test_scan_model_streaming_critical_findings_do_not_set_operational_errors(temp_test_files): + """Security findings in streaming mode should still return the security exit code.""" + + def file_generator(): + yield (temp_test_files[0], True) + + with patch("modelaudit.core.scan_file") as mock_scan: + mock_scan.return_value = create_mock_scan_result(with_critical_issue=True) + + result = scan_model_streaming( + file_generator=file_generator(), + timeout=30, + delete_after_scan=False, + ) + + assert result.files_scanned == 1 + assert len(result.issues) == 1 + assert result.has_errors is False + assert determine_exit_code(result) == 1 def test_scan_model_streaming_content_hash_deterministic(): From 81229c09431e658b632af59d82633e9c826cc1e6 Mon Sep 17 00:00:00 2001 From: mldangelo Date: Mon, 16 Mar 2026 02:46:11 -0700 Subject: [PATCH 7/7] fix: harden numpy recursion and local streaming --- CHANGELOG.md | 1 + modelaudit/cli.py | 5 +- modelaudit/core.py | 71 ++++++++++++++------------- modelaudit/scanners/numpy_scanner.py | 16 +----- modelaudit/scanners/pickle_scanner.py | 4 +- tests/scanners/test_numpy_scanner.py | 30 +++++++++++ tests/test_cli.py | 4 +- tests/test_streaming_scan.py | 32 ++++++++++++ 8 files changed, 109 insertions(+), 54 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dfdd3c6a2..7c8282b31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,6 +81,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- **cli:** preserve original local files during `--stream` directory scans instead of unlinking them after analysis - **security:** recurse into object-dtype `.npy` payloads and `.npz` object members with the pickle scanner while preserving CVE-2019-6446 warnings and archive-member context - **security:** remove `dill.load` / `dill.loads` from the pickle safe-global allowlist so recursive dill deserializers stay flagged as dangerous loader entry points - **security:** add exact dangerous helper coverage for validated torch and NumPy refs such as `numpy.f2py.crackfortran.getlincoef`, `torch._dynamo.guards.GuardBuilder.get`, and `torch.utils.collect_env.run` diff --git a/modelaudit/cli.py b/modelaudit/cli.py index 609fcfa6b..27d666f08 100644 --- a/modelaudit/cli.py +++ b/modelaudit/cli.py @@ -1683,11 +1683,12 @@ def enhanced_progress_callback(message, percentage): # Create file iterator file_generator = iterate_files_streaming(actual_path) - # Scan with streaming mode - propagate all config + # Scan with streaming mode - propagate all config. + # Local files already live on disk, so preserve the originals. streaming_result = scan_model_streaming( file_generator=file_generator, timeout=final_timeout, - delete_after_scan=True, # Delete files after scanning in streaming mode + delete_after_scan=False, progress_callback=progress_callback, blacklist_patterns=list(blacklist) if blacklist else None, max_file_size=final_max_file_size, diff --git a/modelaudit/core.py b/modelaudit/core.py index 05134cba8..3b57944fc 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -45,6 +45,35 @@ logger = logging.getLogger("modelaudit.core") +OPERATIONAL_ERROR_INDICATORS = ( + "Error during scan", + "Error checking file size", + "Error scanning file", + "Scanner crashed", + "Scan timeout", + "Path does not exist", + "Path is not readable", + "Permission denied", + "File not found", + "not installed, cannot scan", + "Missing dependency", + "Import error", + "Module not found", + "not a valid", + "Invalid file format", + "Corrupted file", + "Bad file signature", + "Unable to parse", + "Out of memory", + "Disk space", + "Too many open files", +) + + +def _has_operational_error_message(message: Any) -> bool: + """Return True when an issue message reflects an operational scan failure.""" + return isinstance(message, str) and any(indicator in message for indicator in OPERATIONAL_ERROR_INDICATORS) + def _to_telemetry_severity(severity: Any) -> str: """Normalize severity values to stable telemetry strings.""" @@ -1033,39 +1062,10 @@ def scan_model_directory_or_file( # Determine if there were operational scan errors vs security findings # has_errors should only be True for operational errors (scanner crashes, # file not found, etc.) not for security findings detected in models - operational_error_indicators = [ - # Scanner execution errors - "Error during scan", - "Error checking file size", - "Error scanning file", - "Scanner crashed", - "Scan timeout", - # File system errors - "Path does not exist", - "Path is not readable", - "Permission denied", - "File not found", - # Dependency/environment errors - "not installed, cannot scan", - "Missing dependency", - "Import error", - "Module not found", - # File format/corruption errors - "not a valid", - "Invalid file format", - "Corrupted file", - "Bad file signature", - "Unable to parse", - # Resource/system errors - "Out of memory", - "Disk space", - "Too many open files", - ] - # Check for operational errors in issues results.has_errors = ( any( - any(indicator in issue.message for indicator in operational_error_indicators) + _has_operational_error_message(issue.message) for issue in results.issues if issue.severity in {IssueSeverity.WARNING, IssueSeverity.CRITICAL} ) @@ -1595,6 +1595,9 @@ def scan_model_streaming( if scan_result: metadata_dict = dict(scan_result.metadata or {}) metadata_dict.setdefault("file_size", file_path.stat().st_size) + operational_scan_failure = any( + _has_operational_error_message(issue.message) for issue in (scan_result.issues or []) + ) existing_hashes = metadata_dict.get("file_hashes") if isinstance(existing_hashes, dict): @@ -1606,10 +1609,10 @@ def scan_model_streaming( scan_result_dict = { "bytes_scanned": scan_result.bytes_scanned, "files_scanned": 1, # Each scan_result represents one file - # ScanResult.has_errors means "critical findings", but - # ModelAuditResultModel.has_errors is reserved for - # operational scan failures. - "has_errors": not scan_result.success, + # Preserve the main scan semantics: success=False does not + # imply an operational error when the scanner completed + # and only reported informational integrity findings. + "has_errors": operational_scan_failure, "success": scan_result.success, "issues": [issue.__dict__ for issue in (scan_result.issues or [])], "checks": [check.__dict__ for check in (scan_result.checks or [])], diff --git a/modelaudit/scanners/numpy_scanner.py b/modelaudit/scanners/numpy_scanner.py index a0e238a85..fa4bd7785 100644 --- a/modelaudit/scanners/numpy_scanner.py +++ b/modelaudit/scanners/numpy_scanner.py @@ -2,7 +2,6 @@ from __future__ import annotations -import pickletools import sys import warnings from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar @@ -101,19 +100,6 @@ def _scan_embedded_pickle_payload( pickle_scanner.current_file_path = context_path return pickle_scanner._scan_pickle_bytes(file_obj, payload_size) - def _find_embedded_pickle_end(self, file_obj: BinaryIO) -> int | None: - """Return the absolute file offset of the first embedded pickle STOP opcode.""" - start_pos = file_obj.tell() - try: - for opcode, _arg, _pos in pickletools.genops(file_obj): - if opcode.name == "STOP": - return file_obj.tell() - except Exception: - return None - finally: - file_obj.seek(start_pos) - return None - def _validate_dtype(self, dtype: Any) -> None: """Validate numpy dtype for security""" # Check for problematic data types @@ -284,7 +270,6 @@ def scan(self, path: str) -> ScanResult: # object fields; kind=="O" catches plain object arrays. has_object_dtype = dtype.kind == "O" or bool(getattr(dtype, "hasobject", False)) if has_object_dtype: - pickle_end_offset = self._find_embedded_pickle_end(f) result.add_check( name=f"{self.CVE_2019_6446_ID}: Object Dtype Pickle Deserialization", passed=False, @@ -336,6 +321,7 @@ def scan(self, path: str) -> ScanResult: result.issues.extend(embedded_result.issues) result.checks.extend(embedded_result.checks) + pickle_end_offset = embedded_result.metadata.get("first_pickle_end_pos") if isinstance(pickle_end_offset, int) and pickle_end_offset < file_size: trailing_bytes = file_size - pickle_end_offset result.add_check( diff --git a/modelaudit/scanners/pickle_scanner.py b/modelaudit/scanners/pickle_scanner.py index 491291169..5c2d65b47 100644 --- a/modelaudit/scanners/pickle_scanner.py +++ b/modelaudit/scanners/pickle_scanner.py @@ -4571,7 +4571,9 @@ def _scan_pickle_bytes(self, file_obj: BinaryIO, file_size: int) -> ScanResult: elif opcode.name == "STOP": current_stack_depth = 0 if first_pickle_end_pos is None: - first_pickle_end_pos = start_pos + pos + 1 + # pickletools reports absolute positions even when parsing + # starts from a non-zero file offset. + first_pickle_end_pos = pos + 1 # Store stack depth warnings for ML-context-aware processing later if current_stack_depth > base_stack_depth_limit: diff --git a/tests/scanners/test_numpy_scanner.py b/tests/scanners/test_numpy_scanner.py index 9e210c731..190b33451 100644 --- a/tests/scanners/test_numpy_scanner.py +++ b/tests/scanners/test_numpy_scanner.py @@ -1,3 +1,4 @@ +import zipfile from collections.abc import Callable from pathlib import Path from typing import Any @@ -157,6 +158,21 @@ def _inject_comment_token_into_npy_payload(path: Path) -> None: path.write_bytes(original[:data_offset] + patched) +def _inject_comment_token_into_npz_member(path: Path, member_name: str) -> None: + with zipfile.ZipFile(path, "r") as archive: + members = {info.filename: archive.read(info.filename) for info in archive.infolist()} + + member_path = path.parent / member_name + member_path.write_bytes(members[member_name]) + _inject_comment_token_into_npy_payload(member_path) + members[member_name] = member_path.read_bytes() + member_path.unlink() + + with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as archive: + for name, content in members.items(): + archive.writestr(name, content) + + def test_object_dtype_numpy_recurses_into_pickle_exec(tmp_path: Path) -> None: arr = np.array([_ExecPayload()], dtype=object) path = tmp_path / "malicious_object.npy" @@ -230,6 +246,20 @@ def test_object_dtype_numpy_comment_token_bypass_still_detected(tmp_path: Path) assert any("exec" in c.message.lower() for c in failed) +def test_object_npz_member_comment_token_bypass_still_detected(tmp_path: Path) -> None: + npz_path = tmp_path / "comment_token.npz" + np.savez(npz_path, payload=np.array([_ExecPayload()], dtype=object)) + _inject_comment_token_into_npz_member(npz_path, "payload.npy") + + from modelaudit.scanners.zip_scanner import ZipScanner + + result = ZipScanner().scan(str(npz_path)) + + failed = _failed_checks(result) + assert any("CVE-2019-6446" in (c.name + c.message) and "payload.npy" in str(c.location) for c in failed) + assert any("exec" in i.message.lower() and i.details.get("zip_entry") == "payload.npy" for i in result.issues) + + def test_benign_object_dtype_numpy_no_nested_critical(tmp_path: Path) -> None: arr = np.array([{"k": "v"}, [1, 2, 3]], dtype=object) path = tmp_path / "benign_object.npy" diff --git a/tests/test_cli.py b/tests/test_cli.py index 9cbf1ce44..fc50b360b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1421,7 +1421,7 @@ def __reduce__(self): def test_exit_code_security_issues_streaming_local_directory(tmp_path: Path) -> None: - """Streaming local-directory scans should keep security findings as exit code 1.""" + """Streaming local-directory scans should keep security findings as exit code 1 without deleting originals.""" import pickle evil_pickle_path = tmp_path / "malicious.pkl" @@ -1437,7 +1437,7 @@ def __reduce__(self): result = runner.invoke(cli, ["scan", "--stream", "--format", "text", str(tmp_path)]) assert result.exit_code == 1, f"Expected exit code 1, got {result.exit_code}. Output: {result.output}" - assert not evil_pickle_path.exists() + assert evil_pickle_path.exists() def test_exit_code_scan_errors(tmp_path): diff --git a/tests/test_streaming_scan.py b/tests/test_streaming_scan.py index 8ed459631..27a8a45e8 100644 --- a/tests/test_streaming_scan.py +++ b/tests/test_streaming_scan.py @@ -143,6 +143,38 @@ def file_generator(): assert determine_exit_code(result) == 1 +def test_scan_model_streaming_informational_failed_scan_does_not_set_operational_errors( + temp_test_files: list[Path], +) -> None: + """Informational failed scans should not override security findings with exit code 2.""" + + def file_generator(): + yield (temp_test_files[0], False) + yield (temp_test_files[1], True) + + info_result = ScanResult(scanner_name="numpy") + info_result.add_issue( + "Object-dtype payload contains trailing bytes after the embedded pickle stream", + severity=IssueSeverity.INFO, + location="trailing.npy", + ) + info_result.finish(success=False) + + with patch("modelaudit.core.scan_file") as mock_scan: + mock_scan.side_effect = [info_result, create_mock_scan_result(with_critical_issue=True)] + + result = scan_model_streaming( + file_generator=file_generator(), + timeout=30, + delete_after_scan=False, + ) + + assert result.files_scanned == 2 + assert result.success is True + assert result.has_errors is False + assert determine_exit_code(result) == 1 + + def test_scan_model_streaming_content_hash_deterministic(): """Test that content hash is deterministic for same files.""" # Create two files with same content