Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions modelscan/middlewares/format_via_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

class FormatViaExtensionMiddleware(MiddlewareBase):
def __call__(self, model: Model, call_next: Callable[[Model], None]) -> None:
extension = model.get_source().suffix
source = str(model.get_source())
formats = [
format
for format, extensions in self._settings["formats"].items()
if extension in extensions
if any(source.endswith(extension) for extension in extensions)
]
if len(formats) > 0:
model.set_context("formats", model.get_context("formats") or [] + formats)
Expand Down
18 changes: 11 additions & 7 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _load_scanners(self) -> None:
and self._settings["scanners"][scanner_path]["enabled"]
):
try:
(modulename, classname) = scanner_path.rsplit(".", 1)
modulename, classname = scanner_path.rsplit(".", 1)
imported_module = importlib.import_module(
name=modulename, package=classname
)
Expand Down Expand Up @@ -302,13 +302,17 @@ def _generate_results(self) -> Dict[str, Any]:

def is_compatible(self, path: str) -> bool:
# Determines whether a file path is compatible with any of the available scanners
if Path(path).suffix in self._settings["supported_zip_extensions"]:
if any(
path.endswith(extension)
for extension in self._settings["supported_zip_extensions"]
):
return True
for scanner_path, scanner_settings in self._settings["scanners"].items():
if (
"supported_extensions" in scanner_settings.keys()
and Path(path).suffix
in self._settings["scanners"][scanner_path]["supported_extensions"]
if "supported_extensions" in scanner_settings.keys() and any(
path.endswith(extension)
for extension in self._settings["scanners"][scanner_path][
"supported_extensions"
]
):
return True

Expand All @@ -320,7 +324,7 @@ def generate_report(self) -> Optional[str]:

scan_report = None
try:
(modulename, classname) = reporting_module.rsplit(".", 1)
modulename, classname = reporting_module.rsplit(".", 1)
imported_module = importlib.import_module(
name=modulename, package=classname
)
Expand Down
32 changes: 32 additions & 0 deletions modelscan/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,25 @@ class SupportedModelFormats:
"enabled": True,
"supported_extensions": [
".pkl",
".pkl.bz2",
".pkl.gz",
".pkl.lzma",
".pkl.xz",
".pickle",
".pickle.bz2",
".pickle.gz",
".pickle.lzma",
".pickle.xz",
".joblib",
".joblib.bz2",
".joblib.gz",
".joblib.lzma",
".joblib.xz",
".dill",
".dill.bz2",
".dill.gz",
".dill.lzma",
".dill.xz",
".dat",
".data",
],
Expand All @@ -82,9 +98,25 @@ class SupportedModelFormats:
SupportedModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"],
SupportedModelFormats.PICKLE: [
".pkl",
".pkl.bz2",
".pkl.gz",
".pkl.lzma",
".pkl.xz",
".pickle",
".pickle.bz2",
".pickle.gz",
".pickle.lzma",
".pickle.xz",
".joblib",
".joblib.bz2",
".joblib.gz",
".joblib.lzma",
".joblib.xz",
".dill",
".dill.bz2",
".dill.gz",
".dill.lzma",
".dill.xz",
".dat",
".data",
],
Expand Down
49 changes: 46 additions & 3 deletions modelscan/tools/picklescanner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import logging
import bz2
import gzip
import io
import lzma
import pickletools # nosec
from tarfile import TarError
from typing import IO, Any, Dict, List, Set, Tuple, Union, Optional
Expand All @@ -15,6 +19,13 @@

from .utils import MAGIC_NUMBER, _should_read_directly, get_magic_number

COMPRESSED_PICKLE_SUFFIXES = {
".bz2": bz2.decompress,
".gz": gzip.decompress,
".lzma": lzma.decompress,
".xz": lzma.decompress,
}


class GenOpsError(Exception):
def __init__(self, msg: str, globals: Optional[Set[Tuple[str, str]]]):
Expand Down Expand Up @@ -128,8 +139,26 @@ def scan_pickle_bytes(
) -> ScanResults:
"""Disassemble a Pickle stream and report issues"""
issues: List[Issue] = []
stream = model.get_stream(offset)
decompress = COMPRESSED_PICKLE_SUFFIXES.get(model.get_source().suffix)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Self-review: this uses the outer compound extension to decide whether to decompress before pickle opcode scanning. That directly covers valid joblib.dump(..., compress=...) artifacts such as .joblib.gz; the tradeoff is that compression policy stays extension-based rather than magic-byte based.

if decompress is not None:
try:
stream = io.BytesIO(decompress(stream.read()))
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Self-review: decompression is intentionally placed before _list_globals() so the existing unsafe-global detection remains unchanged. This keeps the patch small, but very large compressed artifacts could justify a streaming decompression follow-up.

except Exception as e:
return ScanResults(
issues,
[
PickleGenopsError(
scan_name,
f"Decompression error: {e}",
model,
)
],
[],
)

try:
raw_globals = _list_globals(model.get_stream(offset), multiple_pickles)
raw_globals = _list_globals(stream, multiple_pickles)
except GenOpsError as e:
if e.globals is not None:
return _build_scan_result_from_raw_globals(
Expand Down Expand Up @@ -228,8 +257,22 @@ def scan_numpy(model: Model, settings: Dict[str, Any]) -> ScanResults:
elif magic == np.lib.format.MAGIC_PREFIX:
# .npy file
version = np.lib.format.read_magic(stream) # type: ignore[no-untyped-call]
np.lib.format._check_version(version) # type: ignore[attr-defined]
_, _, dtype = np.lib.format._read_array_header(stream, version) # type: ignore[attr-defined]
if version == (1, 0):
_, _, dtype = np.lib.format.read_array_header_1_0(stream)
elif version == (2, 0):
_, _, dtype = np.lib.format.read_array_header_2_0(stream)
else:
return ScanResults(
[],
[
PickleGenopsError(
scan_name,
f"Unsupported numpy file format version: {version}",
model,
)
],
[],
)

if dtype.hasobject:
return scan_pickle_bytes(model, settings, scan_name, True, stream.tell())
Expand Down
24 changes: 24 additions & 0 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import aiohttp
import bdb
import gzip
import http.client
import importlib
import io
Expand Down Expand Up @@ -283,6 +284,10 @@ def file_path(tmp_path_factory: Any) -> Any:
initialize_pickle_file(f"{tmp}/data/malicious8.pkl", Malicious7(), 4)
initialize_pickle_file(f"{tmp}/data/malicious9.pkl", Malicious8(), 4)
initialize_pickle_file(f"{tmp}/data/malicious15.pkl", Malicious15(), 4)
initialize_data_file(
f"{tmp}/data/malicious16.joblib.gz",
gzip.compress(pickle.dumps(Malicious2(), protocol=4)),
)

# Malicious Pickle from Capture-the-Flag challenge 'Misc/Safe Pickle' at https://imaginaryctf.org/Challenges
# GitHub Issue: https://github.com/mmaitre314/picklescan/issues/22
Expand Down Expand Up @@ -670,6 +675,25 @@ def test_scan_file_path(file_path: Any) -> None:
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["errors"] == []

compressed_joblib = ModelScan()
expected_compressed_joblib = {
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.CRITICAL,
OperatorIssueDetails(
"posix",
"system",
IssueSeverity.CRITICAL,
f"{file_path}/data/malicious16.joblib.gz",
),
),
}
results = compressed_joblib.scan(Path(f"{file_path}/data/malicious16.joblib.gz"))
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Self-review: this regression asserts the formerly skipped .joblib.gz artifact now goes through the pickle scanner and reports the embedded posix.system payload as CRITICAL, matching the reported scanner-bypass path.

compare_results(compressed_joblib.issues.all_issues, expected_compressed_joblib)
assert results["summary"]["scanned"]["scanned_files"] == ["malicious16.joblib.gz"]
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["errors"] == []


def test_scan_pickle_operators(file_path: Any) -> None:
# Tests the unsafe pickle operators we screen for, across differences in pickle versions 0-2, 3, and 4
Expand Down