Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **security:** detect CVE-2022-25882 ONNX external_data path traversal with CVE attribution, CVSS score, and CWE classification in scan results
- **security:** detect CVE-2024-27318 ONNX nested external_data path traversal bypass via path segment sanitization evasion
- **security:** restore ZIP scanner fallback for invalid `.mar` archives so malicious ZIP payloads renamed to `.mar` cannot bypass archive checks

- **security:** flag risky import-only pickle references for `torch.jit`, `torch._dynamo`, `torch._inductor`, `torch.compile`, `torch.storage._load_from_bytes`, `numpy.f2py`, and `numpy.distutils` while preserving safe state-dict reconstruction paths
- **security:** add low-severity pickle structural tamper findings for duplicate or misplaced `PROTO` opcodes while avoiding benign binary-tail false positives
- **security:** stop treating mixed-case valid pickle module names as implausible, so import and reduce checks no longer bypass on names like `PIL` or attacker-chosen `EvilPkg`
- **security:** scan OCI layer members based on registered file extensions so embedded ONNX, Keras H5, and other real-path scanners are no longer skipped inside tar layers
- **security:** resolve bare-module TorchServe handler references like `custom_handler` to concrete archive members so malicious handler source is no longer skipped by static analysis
- **security:** compare archive entry paths against the intended extraction root without following base-directory symlinks
Expand Down
45 changes: 12 additions & 33 deletions modelaudit/scanners/pickle_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1970,45 +1970,33 @@ def _is_plausible_python_module(name: str) -> bool:
Check whether *name* looks like a real Python module/package path.

Legitimate module names follow Python identifier rules:
- Each dotted segment is a valid Python identifier (letters, digits,
underscores; cannot start with a digit).
- Conventionally all-lowercase, though private/internal modules may use
a leading underscore.

Names that contain uppercase letters, start with digits, or include
characters outside ``[a-z0-9_.]`` are almost certainly **not** real
modules -- they are more likely DataFrame column names, user labels,
or other data strings that ended up as pickle GLOBAL arguments
(e.g. ``PEDRA_2020``).

The check is intentionally conservative: a handful of legitimate but
unusual module names (e.g. ``PIL``, ``Cython``) are covered by
``ML_SAFE_GLOBALS`` and will pass the allowlist before this function
is ever consulted.
- Each dotted segment is an ASCII Python identifier.
- Segments normally contain lowercase characters, with a short explicit
allowlist for case-sensitive imports such as ``PIL``.

Keep obviously malformed names rejected so arbitrary data strings are less
likely to be treated as imports, while still allowing valid mixed-case
segments such as ``EvilPkg`` and ``MyOrg.InternalPkg``.

