From f69ba1398279661a491e15f7da8de398e5feefe8 Mon Sep 17 00:00:00 2001 From: Ian Webster Date: Mon, 16 Mar 2026 15:07:48 -0700 Subject: [PATCH] fix(keras): derive safe layer inventory from exports --- modelaudit/config/generated_keras_layers.py | 176 ++++++++++++++++++++ modelaudit/detectors/suspicious_symbols.py | 153 +---------------- scripts/generate_keras_layer_inventory.py | 101 +++++++++++ tests/scanners/test_keras_h5_scanner.py | 6 +- tests/scanners/test_keras_zip_scanner.py | 9 +- 5 files changed, 293 insertions(+), 152 deletions(-) create mode 100644 modelaudit/config/generated_keras_layers.py create mode 100644 scripts/generate_keras_layer_inventory.py diff --git a/modelaudit/config/generated_keras_layers.py b/modelaudit/config/generated_keras_layers.py new file mode 100644 index 00000000..426fd520 --- /dev/null +++ b/modelaudit/config/generated_keras_layers.py @@ -0,0 +1,176 @@ +"""Generated Keras layer inventory. + +Regenerate this file with `scripts/generate_keras_layer_inventory.py` after +updating the Keras reference version used for scanner maintenance. +""" + +GENERATED_KNOWN_SAFE_KERAS_LAYER_CLASSES: frozenset[str] = frozenset( + { + "Activation", + "ActivityRegularization", + "AdaptiveAveragePooling1D", + "AdaptiveAveragePooling2D", + "AdaptiveAveragePooling3D", + "AdaptiveMaxPooling1D", + "AdaptiveMaxPooling2D", + "AdaptiveMaxPooling3D", + "Add", + "AdditiveAttention", + "AlphaDropout", + "Attention", + "AugMix", + "AutoContrast", + "Average", + "AveragePooling1D", + "AveragePooling2D", + "AveragePooling3D", + "AvgPool1D", + "AvgPool2D", + "AvgPool3D", + "BatchNormalization", + "Bidirectional", + "CategoryEncoding", + "CenterCrop", + "Concatenate", + "Conv1D", + "Conv1DTranspose", + "Conv2D", + "Conv2DTranspose", + "Conv3D", + "Conv3DTranspose", + "ConvLSTM1D", + "ConvLSTM2D", + "ConvLSTM3D", + "Convolution1D", + "Convolution1DTranspose", + "Convolution2D", + "Convolution2DTranspose", + "Convolution3D", + "Convolution3DTranspose", + "Cropping1D", + "Cropping2D", + "Cropping3D", + "CutMix", + "Dense", + "DepthwiseConv1D", + "DepthwiseConv2D", + "Discretization", + "Dot", + "Dropout", + "ELU", + "EinsumDense", + "Embedding", + "Equalization", + "Flatten", + "FlaxLayer", + "Functional", + "GRU", + "GRUCell", + "GaussianDropout", + "GaussianNoise", + "GlobalAveragePooling1D", + "GlobalAveragePooling2D", + "GlobalAveragePooling3D", + "GlobalAvgPool1D", + "GlobalAvgPool2D", + "GlobalAvgPool3D", + "GlobalMaxPool1D", + "GlobalMaxPool2D", + "GlobalMaxPool3D", + "GlobalMaxPooling1D", + "GlobalMaxPooling2D", + "GlobalMaxPooling3D", + "GroupNormalization", + "GroupQueryAttention", + "HashedCrossing", + "Hashing", + "Identity", + "InputLayer", + "InputSpec", + "IntegerLookup", + "JaxLayer", + "LSTM", + "LSTMCell", + "Lambda", + "Layer", + "LayerNormalization", + "LeakyReLU", + "Masking", + "MaxNumBoundingBoxes", + "MaxPool1D", + "MaxPool2D", + "MaxPool3D", + "MaxPooling1D", + "MaxPooling2D", + "MaxPooling3D", + "Maximum", + "MelSpectrogram", + "Minimum", + "MixUp", + "Model", + "MultiHeadAttention", + "Multiply", + "Normalization", + "PReLU", + "Permute", + "Pipeline", + "RMSNormalization", + "RNN", + "RandAugment", + "RandomBrightness", + "RandomColorDegeneration", + "RandomColorJitter", + "RandomContrast", + "RandomCrop", + "RandomElasticTransform", + "RandomErasing", + "RandomFlip", + "RandomGaussianBlur", + "RandomGrayscale", + "RandomHue", + "RandomInvert", + "RandomPerspective", + "RandomPosterization", + "RandomRotation", + "RandomSaturation", + "RandomSharpness", + "RandomShear", + "RandomTranslation", + "RandomZoom", + "ReLU", + "RepeatVector", + "Rescaling", + "Reshape", + "Resizing", + "ReversibleEmbedding", + "STFTSpectrogram", + "SeparableConv1D", + "SeparableConv2D", + "SeparableConvolution1D", + "SeparableConvolution2D", + "Sequential", + "SimpleRNN", + "SimpleRNNCell", + "Softmax", + "Solarization", + "SpatialDropout1D", + "SpatialDropout2D", + "SpatialDropout3D", + "SpectralNormalization", + "StackedRNNCells", + "StringLookup", + "Subtract", + "TFSMLayer", + "TextVectorization", + "TimeDistributed", + "TorchModuleWrapper", + "UnitNormalization", + "UpSampling1D", + "UpSampling2D", + "UpSampling3D", + "Wrapper", + "ZeroPadding1D", + "ZeroPadding2D", + "ZeroPadding3D", + } +) diff --git a/modelaudit/detectors/suspicious_symbols.py b/modelaudit/detectors/suspicious_symbols.py index 9deabc15..da4f5e13 100644 --- a/modelaudit/detectors/suspicious_symbols.py +++ b/modelaudit/detectors/suspicious_symbols.py @@ -55,6 +55,8 @@ from typing import Any +from modelaudit.config.generated_keras_layers import GENERATED_KNOWN_SAFE_KERAS_LAYER_CLASSES + from ..config.explanations import DANGEROUS_OPCODES as _EXPLAIN_OPCODES # OS module aliases that provide system access similar to the 'os' module @@ -669,153 +671,10 @@ # without custom code execution (Sequential, Functional, Model) KNOWN_SAFE_MODEL_CLASSES: set[str] = {"Sequential", "Functional", "Model"} -# Known safe Keras layer class names (standard built-in layers). -# Any layer class_name NOT in this set or SUSPICIOUS_LAYER_TYPES is treated as -# a custom/unknown layer that warrants attention. -KNOWN_SAFE_KERAS_LAYER_CLASSES: frozenset[str] = frozenset( - { - # Input - "InputLayer", - "Input", - # Core - "Dense", - "Activation", - "Embedding", - "Masking", - "Flatten", - "Reshape", - "Permute", - "RepeatVector", - "Identity", - "EinsumDense", - # Activations (added) - "ReLU", - "Softmax", - "LeakyReLU", - "PReLU", - "ELU", - # Convolutional - "Conv1D", - "Conv2D", - "Conv3D", - "SeparableConv1D", - "SeparableConv2D", - "DepthwiseConv1D", - "DepthwiseConv2D", - "Conv1DTranspose", - "Conv2DTranspose", - "Conv3DTranspose", - # Pooling - "MaxPooling1D", - "MaxPooling2D", - "MaxPooling3D", - "AveragePooling1D", - "AveragePooling2D", - "AveragePooling3D", - "GlobalMaxPooling1D", - "GlobalMaxPooling2D", - "GlobalMaxPooling3D", - "GlobalAveragePooling1D", - "GlobalAveragePooling2D", - "GlobalAveragePooling3D", - "MaxPool1D", - "MaxPool2D", - "MaxPool3D", - "AvgPool1D", - "AvgPool2D", - "AvgPool3D", - "GlobalMaxPool1D", - "GlobalMaxPool2D", - "GlobalMaxPool3D", - "GlobalAvgPool1D", - "GlobalAvgPool2D", - "GlobalAvgPool3D", - # RNN - "SimpleRNN", - "LSTM", - "GRU", - "ConvLSTM1D", - "ConvLSTM2D", - "ConvLSTM3D", - "SimpleRNNCell", - "LSTMCell", - "GRUCell", - "StackedRNNCells", - "Bidirectional", - "TimeDistributed", - "RNN", - # Normalization - "BatchNormalization", - "LayerNormalization", - "GroupNormalization", - "UnitNormalization", - "SpectralNormalization", - # Regularization - "Dropout", - "SpatialDropout1D", - "SpatialDropout2D", - "SpatialDropout3D", - "GaussianNoise", - "GaussianDropout", - "AlphaDropout", - "ActivityRegularization", - # Attention - "MultiHeadAttention", - "Attention", - "AdditiveAttention", - # Merging - "Add", - "Subtract", - "Multiply", - "Average", - "Maximum", - "Minimum", - "Concatenate", - "Dot", - # Padding/Cropping - "ZeroPadding1D", - "ZeroPadding2D", - "ZeroPadding3D", - "Cropping1D", - "Cropping2D", - "Cropping3D", - # Upsampling - "UpSampling1D", - "UpSampling2D", - "UpSampling3D", - # Preprocessing - "Rescaling", - "Resizing", - "CenterCrop", - "RandomFlip", - "RandomRotation", - "RandomZoom", - "RandomCrop", - "RandomTranslation", - "RandomContrast", - "RandomBrightness", - "RandomHeight", - "RandomWidth", - "Normalization", - "Discretization", - "CategoryEncoding", - "Hashing", - "HashedCrossing", - "StringLookup", - "IntegerLookup", - "TextVectorization", - # TF-specific - "TFSMLayer", - # Wrapper - "Wrapper", - # Model classes (nested models in configs) - "Sequential", - "Functional", - "Model", - # Keras DType - "DTypePolicy", - } -) +# Known safe Keras layer class names derived from Keras public `keras.layers.*` +# exports. Refresh the generated inventory with +# `scripts/generate_keras_layer_inventory.py` when Keras adds new public layers. +KNOWN_SAFE_KERAS_LAYER_CLASSES: frozenset[str] = GENERATED_KNOWN_SAFE_KERAS_LAYER_CLASSES # Known standard Keras loss function names (string identifiers and class names). # Used to detect custom/unknown loss functions in training_config. diff --git a/scripts/generate_keras_layer_inventory.py b/scripts/generate_keras_layer_inventory.py new file mode 100644 index 00000000..87c17a43 --- /dev/null +++ b/scripts/generate_keras_layer_inventory.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +"""Generate the Keras layer inventory used by ModelAudit scanners.""" + +from __future__ import annotations + +import ast +import importlib.util +from pathlib import Path + + +def find_keras_source_root() -> Path: + """Locate the installed Keras source tree without importing Keras.""" + spec = importlib.util.find_spec("keras") + if spec is None or not spec.submodule_search_locations: + raise RuntimeError("Could not locate an installed 'keras' package") + + package_root = Path(next(iter(spec.submodule_search_locations))) + source_root = package_root / "src" + if not source_root.is_dir(): + raise RuntimeError(f"Expected Keras source tree at {source_root}") + return source_root + + +def exported_keras_layer_classes(source_root: Path) -> list[str]: + """Return public `keras.layers.*` class exports from the Keras source tree.""" + exported_names: set[str] = {"Functional", "Model", "Sequential"} + + for path in source_root.rglob("*.py"): + try: + tree = ast.parse(path.read_text()) + except (OSError, SyntaxError, UnicodeDecodeError): + continue + + for node in ast.walk(tree): + if not isinstance(node, ast.ClassDef): + continue + for decorator in node.decorator_list: + if not isinstance(decorator, ast.Call): + continue + + func = decorator.func + if isinstance(func, ast.Name): + func_name = func.id + elif isinstance(func, ast.Attribute): + func_name = func.attr + else: + func_name = None + if func_name != "keras_export": + continue + + for argument in decorator.args: + values: list[str] = [] + if isinstance(argument, ast.Constant) and isinstance(argument.value, str): + values = [argument.value] + elif isinstance(argument, (ast.List, ast.Tuple)): + values = [ + item.value + for item in argument.elts + if isinstance(item, ast.Constant) and isinstance(item.value, str) + ] + + for value in values: + if value.startswith("keras.layers."): + exported_names.add(value.split(".")[-1]) + + return sorted(exported_names) + + +def render_module(layer_names: list[str]) -> str: + """Render the generated Python module.""" + lines = [ + '"""Generated Keras layer inventory.', + "", + "Regenerate this file with `scripts/generate_keras_layer_inventory.py` after", + "updating the Keras reference version used for scanner maintenance.", + '"""', + "", + "GENERATED_KNOWN_SAFE_KERAS_LAYER_CLASSES: frozenset[str] = frozenset(", + " {", + ] + lines.extend(f' "{name}",' for name in layer_names) + lines.extend( + [ + " }", + ")", + "", + ] + ) + return "\n".join(lines) + + +def main() -> None: + source_root = find_keras_source_root() + layer_names = exported_keras_layer_classes(source_root) + target = Path(__file__).resolve().parent.parent / "modelaudit" / "config" / "generated_keras_layers.py" + target.write_text(render_module(layer_names)) + print(f"Wrote {len(layer_names)} layer names to {target}") + + +if __name__ == "__main__": + main() diff --git a/tests/scanners/test_keras_h5_scanner.py b/tests/scanners/test_keras_h5_scanner.py index e8905865..234b6fd0 100644 --- a/tests/scanners/test_keras_h5_scanner.py +++ b/tests/scanners/test_keras_h5_scanner.py @@ -522,15 +522,15 @@ def test_training_config_safe_aliases_do_not_trigger_custom_object_checks(tmp_pa def test_builtin_random_preprocessing_layers_do_not_trigger_custom_layer_warning(tmp_path: Path) -> None: - """Built-in RandomWidth/RandomHeight preprocessing layers should stay allowlisted.""" + """Built-in preprocessing layers should not be mislabeled as custom layers.""" model_config = { "class_name": "Sequential", "config": { "name": "preprocessing_model", "layers": [ {"class_name": "InputLayer", "config": {"batch_shape": [None, 32, 32, 3], "name": "input"}}, - {"class_name": "RandomWidth", "config": {"factor": 0.1}}, - {"class_name": "RandomHeight", "config": {"factor": 0.1}}, + {"class_name": "RandomShear", "config": {"factor": 0.1}}, + {"class_name": "RandomColorJitter", "config": {"value_range": [0, 255], "brightness_factor": 0.1}}, ], }, } diff --git a/tests/scanners/test_keras_zip_scanner.py b/tests/scanners/test_keras_zip_scanner.py index 034ef335..9a4089a8 100644 --- a/tests/scanners/test_keras_zip_scanner.py +++ b/tests/scanners/test_keras_zip_scanner.py @@ -518,10 +518,15 @@ def test_compile_config_safe_aliases_and_builtin_layers_do_not_false_positive(se "config": {"batch_shape": [None, 32, 32, 3]}, }, { - "class_name": "RandomWidth", - "name": "random_width", + "class_name": "RandomShear", + "name": "random_shear", "config": {"factor": 0.1}, }, + { + "class_name": "RandomColorJitter", + "name": "random_color_jitter", + "config": {"value_range": [0, 255], "brightness_factor": 0.1}, + }, ] }, "compile_config": {