From 5f38596f911d8d9f14c6b5be499d881e9e31eb1f Mon Sep 17 00:00:00 2001 From: massy-o Date: Thu, 14 May 2026 14:52:04 +0900 Subject: [PATCH] Scan compressed pickle artifacts --- modelscan/middlewares/format_via_extension.py | 4 +- modelscan/modelscan.py | 18 ++++--- modelscan/settings.py | 32 ++++++++++++ modelscan/tools/picklescanner.py | 49 +++++++++++++++++-- tests/test_modelscan.py | 24 +++++++++ 5 files changed, 115 insertions(+), 12 deletions(-) diff --git a/modelscan/middlewares/format_via_extension.py b/modelscan/middlewares/format_via_extension.py index 2150acfc..51b2602f 100644 --- a/modelscan/middlewares/format_via_extension.py +++ b/modelscan/middlewares/format_via_extension.py @@ -5,11 +5,11 @@ class FormatViaExtensionMiddleware(MiddlewareBase): def __call__(self, model: Model, call_next: Callable[[Model], None]) -> None: - extension = model.get_source().suffix + source = str(model.get_source()) formats = [ format for format, extensions in self._settings["formats"].items() - if extension in extensions + if any(source.endswith(extension) for extension in extensions) ] if len(formats) > 0: model.set_context("formats", model.get_context("formats") or [] + formats) diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index 4442f5eb..5b6cf1c1 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -61,7 +61,7 @@ def _load_scanners(self) -> None: and self._settings["scanners"][scanner_path]["enabled"] ): try: - (modulename, classname) = scanner_path.rsplit(".", 1) + modulename, classname = scanner_path.rsplit(".", 1) imported_module = importlib.import_module( name=modulename, package=classname ) @@ -302,13 +302,17 @@ def _generate_results(self) -> Dict[str, Any]: def is_compatible(self, path: str) -> bool: # Determines whether a file path is compatible with any of the available scanners - if Path(path).suffix in self._settings["supported_zip_extensions"]: + if any( + path.endswith(extension) + for extension in self._settings["supported_zip_extensions"] + ): return True for scanner_path, scanner_settings in self._settings["scanners"].items(): - if ( - "supported_extensions" in scanner_settings.keys() - and Path(path).suffix - in self._settings["scanners"][scanner_path]["supported_extensions"] + if "supported_extensions" in scanner_settings.keys() and any( + path.endswith(extension) + for extension in self._settings["scanners"][scanner_path][ + "supported_extensions" + ] ): return True @@ -320,7 +324,7 @@ def generate_report(self) -> Optional[str]: scan_report = None try: - (modulename, classname) = reporting_module.rsplit(".", 1) + modulename, classname = reporting_module.rsplit(".", 1) imported_module = importlib.import_module( name=modulename, package=classname ) diff --git a/modelscan/settings.py b/modelscan/settings.py index 56b3a796..9f888f32 100644 --- a/modelscan/settings.py +++ b/modelscan/settings.py @@ -60,9 +60,25 @@ class SupportedModelFormats: "enabled": True, "supported_extensions": [ ".pkl", + ".pkl.bz2", + ".pkl.gz", + ".pkl.lzma", + ".pkl.xz", ".pickle", + ".pickle.bz2", + ".pickle.gz", + ".pickle.lzma", + ".pickle.xz", ".joblib", + ".joblib.bz2", + ".joblib.gz", + ".joblib.lzma", + ".joblib.xz", ".dill", + ".dill.bz2", + ".dill.gz", + ".dill.lzma", + ".dill.xz", ".dat", ".data", ], @@ -82,9 +98,25 @@ class SupportedModelFormats: SupportedModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"], SupportedModelFormats.PICKLE: [ ".pkl", + ".pkl.bz2", + ".pkl.gz", + ".pkl.lzma", + ".pkl.xz", ".pickle", + ".pickle.bz2", + ".pickle.gz", + ".pickle.lzma", + ".pickle.xz", ".joblib", + ".joblib.bz2", + ".joblib.gz", + ".joblib.lzma", + ".joblib.xz", ".dill", + ".dill.bz2", + ".dill.gz", + ".dill.lzma", + ".dill.xz", ".dat", ".data", ], diff --git a/modelscan/tools/picklescanner.py b/modelscan/tools/picklescanner.py index 44c4e2a0..6c695617 100644 --- a/modelscan/tools/picklescanner.py +++ b/modelscan/tools/picklescanner.py @@ -1,4 +1,8 @@ import logging +import bz2 +import gzip +import io +import lzma import pickletools # nosec from tarfile import TarError from typing import IO, Any, Dict, List, Set, Tuple, Union, Optional @@ -15,6 +19,13 @@ from .utils import MAGIC_NUMBER, _should_read_directly, get_magic_number +COMPRESSED_PICKLE_SUFFIXES = { + ".bz2": bz2.decompress, + ".gz": gzip.decompress, + ".lzma": lzma.decompress, + ".xz": lzma.decompress, +} + class GenOpsError(Exception): def __init__(self, msg: str, globals: Optional[Set[Tuple[str, str]]]): @@ -128,8 +139,26 @@ def scan_pickle_bytes( ) -> ScanResults: """Disassemble a Pickle stream and report issues""" issues: List[Issue] = [] + stream = model.get_stream(offset) + decompress = COMPRESSED_PICKLE_SUFFIXES.get(model.get_source().suffix) + if decompress is not None: + try: + stream = io.BytesIO(decompress(stream.read())) + except Exception as e: + return ScanResults( + issues, + [ + PickleGenopsError( + scan_name, + f"Decompression error: {e}", + model, + ) + ], + [], + ) + try: - raw_globals = _list_globals(model.get_stream(offset), multiple_pickles) + raw_globals = _list_globals(stream, multiple_pickles) except GenOpsError as e: if e.globals is not None: return _build_scan_result_from_raw_globals( @@ -228,8 +257,22 @@ def scan_numpy(model: Model, settings: Dict[str, Any]) -> ScanResults: elif magic == np.lib.format.MAGIC_PREFIX: # .npy file version = np.lib.format.read_magic(stream) # type: ignore[no-untyped-call] - np.lib.format._check_version(version) # type: ignore[attr-defined] - _, _, dtype = np.lib.format._read_array_header(stream, version) # type: ignore[attr-defined] + if version == (1, 0): + _, _, dtype = np.lib.format.read_array_header_1_0(stream) + elif version == (2, 0): + _, _, dtype = np.lib.format.read_array_header_2_0(stream) + else: + return ScanResults( + [], + [ + PickleGenopsError( + scan_name, + f"Unsupported numpy file format version: {version}", + model, + ) + ], + [], + ) if dtype.hasobject: return scan_pickle_bytes(model, settings, scan_name, True, stream.tell()) diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index a82e4ecc..1727af1b 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -1,5 +1,6 @@ import aiohttp import bdb +import gzip import http.client import importlib import io @@ -283,6 +284,10 @@ def file_path(tmp_path_factory: Any) -> Any: initialize_pickle_file(f"{tmp}/data/malicious8.pkl", Malicious7(), 4) initialize_pickle_file(f"{tmp}/data/malicious9.pkl", Malicious8(), 4) initialize_pickle_file(f"{tmp}/data/malicious15.pkl", Malicious15(), 4) + initialize_data_file( + f"{tmp}/data/malicious16.joblib.gz", + gzip.compress(pickle.dumps(Malicious2(), protocol=4)), + ) # Malicious Pickle from Capture-the-Flag challenge 'Misc/Safe Pickle' at https://imaginaryctf.org/Challenges # GitHub Issue: https://github.com/mmaitre314/picklescan/issues/22 @@ -670,6 +675,25 @@ def test_scan_file_path(file_path: Any) -> None: assert results["summary"]["skipped"]["skipped_files"] == [] assert results["errors"] == [] + compressed_joblib = ModelScan() + expected_compressed_joblib = { + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", + "system", + IssueSeverity.CRITICAL, + f"{file_path}/data/malicious16.joblib.gz", + ), + ), + } + results = compressed_joblib.scan(Path(f"{file_path}/data/malicious16.joblib.gz")) + compare_results(compressed_joblib.issues.all_issues, expected_compressed_joblib) + assert results["summary"]["scanned"]["scanned_files"] == ["malicious16.joblib.gz"] + assert results["summary"]["skipped"]["skipped_files"] == [] + assert results["errors"] == [] + def test_scan_pickle_operators(file_path: Any) -> None: # Tests the unsafe pickle operators we screen for, across differences in pickle versions 0-2, 3, and 4