Returns:
True if *name* plausibly refers to a real Python module.
"""
import re

if not name:
return False

# Fast reject: real module paths never contain whitespace.
if " " in name or "\t" in name:
return False

# Split on dots; each segment must be a valid Python identifier that
# looks like a conventional module name (lowercase + digits + _).
# Split on dots; each segment must be an ASCII Python identifier.
segments = name.split(".")
if not segments or any(s == "" for s in segments):
if not segments or any(s == "" or not s.isascii() or not s.isidentifier() for s in segments):
return False

_MODULE_SEGMENT_RE = re.compile(r"^[a-z_][a-z0-9_]*$")
return all(_MODULE_SEGMENT_RE.match(seg) for seg in segments)
return all(any(char.islower() for char in seg) or seg in _CASE_SENSITIVE_IMPORT_SEGMENTS for seg in segments)


# Module path segments that contain no lowercase characters yet are still
# accepted as plausible imports; consulted by the plausibility checks in this
# module when a segment has no lowercase letter.
_CASE_SENSITIVE_IMPORT_SEGMENTS: frozenset[str] = frozenset({"PIL", "Cython"})
IMPORT_ONLY_ALWAYS_DANGEROUS_GLOBALS = frozenset(
{
("dill", "load"),
Expand Down Expand Up @@ -2079,16 +2067,7 @@ def _is_resolved_import_target(mod: str, func: str) -> bool:

def _is_plausible_import_only_module(mod: str) -> bool:
    """Return True when a module path looks importable without matching common data labels.

    Delegates to ``_is_plausible_python_module`` so the import-only path and
    the general plausibility check share a single definition of what a
    plausible module path is, instead of maintaining two copies of the same
    per-segment rules that can drift apart.

    Args:
        mod: Dotted module path taken from a pickle import reference.

    Returns:
        Whatever ``_is_plausible_python_module`` returns for *mod*.
    """
    return _is_plausible_python_module(mod)


def _classify_import_reference(
Expand Down
173 changes: 172 additions & 1 deletion tests/scanners/test_pickle_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_genops_with_fallback,
_GenopsBudgetExceeded,
_is_actually_dangerous_global,
_is_plausible_python_module,
_is_safe_import_only_global,
_simulate_symbolic_reference_maps,
)
Expand Down Expand Up @@ -759,6 +760,35 @@ def _craft_global_reduce_pickle(module: str, func: str) -> bytes:
call_ops = b"(" + b"t" + b"R" + b"."
return proto + global_op + call_ops

@staticmethod
def _craft_stack_global_reduce_pickle(module: str, func: str) -> bytes:
    """Build a protocol-4 payload that uses STACK_GLOBAL followed by REDUCE.

    The module and attribute names are each encoded with
    ``_short_binunicode`` and placed ahead of the STACK_GLOBAL opcode; the
    trailing bytes are MARK, TUPLE, REDUCE and the final ``.`` opcode.
    """
    pieces = (
        b"\x80\x04",  # PROTO 4
        _short_binunicode(module.encode()),
        _short_binunicode(func.encode()),
        b"\x93",  # STACK_GLOBAL
        b"(tR.",  # MARK, TUPLE, REDUCE, end of stream
    )
    return b"".join(pieces)

@staticmethod
def _craft_memoized_stack_global_reduce_pickle(module: str, func: str) -> bytes:
    """Build a protocol-4 payload whose callable is memoized, popped, and recalled.

    The STACK_GLOBAL result is stored in memo slot 0 and dropped from the
    stack, then fetched again with BINGET before REDUCE is applied to an
    empty argument tuple.
    """
    opcodes = [
        b"\x80\x04",  # PROTO 4
        _short_binunicode(module.encode()),
        _short_binunicode(func.encode()),
        b"\x93",  # STACK_GLOBAL
        b"\x94",  # MEMOIZE index 0
        b"0",  # POP
        b"h\x00",  # BINGET 0
        b"(",  # MARK
        b"t",  # TUPLE
        b"R",  # REDUCE
        b".",  # end of stream
    ]
    return b"".join(opcodes)

@staticmethod
def _craft_global_import_only_pickle(module: str, func: str) -> bytes:
    """Return a minimal pickle that only references ``module.func`` via GLOBAL.

    Thin alias for ``_craft_global_only_pickle`` on the same test class.
    """
    build_global_only = TestPickleScannerBlocklistHardening._craft_global_only_pickle
    return build_global_only(module, func)

@staticmethod
def _craft_global_only_pickle(module: str, func: str) -> bytes:
"""Craft a minimal pickle with a bare GLOBAL reference and STOP."""
Expand Down Expand Up @@ -882,7 +912,6 @@ def test_builtins_hasattr_stack_global_is_critical(self) -> None:

def test_builtins_hasattr_binput_binget_recall_is_critical(self) -> None:
"""Memoized callable recall via BINPUT/BINGET must keep builtins.hasattr dangerous."""
# Memoize the callable, drop the original stack reference, then recall it.
payload = b"\x80\x02cbuiltins\nhasattr\nq\x010h\x01(tR."

result = self._scan_bytes(payload)
Expand Down Expand Up @@ -925,6 +954,148 @@ def test_builtins_hasattr_detected_after_benign_stream(self) -> None:
f"got: {[check.details for check in failed_reduce_checks]}"
)

def test_plausible_module_allows_mixed_case_identifiers(self) -> None:
    """Mixed-case names and the allowlisted ``PIL`` are accepted as plausible modules."""
    for candidate in ("EvilPkg", "PIL", "MyOrg.InternalPkg"):
        assert _is_plausible_python_module(candidate)

def test_plausible_module_rejects_malformed_names(self) -> None:
    """Whitespace, empty segments, non-identifier characters and all-caps labels are rejected."""
    for candidate in ("foo bar", "foo..bar", "foo/bar", "!!!", "PEDRA_2020"):
        assert not _is_plausible_python_module(candidate)

def test_mixed_case_global_reduce_is_not_suppressed(self) -> None:
    """A GLOBAL + REDUCE on a mixed-case module must produce a failed REDUCE check."""
    payload = self._craft_global_reduce_pickle("EvilPkg", "thing")
    result = self._scan_bytes(payload)

    reduce_checks = [check for check in result.checks if check.name == "REDUCE Opcode Safety Check"]

    failed_hits = [
        check for check in reduce_checks if check.status == CheckStatus.FAILED and "EvilPkg.thing" in check.message
    ]
    assert failed_hits, (
        f"Expected failed REDUCE check for mixed-case module, got: {[c.message for c in reduce_checks]}"
    )

    suppressed = [check for check in reduce_checks if "implausible module name 'EvilPkg'" in check.message]
    assert not suppressed, "Mixed-case module names should not be classified as implausible"

def test_pil_global_reduce_is_not_suppressed(self) -> None:
    """REDUCE analysis must resolve ``PIL.Image`` rather than dismissing ``PIL`` as implausible."""
    payload = self._craft_global_reduce_pickle("PIL", "Image")
    result = self._scan_bytes(payload)

    reduce_checks = [check for check in result.checks if check.name == "REDUCE Opcode Safety Check"]

    resolved = [check for check in reduce_checks if "PIL.Image" in check.message]
    assert resolved, (
        f"Expected REDUCE analysis to resolve PIL.Image, got: {[c.message for c in reduce_checks]}"
    )

    suppressed = [check for check in reduce_checks if "implausible module name 'PIL'" in check.message]
    assert not suppressed, "PIL should no longer be classified as an implausible module"

def test_mixed_case_stack_global_reduce_is_not_suppressed(self) -> None:
    """A STACK_GLOBAL + REDUCE on a mixed-case module must produce a failed REDUCE check."""
    payload = self._craft_stack_global_reduce_pickle("EvilPkg", "thing")
    result = self._scan_bytes(payload)

    reduce_checks = [check for check in result.checks if check.name == "REDUCE Opcode Safety Check"]

    failed_hits = [
        check for check in reduce_checks if check.status == CheckStatus.FAILED and "EvilPkg.thing" in check.message
    ]
    assert failed_hits, (
        "Expected failed REDUCE check for STACK_GLOBAL mixed-case module, "
        f"got: {[c.message for c in reduce_checks]}"
    )

    suppressed = [check for check in reduce_checks if "implausible module name 'EvilPkg'" in check.message]
    assert not suppressed, "Mixed-case STACK_GLOBAL paths should not be suppressed as implausible"

def test_mixed_case_memoized_stack_global_reduce_is_not_suppressed(self) -> None:
    """A memoized-and-recalled STACK_GLOBAL on a mixed-case module must still fail REDUCE."""
    payload = self._craft_memoized_stack_global_reduce_pickle("EvilPkg", "thing")
    result = self._scan_bytes(payload)

    reduce_checks = [check for check in result.checks if check.name == "REDUCE Opcode Safety Check"]

    failed_hits = [
        check for check in reduce_checks if check.status == CheckStatus.FAILED and "EvilPkg.thing" in check.message
    ]
    assert failed_hits, (
        "Expected failed REDUCE check for memoized mixed-case STACK_GLOBAL, "
        f"got: {[c.message for c in reduce_checks]}"
    )

    suppressed = [check for check in reduce_checks if "implausible module name 'EvilPkg'" in check.message]
    assert not suppressed, "Memoized mixed-case STACK_GLOBAL paths should not be suppressed as implausible"

def test_mixed_case_import_only_payload_still_flags_import(self) -> None:
    """Import-only mixed-case references are analysed: one is reported, the benign twin stays a warning.

    First scans an import-only ``Builtins.eval`` reference and expects a
    "Suspicious reference" issue. Then scans ``EvilPkg.thing`` and expects the
    matching import-only checks to be WARNING / ``unknown_third_party`` with no
    CRITICAL finding for that reference.
    """
    dangerous_payload = self._craft_global_import_only_pickle("Builtins", "eval")
    result = self._scan_bytes(dangerous_payload)

    import_issues = [issue for issue in result.issues if "Suspicious reference Builtins.eval" in issue.message]
    assert import_issues, (
        "Expected suspicious import-only detection for mixed-case dangerous global, "
        f"got: {[i.message for i in result.issues]}"
    )

    benign_payload = self._craft_global_import_only_pickle("EvilPkg", "thing")
    benign_result = self._scan_bytes(benign_payload)

    benign_checks = []
    for check in benign_result.checks:
        if check.name != "Global Module Reference Check":
            continue
        if check.details.get("import_reference") != "EvilPkg.thing":
            continue
        if check.details.get("import_only") is True:
            benign_checks.append(check)

    assert benign_checks, f"Expected import-only analysis for EvilPkg.thing: {benign_result.checks}"

    only_warnings = all(check.severity == IssueSeverity.WARNING for check in benign_checks)
    assert only_warnings, f"Mixed-case unknown imports should not be escalated as dangerous: {benign_checks}"

    only_unknown = all(check.details.get("classification") == "unknown_third_party" for check in benign_checks)
    assert only_unknown, f"Expected mixed-case benign counterpart to stay unknown_third_party: {benign_checks}"

    critical_hits = [
        check
        for check in benign_result.checks
        if check.severity == IssueSeverity.CRITICAL and check.details.get("import_reference") == "EvilPkg.thing"
    ]
    assert not critical_hits, (
        f"Unexpected critical mixed-case import finding for EvilPkg.thing: {benign_result.checks}"
    )

def test_mixed_case_unknown_import_only_is_flagged(self) -> None:
    """Mixed-case unknown import-only refs should now reach the import-only warning path."""
    payload = self._craft_global_import_only_pickle("EvilPkg", "thing")
    result = self._scan_bytes(payload)

    global_checks = [check for check in result.checks if check.name == "Global Module Reference Check"]

    failing_checks = [
        check
        for check in global_checks
        if check.status == CheckStatus.FAILED
        and check.severity == IssueSeverity.WARNING
        and check.details.get("import_reference") == "EvilPkg.thing"
        and check.details.get("import_only") is True
        and check.details.get("classification") == "unknown_third_party"
    ]
    assert failing_checks, f"Expected import-only warning for EvilPkg.thing: {result.checks}"

    suppressed = [check for check in global_checks if "implausible module name 'EvilPkg'" in check.message]
    assert not suppressed, (
        f"Mixed-case import-only path should not be suppressed as implausible: {result.checks}"
    )

    warning_issues = [
        issue
        for issue in result.issues
        if issue.severity == IssueSeverity.WARNING and "EvilPkg.thing" in issue.message
    ]
    assert warning_issues, f"Expected warning issue for EvilPkg.thing: {result.issues}"

def test_mixed_case_reduce_in_later_stream_is_not_suppressed(self) -> None:
    """A mixed-case REDUCE placed after a benign pickle stream must still be flagged."""
    benign_stream = pickle.dumps({"safe": True}, protocol=2)
    malicious_stream = self._craft_global_reduce_pickle("EvilPkg", "thing")

    result = self._scan_bytes(benign_stream + malicious_stream)

    reduce_checks = [check for check in result.checks if check.name == "REDUCE Opcode Safety Check"]
    failed_hits = [
        check for check in reduce_checks if check.status == CheckStatus.FAILED and "EvilPkg.thing" in check.message
    ]
    assert failed_hits, (
        f"Expected later-stream REDUCE check for mixed-case module, got: {[c.message for c in reduce_checks]}"
    )

def test_malformed_module_reduce_stays_implausible(self) -> None:
    """A module path with an empty segment is still reported as implausible by a passing check."""
    payload = self._craft_global_reduce_pickle("foo..bar", "thing")
    result = self._scan_bytes(payload)

    reduce_checks = [check for check in result.checks if check.name == "REDUCE Opcode Safety Check"]
    implausible_passes = [
        check
        for check in reduce_checks
        if check.status == CheckStatus.PASSED and "implausible module name 'foo..bar'" in check.message
    ]
    assert implausible_passes, (
        f"Expected malformed module to remain implausible, got: {[c.message for c in reduce_checks]}"
    )

def test_uppercase_data_label_reduce_stays_implausible(self) -> None:
    """An all-uppercase data label such as ``PEDRA_2020`` is still reported as implausible."""
    payload = self._craft_global_reduce_pickle("PEDRA_2020", "thing")
    result = self._scan_bytes(payload)

    reduce_checks = [check for check in result.checks if check.name == "REDUCE Opcode Safety Check"]
    implausible_passes = [
        check
        for check in reduce_checks
        if check.status == CheckStatus.PASSED and "implausible module name 'PEDRA_2020'" in check.message
    ]
    assert implausible_passes, (
        f"Expected uppercase data label to remain implausible, got: {[c.message for c in reduce_checks]}"
    )

@staticmethod
def _structural_tamper_checks(result: ScanResult) -> list:
return [issue for issue in result.issues if issue.details.get("tamper_type") is not None]
Expand Down
2 changes: 2 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,7 @@ def test_exit_code_security_issues_streaming_local_directory(tmp_path: Path) ->
import pickle

evil_pickle_path = tmp_path / "malicious.pkl"
expected_global = f"{os.system.__module__}.system"

class MaliciousClass:
def __reduce__(self):
Expand All @@ -1437,6 +1438,7 @@ def __reduce__(self):
result = runner.invoke(cli, ["scan", "--stream", "--format", "text", str(tmp_path)])

assert result.exit_code == 1, f"Expected exit code 1, got {result.exit_code}. Output: {result.output}"
assert expected_global in result.output, f"Expected malicious finding in output, got: {result.output}"
assert not evil_pickle_path.exists()


Expand Down
15 changes: 14 additions & 1 deletion tests/test_streaming_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,17 @@ def test_scan_model_streaming_critical_findings_do_not_set_operational_errors(
def file_generator():
yield (temp_test_files[0], True)

finding = ScanResult(scanner_name="test_scanner")
finding.bytes_scanned = 1024
finding.success = True
finding.add_issue(
"Detected malicious payload",
severity=IssueSeverity.CRITICAL,
location=str(temp_test_files[0]),
)

with patch("modelaudit.core.scan_file") as mock_scan:
mock_scan.return_value = create_mock_scan_result(with_critical_issue=True)
mock_scan.return_value = finding

result = scan_model_streaming(
file_generator=file_generator(),
Expand All @@ -139,7 +148,11 @@ def file_generator():

assert result.files_scanned == 1
assert len(result.issues) == 1
assert result.issues[0].message == "Detected malicious payload"
assert result.issues[0].severity == IssueSeverity.CRITICAL
assert result.issues[0].location == str(temp_test_files[0])
assert result.has_errors is False
assert result.success is True
assert determine_exit_code(result) == 1


Expand Down
Loading