diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py
index 1e88c389..981ed9d7 100644
--- a/modelscan/scanners/keras/scan.py
+++ b/modelscan/scanners/keras/scan.py
@@ -1,7 +1,7 @@
 import json
 import zipfile
 import logging
-from typing import List, Optional
+from typing import List, Optional, Set
 
 
 from modelscan.error import DependencyError, ModelScanScannerError, JsonDecodeError
@@ -14,6 +14,13 @@
 
 logger = logging.getLogger("modelscan")
 
+# Top-level packages treated as safe; matched on dotted-path boundaries only
+_SAFE_KERAS_MODULE_PREFIXES = (
+    "keras",
+    "tensorflow",
+    "tf_keras",
+)
+
 
 class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
     def scan(self, model: Model) -> Optional[ScanResults]:
@@ -78,8 +85,6 @@ def scan(self, model: Model) -> Optional[ScanResults]:
     def _scan_keras_config_file(self, model: Model) -> ScanResults:
         machine_learning_library_name = "Keras"
 
-        # if self._check_json_data(source, config_file):
-
         try:
             operators_in_model = self._get_keras_operator_names(model)
         except json.JSONDecodeError as e:
@@ -118,16 +123,84 @@ def _scan_keras_config_file(self, model: Model) -> ScanResults:
 
     def _get_keras_operator_names(self, model: Model) -> List[str]:
         model_config_data = json.load(model.get_stream())
+        operators = []
 
+        # Check for Lambda layers (original check)
         lambda_layers = [
             layer.get("config", {}).get("function", {})
             for layer in model_config_data.get("config", {}).get("layers", {})
             if layer.get("class_name", {}) == "Lambda"
         ]
         if lambda_layers:
-            return ["Lambda"] * len(lambda_layers)
+            operators.extend(["Lambda"] * len(lambda_layers))
+
+        # Check for unsafe module references in the entire config tree
+        # NOTE(review): settings registers the bare key "UnsafeModule" while the
+        # names emitted here are "UnsafeModule:module.class" - confirm the
+        # severity lookup matches on that prefix, or CRITICAL is never applied.
+        unsafe_modules = self._extract_unsafe_modules(model_config_data)
+        for module_ref in unsafe_modules:
+            operators.append(f"UnsafeModule:{module_ref}")
+
+        return operators
+
+    @staticmethod
+    def _is_safe_module(module: str) -> bool:
+        """Return True if module is a safe package or one of its submodules.
+
+        Matching is done on dotted-path boundaries so that look-alike names
+        such as "keras_evil" or "tensorflowx.payload" are not mistaken for
+        members of the Keras/TensorFlow namespace.
+        """
+        return any(
+            module == prefix or module.startswith(prefix + ".")
+            for prefix in _SAFE_KERAS_MODULE_PREFIXES
+        )
+
+    @staticmethod
+    def _extract_unsafe_modules(config: dict, visited: Optional[Set[int]] = None) -> List[str]:
+        """Recursively extract non-Keras module references from config tree.
+
+        Keras config.json uses module/class_name pairs throughout the config
+        hierarchy (layers, initializers, regularizers, constraints, dtype
+        policies). Any module outside the Keras/TensorFlow namespace could
+        be used for arbitrary code execution via importlib on load.
+
+        Returns one "module.class_name" string per offending entry, in
+        traversal order ("unknown" stands in for a missing class_name).
+        """
+        if visited is None:
+            visited = set()
+
+        # Identity-based guard so a container reached twice is only walked once
+        obj_id = id(config)
+        if obj_id in visited:
+            return []
+        visited.add(obj_id)
+
+        unsafe = []
+
+        if isinstance(config, dict):
+            module = config.get("module")
+            if isinstance(module, str) and module:
+                if not KerasLambdaDetectScan._is_safe_module(module):
+                    class_name = config.get("class_name", "unknown")
+                    unsafe.append(f"{module}.{class_name}")
+
+            for value in config.values():
+                if isinstance(value, (dict, list)):
+                    unsafe.extend(
+                        KerasLambdaDetectScan._extract_unsafe_modules(value, visited)
+                    )
+
+        elif isinstance(config, list):
+            for item in config:
+                if isinstance(item, (dict, list)):
+                    unsafe.extend(
+                        KerasLambdaDetectScan._extract_unsafe_modules(item, visited)
+                    )
 
-        return []
+        return unsafe
 
     @staticmethod
     def name() -> str:
