diff --git a/CHANGELOG.md b/CHANGELOG.md index 71a1b90e..830004ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -124,9 +124,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **security:** detect CVE-2022-25882 ONNX external_data path traversal with CVE attribution, CVSS score, and CWE classification in scan results - **security:** detect CVE-2024-27318 ONNX nested external_data path traversal bypass via path segment sanitization evasion - **security:** restore ZIP scanner fallback for invalid `.mar` archives so malicious ZIP payloads renamed to `.mar` cannot bypass archive checks - - **security:** flag risky import-only pickle references for `torch.jit`, `torch._dynamo`, `torch._inductor`, `torch.compile`, `torch.storage._load_from_bytes`, `numpy.f2py`, and `numpy.distutils` while preserving safe state-dict reconstruction paths - **security:** add low-severity pickle structural tamper findings for duplicate or misplaced `PROTO` opcodes while avoiding benign binary-tail false positives +- **security:** stop treating valid mixed-case pickle module names as implausible, so import and reduce checks can no longer be bypassed with names like `PIL` or attacker-chosen `EvilPkg` - **security:** scan OCI layer members based on registered file extensions so embedded ONNX, Keras H5, and other real-path scanners are no longer skipped inside tar layers - **security:** resolve bare-module TorchServe handler references like `custom_handler` to concrete archive members so malicious handler source is no longer skipped by static analysis - **security:** compare archive entry paths against the intended extraction root without following base-directory symlinks diff --git a/modelaudit/scanners/pickle_scanner.py b/modelaudit/scanners/pickle_scanner.py index 59238a41..99f247f8 100644 --- a/modelaudit/scanners/pickle_scanner.py +++ b/modelaudit/scanners/pickle_scanner.py @@ -1970,27 +1970,17 @@ def _is_plausible_python_module(name: str) -> bool: Check whether 
*name* looks like a real Python module/package path. Legitimate module names follow Python identifier rules: - - Each dotted segment is a valid Python identifier (letters, digits, - underscores; cannot start with a digit). - - Conventionally all-lowercase, though private/internal modules may use - a leading underscore. - - Names that contain uppercase letters, start with digits, or include - characters outside ``[a-z0-9_.]`` are almost certainly **not** real - modules -- they are more likely DataFrame column names, user labels, - or other data strings that ended up as pickle GLOBAL arguments - (e.g. ``PEDRA_2020``). - - The check is intentionally conservative: a handful of legitimate but - unusual module names (e.g. ``PIL``, ``Cython``) are covered by - ``ML_SAFE_GLOBALS`` and will pass the allowlist before this function - is ever consulted. + - Each dotted segment is an ASCII Python identifier. + - Segments normally contain lowercase characters, with a short explicit + allowlist for case-sensitive imports such as ``PIL``. + + Keep obviously malformed names rejected so arbitrary data strings are less + likely to be treated as imports, while still allowing valid mixed-case + segments such as ``EvilPkg`` and ``MyOrg.InternalPkg``. Returns: True if *name* plausibly refers to a real Python module. """ - import re - if not name: return False @@ -1998,17 +1988,15 @@ def _is_plausible_python_module(name: str) -> bool: if " " in name or "\t" in name: return False - # Split on dots; each segment must be a valid Python identifier that - # looks like a conventional module name (lowercase + digits + _). + # Split on dots; each segment must be an ASCII Python identifier. 
segments = name.split(".") - if not segments or any(s == "" for s in segments): + if not segments or any(s == "" or not s.isascii() or not s.isidentifier() for s in segments): return False - _MODULE_SEGMENT_RE = re.compile(r"^[a-z_][a-z0-9_]*$") - return all(_MODULE_SEGMENT_RE.match(seg) for seg in segments) + return all(any(char.islower() for char in seg) or seg in _CASE_SENSITIVE_IMPORT_SEGMENTS for seg in segments) -_CASE_SENSITIVE_IMPORT_SEGMENTS = frozenset({"PIL", "Cython"}) +_CASE_SENSITIVE_IMPORT_SEGMENTS: frozenset[str] = frozenset({"PIL", "Cython"}) IMPORT_ONLY_ALWAYS_DANGEROUS_GLOBALS = frozenset( { ("dill", "load"), @@ -2079,16 +2067,7 @@ def _is_resolved_import_target(mod: str, func: str) -> bool: def _is_plausible_import_only_module(mod: str) -> bool: """Return True when a module path looks importable without matching common data labels.""" - if not mod: - return False - - segments = mod.split(".") - if not segments or any(segment == "" or not segment.isidentifier() for segment in segments): - return False - - return all( - any(char.islower() for char in segment) or segment in _CASE_SENSITIVE_IMPORT_SEGMENTS for segment in segments - ) + return _is_plausible_python_module(mod) def _classify_import_reference( diff --git a/tests/scanners/test_pickle_scanner.py b/tests/scanners/test_pickle_scanner.py index a3a57e54..fdbcc39a 100644 --- a/tests/scanners/test_pickle_scanner.py +++ b/tests/scanners/test_pickle_scanner.py @@ -26,6 +26,7 @@ _genops_with_fallback, _GenopsBudgetExceeded, _is_actually_dangerous_global, + _is_plausible_python_module, _is_safe_import_only_global, _simulate_symbolic_reference_maps, ) @@ -759,6 +760,35 @@ def _craft_global_reduce_pickle(module: str, func: str) -> bytes: call_ops = b"(" + b"t" + b"R" + b"." 
return proto + global_op + call_ops + @staticmethod + def _craft_stack_global_reduce_pickle(module: str, func: str) -> bytes: + """Craft protocol-4 payload with STACK_GLOBAL + REDUCE.""" + + return b"\x80\x04" + _short_binunicode(module.encode()) + _short_binunicode(func.encode()) + b"\x93(tR." + + @staticmethod + def _craft_memoized_stack_global_reduce_pickle(module: str, func: str) -> bytes: + """Craft protocol-4 payload that recalls a memoized STACK_GLOBAL before REDUCE.""" + + payload = bytearray(b"\x80\x04") + payload += _short_binunicode(module.encode()) + payload += _short_binunicode(func.encode()) + payload += b"\x93" # STACK_GLOBAL + payload += b"\x94" # MEMOIZE index 0 + payload += b"0" # POP + payload += b"h\x00" # BINGET 0 + payload += b"(" # MARK + payload += b"t" # TUPLE + payload += b"R" # REDUCE + payload += b"." + return bytes(payload) + + @staticmethod + def _craft_global_import_only_pickle(module: str, func: str) -> bytes: + """Craft minimal pickle that only imports a GLOBAL and stops.""" + + return TestPickleScannerBlocklistHardening._craft_global_only_pickle(module, func) + @staticmethod def _craft_global_only_pickle(module: str, func: str) -> bytes: """Craft a minimal pickle with a bare GLOBAL reference and STOP.""" @@ -882,7 +912,6 @@ def test_builtins_hasattr_stack_global_is_critical(self) -> None: def test_builtins_hasattr_binput_binget_recall_is_critical(self) -> None: """Memoized callable recall via BINPUT/BINGET must keep builtins.hasattr dangerous.""" - # Memoize the callable, drop the original stack reference, then recall it. payload = b"\x80\x02cbuiltins\nhasattr\nq\x010h\x01(tR." 
result = self._scan_bytes(payload) @@ -925,6 +954,148 @@ def test_builtins_hasattr_detected_after_benign_stream(self) -> None: f"got: {[check.details for check in failed_reduce_checks]}" ) + def test_plausible_module_allows_mixed_case_identifiers(self) -> None: + assert _is_plausible_python_module("EvilPkg") + assert _is_plausible_python_module("PIL") + assert _is_plausible_python_module("MyOrg.InternalPkg") + + def test_plausible_module_rejects_malformed_names(self) -> None: + assert not _is_plausible_python_module("foo bar") + assert not _is_plausible_python_module("foo..bar") + assert not _is_plausible_python_module("foo/bar") + assert not _is_plausible_python_module("!!!") + assert not _is_plausible_python_module("PEDRA_2020") + + def test_mixed_case_global_reduce_is_not_suppressed(self) -> None: + result = self._scan_bytes(self._craft_global_reduce_pickle("EvilPkg", "thing")) + + reduce_checks = [c for c in result.checks if c.name == "REDUCE Opcode Safety Check"] + assert any(c.status == CheckStatus.FAILED and "EvilPkg.thing" in c.message for c in reduce_checks), ( + f"Expected failed REDUCE check for mixed-case module, got: {[c.message for c in reduce_checks]}" + ) + assert not any("implausible module name 'EvilPkg'" in c.message for c in reduce_checks), ( + "Mixed-case module names should not be classified as implausible" + ) + + def test_pil_global_reduce_is_not_suppressed(self) -> None: + """Legitimate mixed-case modules like PIL should no longer be treated as implausible.""" + result = self._scan_bytes(self._craft_global_reduce_pickle("PIL", "Image")) + + reduce_checks = [c for c in result.checks if c.name == "REDUCE Opcode Safety Check"] + assert any("PIL.Image" in c.message for c in reduce_checks), ( + f"Expected REDUCE analysis to resolve PIL.Image, got: {[c.message for c in reduce_checks]}" + ) + assert not any("implausible module name 'PIL'" in c.message for c in reduce_checks), ( + "PIL should no longer be classified as an implausible module" + ) + 
+ def test_mixed_case_stack_global_reduce_is_not_suppressed(self) -> None: + result = self._scan_bytes(self._craft_stack_global_reduce_pickle("EvilPkg", "thing")) + + reduce_checks = [c for c in result.checks if c.name == "REDUCE Opcode Safety Check"] + assert any(c.status == CheckStatus.FAILED and "EvilPkg.thing" in c.message for c in reduce_checks), ( + "Expected failed REDUCE check for STACK_GLOBAL mixed-case module, " + f"got: {[c.message for c in reduce_checks]}" + ) + assert not any("implausible module name 'EvilPkg'" in c.message for c in reduce_checks), ( + "Mixed-case STACK_GLOBAL paths should not be suppressed as implausible" + ) + + def test_mixed_case_memoized_stack_global_reduce_is_not_suppressed(self) -> None: + result = self._scan_bytes(self._craft_memoized_stack_global_reduce_pickle("EvilPkg", "thing")) + + reduce_checks = [c for c in result.checks if c.name == "REDUCE Opcode Safety Check"] + assert any(c.status == CheckStatus.FAILED and "EvilPkg.thing" in c.message for c in reduce_checks), ( + "Expected failed REDUCE check for memoized mixed-case STACK_GLOBAL, " + f"got: {[c.message for c in reduce_checks]}" + ) + assert not any("implausible module name 'EvilPkg'" in c.message for c in reduce_checks), ( + "Memoized mixed-case STACK_GLOBAL paths should not be suppressed as implausible" + ) + + def test_mixed_case_import_only_payload_still_flags_import(self) -> None: + result = self._scan_bytes(self._craft_global_import_only_pickle("Builtins", "eval")) + + import_issues = [issue for issue in result.issues if "Suspicious reference Builtins.eval" in issue.message] + assert import_issues, ( + "Expected suspicious import-only detection for mixed-case dangerous global, " + f"got: {[i.message for i in result.issues]}" + ) + + benign_result = self._scan_bytes(self._craft_global_import_only_pickle("EvilPkg", "thing")) + benign_checks = [ + check + for check in benign_result.checks + if check.name == "Global Module Reference Check" + and 
check.details.get("import_reference") == "EvilPkg.thing" + and check.details.get("import_only") is True + ] + assert benign_checks, f"Expected import-only analysis for EvilPkg.thing: {benign_result.checks}" + assert all(check.severity == IssueSeverity.WARNING for check in benign_checks), ( + f"Mixed-case unknown imports should not be escalated as dangerous: {benign_checks}" + ) + assert all(check.details.get("classification") == "unknown_third_party" for check in benign_checks), ( + f"Expected mixed-case benign counterpart to stay unknown_third_party: {benign_checks}" + ) + assert not any( + check.severity == IssueSeverity.CRITICAL and check.details.get("import_reference") == "EvilPkg.thing" + for check in benign_result.checks + ), f"Unexpected critical mixed-case import finding for EvilPkg.thing: {benign_result.checks}" + + def test_mixed_case_unknown_import_only_is_flagged(self) -> None: + """Mixed-case unknown import-only refs should now reach the import-only warning path.""" + result = self._scan_bytes(self._craft_global_import_only_pickle("EvilPkg", "thing")) + + failing_checks = [ + check + for check in result.checks + if check.name == "Global Module Reference Check" + and check.status == CheckStatus.FAILED + and check.severity == IssueSeverity.WARNING + and check.details.get("import_reference") == "EvilPkg.thing" + and check.details.get("import_only") is True + and check.details.get("classification") == "unknown_third_party" + ] + assert failing_checks, f"Expected import-only warning for EvilPkg.thing: {result.checks}" + assert not any( + "implausible module name 'EvilPkg'" in check.message + for check in result.checks + if check.name == "Global Module Reference Check" + ), f"Mixed-case import-only path should not be suppressed as implausible: {result.checks}" + assert any( + issue.severity == IssueSeverity.WARNING and "EvilPkg.thing" in issue.message for issue in result.issues + ), f"Expected warning issue for EvilPkg.thing: {result.issues}" + + def 
test_mixed_case_reduce_in_later_stream_is_not_suppressed(self) -> None: + import io + + buf = io.BytesIO() + pickle.dump({"safe": True}, buf, protocol=2) + buf.write(self._craft_global_reduce_pickle("EvilPkg", "thing")) + + result = self._scan_bytes(buf.getvalue()) + reduce_checks = [c for c in result.checks if c.name == "REDUCE Opcode Safety Check"] + assert any(c.status == CheckStatus.FAILED and "EvilPkg.thing" in c.message for c in reduce_checks), ( + f"Expected later-stream REDUCE check for mixed-case module, got: {[c.message for c in reduce_checks]}" + ) + + def test_malformed_module_reduce_stays_implausible(self) -> None: + result = self._scan_bytes(self._craft_global_reduce_pickle("foo..bar", "thing")) + + reduce_checks = [c for c in result.checks if c.name == "REDUCE Opcode Safety Check"] + assert any( + c.status == CheckStatus.PASSED and "implausible module name 'foo..bar'" in c.message for c in reduce_checks + ), f"Expected malformed module to remain implausible, got: {[c.message for c in reduce_checks]}" + + def test_uppercase_data_label_reduce_stays_implausible(self) -> None: + result = self._scan_bytes(self._craft_global_reduce_pickle("PEDRA_2020", "thing")) + + reduce_checks = [c for c in result.checks if c.name == "REDUCE Opcode Safety Check"] + assert any( + c.status == CheckStatus.PASSED and "implausible module name 'PEDRA_2020'" in c.message + for c in reduce_checks + ), f"Expected uppercase data label to remain implausible, got: {[c.message for c in reduce_checks]}" + @staticmethod def _structural_tamper_checks(result: ScanResult) -> list: return [issue for issue in result.issues if issue.details.get("tamper_type") is not None] diff --git a/tests/test_cli.py b/tests/test_cli.py index 9cbf1ce4..3430af52 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1425,6 +1425,7 @@ def test_exit_code_security_issues_streaming_local_directory(tmp_path: Path) -> import pickle evil_pickle_path = tmp_path / "malicious.pkl" + expected_global = 
f"{os.system.__module__}.system" class MaliciousClass: def __reduce__(self): @@ -1437,6 +1438,7 @@ def __reduce__(self): result = runner.invoke(cli, ["scan", "--stream", "--format", "text", str(tmp_path)]) assert result.exit_code == 1, f"Expected exit code 1, got {result.exit_code}. Output: {result.output}" + assert expected_global in result.output, f"Expected malicious finding in output, got: {result.output}" assert not evil_pickle_path.exists() diff --git a/tests/test_streaming_scan.py b/tests/test_streaming_scan.py index 8ed45963..a39eb782 100644 --- a/tests/test_streaming_scan.py +++ b/tests/test_streaming_scan.py @@ -128,8 +128,17 @@ def test_scan_model_streaming_critical_findings_do_not_set_operational_errors( def file_generator(): yield (temp_test_files[0], True) + finding = ScanResult(scanner_name="test_scanner") + finding.bytes_scanned = 1024 + finding.success = True + finding.add_issue( + "Detected malicious payload", + severity=IssueSeverity.CRITICAL, + location=str(temp_test_files[0]), + ) + with patch("modelaudit.core.scan_file") as mock_scan: - mock_scan.return_value = create_mock_scan_result(with_critical_issue=True) + mock_scan.return_value = finding result = scan_model_streaming( file_generator=file_generator(), @@ -139,7 +148,11 @@ def file_generator(): assert result.files_scanned == 1 assert len(result.issues) == 1 + assert result.issues[0].message == "Detected malicious payload" + assert result.issues[0].severity == IssueSeverity.CRITICAL + assert result.issues[0].location == str(temp_test_files[0]) assert result.has_errors is False + assert result.success is True assert determine_exit_code(result) == 1