diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fa6e9a9..d8680e46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,6 +81,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- **cli:** preserve original local files during `--stream` directory scans instead of unlinking them after analysis +- **security:** recurse into object-dtype `.npy` payloads and `.npz` object members with the pickle scanner while preserving CVE-2019-6446 warnings and archive-member context - **security:** remove `dill.load` / `dill.loads` from the pickle safe-global allowlist so recursive dill deserializers stay flagged as dangerous loader entry points - **security:** add exact dangerous helper coverage for validated torch and NumPy refs such as `numpy.f2py.crackfortran.getlincoef`, `torch._dynamo.guards.GuardBuilder.get`, and `torch.utils.collect_env.run` - **security:** add exact dangerous-global coverage for `numpy.load`, `site.main`, `_io.FileIO`, `test.support.script_helper.assert_python_ok`, `_osx_support._read_output`, `_aix_support._read_cmd_output`, `_pyrepl.pager.pipe_pager`, `torch.serialization.load`, and `torch._inductor.codecache.compile_file` (9 PickleScan-only loader and execution primitives) diff --git a/modelaudit/cli.py b/modelaudit/cli.py index 609fcfa6..27d666f0 100644 --- a/modelaudit/cli.py +++ b/modelaudit/cli.py @@ -1683,11 +1683,12 @@ def enhanced_progress_callback(message, percentage): # Create file iterator file_generator = iterate_files_streaming(actual_path) - # Scan with streaming mode - propagate all config + # Scan with streaming mode - propagate all config. + # Local files already live on disk, so preserve the originals. streaming_result = scan_model_streaming( file_generator=file_generator, timeout=final_timeout, - delete_after_scan=True, # Delete files after scanning in streaming mode + delete_after_scan=False, progress_callback=progress_callback, blacklist_patterns=list(blacklist) if blacklist else None, max_file_size=final_max_file_size, diff --git a/modelaudit/core.py b/modelaudit/core.py index 1a2a7eff..3b57944f 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -45,6 +45,35 @@ logger = logging.getLogger("modelaudit.core") +OPERATIONAL_ERROR_INDICATORS = ( + "Error during scan", + "Error checking file size", + "Error scanning file", + "Scanner crashed", + "Scan timeout", + "Path does not exist", + "Path is not readable", + "Permission denied", + "File not found", + "not installed, cannot scan", + "Missing dependency", + "Import error", + "Module not found", + "not a valid", + "Invalid file format", + "Corrupted file", + "Bad file signature", + "Unable to parse", + "Out of memory", + "Disk space", + "Too many open files", +) + + +def _has_operational_error_message(message: Any) -> bool: + """Return True when an issue message reflects an operational scan failure.""" + return isinstance(message, str) and any(indicator in message for indicator in OPERATIONAL_ERROR_INDICATORS) + def _to_telemetry_severity(severity: Any) -> str: """Normalize severity values to stable telemetry strings.""" @@ -272,8 +301,12 @@ def _group_checks_by_asset(checks_list: list[Any]) -> dict[tuple[str, str], list check_name = check.get("name", "Unknown Check") location = check.get("location", "") primary_asset = _extract_primary_asset_from_location(location) + details = check.get("details") + zip_entry = details.get("zip_entry") if isinstance(details, dict) else None - group_key = (check_name, primary_asset) + asset_group = f"{primary_asset}:{zip_entry}" if isinstance(zip_entry, str) and zip_entry else primary_asset + + group_key = (check_name, asset_group) check_groups[group_key].append(check) return check_groups @@ -1029,39 +1062,10 @@ def scan_model_directory_or_file( # Determine if there were operational scan errors vs security findings # has_errors should only be True for operational errors (scanner crashes, # file not found, etc.) not for security findings detected in models - operational_error_indicators = [ - # Scanner execution errors - "Error during scan", - "Error checking file size", - "Error scanning file", - "Scanner crashed", - "Scan timeout", - # File system errors - "Path does not exist", - "Path is not readable", - "Permission denied", - "File not found", - # Dependency/environment errors - "not installed, cannot scan", - "Missing dependency", - "Import error", - "Module not found", - # File format/corruption errors - "not a valid", - "Invalid file format", - "Corrupted file", - "Bad file signature", - "Unable to parse", - # Resource/system errors - "Out of memory", - "Disk space", - "Too many open files", - ] - # Check for operational errors in issues results.has_errors = ( any( - any(indicator in issue.message for indicator in operational_error_indicators) + _has_operational_error_message(issue.message) for issue in results.issues if issue.severity in {IssueSeverity.WARNING, IssueSeverity.CRITICAL} ) @@ -1591,6 +1595,9 @@ def scan_model_streaming( if scan_result: metadata_dict = dict(scan_result.metadata or {}) metadata_dict.setdefault("file_size", file_path.stat().st_size) + operational_scan_failure = any( + _has_operational_error_message(issue.message) for issue in (scan_result.issues or []) + ) existing_hashes = metadata_dict.get("file_hashes") if isinstance(existing_hashes, dict): @@ -1602,10 +1609,10 @@ def scan_model_streaming( scan_result_dict = { "bytes_scanned": scan_result.bytes_scanned, "files_scanned": 1, # Each scan_result represents one file - # ScanResult.has_errors means "critical findings", but - # ModelAuditResultModel.has_errors is reserved for - # operational scan failures. - "has_errors": not scan_result.success, + # Preserve the main scan semantics: success=False does not + # imply an operational error when the scanner completed + # and only reported informational integrity findings. + "has_errors": operational_scan_failure, "success": scan_result.success, "issues": [issue.__dict__ for issue in (scan_result.issues or [])], "checks": [check.__dict__ for check in (scan_result.checks or [])], diff --git a/modelaudit/scanners/numpy_scanner.py b/modelaudit/scanners/numpy_scanner.py index 23e89597..fa4bd778 100644 --- a/modelaudit/scanners/numpy_scanner.py +++ b/modelaudit/scanners/numpy_scanner.py @@ -4,9 +4,10 @@ import sys import warnings -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar from .base import BaseScanner, IssueSeverity, ScanResult +from .pickle_scanner import PickleScanner # Import NumPy with compatibility handling try: @@ -88,6 +89,17 @@ def _validate_array_dimensions(self, shape: tuple[int, ...]) -> None: CVE_2019_6446_CVSS = 9.8 CVE_2019_6446_CWE = "CWE-502" + def _scan_embedded_pickle_payload( + self, + file_obj: BinaryIO, + payload_size: int, + context_path: str, + ) -> ScanResult: + """Reuse PickleScanner analysis for object-dtype NumPy payloads.""" + pickle_scanner = PickleScanner(config=self.config) + pickle_scanner.current_file_path = context_path + return pickle_scanner._scan_pickle_bytes(file_obj, payload_size) + def _validate_dtype(self, dtype: Any) -> None: """Validate numpy dtype for security""" # Check for problematic data types @@ -256,7 +268,8 @@ def scan(self, path: str) -> ScanResult: # enabling arbitrary code execution. # dtype.hasobject catches structured dtypes with # object fields; kind=="O" catches plain object arrays. - if dtype.kind == "O" or bool(getattr(dtype, "hasobject", False)): + has_object_dtype = dtype.kind == "O" or bool(getattr(dtype, "hasobject", False)) + if has_object_dtype: result.add_check( name=f"{self.CVE_2019_6446_ID}: Object Dtype Pickle Deserialization", passed=False, @@ -299,6 +312,60 @@ def scan(self, path: str) -> ScanResult: ), ) + f.seek(data_offset) + embedded_result = self._scan_embedded_pickle_payload( + f, + file_size - data_offset, + path, + ) + result.issues.extend(embedded_result.issues) + result.checks.extend(embedded_result.checks) + + pickle_end_offset = embedded_result.metadata.get("first_pickle_end_pos") + if isinstance(pickle_end_offset, int) and pickle_end_offset < file_size: + trailing_bytes = file_size - pickle_end_offset + result.add_check( + name="File Integrity Check", + passed=False, + message=( + "Object-dtype payload contains trailing bytes after the embedded pickle stream" + ), + severity=IssueSeverity.INFO, + location=path, + rule_code="S902", + details={ + "expected_pickle_end": pickle_end_offset, + "actual_size": file_size, + "trailing_bytes": trailing_bytes, + "dtype": str(dtype), + }, + ) + result.finish(success=False) + return result + + # Object-dtype .npy payloads are stored as a pickle stream rather than + # fixed-width element data, so the numeric dtype/size validation path + # is not applicable after we recurse into the embedded pickle payload. + result.add_check( + name="Data Type Safety Check", + passed=True, + message=f"Object dtype '{dtype}' handled via recursive pickle analysis", + location=path, + rule_code=None, + details={ + "dtype": str(dtype), + "dtype_kind": dtype.kind, + "handled_via": "embedded_pickle_scan", + "cve_id": self.CVE_2019_6446_ID, + }, + ) + result.bytes_scanned = file_size + result.metadata.update( + {"shape": shape, "dtype": str(dtype), "fortran_order": fortran}, + ) + result.finish(success=True) + return result + self._validate_dtype(dtype) result.add_check( name="Data Type Safety Check", diff --git a/modelaudit/scanners/pickle_scanner.py b/modelaudit/scanners/pickle_scanner.py index 108f37ff..9cc3af66 100644 --- a/modelaudit/scanners/pickle_scanner.py +++ b/modelaudit/scanners/pickle_scanner.py @@ -4428,12 +4428,14 @@ def _scan_pickle_bytes(self, file_obj: BinaryIO, file_size: int) -> ScanResult: suspicious_count = 0 # For large files, use chunked reading to avoid memory issues - MAX_MEMORY_READ = 50 * 1024 * 1024 # 50MB max in memory at once + MAX_MEMORY_READ = 10 * 1024 * 1024 # 10MB max in memory at once current_pos = file_obj.tell() - # Read file data - either all at once for small files or first chunk for large files - # For large files, read first 50MB for pattern analysis (critical malicious code is usually at the beginning) + # Read file data - either all at once for small files or first chunk for large files. + # For large files, read only the first 10MB for pattern analysis to cap + # embedded-pickle memory usage while still inspecting the most security- + # relevant prefix. file_data = file_obj.read() if file_size <= MAX_MEMORY_READ else file_obj.read(MAX_MEMORY_READ) file_obj.seek(current_pos) # Reset position @@ -4629,7 +4631,9 @@ def _scan_pickle_bytes(self, file_obj: BinaryIO, file_size: int) -> ScanResult: elif opcode.name == "STOP": current_stack_depth = 0 if first_pickle_end_pos is None: - first_pickle_end_pos = start_pos + pos + 1 + # pickletools reports absolute positions even when parsing + # starts from a non-zero file offset. + first_pickle_end_pos = pos + 1 # Store stack depth warnings for ML-context-aware processing later if current_stack_depth > base_stack_depth_limit: diff --git a/modelaudit/scanners/zip_scanner.py b/modelaudit/scanners/zip_scanner.py index 8024a14c..dad432cb 100644 --- a/modelaudit/scanners/zip_scanner.py +++ b/modelaudit/scanners/zip_scanner.py @@ -117,6 +117,44 @@ def scan(self, path: str) -> ScanResult: result.metadata["file_size"] = os.path.getsize(path) return result + def _rewrite_nested_result_context( + self, scan_result: ScanResult, tmp_path: str, archive_path: str, entry_name: str + ) -> None: + """Rewrite nested result locations so archive members, not temp files, are reported.""" + archive_location = f"{archive_path}:{entry_name}" + + for issue in scan_result.issues: + if issue.location: + if issue.location.startswith(tmp_path): + issue.location = issue.location.replace(tmp_path, archive_location, 1) + else: + issue.location = f"{archive_location} {issue.location}" + else: + issue.location = archive_location + + existing_issue_entry = issue.details.get("zip_entry") + issue.details["zip_entry"] = ( + f"{entry_name}:{existing_issue_entry}" + if isinstance(existing_issue_entry, str) and existing_issue_entry + else entry_name + ) + + for check in scan_result.checks: + if check.location: + if check.location.startswith(tmp_path): + check.location = check.location.replace(tmp_path, archive_location, 1) + else: + check.location = f"{archive_location} {check.location}" + else: + check.location = archive_location + + existing_check_entry = check.details.get("zip_entry") + check.details["zip_entry"] = ( + f"{entry_name}:{existing_check_entry}" + if isinstance(existing_check_entry, str) and existing_check_entry + else entry_name + ) + def _scan_zip_file(self, path: str, depth: int = 0) -> ScanResult: """Recursively scan a ZIP file and its contents""" result = ScanResult(scanner_name=self.name) @@ -319,16 +357,7 @@ def _scan_zip_file(self, path: str, depth: int = 0) -> ScanResult: if name.lower().endswith(".zip"): try: nested_result = self._scan_zip_file(tmp_path, depth + 1) - # Update locations in nested results - for issue in nested_result.issues: - if issue.location and issue.location.startswith( - tmp_path, - ): - issue.location = issue.location.replace( - tmp_path, - f"{path}:{name}", - 1, - ) + self._rewrite_nested_result_context(nested_result, tmp_path, path, name) result.merge(nested_result) asset_entry = asset_from_scan_result( @@ -355,26 +384,7 @@ def _scan_zip_file(self, path: str, depth: int = 0) -> ScanResult: # Use core.scan_file to scan with appropriate scanner file_result = core.scan_file(tmp_path, self.config) - - # Update locations in file results - for issue in file_result.issues: - if issue.location: - if issue.location.startswith(tmp_path): - issue.location = issue.location.replace( - tmp_path, - f"{path}:{name}", - 1, - ) - else: - issue.location = f"{path}:{name} {issue.location}" - else: - issue.location = f"{path}:{name}" - - # Add zip entry name to details - if issue.details: - issue.details["zip_entry"] = name - else: - issue.details = {"zip_entry": name} + self._rewrite_nested_result_context(file_result, tmp_path, path, name) result.merge(file_result) diff --git a/tests/scanners/test_numpy_scanner.py b/tests/scanners/test_numpy_scanner.py index 4ef037e4..190b3345 100644 --- a/tests/scanners/test_numpy_scanner.py +++ b/tests/scanners/test_numpy_scanner.py @@ -1,6 +1,11 @@ +import zipfile +from collections.abc import Callable +from pathlib import Path +from typing import Any + import numpy as np -from modelaudit.scanners.base import IssueSeverity +from modelaudit.scanners.base import Check, IssueSeverity, ScanResult from modelaudit.scanners.numpy_scanner import NumPyScanner @@ -42,10 +47,14 @@ def test_object_dtype_triggers_cve(self, tmp_path): scanner = NumPyScanner() result = scanner.scan(str(path)) + assert result.success is True cve_checks = [c for c in result.checks if "CVE-2019-6446" in c.name or "CVE-2019-6446" in c.message] assert len(cve_checks) > 0, f"Should detect CVE-2019-6446. Checks: {[c.message for c in result.checks]}" assert cve_checks[0].severity == IssueSeverity.WARNING assert cve_checks[0].details.get("cve_id") == "CVE-2019-6446" + assert not any(c.name == "Data Type Safety Check" and c.status.value == "failed" for c in result.checks), ( + f"Object dtype should not be treated as a scan failure: {[c.message for c in result.checks]}" + ) def test_numeric_dtype_no_cve(self, tmp_path): """Numeric dtype arrays should not trigger CVE-2019-6446.""" @@ -98,5 +107,227 @@ def test_structured_with_object_field_triggers_cve(self, tmp_path): scanner = NumPyScanner() result = scanner.scan(str(path)) + assert result.success is True cve_checks = [c for c in result.checks if "CVE-2019-6446" in (c.name + c.message)] assert len(cve_checks) > 0, "Structured dtype with object field should trigger CVE" + + +class _ExecPayload: + def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, ...]]: + return (exec, ("print('owned')",)) + + +class _SSLPayload: + def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, ...]]: + import ssl + + return (ssl.get_server_certificate, (("example.com", 443),)) + + +def _failed_checks(result: ScanResult) -> list[Check]: + return [c for c in result.checks if c.status.value == "failed"] + + +def _inject_comment_token_into_npy_payload(path: Path) -> None: + with path.open("rb") as handle: + major, minor = np.lib.format.read_magic(handle) + if (major, minor) == (1, 0): + np.lib.format.read_array_header_1_0(handle) + elif (major, minor) == (2, 0): + np.lib.format.read_array_header_2_0(handle) + else: + read_array_header = getattr(np.lib.format, "_read_array_header", None) + if read_array_header is None: + raise AssertionError(f"Unsupported NumPy header version: {(major, minor)}") + read_array_header(handle, version=(major, minor)) + data_offset = handle.tell() + payload = handle.read() + + if len(payload) < 2 or payload[0] != 0x80: + raise AssertionError(f"Unexpected pickle payload header: {payload[:4]!r}") + + protocol = payload[1] + comment = b"# harmless note" + if protocol >= 4: + comment_op = b"\x8c" + bytes([len(comment)]) + comment + else: + comment_op = b"X" + len(comment).to_bytes(4, "little") + comment + + patched = payload[:2] + comment_op + b"0" + payload[2:] + original = path.read_bytes() + path.write_bytes(original[:data_offset] + patched) + + +def _inject_comment_token_into_npz_member(path: Path, member_name: str) -> None: + with zipfile.ZipFile(path, "r") as archive: + members = {info.filename: archive.read(info.filename) for info in archive.infolist()} + + member_path = path.parent / member_name + member_path.write_bytes(members[member_name]) + _inject_comment_token_into_npy_payload(member_path) + members[member_name] = member_path.read_bytes() + member_path.unlink() + + with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as archive: + for name, content in members.items(): + archive.writestr(name, content) + + +def test_object_dtype_numpy_recurses_into_pickle_exec(tmp_path: Path) -> None: + arr = np.array([_ExecPayload()], dtype=object) + path = tmp_path / "malicious_object.npy" + np.save(path, arr, allow_pickle=True) + + scanner = NumPyScanner() + result = scanner.scan(str(path)) + + assert result.success is True + assert result.has_errors is True + failed = _failed_checks(result) + assert any("CVE-2019-6446" in (c.name + c.message) for c in failed) + assert any("exec" in (c.message.lower()) for c in failed) + + +def test_object_dtype_numpy_recurses_into_pickle_ssl(tmp_path: Path) -> None: + arr = np.array([_SSLPayload()], dtype=object) + path = tmp_path / "malicious_ssl_object.npy" + np.save(path, arr, allow_pickle=True) + + scanner = NumPyScanner() + result = scanner.scan(str(path)) + + assert result.success is True + assert result.has_errors is True + failed = _failed_checks(result) + assert any("CVE-2019-6446" in (c.name + c.message) for c in failed) + assert any("ssl.get_server_certificate" in c.message for c in failed) + + +def test_numeric_npz_has_no_pickle_recursion_findings(tmp_path: Path) -> None: + npz_path = tmp_path / "numeric_only.npz" + np.savez(npz_path, a=np.arange(4), b=np.ones((2, 2), dtype=np.float32)) + + from modelaudit.scanners.zip_scanner import ZipScanner + + result = ZipScanner().scan(str(npz_path)) + + assert not any("CVE-2019-6446" in (c.name + c.message) for c in result.checks) + assert not any("exec" in c.message.lower() for c in result.checks) + assert not any(i.details.get("cve_id") == "CVE-2019-6446" for i in result.issues) + assert not any("exec" in i.message.lower() for i in result.issues) + + +def test_object_npz_member_recurses_into_pickle_exec_with_member_context(tmp_path: Path) -> None: + safe = np.array([1, 2, 3], dtype=np.int64) + malicious = np.array([_ExecPayload()], dtype=object) + npz_path = tmp_path / "mixed_object.npz" + np.savez(npz_path, safe=safe, payload=malicious) + + from modelaudit.scanners.zip_scanner import ZipScanner + + result = ZipScanner().scan(str(npz_path)) + + failed = _failed_checks(result) + assert any("CVE-2019-6446" in (c.name + c.message) and "payload.npy" in str(c.location) for c in failed) + assert any("exec" in i.message.lower() and i.details.get("zip_entry") == "payload.npy" for i in result.issues) + + +def test_object_dtype_numpy_comment_token_bypass_still_detected(tmp_path: Path) -> None: + arr = np.array([_ExecPayload()], dtype=object) + path = tmp_path / "comment_token.npy" + np.save(path, arr, allow_pickle=True) + _inject_comment_token_into_npy_payload(path) + + scanner = NumPyScanner() + result = scanner.scan(str(path)) + + failed = _failed_checks(result) + assert any("CVE-2019-6446" in (c.name + c.message) for c in failed) + assert any("exec" in c.message.lower() for c in failed) + + +def test_object_npz_member_comment_token_bypass_still_detected(tmp_path: Path) -> None: + npz_path = tmp_path / "comment_token.npz" + np.savez(npz_path, payload=np.array([_ExecPayload()], dtype=object)) + _inject_comment_token_into_npz_member(npz_path, "payload.npy") + + from modelaudit.scanners.zip_scanner import ZipScanner + + result = ZipScanner().scan(str(npz_path)) + + failed = _failed_checks(result) + assert any("CVE-2019-6446" in (c.name + c.message) and "payload.npy" in str(c.location) for c in failed) + assert any("exec" in i.message.lower() and i.details.get("zip_entry") == "payload.npy" for i in result.issues) + + +def test_benign_object_dtype_numpy_no_nested_critical(tmp_path: Path) -> None: + arr = np.array([{"k": "v"}, [1, 2, 3]], dtype=object) + path = tmp_path / "benign_object.npy" + np.save(path, arr, allow_pickle=True) + + scanner = NumPyScanner() + result = scanner.scan(str(path)) + + assert result.success is True + assert result.has_errors is False + assert any("CVE-2019-6446" in (c.name + c.message) for c in result.checks) + assert not any(i.severity == IssueSeverity.CRITICAL for i in result.issues if "CVE-2019-6446" not in i.message) + + +def test_benign_object_dtype_npz_no_nested_critical(tmp_path: Path) -> None: + npz_path = tmp_path / "benign_object.npz" + np.savez(npz_path, safe=np.array([{"x": 1}], dtype=object)) + + from modelaudit.scanners.zip_scanner import ZipScanner + + result = ZipScanner().scan(str(npz_path)) + + assert any("CVE-2019-6446" in (c.name + c.message) for c in result.checks) + assert not any(i.severity == IssueSeverity.CRITICAL for i in result.issues) + + +def test_truncated_npy_fails_safely(tmp_path: Path) -> None: + arr = np.array([_ExecPayload()], dtype=object) + path = tmp_path / "truncated.npy" + np.save(path, arr, allow_pickle=True) + path.write_bytes(path.read_bytes()[:-8]) + + scanner = NumPyScanner() + result = scanner.scan(str(path)) + + assert result.success is True + assert result.has_errors is False + assert any( + i.severity in {IssueSeverity.INFO, IssueSeverity.WARNING} and "corrupted pickle" in i.message.lower() + for i in result.issues + ), f"Expected a non-critical corruption finding, got: {[i.message for i in result.issues]}" + + +def test_object_dtype_numpy_trailing_bytes_fail_integrity(tmp_path: Path) -> None: + arr = np.array([{"k": "v"}], dtype=object) + path = tmp_path / "trailing.npy" + np.save(path, arr, allow_pickle=True) + path.write_bytes(path.read_bytes() + b"TRAILINGJUNK") + + scanner = NumPyScanner() + result = scanner.scan(str(path)) + + assert result.success is False + assert any( + check.name == "File Integrity Check" + and check.status.value == "failed" + and "trailing bytes" in check.message.lower() + for check in result.checks + ), f"Expected trailing-byte integrity failure, got: {[c.message for c in result.checks]}" + + +def test_corrupted_npz_fails_safely(tmp_path: Path) -> None: + npz_path = tmp_path / "corrupt.npz" + npz_path.write_bytes(b"not-a-zip") + + from modelaudit.scanners.zip_scanner import ZipScanner + + result = ZipScanner().scan(str(npz_path)) + + assert result.success is False + assert any(i.severity == IssueSeverity.INFO for i in result.issues) diff --git a/tests/scanners/test_zip_scanner.py b/tests/scanners/test_zip_scanner.py index 6c1ff5f7..7a4db65b 100644 --- a/tests/scanners/test_zip_scanner.py +++ b/tests/scanners/test_zip_scanner.py @@ -1,7 +1,9 @@ import os import tempfile import zipfile +from collections.abc import Callable from pathlib import Path +from typing import Any from modelaudit.scanners.base import IssueSeverity from modelaudit.scanners.zip_scanner import ZipScanner @@ -258,7 +260,7 @@ def test_scan_zip_with_dangerous_pickle(self): import pickle class DangerousClass: - def __reduce__(self): + def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, ...]]: return (os_module.system, ("echo pwned",)) dangerous_obj = DangerousClass() @@ -314,6 +316,61 @@ def test_scan_zip_with_prefixed_proto0_pickle_disguised_as_text(self, tmp_path: f"Expected critical os/posix.system issue, got: {critical_messages}" ) + def test_scan_npz_with_object_member_recurses_into_pickle(self, tmp_path: Path) -> None: + import numpy as np + + class _ExecPayload: + def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, ...]]: + return (exec, ("print('owned')",)) + + archive_path = tmp_path / "payload.npz" + np.savez(archive_path, safe=np.arange(3), payload=np.array([_ExecPayload()], dtype=object)) + + result = self.scanner.scan(str(archive_path)) + assert result.success is True + + failed_checks = [c for c in result.checks if c.status.value == "failed"] + assert any("cve-2019-6446" in (c.name + c.message).lower() for c in failed_checks) + assert any( + c.details.get("zip_entry") == "payload.npy" and c.location == f"{archive_path}:payload.npy" + for c in failed_checks + ), f"Expected rewritten check context for payload.npy, got: {[(c.location, c.details) for c in failed_checks]}" + assert any("exec" in i.message.lower() and i.details.get("zip_entry") == "payload.npy" for i in result.issues) + + def test_scan_outer_zip_preserves_nested_npz_member_context(self, tmp_path: Path) -> None: + import numpy as np + + class _ExecPayload: + def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, ...]]: + return (exec, ("print('owned')",)) + + inner_npz = tmp_path / "inner.npz" + np.savez( + inner_npz, + payload_a=np.array([_ExecPayload()], dtype=object), + payload_b=np.array([_ExecPayload()], dtype=object), + ) + + archive_path = tmp_path / "outer.zip" + with zipfile.ZipFile(archive_path, "w") as zf: + zf.write(inner_npz, arcname="inner.npz") + + result = self.scanner.scan(str(archive_path)) + failed_checks = [c for c in result.checks if c.status.value == "failed"] + + assert any( + c.details.get("zip_entry") == "inner.npz:payload_a.npy" + and c.location + and f"{archive_path}:inner.npz:payload_a.npy" in c.location + for c in failed_checks + ) + assert any( + c.details.get("zip_entry") == "inner.npz:payload_b.npy" + and c.location + and f"{archive_path}:inner.npz:payload_b.npy" in c.location + for c in failed_checks + ) + def test_scan_zip_with_plain_text_global_prefix_not_treated_as_pickle(self, tmp_path: Path) -> None: """Plain text entries that start with GLOBAL-like bytes should not trigger pickle parse warnings.""" archive_path = tmp_path / "plain_text_payload.zip" diff --git a/tests/test_cli.py b/tests/test_cli.py index 9cbf1ce4..fc50b360 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1421,7 +1421,7 @@ def __reduce__(self): def test_exit_code_security_issues_streaming_local_directory(tmp_path: Path) -> None: - """Streaming local-directory scans should keep security findings as exit code 1.""" + """Streaming local-directory scans should keep security findings as exit code 1 without deleting originals.""" import pickle evil_pickle_path = tmp_path / "malicious.pkl" @@ -1437,7 +1437,7 @@ def __reduce__(self): result = runner.invoke(cli, ["scan", "--stream", "--format", "text", str(tmp_path)]) assert result.exit_code == 1, f"Expected exit code 1, got {result.exit_code}. Output: {result.output}" - assert not evil_pickle_path.exists() + assert evil_pickle_path.exists() def test_exit_code_scan_errors(tmp_path): diff --git a/tests/test_core_asset_extraction.py b/tests/test_core_asset_extraction.py index 532dd7fd..d7656584 100644 --- a/tests/test_core_asset_extraction.py +++ b/tests/test_core_asset_extraction.py @@ -1,9 +1,13 @@ import ntpath import pickle import sys +import zipfile +from collections.abc import Callable from pathlib import Path +from typing import Any from unittest.mock import patch +import numpy as np import pytest from modelaudit.core import _extract_primary_asset_from_location, scan_model_directory_or_file @@ -72,3 +76,74 @@ def test_check_consolidation_handles_newlines_in_file_paths(tmp_path: Path) -> N expected_paths = {str(group / "dup a.pkl"), str(group / "dup b.pkl")} assert set(check.details["duplicate_files"]) == expected_paths assert check.location in expected_paths + + +def test_npz_member_checks_keep_archive_member_locations(tmp_path: Path) -> None: + class _ExecPayload: + def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, ...]]: + return (exec, ("print('owned')",)) + + archive_path = tmp_path / "payload.npz" + np.savez(archive_path, safe=np.arange(3), payload=np.array([_ExecPayload()], dtype=object)) + + result = scan_model_directory_or_file(str(archive_path)) + payload_checks = [ + check + for check in result.checks + if check.status.value == "failed" + and (check.details.get("zip_entry") == "payload.npy" or ":payload.npy" in (check.location or "")) + ] + + assert any( + check.location == f"{archive_path}:payload.npy" and check.details.get("zip_entry") == "payload.npy" + for check in payload_checks + ), f"Expected archive-member check location, got: {[(c.location, c.details) for c in payload_checks]}" + assert not any(check.location and not check.location.startswith(f"{archive_path}:") for check in payload_checks), ( + f"Unexpected non-archive check locations: {[c.location for c in payload_checks]}" + ) + + +def test_check_consolidation_keeps_distinct_npz_member_findings(tmp_path: Path) -> None: + class ExecPayload: + def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, ...]]: + return (exec, ("print('owned')",)) + + archive_path = tmp_path / "payload.npz" + np.savez(archive_path, safe=np.arange(3), payload=np.array([ExecPayload()], dtype=object)) + + result = scan_model_directory_or_file(str(archive_path)) + cve_checks = [ + check for check in result.checks if check.name.startswith("CVE-2019-6446") and check.status.value == "failed" + ] + + assert len(cve_checks) == 1 + assert cve_checks[0].location == f"{archive_path}:payload.npy" + assert cve_checks[0].details.get("zip_entry") == "payload.npy" + + +def test_check_consolidation_keeps_nested_npz_member_findings_distinct(tmp_path: Path) -> None: + class ExecPayload: + def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, ...]]: + return (exec, ("print('owned')",)) + + inner_npz = tmp_path / "inner.npz" + np.savez( + inner_npz, + payload_a=np.array([ExecPayload()], dtype=object), + payload_b=np.array([ExecPayload()], dtype=object), + ) + + outer_zip = tmp_path / "outer.zip" + with zipfile.ZipFile(outer_zip, "w") as zf: + zf.write(inner_npz, arcname="inner.npz") + + result = scan_model_directory_or_file(str(outer_zip)) + cve_checks = [ + check for check in result.checks if check.name.startswith("CVE-2019-6446") and check.status.value == "failed" + ] + + assert len(cve_checks) == 2 + assert {check.details.get("zip_entry") for check in cve_checks} == { + "inner.npz:payload_a.npy", + "inner.npz:payload_b.npy", + } diff --git a/tests/test_security_enhancements.py b/tests/test_security_enhancements.py index a3e21e8a..24e32dd4 100644 --- a/tests/test_security_enhancements.py +++ b/tests/test_security_enhancements.py @@ -6,6 +6,7 @@ import pickle import zipfile import zlib +from pathlib import Path import numpy as np @@ -270,28 +271,20 @@ def test_dimension_size_limit(self, tmp_path): size_issues = [issue for issue in result.issues if "too large" in issue.message.lower()] assert len(size_issues) > 0 - def test_dangerous_dtype_rejection(self, tmp_path): - """Test rejection of dangerous data types.""" + def test_dangerous_dtype_reports_cve_warning(self, tmp_path: Path) -> None: + """Object dtype arrays should scan successfully while emitting CVE-2019-6446 warnings.""" scanner = NumPyScanner() - - # Create numpy file with object dtype manually npy_file = tmp_path / "object_dtype.npy" - - with open(npy_file, "wb") as f: - f.write(b"\x93NUMPY") # Magic - f.write(b"\x01\x00") # Version 1.0 - header = "{'descr': '|O', 'fortran_order': False, 'shape': (10,), }" - header_len = len(header) - f.write(header_len.to_bytes(2, "little")) - f.write(header.encode("latin1")) - # Add some dummy data - f.write(b"\x00" * 80) # 10 * 8 bytes per object pointer + np.save(npy_file, np.array([{"key": "value"}], dtype=object), allow_pickle=True) result = scanner.scan(str(npy_file)) - assert result.success is False - dtype_issues = [issue for issue in result.issues if "dangerous dtype" in issue.message.lower()] - assert len(dtype_issues) > 0 + assert result.success is True + cve_issues = [issue for issue in result.issues if "CVE-2019-6446" in issue.message] + assert len(cve_issues) > 0 + assert not any( + check.name == "Data Type Safety Check" and check.status.value == "failed" for check in result.checks + ) def test_array_size_overflow_protection(self, tmp_path): """Test protection against integer overflow in size calculation.""" diff --git a/tests/test_streaming_scan.py b/tests/test_streaming_scan.py index 8ed45963..27a8a45e 100644 --- a/tests/test_streaming_scan.py +++ b/tests/test_streaming_scan.py @@ -143,6 +143,38 @@ def file_generator(): assert determine_exit_code(result) == 1 +def test_scan_model_streaming_informational_failed_scan_does_not_set_operational_errors( + temp_test_files: list[Path], +) -> None: + """Informational failed scans should not override security findings with exit code 2.""" + + def file_generator(): + yield (temp_test_files[0], False) + yield (temp_test_files[1], True) + + info_result = ScanResult(scanner_name="numpy") + info_result.add_issue( + "Object-dtype payload contains trailing bytes after the embedded pickle stream", + severity=IssueSeverity.INFO, + location="trailing.npy", + ) + info_result.finish(success=False) + + with patch("modelaudit.core.scan_file") as mock_scan: + mock_scan.side_effect = [info_result, create_mock_scan_result(with_critical_issue=True)] + + result = scan_model_streaming( + file_generator=file_generator(), + timeout=30, + delete_after_scan=False, + ) + + assert result.files_scanned == 2 + assert result.success is True + assert result.has_errors is False + assert determine_exit_code(result) == 1 + + def test_scan_model_streaming_content_hash_deterministic(): """Test that content hash is deterministic for same files.""" # Create two files with same content