Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 58 additions & 5 deletions modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import zipfile
import logging
from typing import List, Optional
from typing import List, Optional, Set


from modelscan.error import DependencyError, ModelScanScannerError, JsonDecodeError
Expand All @@ -14,6 +14,13 @@

logger = logging.getLogger("modelscan")

# Keras-internal module prefixes that are safe to import
_SAFE_KERAS_MODULE_PREFIXES = (
"keras",
"tensorflow",
"tf_keras",
)


class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
def scan(self, model: Model) -> Optional[ScanResults]:
Expand Down Expand Up @@ -78,8 +85,6 @@ def scan(self, model: Model) -> Optional[ScanResults]:
def _scan_keras_config_file(self, model: Model) -> ScanResults:
machine_learning_library_name = "Keras"

# if self._check_json_data(source, config_file):

try:
operators_in_model = self._get_keras_operator_names(model)
except json.JSONDecodeError as e:
Expand Down Expand Up @@ -118,16 +123,64 @@ def _scan_keras_config_file(self, model: Model) -> ScanResults:

def _get_keras_operator_names(self, model: Model) -> List[str]:
model_config_data = json.load(model.get_stream())
operators = []

# Check for Lambda layers (original check)
lambda_layers = [
layer.get("config", {}).get("function", {})
for layer in model_config_data.get("config", {}).get("layers", {})
if layer.get("class_name", {}) == "Lambda"
]
if lambda_layers:
return ["Lambda"] * len(lambda_layers)
operators.extend(["Lambda"] * len(lambda_layers))

# Check for unsafe module references in the entire config tree
unsafe_modules = self._extract_unsafe_modules(model_config_data)
for module_ref in unsafe_modules:
operators.append(f"UnsafeModule:{module_ref}")

return operators

@staticmethod
def _extract_unsafe_modules(config: dict, visited: Optional[Set[int]] = None) -> List[str]:
"""Recursively extract non-Keras module references from config tree.

Keras config.json uses module/class_name pairs throughout the config
hierarchy (layers, initializers, regularizers, constraints, dtype
policies). Any module outside the Keras/TensorFlow namespace could
be used for arbitrary code execution via importlib on load.
"""
if visited is None:
visited = set()

obj_id = id(config)
if obj_id in visited:
return []
visited.add(obj_id)

unsafe = []

if isinstance(config, dict):
module = config.get("module")
if isinstance(module, str) and module:
if not module.startswith(_SAFE_KERAS_MODULE_PREFIXES):
class_name = config.get("class_name", "unknown")
unsafe.append(f"{module}.{class_name}")

for value in config.values():
if isinstance(value, (dict, list)):
unsafe.extend(
KerasLambdaDetectScan._extract_unsafe_modules(value, visited)
)

elif isinstance(config, list):
for item in config:
if isinstance(item, (dict, list)):
unsafe.extend(
KerasLambdaDetectScan._extract_unsafe_modules(item, visited)
)

return []
return unsafe

@staticmethod
def name() -> str:
Expand Down
1 change: 1 addition & 0 deletions modelscan/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class SupportedModelFormats:
"supported_extensions": [".pb"],
"unsafe_keras_operators": {
"Lambda": "MEDIUM",
"UnsafeModule": "CRITICAL",
},
},
"modelscan.scanners.SavedModelTensorflowOpScan": {
Expand Down
125 changes: 125 additions & 0 deletions tests/test_keras_nested_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Test that modelscan detects unsafe module references in .keras config.json.

Verifies that the scanner catches non-Keras module references embedded in
nested config objects (initializers, regularizers, etc.), not just top-level
Lambda layers.
"""

import json
import zipfile
import io
import pytest


def _make_keras_zip(config: dict) -> io.BytesIO:
"""Create an in-memory .keras zip with the given config.json."""
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as z:
z.writestr("config.json", json.dumps(config))
buf.seek(0)
return buf


SAFE_CONFIG = {
"module": "keras",
"class_name": "Sequential",
"config": {
"name": "sequential",
"layers": [
{
"module": "keras.layers",
"class_name": "Dense",
"config": {
"units": 1,
"kernel_initializer": {
"module": "keras.initializers",
"class_name": "GlorotUniform",
"config": {"seed": None},
"registered_name": None,
},
},
"registered_name": None,
}
],
},
"registered_name": None,
}

MALICIOUS_NESTED_CONFIG = {
"module": "keras",
"class_name": "Sequential",
"config": {
"name": "sequential",
"layers": [
{
"module": "keras.layers",
"class_name": "Dense",
"config": {
"units": 1,
"kernel_initializer": {
"module": "builtins",
"class_name": "exec",
"config": {},
"registered_name": None,
},
},
"registered_name": None,
}
],
},
"registered_name": None,
}

MALICIOUS_TOP_LEVEL_MODULE = {
"module": "keras",
"class_name": "Sequential",
"config": {
"name": "sequential",
"layers": [
{
"module": "subprocess",
"class_name": "Popen",
"config": {"name": "exploit"},
"registered_name": None,
}
],
},
"registered_name": None,
}


class TestExtractUnsafeModules:
"""Test the _extract_unsafe_modules static method directly."""

def test_safe_config_returns_empty(self):
from modelscan.scanners.keras.scan import KerasLambdaDetectScan

result = KerasLambdaDetectScan._extract_unsafe_modules(SAFE_CONFIG)
assert result == []

def test_nested_builtins_exec_detected(self):
from modelscan.scanners.keras.scan import KerasLambdaDetectScan

result = KerasLambdaDetectScan._extract_unsafe_modules(MALICIOUS_NESTED_CONFIG)
assert len(result) == 1
assert "builtins.exec" in result[0]

def test_top_level_subprocess_detected(self):
from modelscan.scanners.keras.scan import KerasLambdaDetectScan

result = KerasLambdaDetectScan._extract_unsafe_modules(
MALICIOUS_TOP_LEVEL_MODULE
)
assert len(result) == 1
assert "subprocess.Popen" in result[0]

def test_tensorflow_modules_are_safe(self):
from modelscan.scanners.keras.scan import KerasLambdaDetectScan

config = {
"module": "tensorflow.python.ops",
"class_name": "Operation",
"config": {},
}
result = KerasLambdaDetectScan._extract_unsafe_modules(config)
assert result == []