diff --git a/modelscan/settings.py b/modelscan/settings.py
index 56b3a796..d2b2015a 100644
--- a/modelscan/settings.py
+++ b/modelscan/settings.py
@@ -42,6 +42,7 @@ class SupportedModelFormats:
             "supported_extensions": [".pb"],
             "unsafe_keras_operators": {
                 "Lambda": "MEDIUM",
+                "UnsafeModule": "CRITICAL",
             },
         },
         "modelscan.scanners.SavedModelTensorflowOpScan": {
diff --git a/tests/test_keras_nested_config.py b/tests/test_keras_nested_config.py
new file mode 100644
index 00000000..a2c22f77
--- /dev/null
+++ b/tests/test_keras_nested_config.py
@@ -0,0 +1,136 @@
+"""Test that modelscan detects unsafe module references in .keras config.json.
+
+Verifies that the scanner catches non-Keras module references embedded in
+nested config objects (initializers, regularizers, etc.), not just top-level
+Lambda layers.
+"""
+
+import json
+import zipfile
+import io
+import pytest
+
+
+def _make_keras_zip(config: dict) -> io.BytesIO:
+    """Create an in-memory .keras zip with the given config.json."""
+    buf = io.BytesIO()
+    with zipfile.ZipFile(buf, "w") as z:
+        z.writestr("config.json", json.dumps(config))
+    buf.seek(0)
+    return buf
+
+
+SAFE_CONFIG = {
+    "module": "keras",
+    "class_name": "Sequential",
+    "config": {
+        "name": "sequential",
+        "layers": [
+            {
+                "module": "keras.layers",
+                "class_name": "Dense",
+                "config": {
+                    "units": 1,
+                    "kernel_initializer": {
+                        "module": "keras.initializers",
+                        "class_name": "GlorotUniform",
+                        "config": {"seed": None},
+                        "registered_name": None,
+                    },
+                },
+                "registered_name": None,
+            }
+        ],
+    },
+    "registered_name": None,
+}
+
+MALICIOUS_NESTED_CONFIG = {
+    "module": "keras",
+    "class_name": "Sequential",
+    "config": {
+        "name": "sequential",
+        "layers": [
+            {
+                "module": "keras.layers",
+                "class_name": "Dense",
+                "config": {
+                    "units": 1,
+                    "kernel_initializer": {
+                        "module": "builtins",
+                        "class_name": "exec",
+                        "config": {},
+                        "registered_name": None,
+                    },
+                },
+                "registered_name": None,
+            }
+        ],
+    },
+    "registered_name": None,
+}
+
+MALICIOUS_TOP_LEVEL_MODULE = {
+    "module": "keras",
+    "class_name": "Sequential",
+    "config": {
+        "name": "sequential",
+        "layers": [
+            {
+                "module": "subprocess",
+                "class_name": "Popen",
+                "config": {"name": "exploit"},
+                "registered_name": None,
+            }
+        ],
+    },
+    "registered_name": None,
+}
+
+
+class TestExtractUnsafeModules:
+    """Test the _extract_unsafe_modules static method directly."""
+
+    def test_safe_config_returns_empty(self):
+        from modelscan.scanners.keras.scan import KerasLambdaDetectScan
+
+        result = KerasLambdaDetectScan._extract_unsafe_modules(SAFE_CONFIG)
+        assert result == []
+
+    def test_nested_builtins_exec_detected(self):
+        from modelscan.scanners.keras.scan import KerasLambdaDetectScan
+
+        result = KerasLambdaDetectScan._extract_unsafe_modules(MALICIOUS_NESTED_CONFIG)
+        assert len(result) == 1
+        assert "builtins.exec" in result[0]
+
+    def test_top_level_subprocess_detected(self):
+        from modelscan.scanners.keras.scan import KerasLambdaDetectScan
+
+        result = KerasLambdaDetectScan._extract_unsafe_modules(
+            MALICIOUS_TOP_LEVEL_MODULE
+        )
+        assert len(result) == 1
+        assert "subprocess.Popen" in result[0]
+
+    def test_tensorflow_modules_are_safe(self):
+        from modelscan.scanners.keras.scan import KerasLambdaDetectScan
+
+        config = {
+            "module": "tensorflow.python.ops",
+            "class_name": "Operation",
+            "config": {},
+        }
+        result = KerasLambdaDetectScan._extract_unsafe_modules(config)
+        assert result == []
+
+    def test_lookalike_module_prefix_is_unsafe(self):
+        from modelscan.scanners.keras.scan import KerasLambdaDetectScan
+
+        config = {
+            "module": "keras_evil.payload",
+            "class_name": "Run",
+            "config": {},
+        }
+        result = KerasLambdaDetectScan._extract_unsafe_modules(config)
+        assert result == ["keras_evil.payload.Run"]