diff --git a/.gitmodules b/.gitmodules index e69de29..aa121bc 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "DeepQuant"] + path = DeepQuant + url = https://github.com/JanCSEM/DeepQuant.git diff --git a/DeepQuant b/DeepQuant new file mode 160000 index 0000000..7ec8705 --- /dev/null +++ b/DeepQuant @@ -0,0 +1 @@ +Subproject commit 7ec87052165be8dac15394d6df39443bac427a62 diff --git a/Onnx4Deeploy.py b/Onnx4Deeploy.py index 9b1b9be..565131c 100644 --- a/Onnx4Deeploy.py +++ b/Onnx4Deeploy.py @@ -34,6 +34,7 @@ def list_available_models(): CCTExporter, EpiDeNetExporter, LightweightCnnExporter, + QLiteCnnExporter, MambaExporter, MIBMInetExporter, MobileNetV2Exporter, @@ -41,6 +42,7 @@ def list_available_models(): ResNetExporter, SimpleMlpExporter, SleepConViTExporter, + QSleepConViTExporter ) models = { @@ -140,6 +142,18 @@ def list_available_models(): "input_shape": "(B, 1, 28, 28)", "classes": 10, }, + "QLiteCNN": { + "class": QLiteCnnExporter, + "description": "QLite CNN (Compact CNN for image classification)", + "input_shape": "(B, 1, 28, 28)", + "classes": 10, + }, + "QSleepConViT": { + "class": QSleepConViTExporter, + "description": "QLite SleepConViT (Quantized Vision Transformer for Sleep Stage Classification)", + "input_shape": "(B, 1, 3000)", + "classes": 5 + } } return models @@ -155,6 +169,7 @@ def list_available_operators(): # Matrix operations "Gemm": "General matrix multiplication", "MatMul": "Matrix multiplication", + "Conv2d": "2D convolution", # Pooling "MaxPool": "Max pooling", "AveragePool": "Average pooling", @@ -169,10 +184,20 @@ def list_available_operators(): "ConvGradX": "Convolution input gradient", "ConvGradW": "Convolution weight gradient", "ConvGradB": "Convolution bias gradient", + # ZO + "PerturbNormal": "Perturb input with gaussian random noise", + "PerturbUniform": "Perturb input with uniform random noise", + "PerturbTriangle": "Perturb input with triangle random noise", + "PerturbRademacher": "Perturb input with 
Rademacher random noise", + "PerturbEggroll": "Perturb input with Eggroll random noise", + "RQSPerturbrademacher": "Perturb input with quantized Rademacher random noise", + "RQSPerturbUniform": "Perturb input with quantized Uniform random noise", + # Others "ReduceSum": "Sum reduction", "SoftmaxCrossEntropy": "Softmax cross entropy", "ReluGrad": "ReLU gradient", + } return operators @@ -248,7 +273,7 @@ def generate_operator(operator_name: str, output_path: Optional[str] = None): sys.exit(1) -def generate_model(model_name: str, mode: str, output_path: Optional[str] = None): +def generate_model(model_name: str, mode: str, output_path: Optional[str] = None, noise_type: str = "gaussian"): """Generate model ONNX""" print(f"\n{'='*70}") print(f"๐Ÿš€ Generating model: {model_name} ({mode.upper()} mode)") @@ -292,12 +317,21 @@ def generate_model(model_name: str, mode: str, output_path: Optional[str] = None if mode == "infer": onnx_file = exporter.export_inference() mode_desc = "Inference mode" + elif mode == "q-infer": + onnx_file = exporter.export_inference(quant=True) + mode_desc = "Quantized Inference mode" elif mode == "train": onnx_file = exporter.export_training() mode_desc = "Training mode" + elif mode == "zo-train": + onnx_file = exporter.export_zo_training(noise_type=noise_type) + mode_desc = "Zeroth-order Training mode" + elif mode == "q-zo-train": + onnx_file = exporter.export_zo_training(noise_type=noise_type, quant=True) + mode_desc = "Quantized Zeroth-order Training mode" else: print(f"โŒ Unknown mode: {mode}") - print(" Available modes: infer, train") + print(" Available modes: infer, train, zo-train, q-infer, q-zo-train") sys.exit(1) print(f"\n{'='*70}") @@ -311,7 +345,9 @@ def generate_model(model_name: str, mode: str, output_path: Optional[str] = None files_to_check = ["network.onnx", "inputs.npz", "outputs.npz"] if mode == "train": files_to_check.extend(["network_train.onnx", "optimizer_model.onnx"]) - + elif mode in ["zo-train", "q-zo-train"]: + 
files_to_check.append("network_zo.onnx") + for file in files_to_check: file_path = output_dir / file if file_path.exists(): @@ -406,9 +442,9 @@ def main(): "-mode", "--mode", type=str, - choices=["infer", "train"], + choices=["infer", "train", "zo-train", "q-infer", "q-zo-train"], default="infer", - help="Model export mode: infer (inference) or train (training) [default: infer]", + help="Model export mode: infer (inference), train (BP training), zo-train (zeroth-order training), q-infer (quantized inference), or q-zo-train (quantized zeroth-order training) [default: infer]", ) # Output path @@ -429,7 +465,8 @@ def main(): # Other options parser.add_argument("--examples", action="store_true", help="Show usage examples") - + parser.add_argument("--noise-type", type=str, choices=["gaussian", "uniform", "triangle", "rademacher", "eggroll", "rqs_rademacher", "rqs_uniform"], + default="gaussian", help="Noise type for perturbation operators [default: gaussian]") # Parse arguments args = parser.parse_args() @@ -472,7 +509,7 @@ def main(): if args.operator: generate_operator(args.operator, args.output) elif args.model: - generate_model(args.model, args.mode, args.output) + generate_model(args.model, args.mode, args.output, args.noise_type) if __name__ == "__main__": diff --git a/gen_noise_tests.sh b/gen_noise_tests.sh new file mode 100755 index 0000000..67a2a75 --- /dev/null +++ b/gen_noise_tests.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +python3 Onnx4Deeploy.py -operator PerturbNormal -o PerturbNormal +python3 Onnx4Deeploy.py -operator PerturbUniform -o PerturbUniform +python3 Onnx4Deeploy.py -operator PerturbRademacher -o PerturbRademacher +python3 Onnx4Deeploy.py -operator PerturbTriangle -o PerturbTriangle +python3 Onnx4Deeploy.py -operator PerturbEggroll -o PerturbEggroll diff --git a/gen_zo_model_tests.sh b/gen_zo_model_tests.sh new file mode 100755 index 0000000..76bb7c9 --- /dev/null +++ b/gen_zo_model_tests.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# LiteCNN +python3 
Onnx4Deeploy.py -model LightweightCNN -mode infer -o LiteCNN +python3 Onnx4Deeploy.py -model LightweightCNN -mode zo-train -o LiteCNN-Rad --noise-type rademacher +python3 Onnx4Deeploy.py -model LightweightCNN -mode zo-train -o LiteCNN-Lorp --noise-type eggroll +python3 Onnx4Deeploy.py -model LightweightCNN -mode zo-train -o LiteCNN-Uniform --noise-type uniform +python3 Onnx4Deeploy.py -model LightweightCNN -mode zo-train -o LiteCNN-Gaussian --noise-type gaussian + +# QLiteCNN +python3 Onnx4Deeploy.py -model QLiteCNN -mode q-infer -o QLiteCNN +python3 Onnx4Deeploy.py -model QLiteCNN -mode q-zo-train -o QLiteCNN-Rad --noise-type rademacher +python3 Onnx4Deeploy.py -model QLiteCNN -mode q-zo-train -o QLiteCNN-Lorp --noise-type eggroll +python3 Onnx4Deeploy.py -model QLiteCNN -mode q-zo-train -o QLiteCNN-Uniform --noise-type uniform +python3 Onnx4Deeploy.py -model QLiteCNN -mode q-zo-train -o QLiteCNN-Gaussian --noise-type gaussian \ No newline at end of file diff --git a/onnx4deeploy/core/base_exporter.py b/onnx4deeploy/core/base_exporter.py index 3e14cb0..9ed5cc5 100644 --- a/onnx4deeploy/core/base_exporter.py +++ b/onnx4deeploy/core/base_exporter.py @@ -25,7 +25,11 @@ from onnx import helper from onnxruntime.training import artifacts +from DeepQuant.Export4Deeploy import exportBrevitas + from .onnx_utils import print_model_info, randomize_onnx_initializers +from onnx4deeploy.transform.quant_transform import fix_duplicate_tensor_names +from onnx4deeploy.transform.zo_transform import generate_weight_update_graph, generate_zo_graph class ExportMode(Enum): @@ -33,6 +37,7 @@ class ExportMode(Enum): TRAINING = "train" INFERENCE = "infer" + ZO_TRAINING = "zo-train" class BaseONNXExporter(ABC): @@ -250,6 +255,15 @@ def setup_paths(self, mode: ExportMode) -> Dict[str, str]: } ) + if mode == ExportMode.ZO_TRAINING: + paths.update( + { + "network_infer": os.path.join(output_dir, "network_infer.onnx"), + "network_zo_train": os.path.join(output_dir, "network_zo_train.onnx"), + 
"network_zo_update": os.path.join(output_dir, "network_zo_update.onnx"), + } + ) + return paths def _get_config_string(self) -> str: @@ -294,7 +308,7 @@ def _export_to_onnx( onnx_model = onnx.load_model_from_string(f.getvalue()) return onnx_model - def export_inference(self, save_path: Optional[str] = None) -> str: + def export_inference(self, save_path: Optional[str] = None, quant: bool = False) -> str: """ Export model in inference mode. @@ -325,10 +339,25 @@ def export_inference(self, save_path: Optional[str] = None) -> str: input_tensor = torch.randn(*input_shape, dtype=torch.float32) print(f" Input shape: {input_shape}") - # Export to ONNX - print("\n๐Ÿ“ค Exporting to ONNX...") - opset_version = self.config.get("opset_version", 12) - onnx_model = self._export_to_onnx(model, input_tensor, opset_version) + if not quant: + # initialize weights and biases for testing + for name, param in model.named_parameters(): + if param.requires_grad: + torch.nn.init.normal_(param, mean=0.0, std=0.02) + # Export to ONNX + print("\n๐Ÿ“ค Exporting to ONNX...") + opset_version = self.config.get("opset_version", 12) + onnx_model = self._export_to_onnx(model, input_tensor, opset_version) + elif quant: + #load weights. 
+ state_dict = torch.load(os.path.join(os.getcwd(), self.config.get("weights_path", "model_weights.pth")), map_location="cpu") + model.load_state_dict(state_dict, strict=False) + #Jansno: temporary workaround + for name, param in model.named_parameters(): + if param.requires_grad and "bias" in name: + if torch.all(param.data) == 0: + torch.nn.init.uniform_(param, a=0.01, b=0.02) + onnx_model = exportBrevitas(model, input_tensor, debug=False) # Save onnx.save(onnx_model, self.paths["network"]) @@ -344,6 +373,14 @@ def export_inference(self, save_path: Optional[str] = None) -> str: infer_shapes_with_custom_ops(self.paths["network"], self.paths["network"]) + # Fix duplicate initializer/node-output names introduced by + # onnx.save's write_external_data_tensors on gs-exported models. + if quant: + _m = onnx.load(self.paths["network"]) + _m = fix_duplicate_tensor_names(_m) + with open(self.paths["network"], "wb") as _f: + _f.write(_m.SerializeToString()) + # Save test input/output data if method is implemented if hasattr(self, "save_test_data"): try: @@ -485,43 +522,212 @@ def export_training(self, save_path: Optional[str] = None) -> str: return self.paths["network"] - def _create_test_data(self): + def export_zo_training(self, save_path: Optional[str] = None, noise_type: str = "gaussian", quant: bool = False) -> str: + """ + Export model in zeroth-order training mode. 
+ + Args: + save_path: Optional custom save path + noise_type: Type of noise to use for perturbation + quant: Whether to apply quantization """ + if save_path: + self.save_path = save_path + + # Load configuration + self.config = self.load_config() + self.paths = self.setup_paths(ExportMode.ZO_TRAINING) + + print(f"\n{'='*60}") + print(f"๐Ÿš€ Exporting {self.get_model_name()} to ONNX (Zeroth-Order Training Mode)") + print(f"{'='*60}\n") + + # Create PyTorch model + print("๐Ÿ“ฆ Creating PyTorch model...") + model = self.create_model() + model.eval() # Zeroth-Order Training mode + + # Generate input + input_shape = self.get_input_shape() + input_tensor = torch.randn(*input_shape, dtype=torch.float32) + print(f" Input shape: {input_shape}") + + # Export to ONNX + if not quant: + # initialize weights and biases for testing + for name, param in model.named_parameters(): + if param.requires_grad: + torch.nn.init.normal_(param, mean=0.0, std=0.02) + # Export to ONNX + print("\n๐Ÿ“ค Exporting to ONNX...") + opset_version = self.config.get("opset_version", 12) + onnx_model = self._export_to_onnx(model, input_tensor, opset_version) + elif quant: + #load weights. 
+ state_dict = torch.load(os.path.join(os.getcwd(), self.config.get("weights_path", "model_weights.pth")), map_location="cpu") + model.load_state_dict(state_dict, strict=False) + print("\n๐Ÿ“ค Exporting to ONNX with quantization...") + # JanSno: Temporary workaround + for name, param in model.named_parameters(): + if param.requires_grad and "bias" in name: + if torch.all(param.data) == 0: + torch.nn.init.uniform_(param, a=0.01, b=0.02) + # use DeepQuant to export to ONNX + onnx_model = exportBrevitas(model, input_tensor, debug=False) + onnx_model = fix_duplicate_tensor_names(onnx_model) + + # Save + onnx.save(onnx_model, self.paths["network_infer"]) + print(f"โœ… ONNX model saved: {self.paths['network_infer']}") + + # Run inference optimizations + print("\n๐Ÿ”ง Running inference optimizations...") + self.run_inference_optimization(self.paths["network_infer"], self.paths["network_infer"]) + + # Run shape inference + print("\n๐Ÿ” Running shape inference...") + from ..optimization.shape_optimizer import infer_shapes_with_custom_ops + + infer_shapes_with_custom_ops(self.paths["network_infer"], self.paths["network_infer"]) + + # Save test input/output data if method is implemented + if hasattr(self, "save_test_data"): + try: + self.save_test_data(model, self.paths["output_dir"]) + except Exception as e: + print(f"โš ๏ธ Failed to save test data: {e}") + + # Reload optimized model + onnx_model = onnx.load(self.paths["network_infer"]) + print_model_info(self.paths["network_infer"]) + + # Get trainable parameters + all_param_names = [init.name for init in onnx_model.graph.initializer] + requires_grad = self.get_trainable_params(all_param_names) + frozen_params = [name for name in all_param_names if name not in requires_grad] + + print(f"\n๐Ÿ”น Trainable parameters: {len(requires_grad)}") + print(f"๐Ÿ”น Frozen parameters: {len(frozen_params)}") + + # Transform model for zeroth-order training (e.g., add noise nodes, modify outputs) + print("\n๐Ÿ”ง Transforming model for 
zeroth-order training...") + # Randomize initializers for testing + if not quant: + onnx_model = randomize_onnx_initializers(onnx_model) + generate_zo_graph( + inference_onnx=self.paths["network_infer"], + output_onnx=self.paths["network_zo_train"], + zo_config=self.config["zo"], + noise_type=noise_type, + ) + + generate_weight_update_graph( + onnx_path=self.paths["network_infer"], + output_path=self.paths["network_zo_update"], + zo_config=self.config["zo"], + noise_type=noise_type) + + # # Load training model and add gradient outputs + # onnx_model = onnx.load(self.paths["network_train"]) + # graph = onnx_model.graph + # grad_tensor_names = [name + "_grad" for name in requires_grad] + + # for grad_name in grad_tensor_names: + # if not any(output.name == grad_name for output in graph.output): + # grad_output = helper.make_tensor_value_info(grad_name, onnx.TensorProto.FLOAT, None) + # graph.output.append(grad_output) + + # # Save with gradient outputs + # onnx.save(onnx_model, self.paths["network_train_optim"]) + # onnx.save(onnx_model, self.paths["network_train"]) + + # Run shape inference for training model (handles Microsoft custom ops) + print("\n๐Ÿ” Running shape inference...") + from ..optimization.shape_optimizer import infer_shapes_with_custom_ops + infer_shapes_with_custom_ops( + self.paths["network_zo_train"] + ) + + # # Run training-specific optimizations + # print("\n๐Ÿ”ง Running training optimizations...") + # self.run_training_optimization(self.paths["network_train_optim"], self.paths["network"]) + + # # Save pre-SGD model + # shutil.copy(self.paths["network"], self.paths["network_pre_sgd"]) + # print(f"โœ… Pre-SGD model saved: {self.paths['network_pre_sgd']}") + + # Create test input/output + print("\n๐Ÿงช Creating test input/output...") + self._create_test_data(mode=ExportMode.ZO_TRAINING, quant=quant) + + # # Add optimizer (SGD) nodes + # print("\nโž• Adding SGD optimizer nodes...") + # self._add_optimizer_nodes() + + print(f"\n{'='*60}") + 
print("โœ… Export Complete!") + print(f" Final model: {self.paths['network']}") + print(f"{'='*60}\n") + + def _create_test_data(self, mode=ExportMode.INFERENCE, quant=False): """ Create test input/output data for training. - Uses ONNX Runtime to generate reference output from the (potentially randomized) ONNX model. - This ensures test data matches the actual ONNX model weights. + For standard inference mode, uses ONNX Runtime to generate reference + output. For modes that involve custom ops (q-infer, zo-train, + q-zo-train) the pure-Python ``run_onnx_graph`` executor is used instead + so that Quant/Dequant/RequantShift and MeZO perturbation nodes can be + executed without a custom-op shared library. """ - # Generate test data using ONNX Runtime try: from pathlib import Path import numpy as np - import onnxruntime as ort print("๐Ÿ’พ Generating test input/output data from ONNX model...") - # Create test input input_shape = self.get_input_shape() test_input = np.random.randn(*input_shape).astype(np.float32) - # Run ONNX inference to get output - session = ort.InferenceSession(self.paths["network_infer"]) - input_name = session.get_inputs()[0].name - test_output = session.run(None, {input_name: test_input})[0] + use_pure_python = quant or (mode == ExportMode.ZO_TRAINING) + + if use_pure_python: + from onnx4deeploy.utils.onnx_node_implementations import run_onnx_graph + print(" Using pure-Python ONNX executor (custom ops present)...") + test_output = run_onnx_graph( + self.paths["network_infer"], {"input": test_input} + ) + if not isinstance(test_output, np.ndarray): + test_output = np.array(test_output, dtype=np.float32) + else: + import onnxruntime as ort + from onnxruntime_extensions import get_library_path + sess_options = ort.SessionOptions() + sess_options.register_custom_ops_library(get_library_path()) + session = ort.InferenceSession( + self.paths["network_infer"], sess_options=sess_options + ) + input_name = session.get_inputs()[0].name + test_output = 
session.run(None, {input_name: test_input})[0] # Save as .npz files save_path = Path(self.paths["output_dir"]) save_path.mkdir(parents=True, exist_ok=True) - np.savez(save_path / "inputs.npz", input=test_input) + if mode == ExportMode.ZO_TRAINING: + test_label = np.random.randint( + 0, self.config["num_classes"], size=(input_shape[0], 1) + ).astype(np.int8) + print(f" Generated test labels with shape: {test_label.shape}") + np.savez(save_path / "inputs.npz", input=test_input, label=test_label) + else: + np.savez(save_path / "inputs.npz", input=test_input) np.savez(save_path / "outputs.npz", output=test_output) print(" โœ… Saved test data (ONNX reference):") print(f" Input: {save_path / 'inputs.npz'} shape={test_input.shape}") print(f" Output: {save_path / 'outputs.npz'} shape={test_output.shape}") except Exception as e: - print(f"โš ๏ธ Failed to create test data: {e}") + raise RuntimeError(f"Failed to create test data: {e}") def _add_optimizer_nodes(self): """ @@ -546,5 +752,11 @@ def export(self, mode: str = "train", save_path: Optional[str] = None) -> str: return self.export_training(save_path) elif mode == "infer": return self.export_inference(save_path) + elif mode == "zo-train": + return self.export_zo_training(save_path) + elif mode =="q-infer": + return self.export_inference(save_path, quant=True) + elif mode == "q-zo-train": + return self.export_zo_training(save_path, quant=True) else: raise ValueError(f"Invalid mode: {mode}. 
Must be 'train' or 'infer'") diff --git a/onnx4deeploy/core/onnx_utils.py b/onnx4deeploy/core/onnx_utils.py index 53d14d5..1083825 100644 --- a/onnx4deeploy/core/onnx_utils.py +++ b/onnx4deeploy/core/onnx_utils.py @@ -176,7 +176,12 @@ def randomize_onnx_initializers(onnx_model: onnx.ModelProto) -> onnx.ModelProto: tensor = numpy_helper.to_array(initializer) # Randomize the values - randomized_tensor = np.random.randn(*tensor.shape).astype(tensor.dtype) + print(F"Randomizing initializer '{initializer.name}' with shape {tensor.shape}") + + randomized_tensor = np.random.randn(*tensor.shape) + if not isinstance(randomized_tensor, np.ndarray): + randomized_tensor = np.array(randomized_tensor) + randomized_tensor = randomized_tensor.astype(tensor.dtype) # Update the initializer new_initializer = numpy_helper.from_array(randomized_tensor, initializer.name) diff --git a/onnx4deeploy/io/__init__.py b/onnx4deeploy/io/__init__.py index 7d6d920..7c45a96 100644 --- a/onnx4deeploy/io/__init__.py +++ b/onnx4deeploy/io/__init__.py @@ -10,5 +10,5 @@ __all__ = [ "load_config", "load_train_config", - "compare_onnx_models", + "compare_onnx_models" ] diff --git a/onnx4deeploy/io/config_loader.py b/onnx4deeploy/io/config_loader.py index 9e186d0..6616be1 100644 --- a/onnx4deeploy/io/config_loader.py +++ b/onnx4deeploy/io/config_loader.py @@ -70,4 +70,4 @@ def load_train_config(config_filename: str = "../config.yaml") -> float: with open(config_file, "r") as f: config = yaml.safe_load(f).get("training", {}) - return config.get("learning_rate", 0.01) + return config.get("learning_rate", 0.01) \ No newline at end of file diff --git a/onnx4deeploy/models/__init__.py b/onnx4deeploy/models/__init__.py index 4e3f577..36d9512 100644 --- a/onnx4deeploy/models/__init__.py +++ b/onnx4deeploy/models/__init__.py @@ -7,6 +7,7 @@ from .cct_exporter import CCTExporter from .epidenet_exporter import EpiDeNetExporter from .lightweight_cnn_exporter import LightweightCnnExporter +from .qlite_cnn_exporter 
import QLiteCnnExporter from .mamba_exporter import MambaExporter from .mibminet_exporter import MIBMInetExporter from .mobilenetv2_exporter import MobileNetV2Exporter @@ -14,11 +15,13 @@ from .resnet_exporter import ResNetExporter from .simple_mlp_exporter import SimpleMlpExporter from .sleep_convit_exporter import SleepConViTExporter +from .qsleep_convit_exporter import QSleepConViTExporter __all__ = [ "CCTExporter", "EpiDeNetExporter", "LightweightCnnExporter", + "QLiteCnnExporter", "MIBMInetExporter", "SimpleMlpExporter", "ResNetExporter", @@ -26,4 +29,5 @@ "MobileViTExporter", "MambaExporter", "SleepConViTExporter", + "QSleepConViTExporter" ] diff --git a/onnx4deeploy/models/lightweight_cnn_exporter.py b/onnx4deeploy/models/lightweight_cnn_exporter.py index aa9f8f6..36d3160 100644 --- a/onnx4deeploy/models/lightweight_cnn_exporter.py +++ b/onnx4deeploy/models/lightweight_cnn_exporter.py @@ -49,6 +49,10 @@ def load_config(self) -> Dict[str, Any]: # Training configuration "training_strategy": "full", # Options: "full", "last_layer", "custom" "custom_trainable_params": [], + "zo": { + "epsilon": 0.1, + "seed": 42 + } } self.model_config = config diff --git a/onnx4deeploy/models/pytorch_models/lightweight_cnn/__init__.py b/onnx4deeploy/models/pytorch_models/lightweight_cnn/__init__.py index 8ebc582..ce05964 100644 --- a/onnx4deeploy/models/pytorch_models/lightweight_cnn/__init__.py +++ b/onnx4deeploy/models/pytorch_models/lightweight_cnn/__init__.py @@ -5,5 +5,6 @@ """Lightweight CNN PyTorch model.""" from .lightweight_cnn import LightweightCNN +from .qlite_cnn import QLiteCNN -__all__ = ["LightweightCNN"] +__all__ = ["LightweightCNN", "QLiteCNN"] diff --git a/onnx4deeploy/models/pytorch_models/lightweight_cnn/qlite_cnn.pth b/onnx4deeploy/models/pytorch_models/lightweight_cnn/qlite_cnn.pth new file mode 100644 index 0000000..908305b Binary files /dev/null and b/onnx4deeploy/models/pytorch_models/lightweight_cnn/qlite_cnn.pth differ diff --git 
a/onnx4deeploy/models/pytorch_models/lightweight_cnn/qlite_cnn.py b/onnx4deeploy/models/pytorch_models/lightweight_cnn/qlite_cnn.py new file mode 100644 index 0000000..ad3bdff --- /dev/null +++ b/onnx4deeploy/models/pytorch_models/lightweight_cnn/qlite_cnn.py @@ -0,0 +1,121 @@ +import torch.nn as nn +import torch.nn.functional as F +import brevitas.nn as qnn +from brevitas.inject.enum import FloatToIntImplType +from brevitas.quant.scaled_int import ( + Int8ActPerTensorFloat, + Int32Bias, + Int8WeightPerTensorFloat, + Int8WeightPerChannelFloat +) +from brevitas.core.function_wrapper.stochastic_round import StochasticRoundSte + +class StochasticInt8WeightPerChannelFloat(Int8WeightPerChannelFloat): + float_to_int_impl_type=FloatToIntImplType.STOCHASTIC_ROUND + +class StochasticInt8WeightPerTensorFloat(Int8WeightPerTensorFloat): + float_to_int_impl_type=FloatToIntImplType.STOCHASTIC_ROUND + +class StochasticInt8ActPerTensorFloat(Int8ActPerTensorFloat): + float_to_int_impl_type=FloatToIntImplType.STOCHASTIC_ROUND + +class StochasticInt32Bias(Int32Bias): + float_to_int_impl_type=FloatToIntImplType.STOCHASTIC_ROUND + +class QLiteCNN(nn.Module): + def __init__(self, + batch_size: int = 1, + input_channels: int = 1, + num_classes: int = 10, + dropout: float = 0.0):# ignored + + self.batch_size = batch_size + self.input_channels = input_channels + self.num_classes = num_classes + self.fc_channels = 160 # Fixed: 10 * 4 * 4 = 160 + + self.convAndLinQuantParams = { + "bias": True, + "weight_bit_width": 8, + "bias_quant": Int32Bias, + "input_quant": Int8ActPerTensorFloat, + "weight_quant":Int8WeightPerChannelFloat, #no channel wise support in deeploy yet. 
+ #"weight_quant":Int8WeightPerTensorFloat, + "output_quant": None, + "return_quant_tensor": True + } + + self.convAndLinQuantParamsOut = { + "bias": True, + "weight_bit_width": 8, + "bias_quant": Int32Bias, + "input_quant": Int8ActPerTensorFloat, + "weight_quant":Int8WeightPerChannelFloat,# no channel wise support in deeploy yet. + #"weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True + } + super(QLiteCNN, self).__init__() + # Convolutional layers + # self.inputQuant = qnn.QuantIdentity( + # act_quant=Int8ActPerTensorFloat, return_quant_tensor=True) + + self.conv1 = qnn.QuantConv2d( + in_channels=input_channels, + out_channels=20, + kernel_size=(5,5), + **self.convAndLinQuantParams + ) + self.relu1 = qnn.QuantReLU(bit_width=8, return_quant_tensor=True) + self.pool1 = nn.MaxPool2d(kernel_size=2) # Output: (20, 12, 12) + self.conv2 = qnn.QuantConv2d(20, + 10, + kernel_size=(1, 1), + **self.convAndLinQuantParams) + # Output: (10, 12, 12) + self.relu2 = qnn.QuantReLU(bit_width=8, return_quant_tensor=True) + + self.pool2 = nn.MaxPool2d(kernel_size=2) # Output: (10, 6, 6) + + self.conv3 = qnn.QuantConv2d(10, 12, kernel_size=(3, 3), + **self.convAndLinQuantParams) # Output: (12, 4, 4) + + self.relu3 = qnn.QuantReLU(bit_width=8, return_quant_tensor=True) + + self.conv4 = qnn.QuantConv2d(12, 10, kernel_size=(1, 1), + **self.convAndLinQuantParams) + # Output: (10, 4, 4) + self.relu4 = qnn.QuantReLU(bit_width=8, return_quant_tensor=True) + + self.fc = qnn.QuantLinear(self.fc_channels, num_classes, + **self.convAndLinQuantParamsOut) # Output: num_classes + + def forward(self, x): + + # Convolutional layers with ReLU activation and pooling + # compute min and max of input and scale for quantization debugging + # if isinstance(x, torch.Tensor): + # print(f"Input tensor shape: {x.shape}, dtype: {x.dtype}, min: {x.min().item():.4f}, max: {x.max().item():.4f}") + # print(f"After input quantization: shape: {x.shape}, 
dtype: {x.dtype}, min: {x.min().item():.4f}, max: {x.max().item():.4f}") + # print(f"scale of input quantizer: {self.inputQuant.act_quant.scale().item():.6f}") + # x = self.inputQuant(x) + x = self.conv1(x) + x = self.relu1(x) + + x = self.pool1(x) # Output: (20, 12, 12) + x = self.conv2(x) + x = self.relu2(x) + + x = self.pool2(x) # Output: (10, 6, 6) + x = self.conv3(x) + x = self.relu3(x) + + x = self.conv4(x) # Output: (10, 4, 4) + x = self.relu4(x) + + # Flatten the feature map + x = x.flatten(start_dim=1) # Flatten to (batch_size, 10 * 4 * 4) + # Fully connected layer + x = self.fc(x) + + return x diff --git a/onnx4deeploy/models/pytorch_models/lightweight_cnn/qlite_cnn_scales.json b/onnx4deeploy/models/pytorch_models/lightweight_cnn/qlite_cnn_scales.json new file mode 100644 index 0000000..bad43da --- /dev/null +++ b/onnx4deeploy/models/pytorch_models/lightweight_cnn/qlite_cnn_scales.json @@ -0,0 +1,508 @@ +{ + "inputQuant.act_quant": 0.007786148693412542, + "inputQuant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv1.weight_quant": [ + [ + [ + [ + 0.0046827648766338825 + ] + ] + ], + [ + [ + [ + 0.004657038953155279 + ] + ] + ], + [ + [ + [ + 0.005794027354568243 + ] + ] + ], + [ + [ + [ + 0.004027717746794224 + ] + ] + ], + [ + [ + [ + 0.005896744318306446 + ] + ] + ], + [ + [ + [ + 0.003534339601173997 + ] + ] + ], + [ + [ + [ + 0.005034759175032377 + ] + ] + ], + [ + [ + [ + 0.006583436857908964 + ] + ] + ], + [ + [ + [ + 0.005040623247623444 + ] + ] + ], + [ + [ + [ + 0.0054032206535339355 + ] + ] + ], + [ + [ + [ + 0.004712222144007683 + ] + ] + ], + [ + [ + [ + 0.0044171675108373165 + ] + ] + ], + [ + [ + [ + 0.0045336028560996056 + ] + ] + ], + [ + [ + [ + 0.005082833115011454 + ] + ] + ], + [ + [ + [ + 0.005390029400587082 + ] + ] + ], + [ + [ + [ + 0.004269406199455261 + ] + ] + ], + [ + [ + [ + 0.0057991985231637955 + ] + ] + ], + [ + [ + [ + 0.005343677010387182 + ] + ] + ], + [ + [ + [ + 
0.005501016974449158 + ] + ] + ], + [ + [ + [ + 0.004162007477134466 + ] + ] + ] + ], + "conv1.bias_quant": [ + 3.6067656765226275e-05, + 4.737563358503394e-05, + 4.514355896390043e-05, + 3.0122700991341844e-05, + 4.522434755926952e-05, + 2.542074980738107e-05, + 4.058502963744104e-05, + 5.085647353553213e-05, + 3.5445831599645317e-05, + 4.2773499444592744e-05, + 3.603152072173543e-05, + 2.8624375772778876e-05, + 3.2512234611203894e-05, + 4.047499533044174e-05, + 4.007756433566101e-05, + 3.143427602481097e-05, + 3.283462137915194e-05, + 3.1593885069014505e-05, + 4.189912215224467e-05, + 3.179134728270583e-05 + ], + "conv1.input_quant": 0.007760446518659592, + "conv1.output_quant": 0.02516048587858677, + "conv1.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv1.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "relu1.act_quant": 0.012514653615653515, + "relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv2.weight_quant": [ + [ + [ + [ + 0.0059503596276044846 + ] + ] + ], + [ + [ + [ + 0.004254304803907871 + ] + ] + ], + [ + [ + [ + 0.005004961043596268 + ] + ] + ], + [ + [ + [ + 0.0070042419247329235 + ] + ] + ], + [ + [ + [ + 0.005205494351685047 + ] + ] + ], + [ + [ + [ + 0.005354207940399647 + ] + ] + ], + [ + [ + [ + 0.005926031619310379 + ] + ] + ], + [ + [ + [ + 0.006488564424216747 + ] + ] + ], + [ + [ + [ + 0.0034671735484153032 + ] + ] + ], + [ + [ + [ + 0.0042591276578605175 + ] + ] + ] + ], + "conv2.bias_quant": [ + 0.00015088623331394047, + 0.00011274521966697648, + 0.00010193250636802986, + 0.00015958795847836882, + 0.0001279811403946951, + 0.00014609555364586413, + 0.0001375889842165634, + 0.00012088024959666654, + 0.00010532839951338246, + 8.798576891422272e-05 + ], + "conv2.input_quant": 0.025035124272108078, + "conv2.output_quant": 0.038704611361026764, + 
"conv2.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv2.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "relu2.act_quant": 0.019238846376538277, + "relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv3.weight_quant": [ + [ + [ + [ + 0.0032192468643188477 + ] + ] + ], + [ + [ + [ + 0.002923605963587761 + ] + ] + ], + [ + [ + [ + 0.0027376932557672262 + ] + ] + ], + [ + [ + [ + 0.003503928892314434 + ] + ] + ], + [ + [ + [ + 0.0031144043896347284 + ] + ] + ], + [ + [ + [ + 0.0034859958104789257 + ] + ] + ], + [ + [ + [ + 0.0034436206333339214 + ] + ] + ], + [ + [ + [ + 0.004039732739329338 + ] + ] + ], + [ + [ + [ + 0.0014960176777094603 + ] + ] + ], + [ + [ + [ + 0.0032225819304585457 + ] + ] + ], + [ + [ + [ + 0.0036653317511081696 + ] + ] + ], + [ + [ + [ + 0.003544930834323168 + ] + ] + ] + ], + "conv3.bias_quant": [ + 0.00012999656610190868, + 0.00014617544366046786, + 0.00011316189193166792, + 0.00014395458856597543, + 0.000142513687023893, + 0.00015837881073821336, + 0.00018707614799495786, + 0.00016153800243046135, + 0.00010428271343698725, + 0.00015711947344243526, + 0.00015028772759251297, + 0.00014441728126257658 + ], + "conv3.input_quant": 0.03846479952335358, + "conv3.output_quant": 0.05587683245539665, + "conv3.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv3.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "relu3.act_quant": 0.027778906747698784, + "relu3.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv4.weight_quant": [ + [ + [ + [ + 0.005106140859425068 + ] + ] + ], + [ + [ + [ + 0.0067635830491781235 + ] + ] + ], + [ + [ + [ + 0.005935069173574448 + ] + ] + ], + [ + [ + [ + 0.005529573652893305 + ] + ] + ], + [ + [ + [ + 0.006021668203175068 + ] + ] + ], + [ + [ + [ + 0.006458999123424292 + ] 
+ ] + ], + [ + [ + [ + 0.004328868351876736 + ] + ] + ], + [ + [ + [ + 0.005672066938132048 + ] + ] + ], + [ + [ + [ + 0.009927690960466862 + ] + ] + ], + [ + [ + [ + 0.003988152835518122 + ] + ] + ] + ], + "conv4.bias_quant": [ + 0.00032779158209450543, + 0.0003805554879363626, + 0.00034476761356927454, + 0.0003091433900408447, + 0.0003567334497347474, + 0.00034780139685608447, + 0.00028347299667075276, + 0.00029056513449177146, + 0.000570900272578001, + 0.00031566276447847486 + ], + "conv4.input_quant": 0.05556188151240349, + "conv4.output_quant": 0.06223675236105919, + "conv4.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv4.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "relu4.act_quant": 0.030943837016820908, + "relu4.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "fc.weight_quant": [ + [ + 0.002200733171775937 + ], + [ + 0.003274584887549281 + ], + [ + 0.0030999528244137764 + ], + [ + 0.003084302879869938 + ], + [ + 0.003228644607588649 + ], + [ + 0.0023938606027513742 + ], + [ + 0.002668095985427499 + ], + [ + 0.003263178514316678 + ], + [ + 0.0037302691489458084 + ], + [ + 0.0028242983389645815 + ] + ], + "fc.bias_quant": [ + 0.0001680418208707124, + 0.00022119912318885326, + 0.00019093586888629943, + 0.0001730432704789564, + 0.0002461412223055959, + 0.00026946672005578876, + 0.00020659832807723433, + 0.00031415317789651453, + 0.0002444065175950527, + 0.00017614015087019652 + ], + "fc.input_quant": 0.061888109892606735, + "fc.output_quant": 0.06395246833562851, + "fc.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "fc.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0 +} \ No newline at end of file diff --git a/onnx4deeploy/models/pytorch_models/sleep_convit/qsleep_convit.pth b/onnx4deeploy/models/pytorch_models/sleep_convit/qsleep_convit.pth new file mode 
100755 index 0000000..45c030a Binary files /dev/null and b/onnx4deeploy/models/pytorch_models/sleep_convit/qsleep_convit.pth differ diff --git a/onnx4deeploy/models/pytorch_models/sleep_convit/qsleep_convit.py b/onnx4deeploy/models/pytorch_models/sleep_convit/qsleep_convit.py new file mode 100644 index 0000000..f906e40 --- /dev/null +++ b/onnx4deeploy/models/pytorch_models/sleep_convit/qsleep_convit.py @@ -0,0 +1,367 @@ +""" This is a ViT model fur sleep staging / sleep stage classification """ + +import torch +import torch.nn as nn +import math +from collections import OrderedDict +from functools import partial +import numpy as np +import brevitas.nn as qnn +import torch.nn.functional as F +from brevitas.inject.enum import FloatToIntImplType +from brevitas.quant.scaled_int import ( + Int8ActPerTensorFloat, + Int32Bias, + Int8WeightPerChannelFloat, + Int8WeightPerTensorFloat, + Uint8ActPerTensorFloat) + +from brevitas.quant_tensor import QuantTensor + +# local imports +from DeepQuant.ExportBrevitas import exportBrevitas + + + +class Int8ActStochasticPerTensorFloat(Int8ActPerTensorFloat): + float_to_int_impl_type=FloatToIntImplType.STOCHASTIC_ROUND + +class Int8WeightStochasticPerTensorFloat(Int8WeightPerTensorFloat): + float_to_int_impl_type=FloatToIntImplType.STOCHASTIC_ROUND + +class Int8WeightStochasticPerChannelFloat(Int8WeightPerChannelFloat): + float_to_int_impl_type=FloatToIntImplType.STOCHASTIC_ROUND + +class Int32BiasStochastic(Int32Bias): + float_to_int_impl_type=FloatToIntImplType.STOCHASTIC_ROUND + +class Uint8StochasticActPerTensorFloat(Uint8ActPerTensorFloat): + float_to_int_impl_type=FloatToIntImplType.STOCHASTIC_ROUND + +convAndLinQuantParamsNoOutputQuant = { + "weight_bit_width": 8, + "bias_quant": Int32Bias, + "input_quant": Int8ActPerTensorFloat, + # "weight_quant": Int8WeightPerChannelFloat, + "weight_quant": Int8WeightPerChannelFloat, + "output_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, +} + +convAndLinQuantParams = { + 
"weight_bit_width": 8, + "bias_quant": Int32Bias, + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerChannelFloat, + "output_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, +} + +convAndLinQuantParamsNoInputQuant = { + "weight_bit_width": 8, + "output_bit_width": 8, + "bias_quant": Int32Bias, + "input_quant": None, + "weight_quant": Int8WeightPerChannelFloat, + "output_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, +} + +convAndLinQuantParamsNoInputNoOutputQuant = { + "weight_bit_width": 8, + "bias_quant": Int32Bias, + "input_quant": None, + # "weight_quant": Int8WeightPerChannelFloat, + "weight_quant": Int8WeightPerChannelFloat, + "output_quant": None, + "return_quant_tensor": False, +} + +actQuantParams = { + "act_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8 +} + + +mhaQuantParams = { + "in_proj_input_quant":Int8ActPerTensorFloat, + "in_proj_weight_quant":Int8WeightPerChannelFloat, + "in_proj_bias_quant":Int32Bias, + "attn_output_weights_quant":Uint8ActPerTensorFloat, + "q_scaled_quant":Int8ActPerTensorFloat, + "k_transposed_quant":Int8ActPerTensorFloat, + "v_quant":Int8ActPerTensorFloat, + "out_proj_input_quant":Int8ActPerTensorFloat, + "out_proj_weight_quant":Int8WeightPerChannelFloat, + "out_proj_bias_quant":Int32Bias, + "out_proj_output_quant":Int8ActPerTensorFloat, + "return_quant_tensor":True +} + +class MLPHead(nn.Module): + def __init__(self, dim, hidden_dim, dropout_rate=0.0): + super(MLPHead, self).__init__() + self.ff1 = qnn.QuantLinear(dim, hidden_dim, **convAndLinQuantParams) + self.activation = F.gelu + #self.dropout1 = nn.Dropout(p=dropout_rate) + self.ff2 = qnn.QuantLinear(hidden_dim, dim, **convAndLinQuantParams) + #self.dropout2 = nn.Dropout(p=dropout_rate) + + def forward(self, x): + # input dim = encoder dim = 48 + x = self.ff1(x) + x = self.activation(x) + #x = self.dropout1(x) + x = self.ff2(x) + #x = self.dropout2(x) + return x + +class Encoder(nn.Module): + 
""" + Transformer Encoder block with multi-head attention and feedforward network. + + Args: + embed_dim: Embedding dimension + num_heads: Number of attention heads + seq_len: Fixed sequence length + batch_size: Fixed batch size + att_dropout: Attention dropout rate (ignored in deploy version) + mlp_head_hidden_dim: Hidden dimension for MLP head + mlp_head_dropout: MLP dropout rate (ignored in deploy version) + """ + def __init__(self, + embed_dim, + nheads, + att_dropout, + mlp_head_hidden_dim, + mlp_head_dropout): + super(Encoder, self).__init__() + self.ln_1 = nn.LayerNorm(embed_dim) + self.mha = qnn.QuantMultiheadAttention(embed_dim, + nheads, + dropout=att_dropout, + batch_first=True, + packed_in_proj=False, + **mhaQuantParams) + self.ln_2 = nn.LayerNorm(embed_dim) + self.ff = MLPHead(embed_dim, + hidden_dim=mlp_head_hidden_dim, + dropout_rate=mlp_head_dropout) + self.rescale_residual1 = qnn.QuantIdentity(**actQuantParams) + self.rescale_residual2 = qnn.QuantIdentity(**actQuantParams) + # self.residual1 = qnn.QuantEltwiseAdd(**actQuantParams) + # self.residual2 = qnn.QuantEltwiseAdd(**actQuantParams) + + + def forward(self, x): + # save x for residual + _x = x + x = self.ln_1(x) + x, _ = self.mha(x, x, x, need_weights=False) + # apply residual + x = self.rescale_residual1(x) + _x = self.rescale_residual1(_x) + x = x + _x + _x = x + x = self.ln_2(x) + x = self.ff(x) + # apply residual + x = self.rescale_residual2(x) + _x = self.rescale_residual2(_x) + x = x + _x + return x + +class ConvStem(nn.Module): + def __init__(self, in_channels=1, out_channels=48, kernel_sizes=(25, 200, 100), stride=4, pool_kernel=4): + """ + CNN branch for multi-scale feature extraction. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Total number of output channels across all branches. + kernel_sizes (tuple): Kernel sizes for the three branches. + stride (int): Stride for downsampling in convolutional layers. 
+ pool_kernel (int): Kernel size for pooling layers. + """ + super(ConvStem, self).__init__() + branch_out_channels = out_channels // 3 + + # Branch 1: Kernel size 25 + self.branch1 = nn.Sequential( + qnn.QuantConv2d( + in_channels=in_channels, + out_channels=branch_out_channels, + kernel_size=(1, kernel_sizes[2]), # (height, width) + stride=(1, stride), + padding=(0, kernel_sizes[2] // 2), + bias=True, + **convAndLinQuantParamsNoOutputQuant), + qnn.QuantReLU(bit_width=8, return_quant_tensor=True), + nn.MaxPool2d(kernel_size=(1, pool_kernel), stride=(1, pool_kernel)), + qnn.QuantConv2d( + in_channels=branch_out_channels, + out_channels=branch_out_channels, + kernel_size=(1, 3), + stride=(1, 2), + padding=(0, 1), + bias=True, + **convAndLinQuantParamsNoOutputQuant), + qnn.QuantReLU(bit_width=8, return_quant_tensor=True) + ) + + # Branch 2: Kernel size 200 + self.branch2 = nn.Sequential( + qnn.QuantConv2d( + in_channels=in_channels, + out_channels=branch_out_channels, + kernel_size=(1, kernel_sizes[1]), # (height, width) + stride=(1, stride), + padding=(0, kernel_sizes[1] // 2), + bias=True, + **convAndLinQuantParamsNoOutputQuant), + qnn.QuantReLU(bit_width=8, return_quant_tensor=True), + nn.MaxPool2d(kernel_size=(1, pool_kernel), stride=(1, pool_kernel)), + qnn.QuantConv2d( + in_channels=branch_out_channels, + out_channels=branch_out_channels, + kernel_size=(1, 3), + stride=(1, 2), + padding=(0, 1), + bias=True, + **convAndLinQuantParamsNoOutputQuant), + qnn.QuantReLU(bit_width=8, return_quant_tensor=True) + ) + + + # Branch 3: Kernel size 100 + self.branch3 = nn.Sequential( + qnn.QuantConv2d( + in_channels=in_channels, + out_channels=branch_out_channels, + kernel_size=(1, kernel_sizes[2]), # (height, width) + stride=(1, stride), + padding=(0, kernel_sizes[2] // 2), + bias=True, + **convAndLinQuantParamsNoOutputQuant), + qnn.QuantReLU(bit_width=8, return_quant_tensor=True), + nn.MaxPool2d(kernel_size=(1, pool_kernel), stride=(1, pool_kernel)), + qnn.QuantConv2d( + 
in_channels=branch_out_channels, + out_channels=branch_out_channels, + kernel_size=(1, 3), + stride=(1, 2), + padding=(0, 1), + bias=True, + **convAndLinQuantParamsNoOutputQuant), + qnn.QuantReLU(bit_width=8, return_quant_tensor=True) + ) + + self.cat_rescale = qnn.QuantIdentity(**actQuantParams) + + + def forward(self, x): + """ + Forward pass through dual-branch convolutional stem. + + Args: + x: Input tensor of shape (batch_size, channels, height, width) + Expected: (B, 1, 1, 3000) + + Returns: + Concatenated features from both branches (B, model_dim, 1, num_patches) + """ + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + + x1 = self.cat_rescale(x1) + x2 = self.cat_rescale(x2) + x3 = self.cat_rescale(x3) + x12 = torch.cat((x1, x2), dim=1) + x123 = torch.cat((x12, x3), dim=1) + return x123 + +class QSleepConViT(nn.Module): + """Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + + def __init__( + self, + config:dict): + # img_size=224, + # patch_size=16, + # in_chans=3, + # num_classes=1000, + # embed_dim=768, + # depth=12, + # num_heads=12, + # mlp_ratio=4.0, + # qkv_bias=True, + # qk_scale=None, + # representation_size=None, + # drop_rate=0.0, + # attn_drop_rate=0.0, + # drop_path_rate=0.0, + # norm_layer=None): + super().__init__() + self.num_heads = config.get("num_heads", 8) + self.model_dim = config.get("model_dim", 48) + self.num_patches = config.get("num_patches", 94) + self.num_classes = config.get("num_classes", 4) + self.batch_size = config.get("batch_size", 1) + seq_len = config.get("seq_len", 95) # num_patches + 1 (for CLS token) + + self.conv_stem = ConvStem(in_channels=1, + out_channels=self.model_dim, + kernel_sizes=(25, 200, 100), + stride=4) + + num_patches = config["num_patches"] + + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.model_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, 
self.model_dim)) + + # CLS token selector (fixed one-hot vector for ONNX-friendly extraction) + self.cls_selector = nn.Parameter(torch.zeros(1, self.num_patches + 1), requires_grad=False) + self.cls_selector.data[0, 0] = 1.0 # Select only the first token (CLS) + + self.pos_drop = nn.Dropout(p=config["attention_dropout"]) + + self.qaddpos = qnn.QuantIdentity(**actQuantParams) + + self.encoder = Encoder( + self.model_dim, + config["num_heads"], + config["attention_dropout"], + config["mlp_head_hidden_dim"], + config["encoder_ff_dropout"]) + self.norm = nn.LayerNorm(self.model_dim, eps =1e-6) + self.rescale_norm = qnn.QuantIdentity(**actQuantParams) + + self.head = qnn.QuantLinear(self.model_dim, self.num_classes, **convAndLinQuantParams) + + def forward(self, x): + x = self.conv_stem(x) + + x = x.reshape(self.batch_size, self.model_dim, self.num_patches).permute(0, 2, 1) + + cls_tokens = self.cls_token.expand(self.batch_size, -1, -1) + cls_tokens = self.qaddpos(cls_tokens) + x = self.qaddpos(x) + x_cls = torch.cat((cls_tokens, x), dim=1) + pos = self.qaddpos(self.pos_embed) + x_pos = x_cls + pos + x = self.encoder(x_pos) + x = self.norm(x) + x = torch.matmul(self.cls_selector, x) + x = x.squeeze(1) # [B, 1, 48] -> [B, 48] + x = self.rescale_norm(x) + print(F"hello") + x = self.head(x) + print(F"hello2222") + + return x diff --git a/onnx4deeploy/models/pytorch_models/sleep_convit/qsleep_convit_scales.json b/onnx4deeploy/models/pytorch_models/sleep_convit/qsleep_convit_scales.json new file mode 100755 index 0000000..5170c2f --- /dev/null +++ b/onnx4deeploy/models/pytorch_models/sleep_convit/qsleep_convit_scales.json @@ -0,0 +1,1764 @@ +{ + "cls_token_quant": 0.01728847809135914, + "pos_embed_quant": 0.01728847809135914, + "inputQuant.act_quant": 0.0377829372882843, + "inputQuant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch1.conv1.weight_quant": [ + [ + [ + 0.007700167130678892 + ] + ], + [ + [ + 
0.0043680076487362385 + ] + ], + [ + [ + 0.0037727945018559694 + ] + ], + [ + [ + 0.0077829216606915 + ] + ], + [ + [ + 0.005485333036631346 + ] + ], + [ + [ + 0.005003750324249268 + ] + ], + [ + [ + 0.006256120279431343 + ] + ], + [ + [ + 0.004952708724886179 + ] + ], + [ + [ + 0.005447492003440857 + ] + ], + [ + [ + 0.005874777678400278 + ] + ], + [ + [ + 0.004525422118604183 + ] + ], + [ + [ + 0.0034753396175801754 + ] + ], + [ + [ + 0.0050322627648711205 + ] + ], + [ + [ + 0.0038158234674483538 + ] + ], + [ + [ + 0.004551413934677839 + ] + ], + [ + [ + 0.005285318940877914 + ] + ] + ], + "conv_stem.branch1.conv1.input_quant": 0.03770062327384949, + "conv_stem.branch1.conv1.output_quant": 0.04255806654691696, + "conv_stem.branch1.conv1.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch1.conv1.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch1.relu1.act_quant": 0.021949702873826027, + "conv_stem.branch1.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch1.conv2.weight_quant": [ + [ + [ + 0.004444824997335672 + ] + ], + [ + [ + 0.004555459599941969 + ] + ], + [ + [ + 0.0044272239319980145 + ] + ], + [ + [ + 0.004483164753764868 + ] + ], + [ + [ + 0.004080228973180056 + ] + ], + [ + [ + 0.004470501095056534 + ] + ], + [ + [ + 0.0034556023310869932 + ] + ], + [ + [ + 0.004151413682848215 + ] + ], + [ + [ + 0.0036266050301492214 + ] + ], + [ + [ + 0.0033739982172846794 + ] + ], + [ + [ + 0.0030401505064219236 + ] + ], + [ + [ + 0.00424550985917449 + ] + ], + [ + [ + 0.003907631151378155 + ] + ], + [ + [ + 0.004164642188698053 + ] + ], + [ + [ + 0.005155718885362148 + ] + ], + [ + [ + 0.0036308413837105036 + ] + ] + ], + "conv_stem.branch1.conv2.input_quant": 0.03199717402458191, + "conv_stem.branch1.conv2.output_quant": 0.017414741218090057, + 
"conv_stem.branch1.conv2.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch1.conv2.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch1.relu2.act_quant": 0.008679273538291454, + "conv_stem.branch1.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch2.conv1.weight_quant": [ + [ + [ + 0.00259192381054163 + ] + ], + [ + [ + 0.0032467590644955635 + ] + ], + [ + [ + 0.0031782411970198154 + ] + ], + [ + [ + 0.0030435211956501007 + ] + ], + [ + [ + 0.003025458427146077 + ] + ], + [ + [ + 0.0034957064781337976 + ] + ], + [ + [ + 0.0043477690778672695 + ] + ], + [ + [ + 0.003210951341316104 + ] + ], + [ + [ + 0.003571817884221673 + ] + ], + [ + [ + 0.0030537741258740425 + ] + ], + [ + [ + 0.002389652421697974 + ] + ], + [ + [ + 0.003742265747860074 + ] + ], + [ + [ + 0.0027983197942376137 + ] + ], + [ + [ + 0.002486150711774826 + ] + ], + [ + [ + 0.003615174675360322 + ] + ], + [ + [ + 0.0026133409701287746 + ] + ] + ], + "conv_stem.branch2.conv1.input_quant": 0.03344094380736351, + "conv_stem.branch2.conv1.output_quant": 0.04139895737171173, + "conv_stem.branch2.conv1.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch2.conv1.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch2.relu1.act_quant": 0.02064749039709568, + "conv_stem.branch2.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch2.conv2.weight_quant": [ + [ + [ + 0.004379000514745712 + ] + ], + [ + [ + 0.004424882587045431 + ] + ], + [ + [ + 0.0032617414835840464 + ] + ], + [ + [ + 0.0035319533199071884 + ] + ], + [ + [ + 0.0037408589851111174 + ] + ], + [ + [ + 0.00466295937076211 + ] + ], + [ + [ + 0.0034633143804967403 + ] + ], + [ + [ + 0.004278861917555332 + ] + ], + [ + [ + 
0.003981275949627161 + ] + ], + [ + [ + 0.004527073819190264 + ] + ], + [ + [ + 0.003036445938050747 + ] + ], + [ + [ + 0.003864831058308482 + ] + ], + [ + [ + 0.004883200395852327 + ] + ], + [ + [ + 0.004086349159479141 + ] + ], + [ + [ + 0.0044512441381812096 + ] + ], + [ + [ + 0.003594176610931754 + ] + ] + ], + "conv_stem.branch2.conv2.input_quant": 0.03914378210902214, + "conv_stem.branch2.conv2.output_quant": 0.017506461590528488, + "conv_stem.branch2.conv2.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch2.conv2.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch2.relu2.act_quant": 0.008729592896997929, + "conv_stem.branch2.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch3.conv1.weight_quant": [ + [ + [ + 0.0022716252133250237 + ] + ], + [ + [ + 0.0020840235520154238 + ] + ], + [ + [ + 0.002908001421019435 + ] + ], + [ + [ + 0.0022023706696927547 + ] + ], + [ + [ + 0.00227094953879714 + ] + ], + [ + [ + 0.0025678740348666906 + ] + ], + [ + [ + 0.0023773445282131433 + ] + ], + [ + [ + 0.0018388611497357488 + ] + ], + [ + [ + 0.0028274417854845524 + ] + ], + [ + [ + 0.0026444629766047 + ] + ], + [ + [ + 0.0024917894043028355 + ] + ], + [ + [ + 0.003058363450691104 + ] + ], + [ + [ + 0.0025761625729501247 + ] + ], + [ + [ + 0.002443275647237897 + ] + ], + [ + [ + 0.002814142731949687 + ] + ], + [ + [ + 0.0023239522706717253 + ] + ] + ], + "conv_stem.branch3.conv1.input_quant": 0.036097608506679535, + "conv_stem.branch3.conv1.output_quant": 0.04434343799948692, + "conv_stem.branch3.conv1.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch3.conv1.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch3.relu1.act_quant": 0.022029001265764236, + 
"conv_stem.branch3.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch3.conv2.weight_quant": [ + [ + [ + 0.0034151491709053516 + ] + ], + [ + [ + 0.004080288112163544 + ] + ], + [ + [ + 0.0038126357831060886 + ] + ], + [ + [ + 0.005141297355294228 + ] + ], + [ + [ + 0.0032300332095474005 + ] + ], + [ + [ + 0.004376872908324003 + ] + ], + [ + [ + 0.0042967842891812325 + ] + ], + [ + [ + 0.0038058566860854626 + ] + ], + [ + [ + 0.005534668453037739 + ] + ], + [ + [ + 0.004367930814623833 + ] + ], + [ + [ + 0.00320817856118083 + ] + ], + [ + [ + 0.00449279835447669 + ] + ], + [ + [ + 0.004353045951575041 + ] + ], + [ + [ + 0.003384666284546256 + ] + ], + [ + [ + 0.0037087006494402885 + ] + ], + [ + [ + 0.0034532949794083834 + ] + ] + ], + "conv_stem.branch3.conv2.input_quant": 0.04130538925528526, + "conv_stem.branch3.conv2.output_quant": 0.018119122833013535, + "conv_stem.branch3.conv2.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch3.conv2.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.branch3.relu2.act_quant": 0.008606902323663235, + "conv_stem.branch3.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "conv_stem.cat_rescale.act_quant": 0.033355433493852615, + "conv_stem.cat_rescale.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "qaddpos.act_quant": 0.01728847809135914, + "qaddpos.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.ln_1.weight_quant": 1.0, + "encoder.ln_1.bias_quant": 1.0, + "encoder.mha.in_proj.weight_quant": [ + [ + 0.002227425342425704 + ], + [ + 0.0019200658425688744 + ], + [ + 0.0015780801186338067 + ], + [ + 0.0022304558660835028 + ], + [ + 0.0021207055542618036 + ], + [ + 0.0025219551753252745 + ], + [ + 0.0020652839448302984 + ], + [ + 
0.0016119007486850023 + ], + [ + 0.001190426992252469 + ], + [ + 0.0015318552032113075 + ], + [ + 0.0021116258576512337 + ], + [ + 0.0014778232434764504 + ], + [ + 0.0018063544994220138 + ], + [ + 0.00247784866951406 + ], + [ + 0.0021212827414274216 + ], + [ + 0.0021191565319895744 + ], + [ + 0.002262998139485717 + ], + [ + 0.0015699389623478055 + ], + [ + 0.0015913064125925303 + ], + [ + 0.0020706388168036938 + ], + [ + 0.0013684657169505954 + ], + [ + 0.0017239818116649985 + ], + [ + 0.0022642784751951694 + ], + [ + 0.002249767305329442 + ], + [ + 0.0019359608413651586 + ], + [ + 0.0016061724163591862 + ], + [ + 0.0017447177087888122 + ], + [ + 0.0016570232110098004 + ], + [ + 0.0017342462670058012 + ], + [ + 0.002182023599743843 + ], + [ + 0.001516907592304051 + ], + [ + 0.0019236189546063542 + ], + [ + 0.0021568378433585167 + ], + [ + 0.0016910543199628592 + ], + [ + 0.002450961619615555 + ], + [ + 0.0019052011193707585 + ], + [ + 0.001985696377232671 + ], + [ + 0.0017749292310327291 + ], + [ + 0.002220303053036332 + ], + [ + 0.0018171115079894662 + ], + [ + 0.002102222293615341 + ], + [ + 0.0016377090942114592 + ], + [ + 0.001501825638115406 + ], + [ + 0.0023003611713647842 + ], + [ + 0.002273709047585726 + ], + [ + 0.00214559119194746 + ], + [ + 0.001888459431938827 + ], + [ + 0.0025583531241863966 + ], + [ + 0.002527237869799137 + ], + [ + 0.0017287884838879108 + ], + [ + 0.0023058399092406034 + ], + [ + 0.002082264283671975 + ], + [ + 0.00220063840970397 + ], + [ + 0.002563020447269082 + ], + [ + 0.001976113999262452 + ], + [ + 0.002403559861704707 + ], + [ + 0.002103435341268778 + ], + [ + 0.002617834834381938 + ], + [ + 0.0018967618234455585 + ], + [ + 0.002261911751702428 + ], + [ + 0.002293736208230257 + ], + [ + 0.0033512054942548275 + ], + [ + 0.002045930363237858 + ], + [ + 0.002225602278485894 + ], + [ + 0.0022986880503594875 + ], + [ + 0.001997053623199463 + ], + [ + 0.0017529248725622892 + ], + [ + 0.0025695357471704483 + ], + [ + 
0.0032922953832894564 + ], + [ + 0.0016513023292645812 + ], + [ + 0.0019141251686960459 + ], + [ + 0.0021567908115684986 + ], + [ + 0.0023256458807736635 + ], + [ + 0.0018727181013673544 + ], + [ + 0.0017376739997416735 + ], + [ + 0.0020870454609394073 + ], + [ + 0.0024144649505615234 + ], + [ + 0.0017871428281068802 + ], + [ + 0.0015482084127143025 + ], + [ + 0.0017507793381810188 + ], + [ + 0.0017274431884288788 + ], + [ + 0.0015455292304977775 + ], + [ + 0.0021849204786121845 + ], + [ + 0.002436551498249173 + ], + [ + 0.0020151345524936914 + ], + [ + 0.001740949461236596 + ], + [ + 0.0020386443939059973 + ], + [ + 0.0021721338853240013 + ], + [ + 0.0021542191971093416 + ], + [ + 0.0019129628781229258 + ], + [ + 0.0022090948186814785 + ], + [ + 0.0023202283773571253 + ], + [ + 0.0017562947468832135 + ], + [ + 0.002050582552328706 + ], + [ + 0.0016276971437036991 + ], + [ + 0.0018958599539473653 + ], + [ + 0.0028672567568719387 + ], + [ + 0.0023723426274955273 + ], + [ + 0.0019987544510513544 + ], + [ + 0.0014957032399252057 + ], + [ + 0.002106226049363613 + ], + [ + 0.0023632990196347237 + ], + [ + 0.002179571893066168 + ], + [ + 0.001977774081751704 + ], + [ + 0.002300483640283346 + ], + [ + 0.002420345088467002 + ], + [ + 0.001927125733345747 + ], + [ + 0.0026756927836686373 + ], + [ + 0.001601619180291891 + ], + [ + 0.002247429918497801 + ], + [ + 0.002316506579518318 + ], + [ + 0.0024339063093066216 + ], + [ + 0.0030354605987668037 + ], + [ + 0.002391157438978553 + ], + [ + 0.002622373402118683 + ], + [ + 0.0017479618545621634 + ], + [ + 0.0017675459384918213 + ], + [ + 0.0015539828455075622 + ], + [ + 0.001859157346189022 + ], + [ + 0.001591840642504394 + ], + [ + 0.0022190390154719353 + ], + [ + 0.002155131893232465 + ], + [ + 0.0020647626370191574 + ], + [ + 0.0016848515952005982 + ], + [ + 0.002332368167117238 + ], + [ + 0.0026039527729153633 + ], + [ + 0.0025643929839134216 + ], + [ + 0.002681432059034705 + ], + [ + 0.0021394183859229088 + ], + [ + 
0.001830349792726338 + ], + [ + 0.0018074617255479097 + ], + [ + 0.0022083832882344723 + ], + [ + 0.00212676078081131 + ], + [ + 0.002255026251077652 + ], + [ + 0.001608288032002747 + ], + [ + 0.0021116887219250202 + ], + [ + 0.0017895271303132176 + ], + [ + 0.001574153546243906 + ], + [ + 0.002267166506499052 + ], + [ + 0.002660273341462016 + ], + [ + 0.0020764567889273167 + ], + [ + 0.0020272093825042248 + ], + [ + 0.0019864789210259914 + ], + [ + 0.00267822970636189 + ] + ], + "encoder.mha.in_proj.bias_quant": [ + 6.050807860447094e-05, + 4.922443986288272e-05, + 4.68466714664828e-05, + 5.296996459946968e-05, + 3.756155274459161e-05, + 5.658749432768673e-05, + 4.557514330372214e-05, + 4.555678242468275e-05, + 3.539890894899145e-05, + 3.611960346461274e-05, + 5.6647539167897776e-05, + 4.4585354771697894e-05, + 4.352255564299412e-05, + 3.7262441765051335e-05, + 4.690066634793766e-05, + 5.057119778939523e-05, + 5.256723670754582e-05, + 3.767906309803948e-05, + 4.2507370380917564e-05, + 4.036747486679815e-05, + 4.279847053112462e-05, + 4.438845280674286e-05, + 5.2103394409641623e-05, + 4.336747952038422e-05, + 3.657146226032637e-05, + 3.230433867429383e-05, + 4.955203257850371e-05, + 4.33584900747519e-05, + 3.4494762076064944e-05, + 5.161468652659096e-05, + 4.749869913212024e-05, + 4.490981882554479e-05, + 6.151020352263004e-05, + 3.709495649673045e-05, + 5.546927059185691e-05, + 5.4975083912722766e-05, + 4.561202513286844e-05, + 4.2529383790679276e-05, + 4.7862526116659865e-05, + 5.6377695727860555e-05, + 4.419066681293771e-05, + 3.627356272772886e-05, + 4.332325261202641e-05, + 4.6744971768930554e-05, + 4.89504418510478e-05, + 4.948950299876742e-05, + 3.5182780266040936e-05, + 4.947255001752637e-05, + 9.07920693862252e-05, + 4.894350058748387e-05, + 6.287409632932395e-05, + 8.18209518911317e-05, + 8.316156890941784e-05, + 0.00010528480197535828, + 6.556376320077106e-05, + 4.793433254235424e-05, + 5.923550634179264e-05, + 5.246044020168483e-05, + 
4.607233131537214e-05, + 9.146483353106305e-05, + 7.756686682114378e-05, + 0.000116885727038607, + 6.287929863901809e-05, + 5.9779620642075315e-05, + 8.138502016663551e-05, + 5.619056537398137e-05, + 4.1679748392198235e-05, + 8.837071800371632e-05, + 6.762079283362255e-05, + 7.908021507319063e-05, + 6.078336446080357e-05, + 0.00010703787120291963, + 4.26152691943571e-05, + 5.7693319831741974e-05, + 7.634259964106604e-05, + 4.140152668696828e-05, + 5.301801866153255e-05, + 7.98362452769652e-05, + 6.117582233855501e-05, + 9.58439995883964e-05, + 0.00017761954222805798, + 3.056228888453916e-05, + 4.116747004445642e-05, + 0.00014178532001096755, + 0.00016005351790226996, + 4.5663866330869496e-05, + 3.932779509341344e-05, + 0.00021386156731750816, + 6.770022446289659e-05, + 6.698924698866904e-05, + 4.681516657001339e-05, + 6.321432738332078e-05, + 4.233980871504173e-05, + 7.646171434316784e-05, + 3.360000846441835e-05, + 9.593149297870696e-05, + 0.0001145838905358687, + 0.00014977862883824855, + 8.654566772747785e-05, + 7.856641605030745e-05, + 0.0001036906469380483, + 8.620283915661275e-05, + 9.573156421538442e-05, + 8.059927495196462e-05, + 0.00010001923510571942, + 0.00010245416342513636, + 8.587779302615672e-05, + 0.00010281941649736837, + 8.3255632489454e-05, + 9.708908328320831e-05, + 8.199297735700384e-05, + 8.647535287309438e-05, + 0.0001182715714094229, + 0.00011093866487499326, + 0.00012129334936616942, + 0.0001401532645104453, + 8.586016338085756e-05, + 6.286949792411178e-05, + 0.00010298903362127021, + 8.130848436849192e-05, + 0.0001292371889576316, + 0.00010454284347360954, + 9.429518831893802e-05, + 8.248539961641654e-05, + 0.00011942695709876716, + 9.332699119113386e-05, + 0.00010913529695244506, + 8.825951954349875e-05, + 0.00011692324187606573, + 0.00017953739734366536, + 8.932058699429035e-05, + 7.868179091019556e-05, + 8.940260158851743e-05, + 6.824493902968243e-05, + 0.00010123774700332433, + 7.339402509387583e-05, + 8.13867591205053e-05, + 
0.00010148149885935709, + 0.00010597008804325014, + 0.0001041899886331521, + 7.151288446038961e-05, + 9.605907689547166e-05, + 9.291098831454292e-05, + 9.540798782836646e-05 + ], + "encoder.mha.in_proj.input_quant": 0.032149918377399445, + "encoder.mha.in_proj.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.mha.out_proj.weight_quant": [ + [ + 0.0025450787506997585 + ], + [ + 0.0033634952269494534 + ], + [ + 0.0022102659568190575 + ], + [ + 0.002740070689469576 + ], + [ + 0.0025040581822395325 + ], + [ + 0.0031312594655901194 + ], + [ + 0.002642548643052578 + ], + [ + 0.0029453104361891747 + ], + [ + 0.0027452234644442797 + ], + [ + 0.0028631347231566906 + ], + [ + 0.0025560350622981787 + ], + [ + 0.0032959417439997196 + ], + [ + 0.0034846037160605192 + ], + [ + 0.0026389879640191793 + ], + [ + 0.00330344308167696 + ], + [ + 0.0037991786375641823 + ], + [ + 0.0023303786292672157 + ], + [ + 0.0028480386827141047 + ], + [ + 0.0020462400279939175 + ], + [ + 0.0022324773017317057 + ], + [ + 0.002819840796291828 + ], + [ + 0.003238159930333495 + ], + [ + 0.0027130895759910345 + ], + [ + 0.0027771799359470606 + ], + [ + 0.0029279368463903666 + ], + [ + 0.003152851015329361 + ], + [ + 0.00266431481577456 + ], + [ + 0.002973100868985057 + ], + [ + 0.0026854926254600286 + ], + [ + 0.0032350695692002773 + ], + [ + 0.004139230120927095 + ], + [ + 0.0029536543879657984 + ], + [ + 0.004061378538608551 + ], + [ + 0.0032096209470182657 + ], + [ + 0.0027513550594449043 + ], + [ + 0.0031433275435119867 + ], + [ + 0.0024465785827487707 + ], + [ + 0.0027551387902349234 + ], + [ + 0.0037565298844128847 + ], + [ + 0.002779862843453884 + ], + [ + 0.0021981794852763414 + ], + [ + 0.0032158789690583944 + ], + [ + 0.002521520247682929 + ], + [ + 0.003021752927452326 + ], + [ + 0.0027959533035755157 + ], + [ + 0.0024837807286530733 + ], + [ + 0.0032649582717567682 + ], + [ + 0.0028171902522444725 + ] + ], + "encoder.mha.out_proj.bias_quant": [ 
+ 2.3149881599238142e-05, + 2.5793873646762222e-05, + 2.4190048861782998e-05, + 3.5503599065123126e-05, + 2.2698917746311054e-05, + 1.9987739506177604e-05, + 2.1740779629908502e-05, + 2.1676312826457433e-05, + 2.1760217350674793e-05, + 3.1122242944547907e-05, + 3.110301622655243e-05, + 2.599974141048733e-05, + 2.6771669581648894e-05, + 2.8198943255119957e-05, + 2.716815833991859e-05, + 2.5208031729562208e-05, + 2.6325098588131368e-05, + 2.5534336600685492e-05, + 2.02864521270385e-05, + 2.6430747311678715e-05, + 2.4457227482344024e-05, + 3.087884761043824e-05, + 2.8083823053748347e-05, + 2.0496856450336054e-05, + 2.0321023839642294e-05, + 2.8517079044831917e-05, + 2.86385457002325e-05, + 2.222750845248811e-05, + 3.508044756017625e-05, + 2.2550859284820035e-05, + 2.8742635549861006e-05, + 2.03707877517445e-05, + 2.6890724257100374e-05, + 2.7156776923220605e-05, + 2.663042687345296e-05, + 3.320925679872744e-05, + 1.901651921798475e-05, + 2.392613168922253e-05, + 2.8610702429432422e-05, + 2.460148789396044e-05, + 2.430097447359003e-05, + 3.24661705235485e-05, + 2.562650479376316e-05, + 2.752750151557848e-05, + 2.4320879674633034e-05, + 2.595531623228453e-05, + 2.7301739464746788e-05, + 2.3921917090774514e-05 + ], + "encoder.mha.out_proj.input_quant": 0.007638814393430948, + "encoder.mha.out_proj.output_quant": 0.01099616289138794, + "encoder.mha.out_proj.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.mha.out_proj.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.mha.attn_output_weights_quant.act_quant": 0.0010784146143123507, + "encoder.mha.attn_output_weights_quant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.mha.q_scaled_quant.act_quant": 0.004697044380009174, + "encoder.mha.q_scaled_quant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.mha.k_transposed_quant.act_quant": 
0.016757408156991005, + "encoder.mha.k_transposed_quant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.mha.v_quant.act_quant": 0.016788391396403313, + "encoder.mha.v_quant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.ln_2.weight_quant": 1.0, + "encoder.ln_2.bias_quant": 1.0, + "encoder.ff.ff1.weight_quant": [ + [ + 0.0032328846864402294 + ], + [ + 0.0025348379276692867 + ], + [ + 0.002104574115946889 + ], + [ + 0.0023122639395296574 + ], + [ + 0.002344944979995489 + ], + [ + 0.003326729405671358 + ], + [ + 0.003282792866230011 + ], + [ + 0.002623545238748193 + ], + [ + 0.0027602515183389187 + ], + [ + 0.003014691872522235 + ], + [ + 0.002916456200182438 + ], + [ + 0.0024082937743514776 + ], + [ + 0.002552654827013612 + ], + [ + 0.0031073656864464283 + ], + [ + 0.0035184575244784355 + ], + [ + 0.0032426666002720594 + ], + [ + 0.002688169479370117 + ], + [ + 0.0024928851053118706 + ], + [ + 0.0021280536893755198 + ], + [ + 0.0032955931965261698 + ], + [ + 0.0033104673493653536 + ], + [ + 0.003410267410799861 + ], + [ + 0.0032123830169439316 + ], + [ + 0.0023184348829090595 + ], + [ + 0.002515599364414811 + ], + [ + 0.0031135152094066143 + ], + [ + 0.002972954884171486 + ], + [ + 0.002515817992389202 + ], + [ + 0.0029994600918143988 + ], + [ + 0.002357449848204851 + ], + [ + 0.0033730571158230305 + ], + [ + 0.002846228191629052 + ], + [ + 0.003079464193433523 + ], + [ + 0.003881868440657854 + ], + [ + 0.0028620802331715822 + ], + [ + 0.002888432936742902 + ], + [ + 0.002717219525948167 + ], + [ + 0.0016008632956072688 + ], + [ + 0.002567456103861332 + ], + [ + 0.0025209919549524784 + ], + [ + 0.002560699824243784 + ], + [ + 0.0027761533856391907 + ], + [ + 0.002413744805380702 + ], + [ + 0.002689571352675557 + ], + [ + 0.0020539232064038515 + ], + [ + 0.002438942203298211 + ], + [ + 0.002750420942902565 + ], + [ + 0.0027201515622437 + ] + ], + "encoder.ff.ff1.bias_quant": 
[ + 7.498719060095027e-05, + 4.926235487801023e-05, + 5.050877371104434e-05, + 5.95391247770749e-05, + 6.156202289275825e-05, + 7.18023075023666e-05, + 6.633693556068465e-05, + 6.831396603956819e-05, + 5.878887895960361e-05, + 6.835787644376978e-05, + 5.9965568652842194e-05, + 4.8976828111335635e-05, + 4.337828795542009e-05, + 7.494446617783979e-05, + 8.367924601770937e-05, + 8.022222755244002e-05, + 6.349650357151404e-05, + 6.474449764937162e-05, + 5.1259808969916776e-05, + 6.987973029026762e-05, + 7.845495565561578e-05, + 4.757084752782248e-05, + 6.418041448341683e-05, + 9.647008846513927e-05, + 8.411757153226063e-05, + 5.395405969466083e-05, + 8.463787526125088e-05, + 5.833321120007895e-05, + 6.0272974224062636e-05, + 5.969961785012856e-05, + 7.512702723033726e-05, + 6.501157622551546e-05, + 5.7351207942701876e-05, + 7.466744864359498e-05, + 6.815824599470943e-05, + 4.395714131533168e-05, + 6.773317727493122e-05, + 3.7081736081745476e-05, + 6.156482413643971e-05, + 7.225068111438304e-05, + 6.43636958557181e-05, + 5.388354111346416e-05, + 6.28075358690694e-05, + 5.432395118987188e-05, + 5.495657023857348e-05, + 6.445019971579313e-05, + 7.341690798057243e-05, + 5.165556285646744e-05 + ], + "encoder.ff.ff1.input_quant": 0.019113758578896523, + "encoder.ff.ff1.output_quant": 0.021797163411974907, + "encoder.ff.ff1.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.ff.ff1.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.ff.ff2.weight_quant": [ + [ + 0.002540784189477563 + ], + [ + 0.0025547058321535587 + ], + [ + 0.0026918419171124697 + ], + [ + 0.002338874852284789 + ], + [ + 0.003251499729231 + ], + [ + 0.003947967663407326 + ], + [ + 0.0030620135366916656 + ], + [ + 0.003166659502312541 + ], + [ + 0.0035519192460924387 + ], + [ + 0.0036688782274723053 + ], + [ + 0.0031043800991028547 + ], + [ + 0.002984486985951662 + ], + [ + 0.0019274965161457658 + ], + [ + 
0.002529407385736704 + ], + [ + 0.002311328426003456 + ], + [ + 0.00278285495005548 + ], + [ + 0.0036396037321537733 + ], + [ + 0.0032709587831050158 + ], + [ + 0.002647558692842722 + ], + [ + 0.0027007642202079296 + ], + [ + 0.0025122605729848146 + ], + [ + 0.002891841111704707 + ], + [ + 0.0025530215352773666 + ], + [ + 0.0021426621824502945 + ], + [ + 0.003227367764338851 + ], + [ + 0.002991440938785672 + ], + [ + 0.0020417352207005024 + ], + [ + 0.0031559448689222336 + ], + [ + 0.002952039008960128 + ], + [ + 0.0026928395964205265 + ], + [ + 0.002132965950295329 + ], + [ + 0.0027183855418115854 + ], + [ + 0.0031710907351225615 + ], + [ + 0.003934640437364578 + ], + [ + 0.0021151858381927013 + ], + [ + 0.002895907498896122 + ], + [ + 0.0031890079844743013 + ], + [ + 0.003015984781086445 + ], + [ + 0.003341935109347105 + ], + [ + 0.0025853347033262253 + ], + [ + 0.0030722706578671932 + ], + [ + 0.0024727594573050737 + ], + [ + 0.00430693943053484 + ], + [ + 0.002899958286434412 + ], + [ + 0.0029888523276895285 + ], + [ + 0.002468215301632881 + ], + [ + 0.0021575740538537502 + ], + [ + 0.0036121737211942673 + ] + ], + "encoder.ff.ff2.bias_quant": [ + 4.566737334243953e-05, + 5.065666846348904e-05, + 5.845870327902958e-05, + 3.3672829886199906e-05, + 5.23770613654051e-05, + 5.571214569499716e-05, + 6.301180110312998e-05, + 5.294337825034745e-05, + 5.839547156938352e-05, + 5.684942516381852e-05, + 4.855397855862975e-05, + 4.916664693155326e-05, + 3.133434438495897e-05, + 3.880137956002727e-05, + 4.7777139116078615e-05, + 4.578085645334795e-05, + 4.729117063106969e-05, + 3.761542393476702e-05, + 5.1902861741837114e-05, + 6.013745587551966e-05, + 4.579684537020512e-05, + 4.3461157474666834e-05, + 4.273742524674162e-05, + 4.934638491249643e-05, + 5.9881193010369316e-05, + 7.744455069769174e-05, + 5.4648073273710907e-05, + 6.650984869338572e-05, + 3.375795859028585e-05, + 4.954998075845651e-05, + 4.33569002780132e-05, + 4.52802560175769e-05, + 4.2928571929223835e-05, + 
4.6040015149628744e-05, + 4.666531094699167e-05, + 7.07791987224482e-05, + 5.071628402220085e-05, + 4.680823258240707e-05, + 5.402659371611662e-05, + 4.765831909026019e-05, + 5.396133929025382e-05, + 4.038778934045695e-05, + 4.6490800741594285e-05, + 4.875994272879325e-05, + 5.950702689005993e-05, + 4.113189788768068e-05, + 3.662782546598464e-05, + 6.549015233758837e-05 + ], + "encoder.ff.ff2.input_quant": 0.016389423981308937, + "encoder.ff.ff2.output_quant": 0.015031063929200172, + "encoder.ff.ff2.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.ff.ff2.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.rescale_residual1.act_quant": 0.011087513528764248, + "encoder.rescale_residual1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "encoder.rescale_residual2.act_quant": 0.01659729890525341, + "encoder.rescale_residual2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "norm.weight_quant": 1.0, + "norm.bias_quant": 1.0, + "rescale_norm.act_quant": 0.015843253582715988, + "rescale_norm.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "head.weight_quant": [ + [ + 0.0028364479076117277 + ], + [ + 0.003566386876627803 + ], + [ + 0.0033993548713624477 + ], + [ + 0.0034161575604230165 + ] + ], + "head.bias_quant": [ + 5.7410190493101254e-05, + 4.766035999637097e-05, + 5.813076859340072e-05, + 6.15611279499717e-05 + ], + "head.input_quant": 0.015449205413460732, + "head.output_quant": 0.0427057258784771, + "head.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0, + "head.output_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value_quant": 1.0 +} \ No newline at end of file diff --git a/onnx4deeploy/models/pytorch_models/sleep_convit/sleep_convit.py b/onnx4deeploy/models/pytorch_models/sleep_convit/sleep_convit.py index 
0261cb3..945a015 100644 --- a/onnx4deeploy/models/pytorch_models/sleep_convit/sleep_convit.py +++ b/onnx4deeploy/models/pytorch_models/sleep_convit/sleep_convit.py @@ -137,7 +137,6 @@ class ConvStem(nn.Module): This module applies two parallel convolutional branches with different kernel sizes to capture multi-scale temporal features, then concatenates the results. - Simplified to 2 branches for better Deeploy compatibility. Uses Conv2d for better compatibility with ONNX Runtime transformer optimizer. Args: @@ -149,11 +148,11 @@ class ConvStem(nn.Module): """ def __init__( - self, in_channels=1, out_channels=48, kernel_sizes=(25, 100), stride=4, pool_kernel=4 + self, in_channels=1, out_channels=48, kernel_sizes=(25, 200, 100), stride=4, pool_kernel=4 ): super().__init__() # Divide the total output channels equally across the 2 branches - branch_out_channels = out_channels // 2 + branch_out_channels = out_channels // 3 # Branch 1: Kernel size 25 (fine-grained features) self.branch1 = nn.Sequential( @@ -177,8 +176,8 @@ def __init__( ), nn.ReLU(inplace=False), ) - - # Branch 2: Kernel size 100 (coarse-grained features) + + # Branch 2: Kernel size 200 (middle-grained features) self.branch2 = nn.Sequential( nn.Conv2d( in_channels=in_channels, @@ -186,6 +185,29 @@ def __init__( kernel_size=(1, kernel_sizes[1]), stride=(1, stride), padding=(0, kernel_sizes[1] // 2), + bias=False + ), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(1, pool_kernel), stride=(1, pool_kernel)), + nn.Conv2d( + in_channels=branch_out_channels, + out_channels=branch_out_channels, + kernel_size=(1, 3), + stride=(1, 2), + padding=(0, 1), + bias=False + ), + nn.ReLU(inplace=True) + ) + + # Branch 3: Kernel size 100 (coarse-grained features) + self.branch3 = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=branch_out_channels, + kernel_size=(1, kernel_sizes[2]), + stride=(1, stride), + padding=(0, kernel_sizes[2] // 2), bias=False, ), nn.ReLU(inplace=False), @@ -214,9 
+236,12 @@ def forward(self, x): """ x1 = self.branch1(x) x2 = self.branch2(x) + x3 = self.branch3(x) # Single concatenation (Deeploy compatible) - x = torch.cat((x1, x2), dim=1) - return x + x12 = torch.cat((x1, x2), dim=1) + x123 = torch.cat((x12, x3), dim=1) + print(F"ConvStem output shape: {x123.shape}") # Debug print to verify output shape + return x123 class Encoder(nn.Module): @@ -327,7 +352,7 @@ def __init__(self, config: dict): self.conv_stem = ConvStem( in_channels=1, out_channels=self.model_dim, - kernel_sizes=(25, 100), # 2 branches: fine-grained (25) and coarse-grained (100) + kernel_sizes=(25, 200, 100), # 3 branches with different kernel sizes stride=4, ) diff --git a/onnx4deeploy/models/qlite_cnn_exporter.py b/onnx4deeploy/models/qlite_cnn_exporter.py new file mode 100644 index 0000000..afd627a --- /dev/null +++ b/onnx4deeploy/models/qlite_cnn_exporter.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +"""QLite CNN Model Exporter.""" + +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np +import torch +import torch.onnx.utils +from brevitas.quant_tensor import QuantTensor + +from DeepQuant.ExportBrevitas import exportBrevitas + + +from ..core.base_exporter import BaseONNXExporter + +# Import QLiteCNN PyTorch model from new location +from .pytorch_models.lightweight_cnn import QLiteCNN +import brevitas.onnx as bo +import json +import onnx +import onnx_graphsurgeon as gs +from onnx4deeploy.transform.quant_transform import replace_qdq_with_deeploy, insert_rqs_from_map + + +class QLiteCnnExporter(BaseONNXExporter): + """ONNX exporter for QLite CNN model.""" + + def __init__(self, save_path: str = None, config_file: str = "config.yaml"): + """ + Initialize QLite CNN exporter. 
+ + Args: + save_path: Optional custom path to save ONNX files + config_file: Path to configuration YAML file + """ + super().__init__(save_path, config_file) + self.model_config = {} + + def load_config(self) -> Dict[str, Any]: + """ + Load QLite CNN configuration. + + Returns: + Dictionary containing QLite CNN configuration parameters + """ + # Default QLite CNN configuration + config = { + "batch_size": 1, + "input_height": 28, + "input_width": 28, + "input_channels": 1, # Grayscale images + "num_classes": 10, + "opset_version": 17, + "dropout": 0.0, # No dropout for inference + # Training configuration + "training_strategy": "full", # Options: "full", "last_layer", "custom" + "custom_trainable_params": [], + "zo": { + "epsilon": 0.1, + "seed": 42 + }, + "weights_path":"onnx4deeploy/models/pytorch_models/lightweight_cnn/qlite_cnn.pth" + } + + self.model_config = config + return config + + def create_model(self) -> torch.nn.Module: + """ + Create QLite CNN PyTorch model. + + Returns: + QLite CNN model ready for export + """ + model = QLiteCNN( + batch_size=self.model_config["batch_size"], + input_channels=self.model_config["input_channels"], + num_classes=self.model_config["num_classes"], + dropout=self.model_config["dropout"], + ) + + return model + + def get_input_shape(self) -> Tuple[int, ...]: + """ + Get the input tensor shape for QLite CNN. + + Returns: + Tuple representing input shape (batch_size, channels, height, width) + """ + batch_size = self.config["batch_size"] + channels = self.config["input_channels"] + height = self.config["input_height"] + width = self.config["input_width"] + return (batch_size, channels, height, width) + + def get_trainable_params(self, all_param_names: List[str]) -> List[str]: + """ + Get list of trainable parameter names for QLite CNN. 
+ + Supports multiple training strategies: + - "full": Train all parameters (default) + - "last_layer": Only train the final classification layer + - "custom": Use custom_trainable_params from config + + Args: + all_param_names: List of all parameter names in the model + + Returns: + List of parameter names that should be trainable + """ + strategy = self.config.get("training_strategy", "full") + + # Define training strategies + strategy_params = { + "full": all_param_names, # Train everything + "last_layer": [ + "fc.weight", + "fc.bias", + ], + "custom": self.config.get("custom_trainable_params", []), + } + + # Get trainable params based on strategy + if strategy not in strategy_params: + print(f"โš ๏ธ Unknown training strategy '{strategy}', using 'full' as fallback") + strategy = "full" + + trainable_params = strategy_params[strategy] + + # Filter to only include params that exist in the model + requires_grad = [name for name in all_param_names if name in trainable_params] + + # Print strategy info + print(f"\n๐ŸŽฏ Training Strategy: '{strategy}'") + print(f" Total params in model: {len(all_param_names)}") + print(f" Params to train: {len(requires_grad)}") + print(f" Frozen params: {len(all_param_names) - len(requires_grad)}") + + return requires_grad + + def _get_config_string(self) -> str: + """ + Get configuration string for folder naming. + + Returns: + Configuration string like "_28x28_1ch_10" + """ + return ( + f"_{self.config['input_height']}x{self.config['input_width']}" + f"_{self.config['input_channels']}ch_{self.config['num_classes']}" + ) + + def save_test_data(self, model: torch.nn.Module, save_dir: str): + """ + Save test input/output data for validation. + + Uses PyTorch model to generate reference output for validating ONNX correctness. 
+ + Args: + model: PyTorch model to run inference with + save_dir: Directory to save test data + """ + print("๐Ÿ’พ Saving test input/output data...") + + # Create test input + input_shape = self.get_input_shape() + test_input = np.random.randn(*input_shape).astype(np.float32) + + # Get PyTorch output (reference for validating ONNX) + was_training = model.training + model.eval() + + with torch.no_grad(): + input_tensor = torch.from_numpy(test_input) + output_tensor = model(input_tensor) + if isinstance(output_tensor, QuantTensor): + output_tensor = output_tensor.value + test_output = output_tensor.numpy() + + # Restore training mode if needed + if was_training: + model.train() + + # Save as .npz files + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + np.savez(save_path / "inputs.npz", input=test_input) + np.savez(save_path / "outputs.npz", output=test_output) + + print(" โœ… Saved test data (PyTorch reference):") + print(f" Input: {save_path / 'inputs.npz'} shape={test_input.shape}") + print(f" Output: {save_path / 'outputs.npz'} shape={test_output.shape}") + + def _build_rqs_map(self, graph: gs.Graph, brevitas_scales: Dict[str, Any]) -> Dict[str, Any]: + """ + Translate the flat Brevitas scales dump into the Deeploy edges map + expected by insert_rqs_from_map. + """ + rqs_map = {"edges": []} + + # Traverse the ONNX graph to find operators that require precision reduction + for node in graph.nodes: + if node.op in ["Conv", "Gemm", "MatMul"]: + layer_name = None + + # Match ONNX node name (e.g., "/conv1/Conv") with Brevitas scale keys + for key in brevitas_scales.keys(): + if key.endswith(".weight_quant"): + base_name = key.split(".")[0] # Extracts "conv1", "fc", etc. 
+ if f"/{base_name}/" in node.name or node.name.startswith(base_name): + layer_name = base_name + break + + if layer_name: + in_scale = brevitas_scales.get(f"{layer_name}.input_quant") + w_scale = brevitas_scales.get(f"{layer_name}.weight_quant") + out_scale = brevitas_scales.get(f"{layer_name}.output_quant") + + + if in_scale is not None and w_scale is not None and out_scale is not None: + # Conv accumulation scale = input_scale * weight_scale + + w_flat = np.array(w_scale).flatten() + src_scale = in_scale * w_flat + + out_tensor_name = node.outputs[0].name + + # We use out_tensor_name for both src and dst to intercept and rewire + # all nodes strictly consuming the output of this convolution. + rqs_map["edges"].append({ + "src_tensor": out_tensor_name, + "dst_tensor": out_tensor_name, + "src_scale": src_scale.tolist(), + "dst_scale": float(out_scale) + }) + return rqs_map \ No newline at end of file diff --git a/onnx4deeploy/models/qsleep_convit_exporter.py b/onnx4deeploy/models/qsleep_convit_exporter.py new file mode 100644 index 0000000..472ccef --- /dev/null +++ b/onnx4deeploy/models/qsleep_convit_exporter.py @@ -0,0 +1,245 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +"""QSleepConViT Model Exporter.""" + +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np +import torch +import torch.onnx.utils +from brevitas.quant_tensor import QuantTensor + +from DeepQuant.ExportBrevitas import exportBrevitas + + +from ..core.base_exporter import BaseONNXExporter + +# Import QSleepConViT PyTorch model from new location +from .pytorch_models.sleep_convit.qsleep_convit import QSleepConViT +import brevitas.onnx as bo +import json +import onnx +import onnx_graphsurgeon as gs +from onnx4deeploy.transform.quant_transform import replace_qdq_with_deeploy, insert_rqs_from_map + + +class QSleepConViTExporter(BaseONNXExporter): + """ONNX exporter for QSleepConViT model.""" + + def 
__init__(self, save_path: str = None, config_file: str = "config.yaml"): + """ + Initialize QSleepConViT exporter. + + Args: + save_path: Optional custom path to save ONNX files + config_file: Path to configuration YAML file + """ + super().__init__(save_path, config_file) + self.model_config = {} + + def load_config(self) -> Dict[str, Any]: + """ + Load QSleepConViT configuration. + + Returns: + Dictionary containing QSleepConViT configuration parameters + """ + # Default QSleepConViT configuration + config = { + "batch_size": 1, + "input_channels": 1, + "input_length": 3000, # Time-series sequence length + "model_dim": 48, + "num_heads": 6, + "num_patches": 94, # Computed from ConvStem output + "seq_len": 95, # num_patches + 1 (CLS token) + "attention_dropout": 0.0, # No dropout for inference + "mlp_head_hidden_dim": 48, + "encoder_ff_dropout": 0.0, # No dropout for inference + "num_classes": 4, # Sleep stages: Wake, N1, N2, N3, REM + "opset_version": 17, # Match CCT opset version for compatibility + # Training configuration + "training_strategy": "full", # Options: "full", "last_layer", "custom" + "custom_trainable_params": [], + # ZO training configuration + "zo": { + "epsilon": 0.1, + "seed": 42, + "exceptions": ["node_matmul", "node_bmm_requant", "node_bmm_1_requant"] + }, + "weights_path":"onnx4deeploy/models/pytorch_models/sleep_convit/qsleep_convit.pth" + } + + self.model_config = config + return config + + def create_model(self) -> torch.nn.Module: + """ + Create SleepConViT PyTorch model. + + Returns: + SleepConViT model ready for export + """ + model = QSleepConViT(config=self.model_config) + return model + + def get_input_shape(self) -> Tuple[int, ...]: + """ + Get the input tensor shape for SleepConViT. 
+ + Returns: + Tuple representing input shape (batch_size, channels, height, width) + Shape: (B, 1, 1, 3000) for compatibility with ViT/transformer optimizer + """ + batch_size = self.config["batch_size"] + channels = self.config["input_channels"] + length = self.config["input_length"] + return (batch_size, channels, 1, length) # 4D: (B, C, H, W) + + def get_trainable_params(self, all_param_names: List[str]) -> List[str]: + """ + Get list of trainable parameter names for SleepConViT. + + Supports multiple training strategies: + - "full": Train all parameters (default) + - "last_layer": Only train the final classification layer + - "custom": Use custom_trainable_params from config + + Args: + all_param_names: List of all parameter names in the model + + Returns: + List of parameter names that should be trainable + """ + strategy = self.config.get("training_strategy", "full") + + # Define training strategies + strategy_params = { + "full": all_param_names, # Train everything + "last_layer": [ + "classifier.lin1.weight", + "classifier.lin1.bias", + ], + "custom": self.config.get("custom_trainable_params", []), + } + + # Get trainable params based on strategy + if strategy not in strategy_params: + print(f"โš ๏ธ Unknown training strategy '{strategy}', using 'full' as fallback") + strategy = "full" + + trainable_params = strategy_params[strategy] + + # Filter to only include params that exist in the model + requires_grad = [name for name in all_param_names if name in trainable_params] + + # Print strategy info + print(f"\n๐ŸŽฏ Training Strategy: '{strategy}'") + print(f" Total params in model: {len(all_param_names)}") + print(f" Params to train: {len(requires_grad)}") + print(f" Frozen params: {len(all_param_names) - len(requires_grad)}") + + return requires_grad + + def _get_config_string(self) -> str: + """ + Get configuration string for folder naming. 
+ + Returns: + Configuration string like "_3000_48d_8h_5cls" + """ + return ( + f"_{self.config['input_length']}" + f"_{self.config['model_dim']}d" + f"_{self.config['num_heads']}h" + f"_{self.config['num_classes']}cls" + ) + + def save_test_data(self, model: torch.nn.Module, save_dir: str): + """ + Save test input/output data for validation. + + Uses PyTorch model to generate reference output for validating ONNX correctness. + + Args: + model: PyTorch model to run inference with + save_dir: Directory to save test data + """ + print("๐Ÿ’พ Saving test input/output data...") + + # Create test input + input_shape = self.get_input_shape() + test_input = np.random.randn(*input_shape).astype(np.float32) + + # Get PyTorch output (reference for validating ONNX) + was_training = model.training + model.eval() + + with torch.no_grad(): + input_tensor = torch.from_numpy(test_input) + output_tensor = model(input_tensor) + if isinstance(output_tensor, QuantTensor): + output_tensor = output_tensor.value + test_output = output_tensor.numpy() + + # Restore training mode if needed + if was_training: + model.train() + + # Save as .npz files + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + np.savez(save_path / "inputs.npz", input=test_input) + np.savez(save_path / "outputs.npz", output=test_output) + + print(" โœ… Saved test data (PyTorch reference):") + print(f" Input: {save_path / 'inputs.npz'} shape={test_input.shape}") + print(f" Output: {save_path / 'outputs.npz'} shape={test_output.shape}") + + def _build_rqs_map(self, graph: gs.Graph, brevitas_scales: Dict[str, Any]) -> Dict[str, Any]: + """ + Translate the flat Brevitas scales dump into the Deeploy edges map + expected by insert_rqs_from_map. 
+ """ + rqs_map = {"edges": []} + + # Traverse the ONNX graph to find operators that require precision reduction + for node in graph.nodes: + if node.op in ["Conv", "Gemm", "MatMul"]: + layer_name = None + + # Match ONNX node name (e.g., "/conv1/Conv") with Brevitas scale keys + for key in brevitas_scales.keys(): + if key.endswith(".weight_quant"): + base_name = key.split(".")[0] # Extracts "conv1", "fc", etc. + if f"/{base_name}/" in node.name or node.name.startswith(base_name): + layer_name = base_name + break + + if layer_name: + in_scale = brevitas_scales.get(f"{layer_name}.input_quant") + w_scale = brevitas_scales.get(f"{layer_name}.weight_quant") + out_scale = brevitas_scales.get(f"{layer_name}.output_quant") + + + if in_scale is not None and w_scale is not None and out_scale is not None: + # Conv accumulation scale = input_scale * weight_scale + + w_flat = np.array(w_scale).flatten() + src_scale = in_scale * w_flat + + out_tensor_name = node.outputs[0].name + + # We use out_tensor_name for both src and dst to intercept and rewire + # all nodes strictly consuming the output of this convolution. 
+ rqs_map["edges"].append({ + "src_tensor": out_tensor_name, + "dst_tensor": out_tensor_name, + "src_scale": src_scale.tolist(), + "dst_scale": float(out_scale) + }) + return rqs_map \ No newline at end of file diff --git a/onnx4deeploy/models/sleep_convit_exporter.py b/onnx4deeploy/models/sleep_convit_exporter.py index 5960eee..936f945 100644 --- a/onnx4deeploy/models/sleep_convit_exporter.py +++ b/onnx4deeploy/models/sleep_convit_exporter.py @@ -67,17 +67,23 @@ def load_config(self) -> Dict[str, Any]: "input_channels": 1, "input_length": 3000, # Time-series sequence length "model_dim": 48, - "num_heads": 8, + "num_heads": 6, "num_patches": 94, # Computed from ConvStem output "seq_len": 95, # num_patches + 1 (CLS token) "attention_dropout": 0.0, # No dropout for inference - "mlp_head_hidden_dim": 192, + "mlp_head_hidden_dim": 48, "encoder_ff_dropout": 0.0, # No dropout for inference - "num_classes": 5, # Sleep stages: Wake, N1, N2, N3, REM + "num_classes": 4, # Sleep stages: Wake, N1, N2, N3, REM "opset_version": 17, # Match CCT opset version for compatibility # Training configuration "training_strategy": "full", # Options: "full", "last_layer", "custom" "custom_trainable_params": [], + # ZO training configuration + "zo": { + "epsilon": 0.1, + "seed": 42, + "exceptions": "node_matmul_2" + } } self.model_config = config diff --git a/onnx4deeploy/operators/__init__.py b/onnx4deeploy/operators/__init__.py index f3c6512..f609d62 100644 --- a/onnx4deeploy/operators/__init__.py +++ b/onnx4deeploy/operators/__init__.py @@ -37,6 +37,11 @@ from .softmax_grad import SoftmaxGradOperatorTest from .split import SplitOperatorTest from .transpose import TransposeOperatorTest +from .perturbnormal import PerturbNormalOperatorTest +from .perturbeggroll import PerturbEggrollOperatorTest +from .perturbrademacher import PerturbRademacherOperatorTest +from .perturbuniform import PerturbUniformOperatorTest +from .perturbtriangle import PerturbTriangleOperatorTest __all__ = [ 
"BaseOperatorTest", @@ -69,4 +74,11 @@ "ConvGradXOperatorTest", "ConvGradWOperatorTest", "ConvGradBOperatorTest", + "PerturbNormalOperatorTest", + "PerturbEggrollOperatorTest", + "PerturbRademacherOperatorTest", + "PerturbUniformOperatorTest", + "PerturbTriangleOperatorTest", + "RQSPerturbRademacherOperatorTest", + "RQSPerturbUniformOperatorTest", ] diff --git a/onnx4deeploy/operators/base_operator.py b/onnx4deeploy/operators/base_operator.py index fb16d6c..5ae89f5 100644 --- a/onnx4deeploy/operators/base_operator.py +++ b/onnx4deeploy/operators/base_operator.py @@ -231,7 +231,7 @@ def generate(self) -> Tuple[str, str, str]: # Generate inputs inputs = self.generate_inputs() - + print(F"inputs: {inputs}") # Create ONNX graph graph = self.create_onnx_graph(inputs) diff --git a/onnx4deeploy/operators/config.yaml b/onnx4deeploy/operators/config.yaml new file mode 100644 index 0000000..3f8f98c --- /dev/null +++ b/onnx4deeploy/operators/config.yaml @@ -0,0 +1,30 @@ +perturbnormal: + input_shape: [128, 48] + +perturbuniform: + input_shape: [128, 48] + +perturbtriangle: + input_shape: [128, 48] + +perturbrademacher: + input_shape: [128, 48] + +perturbeggroll: + input_shape: [128, 48] + +rqsperturbrademacher: + input_shape: [128, 48] + +rqsperturbuniform: + input_shape: [128, 48] + +conv2d: + input_shape: [1, 1, 28, 28] + kernel_size: 5 + out_channels: 16 + stride: 1 + padding: 0 + use_bias: true + group: 1 + dilation: 1 \ No newline at end of file diff --git a/onnx4deeploy/operators/perturbeggroll.py b/onnx4deeploy/operators/perturbeggroll.py new file mode 100644 index 0000000..8b2eae7 --- /dev/null +++ b/onnx4deeploy/operators/perturbeggroll.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +"""PerturbEggroll operator test implementation.""" + +from typing import Any, Dict, Tuple + +import numpy as np +import onnxruntime as ort +from onnx import TensorProto, helper + +from .base_operator import 
class PerturbEggrollOperatorTest(BaseOperatorTest):
    """Test generator for the custom PerturbEggroll operator (``com.microsoft`` domain).

    The exported graph applies a rank-1 perturbation to the input::

        perturbed_x = x + alpha * (a @ b.T)

    ``a`` holds one noise value per batch row and ``b`` one per flattened
    feature; both are produced by ``PerturbEggroll`` noise nodes and combined
    with a standard ``Gemm``. Inputs with more than two dimensions are
    reshaped to ``(batch, features)`` before the ``Gemm`` and restored after.
    """

    # Gemm ``alpha``, i.e. the magnitude of the injected noise.
    # NOTE(review): this is the sqrt(3)-scaled value, while the noise nodes'
    # doc_strings mention Rademacher noise - confirm which scale the kernel expects.
    _NOISE_SCALE = float(0.01 * np.sqrt(3))

    def __init__(self, config_path=None, save_path=None):
        super().__init__(config_path, save_path)
        self.input_shape = None
        self.num_classes = None
        self.batch_size = None

    def get_operator_name(self) -> str:
        """Return the operator name used by the test framework."""
        return "PerturbEggroll"

    def load_config(self) -> Dict[str, Any]:
        """Load the base configuration and read ``perturbeggroll.input_shape``.

        Raises:
            KeyError: if the ``perturbeggroll`` section or its ``input_shape``
                entry is missing.
        """
        config = super().load_config()

        section = config.get("perturbeggroll", {})
        if "input_shape" not in section:
            raise KeyError("config section 'perturbeggroll' must define 'input_shape'")
        self.input_shape = tuple(section["input_shape"])
        return config

    def generate_inputs(self) -> Dict[str, np.ndarray]:
        """Return ``{"x": ...}`` with standard-normal float32 values."""
        return {"x": np.random.randn(*self.input_shape).astype(np.float32)}

    def create_onnx_graph(self, inputs: Dict[str, np.ndarray]):
        """Build the ONNX graph ``perturbed_x = x + alpha * (a @ b.T)``."""
        batch = self.input_shape[0]
        features = int(np.prod(self.input_shape[1:]))
        a_shape = [batch, 1]
        b_shape = [features, 1]
        flat_shape = [batch, features]
        needs_flatten = len(self.input_shape) > 2

        def shape_constant(name, dims):
            # INT64 initializer holding a shape vector.
            return helper.make_tensor(
                name=name,
                data_type=TensorProto.INT64,
                dims=[len(dims)],
                vals=np.array(dims, dtype=np.int64),
            )

        # Graph input / output
        x_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, self.input_shape)
        perturbed_x_tensor = helper.make_tensor_value_info(
            "perturbed_x", TensorProto.FLOAT, self.input_shape
        )

        # Shape annotations for the two noise factors
        a_tensor = helper.make_tensor_value_info("a", TensorProto.FLOAT, a_shape)
        b_tensor = helper.make_tensor_value_info("b", TensorProto.FLOAT, b_shape)

        shape_a_tensor = shape_constant("shape_a", a_shape)
        shape_b_tensor = shape_constant("shape_b", b_shape)

        # Gemm only accepts 2-D operands, so N-D inputs are flattened around it.
        gemm_input = "flattened_x" if needs_flatten else "x"
        gemm_output = "flattened_perturbed_x" if needs_flatten else "perturbed_x"

        noise_node_a = helper.make_node(
            "PerturbEggroll",
            inputs=["shape_a"],
            outputs=["a"],
            name="gen_eggroll_noise_a",
            seed=13,
            idx=0,
            domain="com.microsoft",
            doc_string="a = RandomRademacher(x[0], seed)",
        )
        noise_node_b = helper.make_node(
            "PerturbEggroll",
            inputs=["shape_b"],
            outputs=["b"],
            name="gen_eggroll_noise_b",
            seed=14,
            idx=1,
            domain="com.microsoft",
            doc_string="b = RandomRademacher(x[1:], seed)",
        )

        # Y = alpha * (a @ b^T) + beta * C.  beta must be 1.0 (and a float
        # attribute): with beta = 0 the input operand would be discarded and the
        # graph would emit the noise alone instead of the perturbed input.
        gemm_node = helper.make_node(
            "Gemm",
            inputs=["a", "b", gemm_input],
            outputs=[gemm_output],
            name="eggroll_gemm_perturb_x",
            transA=0,
            transB=1,
            alpha=self._NOISE_SCALE,
            beta=1.0,
        )

        if not needs_flatten:
            return helper.make_graph(
                [noise_node_a, noise_node_b, gemm_node],
                "perturb_eggroll_graph",
                [x_tensor],
                [perturbed_x_tensor],
                [shape_a_tensor, shape_b_tensor],
                value_info=[a_tensor, b_tensor],
            )

        flatten_node = helper.make_node(
            "Reshape",
            inputs=["x", "shape_x_flat"],
            outputs=["flattened_x"],
            name="flatten_x",
        )
        unflatten_node = helper.make_node(
            "Reshape",
            inputs=["flattened_perturbed_x", "shape_x"],
            outputs=["perturbed_x"],
            name="unflatten_perturbed_x",
        )
        flattened_info = helper.make_tensor_value_info(
            "flattened_x", TensorProto.FLOAT, flat_shape
        )
        flattened_perturbed_info = helper.make_tensor_value_info(
            "flattened_perturbed_x", TensorProto.FLOAT, flat_shape
        )

        return helper.make_graph(
            [flatten_node, noise_node_a, noise_node_b, gemm_node, unflatten_node],
            "perturb_eggroll_graph",
            [x_tensor],
            [perturbed_x_tensor],
            [
                shape_constant("shape_x", list(self.input_shape)),
                shape_constant("shape_x_flat", flat_shape),
                shape_a_tensor,
                shape_b_tensor,
            ],
            value_info=[a_tensor, b_tensor, flattened_info, flattened_perturbed_info],
        )

    def create_model(self, graph, opset_version: int = 13):
        """Create the ONNX model, importing the default and ``com.microsoft`` opsets."""
        model = helper.make_model(
            graph,
            producer_name=f"{self.get_operator_name().lower()}_test",
            opset_imports=[
                helper.make_opsetid("", opset_version),
                helper.make_opsetid("com.microsoft", 1),
            ],
        )

        return model

    def run_inference(self, onnx_file: str, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Emulate the operator in NumPy (``onnx_file`` is not executed).

        The emulation mirrors the structure of the exported graph: a rank-1
        perturbation over ``(batch, flattened features)`` scaled by the Gemm
        ``alpha``.  The random factors are drawn with ``np.random.randn``, so
        values are not expected to match a device implementation bit-for-bit.
        """
        batch = self.input_shape[0]
        features = int(np.prod(self.input_shape[1:]))
        a = np.random.randn(batch, 1).astype(np.float32)
        b = np.random.randn(features, 1).astype(np.float32)
        perturbation = np.float32(self._NOISE_SCALE) * (a @ b.T)
        perturbed_x = inputs["x"] + perturbation.reshape(self.input_shape)

        return {"perturbed_x": perturbed_x}

    def compute_expected_output(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Return None to skip validation - this is a custom, randomised operator."""
        return None
class PerturbNormalOperatorTest(BaseOperatorTest):
    """Test generator for the custom PerturbNormal operator (``com.microsoft`` domain).

    The exported graph holds a single ``PerturbNormal`` node documented as
    ``y = x + epsilon * RandomNormal(x, seed)``.
    """

    # Attributes written onto the PerturbNormal node; the emulation in
    # ``run_inference`` reuses ``_EPS`` so both describe the same noise magnitude.
    _SEED = 42
    _EPS = 0.01

    def __init__(self, config_path=None, save_path=None):
        super().__init__(config_path, save_path)
        self.input_shape = None
        self.num_classes = None
        self.batch_size = None

    def get_operator_name(self) -> str:
        """Return the operator name used by the test framework."""
        return "PerturbNormal"

    def load_config(self) -> Dict[str, Any]:
        """Load the base configuration and read ``perturbnormal.input_shape``.

        Raises:
            KeyError: if the ``perturbnormal`` section or its ``input_shape``
                entry is missing.
        """
        config = super().load_config()

        section = config.get("perturbnormal", {})
        if "input_shape" not in section:
            raise KeyError("config section 'perturbnormal' must define 'input_shape'")
        self.input_shape = tuple(section["input_shape"])
        return config

    def generate_inputs(self) -> Dict[str, np.ndarray]:
        """Return ``{"x": ...}`` with standard-normal float32 values."""
        return {"x": np.random.randn(*self.input_shape).astype(np.float32)}

    def create_onnx_graph(self, inputs: Dict[str, np.ndarray]):
        """Create the ONNX graph containing the single PerturbNormal node."""
        x_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, self.input_shape)
        perturbed_x_tensor = helper.make_tensor_value_info(
            "perturbed_x", TensorProto.FLOAT, self.input_shape
        )

        perturb_node = helper.make_node(
            "PerturbNormal",
            inputs=["x"],
            outputs=["perturbed_x"],
            name="perturb_normal_node",
            seed=self._SEED,
            eps=self._EPS,
            idx=0,
            doc_string="y = x + epsilon * RandomNormal(x, seed)",
            domain="com.microsoft",
        )

        return helper.make_graph(
            [perturb_node],
            "perturb_normal_graph",
            [x_tensor],
            [perturbed_x_tensor],
        )

    def create_model(self, graph, opset_version: int = 13):
        """Create the ONNX model, importing the default and ``com.microsoft`` opsets."""
        model = helper.make_model(
            graph,
            producer_name=f"{self.get_operator_name().lower()}_test",
            opset_imports=[
                helper.make_opsetid("", opset_version),
                helper.make_opsetid("com.microsoft", 1),
            ],
        )

        return model

    def run_inference(self, onnx_file: str, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Emulate PerturbNormal with standard ONNX ops on ONNX Runtime.

        ``onnx_file`` is not executed.  Instead a separate graph made of
        ``RandomNormal`` + ``Add`` is built and run.  The noise standard
        deviation is ``_EPS`` so the emulation follows the node's documented
        formula ``y = x + epsilon * RandomNormal`` rather than adding
        unit-variance noise.
        """
        x_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, self.input_shape)
        perturbed_x_tensor = helper.make_tensor_value_info(
            "perturbed_x", TensorProto.FLOAT, self.input_shape
        )

        noise_tensor_name = "random_noise"

        # 1. Noise with the same shape as the input: N(0, eps^2) == eps * N(0, 1).
        random_node = helper.make_node(
            "RandomNormal",
            inputs=[],  # RandomNormal has no inputs
            outputs=[noise_tensor_name],
            name="random_normal_for_perturb",
            shape=list(self.input_shape),
            dtype=TensorProto.FLOAT,
            mean=0.0,
            scale=self._EPS,
        )

        # 2. Add the noise to the input.
        add_node = helper.make_node(
            "Add",
            inputs=["x", noise_tensor_name],
            outputs=["perturbed_x"],
            name="add_perturbation",
        )

        execution_graph = helper.make_graph(
            [random_node, add_node],
            "perturb_normal_execution_graph",
            [x_tensor],
            [perturbed_x_tensor],
        )
        execution_model = self.create_model(execution_graph)

        sess_options = ort.SessionOptions()
        # Disable all optimizations to ensure nodes are not fused or altered
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

        session = ort.InferenceSession(execution_model.SerializeToString(), sess_options)
        outputs = session.run(["perturbed_x"], inputs)

        return {"perturbed_x": outputs[0]}

    def compute_expected_output(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Return None to skip validation - this is a custom, randomised operator."""
        return None
name="perturb_rademacher_node", + domain="com.microsoft" + ) + + # Graph + graph = helper.make_graph( + [perturb_node], + "perturb_rademacher_graph", + [x_tensor], + [perturbed_x_tensor], + ) + + return graph + + def create_model(self, graph, opset_version: int = 13): + """Create ONNX model for PerturbRademacher with custom domain.""" + model = helper.make_model( + graph, + producer_name=f"{self.get_operator_name().lower()}_test", + opset_imports=[ + helper.make_opsetid("", opset_version), + helper.make_opsetid("com.microsoft", 1), + ], + ) + + return model + + def run_inference(self, onnx_file: str, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Run inference using custom emulation + """ + # perturbation is built from -1's and 1's + perturbation = np.random.choice([-1, 1], size=self.input_shape).astype(np.float32) + perturbed_x = inputs["x"] + perturbation + + return {"perturbed_x": perturbed_x} + + def compute_expected_output(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Return None to skip validation - this is a custom operator. 
class PerturbTriangleOperatorTest(BaseOperatorTest):
    """Test generator for the custom PerturbTriangle operator (``com.microsoft`` domain).

    The exported graph holds a single ``PerturbTriangle`` node whose ``eps``
    attribute is ``0.01 * sqrt(6)``.
    """

    _SEED = 42
    # sqrt(6) compensates the 1/6 variance of the difference of two U[0, 1)
    # samples, giving noise with standard deviation 0.01.
    _EPS = float(0.01 * np.sqrt(6))

    def __init__(self, config_path=None, save_path=None):
        super().__init__(config_path, save_path)
        self.input_shape = None
        self.num_classes = None
        self.batch_size = None

    def get_operator_name(self) -> str:
        """Return the operator name used by the test framework."""
        return "PerturbTriangle"

    def load_config(self) -> Dict[str, Any]:
        """Load the base configuration and read ``perturbtriangle.input_shape``.

        Raises:
            KeyError: if the ``perturbtriangle`` section or its ``input_shape``
                entry is missing.
        """
        config = super().load_config()

        section = config.get("perturbtriangle", {})
        if "input_shape" not in section:
            raise KeyError("config section 'perturbtriangle' must define 'input_shape'")
        self.input_shape = tuple(section["input_shape"])
        return config

    def generate_inputs(self) -> Dict[str, np.ndarray]:
        """Return ``{"x": ...}`` with standard-normal float32 values."""
        return {"x": np.random.randn(*self.input_shape).astype(np.float32)}

    def create_onnx_graph(self, inputs: Dict[str, np.ndarray]):
        """Create the ONNX graph containing the single PerturbTriangle node."""
        x_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, self.input_shape)
        perturbed_x_tensor = helper.make_tensor_value_info(
            "perturbed_x", TensorProto.FLOAT, self.input_shape
        )

        perturb_node = helper.make_node(
            "PerturbTriangle",
            inputs=["x"],
            outputs=["perturbed_x"],
            seed=self._SEED,
            eps=self._EPS,  # Scale epsilon for triangle distribution
            idx=0,
            name="perturb_triangle_node",
            domain="com.microsoft",
        )

        return helper.make_graph(
            [perturb_node],
            "perturb_triangle_graph",
            [x_tensor],
            [perturbed_x_tensor],
        )

    def create_model(self, graph, opset_version: int = 13):
        """Create the ONNX model, importing the default and ``com.microsoft`` opsets."""
        model = helper.make_model(
            graph,
            producer_name=f"{self.get_operator_name().lower()}_test",
            opset_imports=[
                helper.make_opsetid("", opset_version),
                helper.make_opsetid("com.microsoft", 1),
            ],
        )

        return model

    def run_inference(self, onnx_file: str, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Emulate the operator in NumPy (``onnx_file`` is not executed).

        The difference of two uniform samples is triangularly distributed on
        ``(-1, 1)``; it is scaled by the same ``eps`` that is written on the
        node, so the emulated noise has the magnitude the graph describes.
        """
        a = np.random.rand(*self.input_shape).astype(np.float32)
        b = np.random.rand(*self.input_shape).astype(np.float32)
        perturbation = (a - b) * np.float32(self._EPS)  # triangle distribution
        perturbed_x = inputs["x"] + perturbation

        return {"perturbed_x": perturbed_x}

    def compute_expected_output(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Return None to skip validation - this is a custom, randomised operator."""
        return None
# Scale epsilon for uniform distribution + low=-np.sqrt(3), + high=np.sqrt(3), + # dtype=dtype, + doc_string="y = x + eps * RandomUniform(x, seed)", + domain="com.microsoft" + ) + + # Graph + graph = helper.make_graph( + [perturb_node], + "perturb_uniform_graph", + [x_tensor], + [perturbed_x_tensor], + ) + + return graph + + def create_model(self, graph, opset_version: int = 13): + """Create ONNX model for PerturbUniform with custom domain.""" + model = helper.make_model( + graph, + producer_name=f"{self.get_operator_name().lower()}_test", + opset_imports=[ + helper.make_opsetid("", opset_version), + helper.make_opsetid("com.microsoft", 1), + ], + ) + + return model + + def run_inference(self, onnx_file: str, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Run inference using ONNX Runtime. + + For this custom op, we build a separate model that implements the + PerturbUniform functionality using standard ONNX ops (RandomUniform + Add) + and run inference on that to get the output. + """ + # --- Create the "Execution" Graph --- + # This graph implements the behavior of PerturbUniform for testing. + + # Input tensor info + x_tensor = helper.make_tensor_value_info( + "x", TensorProto.FLOAT, self.input_shape + ) + # Output tensor info + perturbed_x_tensor = helper.make_tensor_value_info( + "perturbed_x", TensorProto.FLOAT, self.input_shape + ) + + # Intermediate tensor for the random noise + noise_tensor_name = "random_noise" + + # 1. RandomUniform node to generate noise + # The shape of the noise must match the input shape. + random_node = helper.make_node( + "RandomUniform", + inputs=[], # RandomUniform has no inputs + outputs=[noise_tensor_name], + name="random_uniform_for_perturb", + shape=self.input_shape, + dtype=TensorProto.FLOAT, + low = -np.sqrt(3), + high = np.sqrt(3) + ) + + # 2. 
Add node to add the noise to the input + add_node = helper.make_node( + "Add", + inputs=["x", noise_tensor_name], + outputs=["perturbed_x"], + name="add_perturbation", + ) + + # Create the graph that implements the custom op's logic + execution_graph = helper.make_graph( + [random_node, add_node], + "perturb_uniform_execution_graph", + [x_tensor], + [perturbed_x_tensor], + ) + + # Create the ONNX model for execution + execution_model = self.create_model(execution_graph) + + # Run inference on the execution model + sess_options = ort.SessionOptions() + # Disable all optimizations to ensure nodes are not fused or altered + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + + session = ort.InferenceSession(execution_model.SerializeToString(), sess_options) + # The output name is "perturbed_x" + output_names = ["perturbed_x"] + outputs = session.run(output_names, inputs) + + # Return the output in the expected dictionary format + return {"perturbed_x": outputs[0]} + + def compute_expected_output(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Return None to skip validation - this is a custom operator. 
+ """ + return None diff --git a/onnx4deeploy/operators/rqsperturbrademacher.py b/onnx4deeploy/operators/rqsperturbrademacher.py new file mode 100644 index 0000000..7cb0a4a --- /dev/null +++ b/onnx4deeploy/operators/rqsperturbrademacher.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +"""PerturbRademacher operator test implementation.""" + +from typing import Any, Dict, Tuple + +import numpy as np +import onnxruntime as ort +from onnx import TensorProto, helper + +from .base_operator import BaseOperatorTest + + +class RQSPerturbRademacherOperatorTest(BaseOperatorTest): + """Test generator for ONNX PerturbRademacher operator (custom/training op).""" + + def __init__(self, config_path=None, save_path=None): + super().__init__(config_path, save_path) + self.input_shape = None + self.num_classes = None + self.batch_size = None + + def get_operator_name(self) -> str: + return "PerturbRademacher" + + def load_config(self) -> Dict[str, Any]: + """Load PerturbRademacher-specific configuration.""" + config = super().load_config() + + pn_config = config.get("perturbrademacher", {}) + self.input_shape = tuple(pn_config["input_shape"]) + return config + + def generate_inputs(self) -> np.ndarray: + """Generate input with both positive and negative values.""" + x = np.random.randn(*self.input_shape).astype(np.float32) + # quantize: + max_val = np.max(np.abs(x), axis=1) + s = max_val / 127.0 + s[s == 0] = 1.0 # Avoid division by zero + mul = np.round(0.01 / s * (2**15)).astype(np.int32) # quantized multiplier for perturbation + x_quantized = np.round(x / s[:, np.newaxis]) + return {"x": x_quantized.astype(np.float32), "mul": mul.astype(np.float32)} + + def create_onnx_graph(self, inputs: Dict[str, np.ndarray]): + """Create ONNX graph for PerturbRademacher operator.""" + # Input tensors (without loss_grad for the final model) + x_tensor = helper.make_tensor_value_info( + "x", TensorProto.FLOAT, 
self.input_shape + ) + mul_initializer = helper.make_tensor( + name="mul", + data_type=TensorProto.FLOAT, + dims=[self.input_shape[0]], + vals=inputs["mul"], + ) + # Output tensor + perturbed_x_tensor = helper.make_tensor_value_info( + "perturbed_x", TensorProto.FLOAT, self.input_shape + ) + + # PerturbRademacher node (without loss_grad input) + perturb_node = helper.make_node( + "RQSPerturbRademacher", + inputs=["x", "mul"], + outputs=["perturbed_x"], + seed=42, + idx=0, + div=2**15, + n_levels=256, + signed=1, + name="rqs_perturb_rademacher_node", + domain="com.microsoft" + ) + + # Graph + graph = helper.make_graph( + [perturb_node], + "perturb_rademacher_graph", + [x_tensor], + [perturbed_x_tensor], + [mul_initializer] + ) + + return graph + + def create_model(self, graph, opset_version: int = 13): + """Create ONNX model for PerturbRademacher with custom domain.""" + model = helper.make_model( + graph, + producer_name=f"{self.get_operator_name().lower()}_test", + opset_imports=[ + helper.make_opsetid("", opset_version), + helper.make_opsetid("com.microsoft", 1), + ], + ) + + return model + + def run_inference(self, onnx_file: str, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Run inference using custom emulation + """ + # perturbation is built from -1's and 1's + perturbation = np.random.choice([-1, 1], size=self.input_shape).astype(np.int8) + perturbation = perturbation * inputs["mul"].reshape(-1, 1) // 2 ** 15 + perturbed_x = inputs["x"] + perturbation + + return {"perturbed_x": perturbed_x} + + def compute_expected_output(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Return None to skip validation - this is a custom operator. 
class RQSPerturbUniformOperatorTest(BaseOperatorTest):
    """Test generator for the custom RQSPerturbUniform operator (``com.microsoft`` domain).

    The input is quantized per batch row to integer levels in ``[-127, 127]``
    (stored as float32, matching the FLOAT tensors declared in the graph) and
    accompanied by a per-row fixed-point multiplier ``mul`` with divisor
    ``2**15``.
    """

    _DIV = 2 ** 15                 # fixed-point divisor, also written as the node's ``div``
    _EPS = 0.01 * np.sqrt(3)       # perturbation magnitude before quantization

    def __init__(self, config_path=None, save_path=None):
        super().__init__(config_path, save_path)
        self.input_shape = None
        self.num_classes = None
        self.batch_size = None

    def get_operator_name(self) -> str:
        """Return the operator name used by the test framework."""
        # NOTE(review): this is the same name and config section as the
        # non-quantized PerturbUniform test, while the node emitted below is
        # "RQSPerturbUniform" - confirm whether the two are meant to share them.
        return "PerturbUniform"

    def load_config(self) -> Dict[str, Any]:
        """Load the base configuration and read ``perturbuniform.input_shape``.

        Raises:
            KeyError: if the ``perturbuniform`` section or its ``input_shape``
                entry is missing.
        """
        config = super().load_config()

        section = config.get("perturbuniform", {})
        if "input_shape" not in section:
            raise KeyError("config section 'perturbuniform' must define 'input_shape'")
        self.input_shape = tuple(section["input_shape"])
        return config

    def generate_inputs(self) -> Dict[str, np.ndarray]:
        """Return the quantized input ``x`` and the per-row multiplier ``mul``.

        Both arrays hold integer values but are returned as float32 so they
        agree with the FLOAT element type declared in ``create_onnx_graph``.
        """
        x = np.random.randn(*self.input_shape).astype(np.float32)
        rows = x.shape[0]

        # One scale per batch row, taken over all remaining axes so that inputs
        # of any rank yield exactly ``input_shape[0]`` multipliers.
        max_val = np.max(np.abs(x).reshape(rows, -1), axis=1)
        s = max_val / 127.0
        s[s == 0] = 1.0  # Avoid division by zero

        mul = np.round(self._EPS / s * self._DIV)  # quantized multiplier for perturbation
        x_quantized = np.round(x / s.reshape((rows,) + (1,) * (x.ndim - 1)))
        return {"x": x_quantized.astype(np.float32), "mul": mul.astype(np.float32)}

    def create_onnx_graph(self, inputs: Dict[str, np.ndarray]):
        """Create the ONNX graph containing the single RQSPerturbUniform node."""
        x_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, self.input_shape)
        mul_initializer = helper.make_tensor(
            name="mul",
            data_type=TensorProto.FLOAT,
            dims=[self.input_shape[0]],
            # Explicit float32 so the values always match the FLOAT data type.
            vals=np.asarray(inputs["mul"], dtype=np.float32).flatten(),
        )
        perturbed_x_tensor = helper.make_tensor_value_info(
            "perturbed_x", TensorProto.FLOAT, self.input_shape
        )

        perturb_node = helper.make_node(
            "RQSPerturbUniform",
            inputs=["x", "mul"],
            outputs=["perturbed_x"],
            seed=42,
            idx=0,
            div=self._DIV,
            n_levels=256,
            signed=1,
            name="rqs_perturb_uniform_node",
            domain="com.microsoft",
        )

        return helper.make_graph(
            [perturb_node],
            "perturb_uniform_graph",
            [x_tensor],
            [perturbed_x_tensor],
            [mul_initializer],
        )

    def create_model(self, graph, opset_version: int = 13):
        """Create the ONNX model, importing the default and ``com.microsoft`` opsets."""
        model = helper.make_model(
            graph,
            producer_name=f"{self.get_operator_name().lower()}_test",
            opset_imports=[
                helper.make_opsetid("", opset_version),
                helper.make_opsetid("com.microsoft", 1),
            ],
        )

        return model

    def run_inference(self, onnx_file: str, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Emulate the operator in NumPy (``onnx_file`` is not executed).

        The noise is drawn uniformly from ``{-1, 0, 1}``, multiplied by the
        per-row ``mul`` and floor-divided by ``2**15``.
        """
        x = np.asarray(inputs["x"], dtype=np.float32)
        # Broadcast one multiplier per batch row over all remaining axes.
        mul = np.asarray(inputs["mul"], dtype=np.float32).reshape((-1,) + (1,) * (x.ndim - 1))

        noise = np.random.randint(-1, 2, size=self.input_shape).astype(np.float32)
        perturbation = noise * mul // self._DIV
        perturbed_x = x + perturbation

        return {"perturbed_x": perturbed_x}

    def compute_expected_output(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Return None to skip validation - this is a custom, randomised operator."""
        return None
shape_inference.infer_shapes(model) @@ -613,7 +632,7 @@ def infer_shapes_with_custom_ops( except Exception as node_err: print(f" Node {i}: {node.op_type} failed: {str(node_err)}") # Try custom inference for Microsoft ops - if "com.microsoft" in node.op_type: + if "com.microsoft" in node.op_type or "ai.onnx.contrib" in node.op_type: print(f" Applying custom inference for: {node.op_type}") try: apply_custom_inference(inferred_model.graph, node) @@ -655,6 +674,13 @@ def apply_custom_inference(graph: onnx.GraphProto, node: onnx.NodeProto) -> None set_tensor_shape(graph, node.output[0], input_shape) print(f" {node.op_type} output shape: {input_shape}") + elif "ai.onnx.contrib.RequantizeShift" in node.op_type: + # RequantizeShift output shape matches first input + if len(node.input) >= 1 and len(node.output) >= 1: + input_shape = get_tensor_shape(graph, node.input[0]) + if input_shape: + set_tensor_shape(graph, node.output[0], input_shape) + print(f" RequantizeShift output shape: {input_shape}") def extract_subgraph(model: onnx.ModelProto, nodes: List[onnx.NodeProto]) -> onnx.ModelProto: """ diff --git a/onnx4deeploy/transform/quant_transform.py b/onnx4deeploy/transform/quant_transform.py new file mode 100644 index 0000000..f3a2763 --- /dev/null +++ b/onnx4deeploy/transform/quant_transform.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +""" +qdq_to_deeploy.py + +Convert ONNX QDQ/QCDQ graphs (QuantizeLinear/DequantizeLinear) into Deeploy-style +Quant/Dequant nodes, and optionally insert RequantShift nodes for known +scale transitions. 
def _get_initializer_value(graph: gs.Graph, name: str) -> Optional[np.ndarray]:
    """Return the values of the constant tensor called *name*, or ``None``.

    ``onnx_graphsurgeon`` imports ONNX initializers as ``gs.Constant`` tensors
    reachable through ``graph.tensors()``; ``gs.Graph`` exposes no
    ``initializers`` attribute, so the lookup goes through the tensor map only.

    Args:
        graph: graph to search.
        name: tensor name to look up.

    Returns:
        The constant's values as a NumPy array, or ``None`` when no tensor of
        that name exists or the tensor is not a constant.
    """
    tensor = graph.tensors().get(name)
    if isinstance(tensor, gs.Constant):
        return _as_numpy(tensor)
    return None
+ + When torch.onnx.export traces a model with multiple quantizers that share + the same name prefix, it can produce: + - Multiple initializers with the same name but different values + - Multiple node outputs with the same name + + Both conditions produce invalid ONNX that ORT rejects. This pass repairs + them by renaming duplicate occurrences to unique names and rewiring + downstream consumers accordingly. + """ + graph = onnx_model.graph + + # --- Step 1: fix duplicate initializers --- + init_occs = defaultdict(list) + for init in graph.initializer: + init_occs[init.name].append(init) + + init_rename = {} + new_inits = [] + for name, occ_list in init_occs.items(): + for k, orig in enumerate(occ_list): + new_tensor = onnx.TensorProto() + new_tensor.CopyFrom(orig) + new_name = name if k == 0 else f"{name}__v{k + 1}" + new_tensor.name = new_name + new_inits.append(new_tensor) + init_rename[(name, k)] = new_name + + del graph.initializer[:] + graph.initializer.extend(new_inits) + + dup_init_names = {name for name, occs in init_occs.items() if len(occs) > 1} + next_occ = defaultdict(int) + for node in graph.node: + updated = list(node.input) + for i, inp in enumerate(node.input): + if inp in dup_init_names: + occ = next_occ[inp] + next_occ[inp] += 1 + updated[i] = init_rename[(inp, min(occ, len(init_occs[inp]) - 1))] + del node.input[:] + node.input.extend(updated) + + # --- Step 2: fix duplicate node outputs --- + produced = set() + active_name: Dict[str, str] = {} + + for node in graph.node: + updated_in = list(node.input) + for i, inp in enumerate(node.input): + if inp in active_name: + updated_in[i] = active_name[inp] + del node.input[:] + node.input.extend(updated_in) + + updated_out = list(node.output) + for i, out in enumerate(node.output): + if not out: + continue + if out in produced: + v = 2 + new_out = f"{out}__v{v}" + while new_out in produced: + v += 1 + new_out = f"{out}__v{v}" + active_name[out] = new_out + updated_out[i] = new_out + 
def float_to_rqs_params(scale_ratio: np.ndarray, max_shift: int = 30) -> Tuple[np.ndarray, int, int]:
    """
    Convert a floating-point scale ratio into integer RequantShift parameters.

    Returns (mul, add, div) such that
        out ~= ((in * mul) + add) >> shift,   with div = 2**shift

    A single shift (power-of-two divisor) is used for the whole tensor; `mul`
    is per-element when `scale_ratio` is a vector. `add` implements
    round-to-nearest: add = 1 << (shift - 1), or 0 when shift == 0.

    Args:
        scale_ratio: Scalar or array of source_scale / destination_scale.
        max_shift: Largest shift to try; smaller shifts are tried in
            descending order until every multiplier fits in signed int32.

    Returns:
        mul: int32 multiplier(s), same shape as `scale_ratio`.
        add: Rounding offset as a Python int.
        div: Power-of-two divisor (2**shift) as a Python int.

    Raises:
        RuntimeError: If `scale_ratio` contains NaN/inf, or if no shift in
            [0, max_shift] keeps `mul` within the signed int32 range.
    """
    r = np.asarray(scale_ratio, dtype=np.float64)

    # NaN/inf can never satisfy the range check below; report the real cause
    # instead of the generic "no valid shift" message.
    if not np.all(np.isfinite(r)):
        raise RuntimeError(f"scale_ratio must be finite, got {r!r}")

    int32_max = 2**31 - 1
    # Prefer the largest shift (highest precision) whose multipliers still fit.
    for shift in range(max_shift, -1, -1):
        mul = np.round(r * (2.0 ** shift))
        if np.all(np.abs(mul) <= int32_max):
            add = (1 << (shift - 1)) if shift > 0 else 0
            return mul.astype(np.int32), add, 1 << shift

    raise RuntimeError("Could not find valid shift for scale_ratio")
[2](https://deepwiki.com/pulp-platform/Deeploy/8.1-quantization-and-training-support) + """ + new_nodes = [] + for node in graph.nodes: + if node.op == "QuantizeLinear": + # Inputs: x, y_scale, y_zero_point (zp optional per ONNX) + x = node.inputs[0] + scale_in = node.inputs[1] if len(node.inputs) > 1 else None + zp_in = node.inputs[2] if len(node.inputs) > 2 else None + + scale = _as_numpy(scale_in) if isinstance(scale_in, gs.Constant) else _get_initializer_value(graph, scale_in.name) + zp = None + if zp_in is not None: + zp = _as_numpy(zp_in) if isinstance(zp_in, gs.Constant) else _get_initializer_value(graph, zp_in.name) + + if scale is None: + raise RuntimeError(f"QuantizeLinear node {node.name} has non-constant scale; provide constant initializer.") + + signed = _infer_signed_from_zero_point(zp) + bit_width = 8 # Deeploy QuantParser expects bit_width attribute; typical is 8. [2](https://deepwiki.com/pulp-platform/Deeploy/8.1-quantization-and-training-support) + + # Create Deeploy Quant node + old_outputs = list(node.outputs) + node.inputs.clear() + node.outputs.clear() + q = gs.Node( + op="Quant", + name=(node.name or "Quant") + "_Deeploy", + inputs=[x], + outputs=old_outputs, + attrs={ + "scale": float(scale.reshape(-1)[0]) if scale.size == 1 else scale.astype(np.float32), + "zero_point": int(zp.reshape(-1)[0]) if (zp is not None and zp.size == 1) else (zp.astype(np.int32) if zp is not None else 0), + "bit_width": bit_width, + "signed": int(signed), + } + ) + new_nodes.append(q) + continue + + if node.op == "DequantizeLinear": + # Inputs: x, x_scale, x_zero_point (zp optional) + xq = node.inputs[0] + scale_in = node.inputs[1] if len(node.inputs) > 1 else None + zp_in = node.inputs[2] if len(node.inputs) > 2 else None + + scale = _as_numpy(scale_in) if isinstance(scale_in, gs.Constant) else _get_initializer_value(graph, scale_in.name) + zp = None + if zp_in is not None: + zp = _as_numpy(zp_in) if isinstance(zp_in, gs.Constant) else 
def insert_rqs_from_map(graph: gs.Graph, rqs_map: Dict[str, Any]) -> None:
    """
    Insert RequantShift nodes described by a user-supplied edge map.

    For every entry of rqs_map["edges"] a RequantShift node is created that
    reads `src_tensor` and whose output replaces `dst_tensor` in the inputs
    of every node that consumed `dst_tensor`. The integer parameters are
    derived from src_scale / dst_scale via float_to_rqs_params.

    rqs_map format:
      {
        "edges": [
          {
            "src_tensor": "tensorA_q",
            "dst_tensor": "tensorB_q",
            "src_scale": 0.0078125,
            "dst_scale": 0.00390625
          }
        ]
      }

    Raises:
        KeyError: If an edge names a tensor that is not present in the graph.
    """
    edge_specs = rqs_map.get("edges", [])
    if not edge_specs:
        return

    known_tensors = graph.tensors()

    for spec in edge_specs:
        src_name = spec["src_tensor"]
        dst_name = spec["dst_tensor"]
        src_scale = np.array(spec["src_scale"], dtype=np.float64)
        dst_scale = np.array(spec["dst_scale"], dtype=np.float64)

        if src_name not in known_tensors:
            raise KeyError(f"src_tensor '{src_name}' not found in graph tensors.")
        if dst_name not in known_tensors:
            raise KeyError(f"dst_tensor '{dst_name}' not found in graph tensors.")

        producer_side = known_tensors[src_name]
        consumer_side = known_tensors[dst_name]

        # Integer approximation of the scale transition.
        mul, add, div = float_to_rqs_params(src_scale / dst_scale)

        prefix = f"{src_name}_to_{dst_name}"
        params = [
            gs.Constant(name=f"{prefix}_mul", values=np.array(mul, dtype=np.int32)),
            gs.Constant(name=f"{prefix}_add", values=np.array(add, dtype=np.int32)),
            gs.Constant(name=f"{prefix}_div", values=np.array(div, dtype=np.int32)),
        ]

        # Output of the new node inherits dtype/shape from the source tensor.
        requantized = gs.Variable(
            name=f"{src_name}_rqs_to_{dst_name}",
            dtype=producer_side.dtype,
            shape=producer_side.shape,
        )

        # NOTE(review): div is passed as a fourth input here with no attributes;
        # confirm this matches what the RequantShift consumer expects.
        rqs_node = gs.Node(
            op="RequantShift",
            name=f"RQS_{src_name}_to_{dst_name}",
            inputs=[producer_side] + params,
            outputs=[requantized],
            attrs={},
        )

        # Redirect every existing use of dst_tensor to the requantized tensor.
        # The new node is appended afterwards so it is not visited here.
        for consumer in graph.nodes:
            for slot, tensor in enumerate(consumer.inputs):
                if tensor is consumer_side:
                    consumer.inputs[slot] = requantized

        graph.nodes.append(rqs_node)

    graph.cleanup().toposort()
def generate_weight_update_graph(onnx_path: str, output_path: str, zo_config: dict, noise_type: str) -> None:
    """
    Generate a weight-update ONNX graph made only of perturbation nodes.

    Every initializer of the model at `onnx_path` whose name contains
    "weight" or "bias" (and is not listed in zo_config["exceptions"]) gets
    perturbation node(s) that write back to the initializer's own name.
    The resulting graph has no inputs and no outputs.

    Args:
        onnx_path: Path of the ONNX model whose initializers are perturbed.
        output_path: Path the weight-update graph is saved to.
        zo_config: Dict with keys "epsilon", "seed" and optional "exceptions"
            (initializer names to leave untouched).
        noise_type: One of "gaussian", "uniform", "eggroll", "rademacher",
            "rqs_rademacher", "rqs_uniform".

    Raises:
        ValueError: If `noise_type` is not supported (raised when the first
            matching initializer is processed).
    """
    epsilon, seed, exceptions = zo_config["epsilon"], zo_config["seed"], zo_config.get("exceptions", [])

    model = onnx.load(onnx_path)
    initializers = [
        init for init in model.graph.initializer
        if ("weight" in init.name or "bias" in init.name) and init.name not in exceptions
    ]
    # Original initializers plus every helper tensor (shapes, multipliers)
    # referenced by the generated nodes.
    new_initializers = list(initializers)
    nodes = []
    perturbation_counter = 0

    for init in initializers:
        perturbed_name = init.name  # Overwrite the initializer directly

        if noise_type == "gaussian":
            node = helper.make_node(
                "PerturbNormal",
                inputs=[init.name],
                outputs=[perturbed_name],
                name=f"perturbnormal_{perturbed_name}",
                domain="mezo",
                seed=seed,
                eps=epsilon,
                idx=perturbation_counter,
                doc_string="y = x + epsilon * RandomNormal(x, seed)"
            )
            nodes.append(node)
            perturbation_counter += 1

        elif noise_type == "uniform":
            node = helper.make_node(
                "PerturbUniform",
                inputs=[init.name],
                outputs=[perturbed_name],
                name=f"perturbuniform_{perturbed_name}",
                domain="mezo",
                idx=perturbation_counter,
                seed=seed,
                eps=epsilon,
                low=-np.sqrt(3),
                high=np.sqrt(3),
                doc_string="y = x + epsilon * RandomUniform(x, seed)"
            )
            nodes.append(node)
            perturbation_counter += 1

        elif noise_type == "eggroll":
            # Rank-1 noise: a is (dim0, 1), b is (prod(remaining dims), 1).
            noise_shape = list(onnx.numpy_helper.to_array(init).shape)
            a_shape = [noise_shape[0], 1]
            b_shape = [int(np.prod(noise_shape[1:])), 1]
            print(F"noise_shape: {noise_shape}, a_shape: {a_shape}, b_shape: {b_shape}")

            shape_a_tensor = helper.make_tensor(
                name=f"shape_a_{init.name}", data_type=TensorProto.INT64, dims=[len(a_shape)],
                vals=np.array(a_shape, dtype=np.int64)
            )
            shape_b_tensor = helper.make_tensor(
                name=f"shape_b_{init.name}", data_type=TensorProto.INT64, dims=[len(b_shape)],
                vals=np.array(b_shape, dtype=np.int64)
            )
            shape_input_tensor = helper.make_tensor(
                name=f"shape_{init.name}", data_type=TensorProto.INT64, dims=[len(noise_shape)],
                vals=np.array(noise_shape, dtype=np.int64)
            )

            unflatten_node = None
            if len(noise_shape) > 2:
                # Tensors with more than two dims are flattened to 2-D around the Gemm.
                shape_flat_tensor = helper.make_tensor(
                    name=f"shape_{init.name}_flat", data_type=TensorProto.INT64, dims=[2],
                    vals=np.array([a_shape[0], b_shape[0]], dtype=np.int64)
                )
                flatten_node = helper.make_node(
                    "Reshape",
                    inputs=[init.name, f"shape_{init.name}_flat"],
                    outputs=[f"flattened_{init.name}"],
                    name=f"flatten_{init.name}"
                )
                nodes.append(flatten_node)
                eggroll_input = f"flattened_{init.name}"
                eggroll_output = f"flattened_{init.name}_perturbed"
                unflatten_node = helper.make_node(
                    "Reshape",
                    inputs=[eggroll_output, f"shape_{init.name}"],
                    outputs=[init.name],
                    name=f"unflatten_{init.name}_perturbed"
                )
                # Fix: append the tensor itself, not a one-element list.
                new_initializers.append(shape_flat_tensor)
            else:
                eggroll_input = init.name
                eggroll_output = f"{init.name}_perturbed"

            new_initializers.extend([shape_a_tensor, shape_b_tensor, shape_input_tensor])

            noise_node_a = helper.make_node(
                "PerturbEggroll",
                inputs=[f"shape_a_{init.name}"],
                outputs=[f"a_{init.name}"],
                name=f"gen_eggroll_noise_a_{init.name}",
                seed=seed,
                eps=epsilon,
                idx=perturbation_counter,
                domain="com.microsoft",
                doc_string="a = RandomRademacher(x[0], seed)"
            )
            noise_node_b = helper.make_node(
                "PerturbEggroll",
                inputs=[f"shape_b_{init.name}"],
                outputs=[f"b_{init.name}"],
                name=f"gen_eggroll_noise_b_{init.name}",
                seed=seed,
                eps=epsilon,
                idx=perturbation_counter,
                domain="com.microsoft",
                doc_string="b = RandomRademacher(x[1:], seed)"
            )
            # NOTE(review): beta=0 means the third Gemm input does not contribute
            # under standard Gemm semantics -- confirm this is intended.
            gemm_node = helper.make_node(
                "Gemm",
                inputs=[f"a_{init.name}", f"b_{init.name}", eggroll_input],
                outputs=[eggroll_output],
                name=f"eggroll_gemm_{init.name}",
                transA=0,
                transB=1,
                alpha=epsilon,
                beta=0
            )

            nodes.extend([noise_node_a, noise_node_b, gemm_node])
            # Fix: the unflatten Reshape consumes the Gemm output, so it must
            # come after the Gemm in the node list.
            if unflatten_node is not None:
                nodes.append(unflatten_node)
            perturbation_counter += 2

        elif noise_type == "rademacher":
            node = helper.make_node(
                "PerturbRademacher",
                inputs=[init.name],
                outputs=[perturbed_name],
                name=f"perturbrademacher_{perturbed_name}",
                domain="mezo",
                idx=perturbation_counter,
                seed=seed,
                eps=epsilon,
                doc_string="y = x + epsilon * RandomRademacher(x, seed)"
            )
            nodes.append(node)
            perturbation_counter += 1

        elif noise_type == "rqs_rademacher":
            # Per-leading-index max |value|, used to derive the integer multiplier.
            scale = np.max(np.abs(onnx.numpy_helper.to_array(init)), axis=tuple(range(1, len(init.dims))), keepdims=True)
            # NOTE(review): the configured epsilon is overridden with 0.01 for the
            # quantized noise types -- confirm whether zo_config["epsilon"] should apply.
            epsilon = 0.01
            # Names containing '_add' are treated as 32-bit biases, the rest as 8-bit weights.
            if '_add' in init.name:
                quant_max = 2**31 - 1
                mul = np.round(epsilon / (scale / quant_max) * (2**31)).astype(np.int32)  # quantized multiplier for perturbation
            else:
                quant_max = 127
                mul = np.round(epsilon / (scale / quant_max) * (2**15)).astype(np.int64)  # quantized multiplier for perturbation
            # NOTE(review): integer multipliers are stored in a FLOAT tensor -- confirm expected dtype.
            init_mul = helper.make_tensor(
                name=f"{init.name}_mul",
                data_type=TensorProto.FLOAT,
                dims=[mul.shape[0]],
                vals=mul
            )
            node = helper.make_node(
                "RQSPerturbRademacher",
                inputs=[init.name, f"{init.name}_mul"],
                outputs=[perturbed_name],
                name=f"rqs_perturb_rademacher_{perturbed_name}",
                domain="mezo",
                idx=perturbation_counter,
                seed=seed,
                signed=1,
                div=2**15,
                n_levels=256,
                doc_string="y = x + epsilon * RQSRandomRademacher(x, seed)"
            )
            nodes.append(node)
            new_initializers.append(init_mul)
            perturbation_counter += 1

        elif noise_type == "rqs_uniform":
            # Per-leading-index max |value|, used to derive the integer multiplier.
            scale = np.max(np.abs(onnx.numpy_helper.to_array(init)), axis=tuple(range(1, len(init.dims))), keepdims=True)
            # NOTE(review): the configured epsilon is overridden with 0.01 here as well.
            epsilon = 0.01
            # Names containing '_add' are treated as 32-bit biases, the rest as 8-bit weights.
            if '_add' in init.name:
                quant_max = 2**31 - 1
            else:
                quant_max = 127.0
            mul = np.round(epsilon / (scale / quant_max) * (2**15)).astype(np.int32)
            init_mul = helper.make_tensor(
                name=f"{init.name}_mul",
                data_type=TensorProto.FLOAT,
                dims=mul.shape,
                vals=mul
            )

            node = helper.make_node(
                "RQSPerturbUniform",
                inputs=[init.name, f"{init.name}_mul"],
                outputs=[perturbed_name],
                name=f"rqs_perturb_uniform_{perturbed_name}",
                domain="mezo",
                seed=seed,
                idx=perturbation_counter,
                signed=1,
                div=2**15,
                n_levels=256,
                doc_string="y = x + epsilon * RQSRandomUniform(x, seed)"
            )
            nodes.append(node)
            new_initializers.append(init_mul)
            perturbation_counter += 1

        else:
            raise ValueError(f"Unsupported noise_type: {noise_type}")

    # Build a minimal graph: no inputs, no outputs, just initializers and nodes.
    # Fix: pass new_initializers so the helper tensors referenced by the nodes
    # (shape tensors, *_mul multipliers) are actually present in the graph.
    graph = helper.make_graph(
        nodes=nodes,
        name="weight_update_graph",
        inputs=[],  # No inputs
        outputs=[],  # No outputs
        initializer=new_initializers
    )

    # Use the same default-domain opset as the original model (13 if absent), plus the mezo domain.
    standard_opset_version = next((op.version for op in model.opset_import if op.domain == ""), 13)
    opset_list = [
        helper.make_opsetid("", standard_opset_version),
        helper.make_opsetid("mezo", 1)
    ]
    new_model = helper.make_model(graph, producer_name="mezo-weight-update", opset_imports=opset_list)
    onnx.save(new_model, output_path)
    print(f"Saved weight update graph to: {output_path}")
The unique seed for each + operator serves as an identifier that a custom hardware runtime can override with + a dynamic, runtime-provided seed. + + Args: + onnx_path: Path to the original ONNX model. + epsilon: The magnitude of the perturbation. + For the negative forward pass, just reverse the sign. + For inference, set to 0. + seed: A base seed to generate unique, deterministic seeds for each operator. + noise_type: The type of random distribution to use ('gaussian' or 'uniform'). + """ + # Load original ONNX model + p = Path(onnx_path) + + # --- 1. Identify target weights and biases --- + model = onnx.load(onnx_path) + weights_and_biases = { + init.name + for init in model.graph.initializer + if "weight" in init.name or "bias" in init.name + } + + if not weights_and_biases: + print("Warning: No weights or biases containing 'weight' or 'bias' in their names were found to perturb.") + return + + print(f"Found {len(weights_and_biases)} weight/bias tensors to perturb.") + + def modify_graph(original_model: onnx.ModelProto, output_path: str, exceptions: list[str]): + new_nodes = [] + extra_value_infos = [] + + # Keep track of all initializers. We will add to this list. 
+ new_initializers = list(original_model.graph.initializer) + + # Create a set of initializer names for quick lookups + initializer_names = {init.name for init in new_initializers} + + base_seed = int(seed) + perturbation_counter = 0 + epsilon=0.01 + # Prepare a fast lookup for initializer names + initializer_names = {init.name for init in new_initializers} + + print(f"all initializer names: {initializer_names}") + + for node in original_model.graph.node: + # Check if this is a node we want to modify + if node.op_type in ["Conv", "Gemm", "MatMul", "RequantShift"] and node.name not in exceptions: + print(F"node: {node.name}, op_type: {node.op_type}") + modified_inputs = list(node.input) + made_change = False + + for i, input_name in enumerate(node.input): + # Check if the input is a weight/bias initializer + + # For RequantShift, only perturb the 3rd input (the 'add' term/bias). + if node.op_type == "RequantShift" and i != 2: + continue + + if input_name in initializer_names: + made_change = True + + print(f"input_name {i}: {input_name}") + # Find the original weight tensor to get its properties + original_weight_tensor = next(t for t in new_initializers if t.name == input_name) + dtype = TensorProto.DataType.Name(original_weight_tensor.data_type) # "FLOAT" + noise_shape = original_weight_tensor.dims + noise_shape = [int(x) for x in noise_shape] + + # --- This is the core logic for injecting nodes --- + + # 1. Define names for the new intermediate tensors + perturbed_tensor_name = f"{perturbation_counter}_{input_name}" + # 2. 
Create the RandomNormal/RandomUniform node + unique_seed = float(base_seed + perturbation_counter) + + if noise_type == "gaussian": + perturbation_node = helper.make_node( + "PerturbNormal", + inputs=[input_name], + outputs=[perturbed_tensor_name], + name=f"perturbnormal_{perturbed_tensor_name}", + domain="mezo", + seed=seed, + eps=epsilon, + idx=perturbation_counter, + # dtype=dtype, + doc_string="y = x + epsilon * RandomNormal(x, seed)" + ) + new_nodes.append(perturbation_node) + perturbation_counter += 1 + + elif noise_type == "uniform": + perturbation_node = helper.make_node( + "PerturbUniform", + inputs=[input_name], + outputs=[perturbed_tensor_name], + name=f"perturbuniform_{perturbed_tensor_name}", + domain="mezo", + idx=perturbation_counter, + seed=seed, + eps=epsilon*2*np.sqrt(3), + low=-np.sqrt(3), + high=np.sqrt(3), + # dtype=dtype, + doc_string="y = x + epsilon * RandomUniform(x, seed)" + ) + new_nodes.append(perturbation_node) + perturbation_counter += 1 + + elif noise_type == "triangle": + perturbation_node = helper.make_node( + "PerturbTriangle", + inputs=[input_name], + outputs=[perturbed_tensor_name], + name=f"perturbtriangle_{perturbed_tensor_name}", + domain="mezo", + idx=perturbation_counter, + seed=seed, + eps=epsilon*2*np.sqrt(6), + low=-np.sqrt(6), + high=np.sqrt(6), + # dtype=dtype, + doc_string="y = x + epsilon * RandomTriangle(x, seed)" + ) + new_nodes.append(perturbation_node) + perturbation_counter += 1 + + elif noise_type == "rademacher": + perturbation_node = helper.make_node( + "PerturbRademacher", + inputs=[input_name], + outputs=[perturbed_tensor_name], + name=f"perturbrademacher_{perturbed_tensor_name}", + domain="mezo", + idx=perturbation_counter, + seed=seed, + eps=epsilon, + # dtype=dtype, + doc_string="y = x + epsilon * RandomRademacher(x, seed)" + ) + new_nodes.append(perturbation_node) + perturbation_counter += 1 + + elif noise_type == "rqs_rademacher": + scale = 
np.max(np.abs(onnx.numpy_helper.to_array(original_weight_tensor)), + axis=tuple(range(1, len(original_weight_tensor.dims))), keepdims=True) + epsilon = 0.01 + # compute mul factor from scale. input 1 is normally the weight tensor. + if '_add' in input_name: + print(F"found add in name: {input_name}, treating as bias with 32-bit quantization") + quant_max = 2**31 - 1 + # For biases, inherit the 'div' from the RequantShift node itself + print(F"NOde: {node.name}, op_type: {node.op_type}, attributes: {node.attribute}") + # CORRECT: 'div' is a TENSOR attribute, not an INT attribute. + div_tensor_proto = next((attr.t for attr in node.attribute if attr.name == 'div'), None) + if div_tensor_proto is None: + raise ValueError(f"Could not find 'div' TENSOR attribute on RequantShift node: {node.name}") + + # Convert the TensorProto to a numpy array and get the scalar value. + producer_div = onnx.numpy_helper.to_array(div_tensor_proto).item() + producer_mul_tensor = None + if len(node.input) > 1: + mul_input_name = node.input[1] + producer_mul_tensor = next((init for init in new_initializers if init.name == mul_input_name), None) + if producer_div is None: + raise ValueError(f"Could not find 'div' attribute on RequantShift node: {node.name}") + print(F"producer_div: {producer_div}, producer_mul: {producer_mul_tensor}") + + div=producer_div + n_levels = 2**32 + producer_mul = onnx.numpy_helper.to_array(producer_mul_tensor) + + mul = np.round(epsilon *producer_mul).astype(np.int32) # quantized multiplier for perturbation + + else: + quant_max = 127 + div = 2**15 + n_levels = 2**8 + mul = np.round(epsilon / (scale / quant_max) * (div)).astype(np.int64) + init_mul = helper.make_tensor( + name=f"{input_name}_mul", + data_type=TensorProto.FLOAT, + dims=[mul.shape[0]], + vals=mul if mul.ndim > 0 else mul + ) + new_initializers.append(init_mul) + + perturbation_node = helper.make_node( + "RQSPerturbRademacher", + inputs=[input_name, f"{input_name}_mul"], + 
outputs=[perturbed_tensor_name], + name=f"rqs_perturb_rademacher_{perturbed_tensor_name}", + domain="mezo", + idx=perturbation_counter, + seed=seed, + signed=1, + div=div, + n_levels=n_levels, + doc_string="y = x + epsilon * RQSRandomRademacher(x, seed)" + ) + new_nodes.append(perturbation_node) + perturbation_counter += 1 + + elif noise_type == "rqs_uniform": + scale = np.max(np.abs(onnx.numpy_helper.to_array(original_weight_tensor)), + axis=tuple(range(1, len(original_weight_tensor.dims))), keepdims=True) + # compute mul factor from scale. input 1 is normally the weight tensor. + if '_add' in input_name: + quant_max = 2**31 - 1 + producer_div = next((attr.i for attr in node.attribute if attr.name == 'div'), None) + if producer_div is None: + raise ValueError(f"Could not find 'div' attribute on RequantShift node: {node.name}") + div = producer_div + n_levels = 2**8 + mul = np.round(epsilon / (scale / quant_max) * (2**31)).astype(np.int32) # quantized multiplier for perturbation + + else: + quant_max = 127 + div = 2**15 + n_levels = 2**8 + mul = np.round(epsilon / (scale / quant_max) * (2**15)).astype(np.int64) # quantized multiplier for perturbation + init_mul = helper.make_tensor( + name=f"{input_name}_mul", + data_type=TensorProto.FLOAT, + dims=mul.shape, + vals=mul + ) + init_mul = helper.make_tensor( + name=f"{input_name}_mul", + data_type=TensorProto.FLOAT, + dims=mul.shape, + vals=mul + ) + new_initializers.append(init_mul) + + perturbation_node = helper.make_node( + "RQSPerturbUniform", + inputs=[input_name, f"{input_name}_mul"], + outputs=[perturbed_tensor_name], + name=f"rqs_perturb_uniform_{perturbed_tensor_name}", + domain="mezo", + seed=seed, + idx=perturbation_counter, + signed=1, + div=div, + n_levels=n_levels, + doc_string="y = x + epsilon * RQSRandomUniform(x, seed)" + ) + new_nodes.append(perturbation_node) + perturbation_counter += 1 + + + elif noise_type == "eggroll": + # Shape annotation for intermediate outputs + a_shape = [noise_shape[0], 1] 
+ b_shape = [int(np.prod(noise_shape[1:])), 1] + + shape_a_tensor = helper.make_tensor(name=f"shape_a_{input_name}", data_type=TensorProto.INT64, dims=[len(a_shape)], + vals=np.array(a_shape, dtype=np.int64)) + shape_b_tensor = helper.make_tensor(name=f"shape_b_{input_name}", data_type=TensorProto.INT64, dims=[len(b_shape)], + vals=np.array(b_shape, dtype=np.int64)) + + shape_input_name = helper.make_tensor(name=f"shape_{input_name}", data_type=TensorProto.INT64, dims=[len(noise_shape)], + vals=np.array(noise_shape, dtype=np.int64)) + new_initializers.append(shape_input_name) + new_initializers.append(shape_a_tensor) + new_initializers.append(shape_b_tensor) + + if len(noise_shape) > 2: + + shape_flat_name = helper.make_tensor(name=f"shape_{input_name}_flat", data_type=TensorProto.INT64, dims=[2], + vals=np.array([a_shape[0], b_shape[0]], dtype=np.int64)) + + new_initializers.append(shape_flat_name) + + # insert flattening nodes + flatten_node = helper.make_node( + "Reshape", + inputs=[input_name, f"shape_{input_name}_flat"], + outputs=[f"flattened_{input_name}"], + name=f"flatten_{input_name}" + ) + new_nodes.append(flatten_node) + extra_value_infos.append(helper.make_tensor_value_info( + f"flattened_{input_name}", TensorProto.FLOAT,[noise_shape[0], int(np.prod(noise_shape[1:]))] + )) + + unflatten_node = helper.make_node( + "Reshape", + inputs=[f"flattened_{perturbed_tensor_name}", f"shape_{input_name}"], + outputs=[perturbed_tensor_name], + name=f"unflatten_{perturbed_tensor_name}" + ) + new_nodes.append(unflatten_node) + extra_value_infos.append(helper.make_tensor_value_info( + f"flattened_{perturbed_tensor_name}", TensorProto.FLOAT, [noise_shape[0], int(np.prod(noise_shape[1:]))] + )) + eggroll_input = f"flattened_{input_name}" + eggroll_output = f"flattened_{perturbed_tensor_name}" + else: + eggroll_input = input_name + eggroll_output = perturbed_tensor_name + + extra_value_infos.append(helper.make_tensor_value_info( + f"a_{perturbed_tensor_name}", 
TensorProto.FLOAT, a_shape + )) + extra_value_infos.append(helper.make_tensor_value_info( + f"b_{perturbed_tensor_name}", TensorProto.FLOAT, b_shape + )) + + # Eggroll noise node (without loss_grad input) + noise_node_a = helper.make_node( + "PerturbEggroll", + inputs=[f"shape_a_{input_name}"], + outputs=[f"a_{perturbed_tensor_name}"], + name=f"gen_eggroll_noise_a_{perturbed_tensor_name}", + seed=seed, + eps=epsilon, + idx=perturbation_counter, + domain="com.microsoft", + doc_string="a = RandomRademacher(x[0], seed)" + ) + + noise_node_b = helper.make_node( + "PerturbEggroll", + inputs=[f"shape_b_{input_name}"], + outputs=[f"b_{perturbed_tensor_name}"], + name=f"gen_eggroll_noise_b_{perturbed_tensor_name}", + seed=seed, + eps=epsilon, + idx=perturbation_counter, + domain="com.microsoft", + doc_string="b = RandomRademacher(x[1:], seed)" + ) + + gemm_node = helper.make_node( + "Gemm", + inputs=[f"a_{perturbed_tensor_name}", f"b_{perturbed_tensor_name}", eggroll_input], + outputs=[eggroll_output], + name=f"eggroll_gemm_{perturbed_tensor_name}", + transA=0, + transB=1, + alpha=epsilon, + beta=0 + ) + + new_nodes.append(noise_node_a) + new_nodes.append(noise_node_b) + new_nodes.append(gemm_node) + + else: + raise ValueError(f"Unsupported noise_type: {noise_type}") + + # **CRITICAL**: annotate perturbed edge with same dtype/shape as weight + if len(original_weight_tensor.dims) == 1: + out_shape = (original_weight_tensor.dims[0], ) + print(f"out_shape: {out_shape}") + else: + out_shape = original_weight_tensor.dims + extra_value_infos.append( + helper.make_tensor_value_info(perturbed_tensor_name, + elem_type=TensorProto.FLOAT, + shape=out_shape) + ) + + # 5. 
Update the input list for the *original* node + modified_inputs[i] = perturbed_tensor_name + perturbation_counter += 1 + + if made_change: + + # handle attributes + kwargs = {} + for attr in node.attribute: + # Use get_attribute_value to extract the python value from the AttributeProto + kwargs[attr.name] = helper.get_attribute_value(attr) + + # Create a new version of the Conv/Gemm node with the modified inputs + new_original_node = helper.make_node( + node.op_type, + modified_inputs, # Use the updated input list + node.output, + name=node.name, + domain=node.domain, + **kwargs + ) + new_nodes.append(new_original_node) + else: + # If no weights were perturbed, add the original node back unchanged + new_nodes.append(node) + else: + # This node is not a target, so add it to our new list as-is + new_nodes.append(node) + + new_value_info = list(original_model.graph.value_info) + extra_value_infos + + # Create a new graph with the new list of nodes and initializers + new_graph = helper.make_graph( + nodes=new_nodes, + name=f"{original_model.graph.name}-{node.op_type}", + inputs=original_model.graph.input, + outputs=original_model.graph.output, + initializer=new_initializers, + value_info=new_value_info + ) + + # Create and save the new model + for op in original_model.opset_import: + if op.domain == "": + standard_opset_version = op.version + break + + opset_list = [ + # Add the standard opset with the version we found + helper.make_opsetid("", standard_opset_version), + + # Addcustom domains + helper.make_opsetid("mezo", 1), + helper.make_opsetid("ai.onnx.contrib", 1), + helper.make_opsetid("com.microsoft", 1) + ] + new_model = helper.make_model(new_graph, producer_name="mezo-graph-generator", + opset_imports=opset_list) + + # onnx.checker.check_model(new_model) + onnx.save(new_model, output_path) + + # --- Main execution --- + original_model = onnx.load(onnx_path) + + print(f"Found {len(original_model.graph.initializer)} initializers. 
Perturbing weights/biases in Conv, MatMul, Gemm nodes.") + + modify_graph(original_model, output_path, exceptions=exceptions) + + print(f"Saved perturbed models to:\n- {output_path}") + return output_path + +def append_cross_entropy_loss(onnx_path, output_path, label_name='y', logits_output_idx=0, reduction='mean'): + """ + Adds a brand-new label input (INT64, shape ['batch_size']) -- guaranteed to be a new input name -- + appends a SoftmaxCrossEntropyLoss node consuming the model logits and the new label input, + and replaces the graph output with the scalar loss. + """ + model = onnx.load(onnx_path) + graph = model.graph + + if len(graph.output) == 0: + raise RuntimeError("Model has no outputs to attach the loss to.") + + # resolve logits tensor name (default: first graph output) + if logits_output_idx < 0 or logits_output_idx >= len(graph.output): + raise RuntimeError(f"Invalid logits_output_idx {logits_output_idx}") + logits_name = graph.output[logits_output_idx].name + + existing_inputs = {inp.name for inp in graph.input} + label_input_name = label_name + suffix = 0 + while label_input_name in existing_inputs: + label_input_name = f"{label_name}_mezo{suffix}" + suffix += 1 + + # get batch size from the first graph input (no initializer checks) + batch_dim = "batch_size" # fallback symbolic name + if graph.input: + first_inp = graph.input[0] + if first_inp.type.HasField("tensor_type") and first_inp.type.tensor_type.shape.dim: + first_dim = first_inp.type.tensor_type.shape.dim[0] + if first_dim.HasField("dim_value") and first_dim.dim_value > 0: + batch_dim = int(first_dim.dim_value) + elif first_dim.HasField("dim_param") and first_dim.dim_param: + batch_dim = first_dim.dim_param + + # add the new label input using resolved batch_dim + label_vi = helper.make_tensor_value_info(label_input_name, TensorProto.INT64, [batch_dim, 1]) + graph.input.append(label_vi) + + # create loss node (standard SoftmaxCrossEntropyLoss) with proper attribute + logprob = "log_prob" + 
loss_node = helper.make_node( + "SoftmaxCrossEntropyLoss", + inputs=[logits_name, label_input_name], + outputs=[logprob], + name="CrossEntropyLoss", + reduction=reduction, + ) + graph.node.append(loss_node) + + # replace graph outputs with the log prob + output_shape = [d.dim_value if d.HasField("dim_value") else d.dim_param + for d in graph.output[0].type.tensor_type.shape.dim] + + del graph.output[:] + graph.output.append(helper.make_tensor_value_info(logprob, TensorProto.FLOAT, output_shape)) + + # try to infer shapes and save + try: + inferred = shape_inference.infer_shapes(model) + onnx.save(inferred, output_path) + except Exception as e: + print(F" shape inference failed, saving without shape inference. Error was: {e}") + onnx.save(model, output_path) diff --git a/onnx4deeploy/utils/__init__.py b/onnx4deeploy/utils/__init__.py index 0dca197..97c572f 100644 --- a/onnx4deeploy/utils/__init__.py +++ b/onnx4deeploy/utils/__init__.py @@ -5,6 +5,7 @@ """Utility functions for ONNX model manipulation.""" from .node_naming import make_c_name, rename_and_save_onnx, rename_nodes, rename_onnx_nodes +from .onnx_node_implementations import run_onnx_graph __all__ = [ "make_c_name", diff --git a/onnx4deeploy/utils/onnx_node_implementations.py b/onnx4deeploy/utils/onnx_node_implementations.py new file mode 100644 index 0000000..686de66 --- /dev/null +++ b/onnx4deeploy/utils/onnx_node_implementations.py @@ -0,0 +1,794 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +""" +Pure-Python ONNX graph executor. + +Executes ONNX graphs without relying on onnxruntime.InferenceSession, so that +graphs containing custom ops (Deeploy Quant/Dequant/RequantShift and MeZO +perturbation ops) can be run natively. 
+ +Supported op domains: + - "" (standard ONNX ops) + - "ai.onnx.contrib" โ†’ Quant, Dequant, RequantShift + - "mezo" โ†’ PerturbNormal, PerturbUniform, PerturbRademacher, + PerturbTriangle, PerturbEggroll, + RQSPerturbRademacher, RQSPerturbUniform +""" + +from __future__ import annotations + +import math +from typing import Any, Dict, List, Optional + +import numpy as np +import onnx +from onnx import numpy_helper, TensorProto + +# --------------------------------------------------------------------------- +# RNG reference matching Deeploy C++ implementation +# --------------------------------------------------------------------------- + +NUM_CORES: int = 8 + + +def _scramble(seed: int) -> np.uint32: + """seed * 1664525 + 1013904223 (mod 2**32, 32-bit LCG scramble).""" + return np.uint32(np.uint32(seed) * np.uint32(1664525) + np.uint32(1013904223)) + + +def _xorshift32(state: np.uint32) -> np.uint32: + """One Xorshift32 step.""" + state = np.uint32(state) + state ^= np.uint32(state << np.uint32(13)) + state ^= np.uint32(state >> np.uint32(17)) + state ^= np.uint32(state << np.uint32(5)) + return state + + +def _uint32_to_signed_float(u: np.uint32) -> float: + """Map uint32 uniformly to (-1, 1).""" + return float(u) / float(np.iinfo(np.uint32).max) * 2.0 - 1.0 + + +def _perturb_uniform( + data: np.ndarray, + global_seed: int, + node_id: int, + eps: float, + sign: int = 1, +) -> np.ndarray: + """ + Uniform perturbation matching Deeploy PerturbUniform kernel. + + seed per core = scramble(global_seed + NUM_CORES*node_id + core_id) + rand in (-1, 1) scaled by eps (the ONNX node attribute already encodes + the full scale, e.g. epsilon*2*sqrt(3)). 
+ """ + flat = data.flatten().astype(np.float32) + size = flat.size + log2core = int(math.log2(NUM_CORES)) + + for core_id in range(NUM_CORES): + chunk = (size >> log2core) + (1 if (size & (NUM_CORES - 1)) else 0) + chunk_start = min(chunk * core_id, size) + chunk_stop = min(chunk_start + chunk, size) + + seed = _scramble(global_seed + NUM_CORES * node_id + core_id) + for i in range(chunk_start, chunk_stop): + seed = _xorshift32(seed) + flat[i] += np.float32(sign * _uint32_to_signed_float(seed) * eps) + + return flat.reshape(data.shape) + + +def _perturb_rademacher( + data: np.ndarray, + global_seed: int, + node_id: int, + eps: float, + sign: int = 1, +) -> np.ndarray: + """Rademacher perturbation: each element offset by ยฑeps.""" + flat = data.flatten().astype(np.float32) + size = flat.size + log2core = int(math.log2(NUM_CORES)) + + for core_id in range(NUM_CORES): + chunk = (size >> log2core) + (1 if (size & (NUM_CORES - 1)) else 0) + chunk_start = min(chunk * core_id, size) + chunk_stop = min(chunk_start + chunk, size) + + seed = _scramble(global_seed + NUM_CORES * node_id + core_id) + for i in range(chunk_start, chunk_stop): + seed = _xorshift32(seed) + rad = np.float32(1.0) if (seed & np.uint32(1)) else np.float32(-1.0) + flat[i] += np.float32(sign) * rad * np.float32(eps) + + return flat.reshape(data.shape) + + +def _perturb_normal( + data: np.ndarray, + global_seed: int, + node_id: int, + eps: float, + sign: int = 1, +) -> np.ndarray: + """Gaussian perturbation via Box-Muller on Xorshift32 draws.""" + flat = data.flatten().astype(np.float32) + size = flat.size + log2core = int(math.log2(NUM_CORES)) + + for core_id in range(NUM_CORES): + chunk = (size >> log2core) + (1 if (size & (NUM_CORES - 1)) else 0) + chunk_start = min(chunk * core_id, size) + chunk_stop = min(chunk_start + chunk, size) + + seed = _scramble(global_seed + NUM_CORES * node_id + core_id) + for i in range(chunk_start, chunk_stop): + seed = _xorshift32(seed) + u1 = max(float(seed) / 
float(np.iinfo(np.uint32).max), 1e-10) + seed = _xorshift32(seed) + u2 = float(seed) / float(np.iinfo(np.uint32).max) * 2.0 * math.pi + z = math.sqrt(-2.0 * math.log(u1)) * math.cos(u2) + flat[i] += np.float32(sign * z * eps) + + return flat.reshape(data.shape) + + +def _perturb_rqs_rademacher( + data: np.ndarray, + mul: np.ndarray, + global_seed: int, + node_id: int, + div: int, + n_levels: int, + signed: int, + sign: int = 1, +) -> np.ndarray: + """RQS Rademacher perturbation for integer-quantised tensors.""" + flat = data.astype(np.int64).flatten() + size = flat.size + log2core = int(math.log2(NUM_CORES)) + + mul_flat = mul.flatten().astype(np.int64) + num_out = mul_flat.size + # Broadcast mul across the remaining dimensions + elems_per = size // num_out if num_out > 0 and size >= num_out else 1 + mul_per_elem = np.repeat(mul_flat, elems_per) + if mul_per_elem.size < size: + mul_per_elem = np.resize(mul_per_elem, size) + + for core_id in range(NUM_CORES): + chunk = (size >> log2core) + (1 if (size & (NUM_CORES - 1)) else 0) + chunk_start = min(chunk * core_id, size) + chunk_stop = min(chunk_start + chunk, size) + + seed = _scramble(global_seed + NUM_CORES * node_id + core_id) + for i in range(chunk_start, chunk_stop): + seed = _xorshift32(seed) + rad = np.int64(1) if (seed & np.uint32(1)) else np.int64(-1) + delta = (rad * mul_per_elem[i]) // np.int64(div) + flat[i] += sign * delta + + lo = -(n_levels // 2) if signed else 0 + hi = (n_levels // 2) - 1 if signed else n_levels - 1 + flat = np.clip(flat, lo, hi) + return flat.reshape(data.shape).astype(data.dtype) + + +def _perturb_rqs_uniform( + data: np.ndarray, + mul: np.ndarray, + global_seed: int, + node_id: int, + div: int, + n_levels: int, + signed: int, + sign: int = 1, +) -> np.ndarray: + """RQS Uniform perturbation for integer-quantised tensors.""" + flat = data.astype(np.int64).flatten() + size = flat.size + log2core = int(math.log2(NUM_CORES)) + + mul_flat = mul.flatten().astype(np.int64) + num_out = 
mul_flat.size + elems_per = size // num_out if num_out > 0 and size >= num_out else 1 + mul_per_elem = np.repeat(mul_flat, elems_per) + if mul_per_elem.size < size: + mul_per_elem = np.resize(mul_per_elem, size) + + for core_id in range(NUM_CORES): + chunk = (size >> log2core) + (1 if (size & (NUM_CORES - 1)) else 0) + chunk_start = min(chunk * core_id, size) + chunk_stop = min(chunk_start + chunk, size) + + seed = _scramble(global_seed + NUM_CORES * node_id + core_id) + for i in range(chunk_start, chunk_stop): + seed = _xorshift32(seed) + rand_f = _uint32_to_signed_float(seed) + delta = int(rand_f * float(mul_per_elem[i])) // div + flat[i] += sign * delta + + lo = -(n_levels // 2) if signed else 0 + hi = (n_levels // 2) - 1 if signed else n_levels - 1 + flat = np.clip(flat, lo, hi) + return flat.reshape(data.shape).astype(data.dtype) + + +# --------------------------------------------------------------------------- +# Attribute extraction helper +# --------------------------------------------------------------------------- + +def _node_attrs(node: onnx.NodeProto) -> Dict[str, Any]: + """Return node attributes as a plain Python dict.""" + attrs: Dict[str, Any] = {} + for attr in node.attribute: + if attr.type == onnx.AttributeProto.FLOAT: + attrs[attr.name] = attr.f + elif attr.type == onnx.AttributeProto.INT: + attrs[attr.name] = attr.i + elif attr.type == onnx.AttributeProto.STRING: + attrs[attr.name] = attr.s.decode("utf-8") if attr.s else "" + elif attr.type == onnx.AttributeProto.TENSOR: + attrs[attr.name] = numpy_helper.to_array(attr.t) + elif attr.type == onnx.AttributeProto.INTS: + attrs[attr.name] = list(attr.ints) + elif attr.type == onnx.AttributeProto.FLOATS: + attrs[attr.name] = list(attr.floats) + else: + attrs[attr.name] = attr + return attrs + + +# --------------------------------------------------------------------------- +# Standard ONNX op dispatcher +# --------------------------------------------------------------------------- + +def 
_exec_standard(op: str, inputs: List, attrs: Dict[str, Any]) -> List[np.ndarray]: # noqa: C901 + if op == "Add": + return [inputs[0] + inputs[1]] + if op == "Sub": + return [inputs[0] - inputs[1]] + if op == "Mul": + return [inputs[0] * inputs[1]] + if op == "Div": + return [inputs[0] / inputs[1]] + if op == "Neg": + return [-inputs[0]] + if op == "Abs": + return [np.abs(inputs[0])] + if op == "Sqrt": + return [np.sqrt(inputs[0])] + if op == "Exp": + return [np.exp(inputs[0])] + if op == "Log": + return [np.log(inputs[0])] + if op == "Pow": + return [np.power(inputs[0], inputs[1])] + if op == "Erf": + return [np.vectorize(math.erf)(inputs[0]).astype(inputs[0].dtype)] + if op == "Ceil": + return [np.ceil(inputs[0])] + if op == "Floor": + return [np.floor(inputs[0])] + if op == "Round": + return [np.round(inputs[0])] + if op == "Sign": + return [np.sign(inputs[0]).astype(inputs[0].dtype)] + + if op == "Clip": + lo = inputs[1] if len(inputs) > 1 and inputs[1] is not None else attrs.get("min", None) + hi = inputs[2] if len(inputs) > 2 and inputs[2] is not None else attrs.get("max", None) + return [np.clip(inputs[0], lo, hi)] + + if op in ("ReduceSum", "ReduceMean", "ReduceMax", "ReduceMin"): + x = inputs[0] + axes = attrs.get("axes", None) + keepdims = bool(attrs.get("keepdims", 1)) + noop_empty = bool(attrs.get("noop_with_empty_axes", 0)) + if len(inputs) > 1 and inputs[1] is not None: + axes = tuple(int(a) for a in inputs[1].flatten()) + elif axes is not None: + axes = tuple(axes) + if noop_empty and (axes is None or len(axes) == 0): + return [x] + fn = {"ReduceSum": np.sum, "ReduceMean": np.mean, + "ReduceMax": np.max, "ReduceMin": np.min}[op] + return [fn(x, axis=axes, keepdims=keepdims)] + + if op == "Relu": + return [np.maximum(inputs[0], 0)] + if op == "Sigmoid": + return [1.0 / (1.0 + np.exp(-inputs[0].astype(np.float64))).astype(np.float32)] + if op == "Tanh": + return [np.tanh(inputs[0])] + if op == "LeakyRelu": + alpha = float(attrs.get("alpha", 0.01)) + x = 
inputs[0] + return [np.where(x >= 0, x, alpha * x).astype(x.dtype)] + if op in ("Gelu", "FastGelu"): + x = inputs[0] + return [(x * 0.5 * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2)))).astype(x.dtype)] + if op == "Softmax": + axis = int(attrs.get("axis", -1)) + x = inputs[0] + x_max = np.max(x, axis=axis, keepdims=True) + ex = np.exp(x - x_max) + return [(ex / np.sum(ex, axis=axis, keepdims=True)).astype(x.dtype)] + + if op == "MatMul": + return [np.matmul(inputs[0], inputs[1])] + + if op == "Gemm": + A, B = inputs[0], inputs[1] + C = inputs[2] if len(inputs) > 2 and inputs[2] is not None else 0.0 + alpha = float(attrs.get("alpha", 1.0)) + beta = float(attrs.get("beta", 1.0)) + if int(attrs.get("transA", 0)): + A = A.T + if int(attrs.get("transB", 0)): + B = B.T + return [(alpha * np.matmul(A, B) + beta * C).astype(A.dtype)] + + if op == "Conv": + x, w = inputs[0], inputs[1] + b = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None + groups = int(attrs.get("group", 1)) + dilations = list(attrs.get("dilations", [1, 1])) + strides = list(attrs.get("strides", [1, 1])) + pads = list(attrs.get("pads", [0, 0, 0, 0])) + N, C_in, H, W = x.shape + C_out, C_in_pg, kH, kW = w.shape + sH, sW = strides[0], strides[1] + dH, dW = dilations[0], dilations[1] + pH_t, pW_l = pads[0], pads[1] + pH_b, pW_r = pads[2], pads[3] + H_out = (H + pH_t + pH_b - dH * (kH - 1) - 1) // sH + 1 + W_out = (W + pW_l + pW_r - dW * (kW - 1) - 1) // sW + 1 + x_pad = np.pad(x, ((0,0),(0,0),(pH_t,pH_b),(pW_l,pW_r))) + out = np.zeros((N, C_out, H_out, W_out), dtype=np.float32) + cpp = C_out // groups + for g in range(groups): + xi = x_pad[:, g*C_in_pg:(g+1)*C_in_pg] + wg = w[g*cpp:(g+1)*cpp] + for oc in range(cpp): + for oh in range(H_out): + for ow in range(W_out): + patch = xi[:, :, + oh*sH:oh*sH+kH*dH:dH, + ow*sW:ow*sW+kW*dW:dW] + out[:, g*cpp+oc, oh, ow] = np.sum( + patch * wg[oc], axis=(1, 2, 3)) + if b is not None: + out += b[np.newaxis, :, np.newaxis, np.newaxis] + return [out] + + 
if op == "BatchNormalization": + x, scale, bias, mean, var = inputs[:5] + eps = float(attrs.get("epsilon", 1e-5)) + return [(scale * (x - mean) / np.sqrt(var + eps) + bias).astype(x.dtype)] + + if op == "LayerNormalization": + x = inputs[0] + scale = inputs[1] if len(inputs) > 1 and inputs[1] is not None else np.ones(1, dtype=x.dtype) + b = inputs[2] if len(inputs) > 2 and inputs[2] is not None else np.zeros(1, dtype=x.dtype) + axis = int(attrs.get("axis", -1)) + eps = float(attrs.get("epsilon", 1e-5)) + mean = np.mean(x, axis=axis, keepdims=True) + var = np.var(x, axis=axis, keepdims=True) + out = ((x - mean) / np.sqrt(var + eps)) * scale + b + return [out.astype(x.dtype)] + + if op == "GroupNormalization": + x, scale, bias = inputs[:3] + num_groups = int(attrs.get("num_groups", 1)) + eps = float(attrs.get("epsilon", 1e-5)) + N, C = x.shape[:2] + spatial = x.shape[2:] + xr = x.reshape(N, num_groups, C // num_groups, *spatial) + axes = tuple(range(2, xr.ndim)) + mean = np.mean(xr, axis=axes, keepdims=True) + var = np.var(xr, axis=axes, keepdims=True) + xn = ((xr - mean) / np.sqrt(var + eps)).reshape(N, C, *spatial) + return [(xn * scale.reshape(1, C, *([1]*len(spatial))) + + bias.reshape(1, C, *([1]*len(spatial)))).astype(x.dtype)] + + if op == "MaxPool": + x = inputs[0] + kernel = list(attrs.get("kernel_shape", [2, 2])) + strides = list(attrs.get("strides", [1, 1])) + pads = list(attrs.get("pads", [0, 0, 0, 0])) + N, C, H, W = x.shape + kH, kW = kernel[0], kernel[1] + sH, sW = strides[0], strides[1] + pH_t, pW_l, pH_b, pW_r = pads[0], pads[1], pads[2], pads[3] + xp = np.pad(x, ((0,0),(0,0),(pH_t,pH_b),(pW_l,pW_r)), constant_values=-np.inf) + H_out = (H + pH_t + pH_b - kH) // sH + 1 + W_out = (W + pW_l + pW_r - kW) // sW + 1 + out = np.stack([[ + np.max(xp[:, :, oh*sH:oh*sH+kH, ow*sW:ow*sW+kW], axis=(2, 3)) + for ow in range(W_out)] for oh in range(H_out)], axis=0) + return [out.transpose(2, 3, 0, 1)] # (N,C,H_out,W_out) + + if op == "AveragePool": + x = inputs[0] 
+ kernel = list(attrs.get("kernel_shape", [2, 2])) + strides = list(attrs.get("strides", [1, 1])) + pads = list(attrs.get("pads", [0, 0, 0, 0])) + N, C, H, W = x.shape + kH, kW = kernel[0], kernel[1] + sH, sW = strides[0], strides[1] + pH_t, pW_l, pH_b, pW_r = pads[0], pads[1], pads[2], pads[3] + xp = np.pad(x, ((0,0),(0,0),(pH_t,pH_b),(pW_l,pW_r))) + H_out = (H + pH_t + pH_b - kH) // sH + 1 + W_out = (W + pW_l + pW_r - kW) // sW + 1 + out = np.stack([[ + np.mean(xp[:, :, oh*sH:oh*sH+kH, ow*sW:ow*sW+kW], axis=(2, 3)) + for ow in range(W_out)] for oh in range(H_out)], axis=0) + return [out.transpose(2, 3, 0, 1)] + + if op == "GlobalAveragePool": + x = inputs[0] + return [np.mean(x, axis=tuple(range(2, x.ndim)), keepdims=True)] + if op == "GlobalMaxPool": + x = inputs[0] + return [np.max(x, axis=tuple(range(2, x.ndim)), keepdims=True)] + + if op == "Reshape": + shape = [int(d) for d in inputs[1].flatten()] + src = inputs[0].shape + new_shape = [int(src[i]) if d == 0 else d for i, d in enumerate(shape)] + return [inputs[0].reshape(new_shape)] + if op == "Flatten": + axis = int(attrs.get("axis", 1)) + x = inputs[0] + pre = int(np.prod(x.shape[:axis])) if axis > 0 else 1 + post = int(np.prod(x.shape[axis:])) + return [x.reshape(pre, post)] + if op == "Transpose": + perm = attrs.get("perm", None) + return [np.transpose(inputs[0], axes=perm)] + + if op == "Squeeze": + x = inputs[0] + if len(inputs) > 1 and inputs[1] is not None: + axes = tuple(int(a) for a in inputs[1].flatten()) + else: + raw = attrs.get("axes", None) + if raw is None: + axes = None + elif isinstance(raw, (int, np.integer)): + axes = (int(raw),) + else: + axes = tuple(int(a) for a in raw) or None + return [np.squeeze(x, axis=axes)] + if op == "Unsqueeze": + x = inputs[0] + if len(inputs) > 1 and inputs[1] is not None: + axes = sorted(int(a) for a in inputs[1].flatten()) + else: + axes = sorted(attrs.get("axes", [])) + for ax in axes: + x = np.expand_dims(x, axis=ax) + return [x] + if op == "Expand": + 
return [np.broadcast_to(inputs[0], inputs[1].tolist()).copy()] + if op == "Concat": + axis = int(attrs.get("axis", 0)) + valid = [t for t in inputs if t is not None] + return [np.concatenate(valid, axis=axis)] + if op == "Split": + axis = int(attrs.get("axis", 0)) + split = list(attrs.get("split", [])) + if len(inputs) > 1 and inputs[1] is not None: + split = inputs[1].flatten().tolist() + if not split: + num_out = int(attrs.get("num_outputs", 2)) + split = [inputs[0].shape[axis] // num_out] * num_out + secs = np.cumsum([int(s) for s in split[:-1]]).tolist() + return list(np.split(inputs[0], secs, axis=axis)) + if op == "Slice": + data, starts, ends = inputs[0], inputs[1].flatten(), inputs[2].flatten() + axes = inputs[3].flatten() if len(inputs) > 3 and inputs[3] is not None else np.arange(len(starts)) + steps = inputs[4].flatten() if len(inputs) > 4 and inputs[4] is not None else np.ones(len(starts), dtype=np.int64) + idx = [slice(None)] * data.ndim + for ax, s, e, st in zip(axes, starts, ends, steps): + idx[int(ax)] = slice(int(s), int(e), int(st)) + return [data[tuple(idx)]] + if op == "Gather": + axis = int(attrs.get("axis", 0)) + return [np.take(inputs[0], inputs[1].astype(np.int64), axis=axis)] + if op == "GatherElements": + axis = int(attrs.get("axis", 0)) + return [np.take_along_axis(inputs[0], inputs[1].astype(np.int64), axis=axis)] + if op == "Shape": + start = int(attrs.get("start", 0)) + end = attrs.get("end", None) + sh = np.array(inputs[0].shape, dtype=np.int64) + return [sh[start:end]] + if op == "Size": + return [np.array(inputs[0].size, dtype=np.int64)] + if op == "Cast": + to_type = int(attrs.get("to", TensorProto.FLOAT)) + _DT = { + TensorProto.FLOAT: np.float32, + TensorProto.DOUBLE: np.float64, + TensorProto.INT32: np.int32, + TensorProto.INT64: np.int64, + TensorProto.INT8: np.int8, + TensorProto.UINT8: np.uint8, + TensorProto.BOOL: bool, + TensorProto.FLOAT16: np.float16, + } + return [inputs[0].astype(_DT.get(to_type, np.float32))] + if op 
== "Identity": + return [inputs[0].copy()] + if op == "Pad": + x = inputs[0] + pads_arr = inputs[1].flatten().tolist() if len(inputs) > 1 and inputs[1] is not None else list(attrs.get("pads", [])) + val = float(inputs[2].flat[0]) if len(inputs) > 2 and inputs[2] is not None else float(attrs.get("value", 0.0)) + n = x.ndim + pw = [(int(pads_arr[i]), int(pads_arr[i+n])) for i in range(n)] + return [np.pad(x, pw, constant_values=val)] + if op == "Tile": + return [np.tile(inputs[0], inputs[1].flatten().tolist())] + if op == "Where": + return [np.where(inputs[0], inputs[1], inputs[2])] + if op in ("Equal", "Less", "Greater", "LessOrEqual", "GreaterOrEqual"): + fn = {"Equal": np.equal, "Less": np.less, "Greater": np.greater, + "LessOrEqual": np.less_equal, "GreaterOrEqual": np.greater_equal}[op] + return [fn(inputs[0], inputs[1])] + if op == "Not": + return [~inputs[0]] + if op == "Min": + r = inputs[0] + for t in inputs[1:]: + r = np.minimum(r, t) + return [r] + if op == "Max": + r = inputs[0] + for t in inputs[1:]: + r = np.maximum(r, t) + return [r] + if op == "Sum": + r = inputs[0].copy() + for t in inputs[1:]: + r = r + t + return [r] + if op == "Mean": + return [sum(inputs) / len(inputs)] + if op == "Einsum": + return [np.einsum(attrs["equation"], *inputs)] + if op == "ArgMax": + axis = int(attrs.get("axis", 0)) + keepdims = bool(attrs.get("keepdims", 1)) + return [np.argmax(inputs[0], axis=axis, keepdims=keepdims).astype(np.int64)] + if op == "ArgMin": + axis = int(attrs.get("axis", 0)) + keepdims = bool(attrs.get("keepdims", 1)) + return [np.argmin(inputs[0], axis=axis, keepdims=keepdims).astype(np.int64)] + if op == "TopK": + k = int(inputs[1].flat[0]) if len(inputs) > 1 else int(attrs.get("k", 1)) + axis = int(attrs.get("axis", -1)) + largest = bool(attrs.get("largest", 1)) + idx = np.argsort(inputs[0], axis=axis) + if largest: + idx = np.flip(idx, axis=axis) + idx = np.take(idx, np.arange(k), axis=axis) + vals = np.take_along_axis(inputs[0], idx, axis=axis) + 
return [vals, idx.astype(np.int64)] + if op == "ConstantOfShape": + shape = inputs[0].flatten().tolist() + val_t = attrs.get("value", None) + fill = float(val_t.flat[0]) if val_t is not None else 0.0 + return [np.full([int(d) for d in shape], fill, dtype=np.float32)] + if op == "Constant": + val_t = attrs.get("value", None) + if val_t is not None: + return [val_t.copy() if isinstance(val_t, np.ndarray) else np.array(val_t)] + vf = attrs.get("value_float", None) + if vf is not None: + return [np.array(vf, dtype=np.float32)] + vi = attrs.get("value_int", None) + if vi is not None: + return [np.array(vi, dtype=np.int64)] + raise ValueError("Constant node has no supported value attribute") + if op == "Range": + s, lim, d = inputs + return [np.arange(float(s), float(lim), float(d)).astype(s.dtype)] + if op == "NonZero": + return [np.array(np.nonzero(inputs[0]), dtype=np.int64)] + if op == "Dropout": + return [inputs[0].copy(), np.ones_like(inputs[0], dtype=bool)] + if op == "SoftmaxCrossEntropyLoss": + logits = inputs[0] + labels = inputs[1] + xm = np.max(logits, axis=-1, keepdims=True) + lp = logits - xm - np.log(np.sum(np.exp(logits - xm), axis=-1, keepdims=True)) + if labels.dtype in (np.int32, np.int64): + nll = -lp[np.arange(logits.shape[0]), labels.flatten()] + else: + nll = -np.sum(lp * labels, axis=-1) + red = attrs.get("reduction", "mean") + loss = np.mean(nll) if red == "mean" else (np.sum(nll) if red == "sum" else nll) + return [np.array(loss, dtype=np.float32), lp.astype(np.float32)] + + raise NotImplementedError(f"Standard ONNX op '{op}' is not implemented in run_onnx_graph") + + +# --------------------------------------------------------------------------- +# Deeploy custom ops (ai.onnx.contrib) +# --------------------------------------------------------------------------- + +def _exec_deeploy(op: str, inputs: List, attrs: Dict[str, Any]) -> List[np.ndarray]: + if op == "Quant": + x = inputs[0].astype(np.float64) + scale = inputs[1].astype(np.float64) + zp 
= inputs[2].astype(np.float64) + bits = int(attrs.get("bits", 8)) + signed = bool(attrs.get("signed", True)) + qmin = -(2 ** (bits-1)) if signed else 0 + qmax = (2 ** (bits-1)) - 1 if signed else (2**bits) - 1 + return [np.clip(np.round(x / scale) + zp, qmin, qmax).astype(np.int8 if signed else np.uint8)] + + if op == "Dequant": + q = inputs[0].astype(np.float64) + scale = inputs[1].astype(np.float64) + zp = inputs[2].astype(np.float64) + return [((q - zp) * scale).astype(np.float32)] + + if op == "RequantShift": + x = inputs[0].astype(np.int64) + mul = inputs[1].astype(np.int64) + add = inputs[2].astype(np.int64) + div_attr = attrs.get("div", None) + if div_attr is None: + raise ValueError("RequantShift missing 'div' tensor attribute") + div = int(div_attr.flat[0]) if hasattr(div_attr, "flat") else int(div_attr) + bits = int(attrs.get("out_bits", 8)) + signed = bool(attrs.get("signed", 1)) + qmin = -(2**(bits-1)) if signed else 0 + qmax = (2**(bits-1))-1 if signed else (2**bits)-1 + return [np.clip((x * mul + add) // div, qmin, qmax).astype(np.int8 if signed else np.uint8)] + + raise NotImplementedError(f"Deeploy op '{op}' not implemented") + + +# --------------------------------------------------------------------------- +# MeZO perturbation ops (mezo / com.microsoft) +# --------------------------------------------------------------------------- + +def _exec_mezo(op: str, inputs: List, attrs: Dict[str, Any]) -> List[np.ndarray]: + seed = int(attrs.get("seed", 0)) + node_id = int(attrs.get("idx", 0)) + eps = float(attrs.get("eps", 0.01)) + sign = 1 # forward perturbation sign + + if op == "PerturbUniform": + return [_perturb_uniform(inputs[0].astype(np.float32), seed, node_id, eps, sign)] + if op == "PerturbRademacher": + return [_perturb_rademacher(inputs[0].astype(np.float32), seed, node_id, eps, sign)] + if op == "PerturbNormal": + return [_perturb_normal(inputs[0].astype(np.float32), seed, node_id, eps, sign)] + if op == "PerturbTriangle": + return 
[_perturb_uniform(inputs[0].astype(np.float32), seed, node_id, eps, sign)] + if op == "PerturbEggroll": + # Generates a Rademacher column vector of the requested shape + shape = [int(d) for d in inputs[0].flatten()] + x = np.zeros(shape, dtype=np.float32) + return [_perturb_rademacher(x, seed, node_id, 1.0, 1)] + if op == "RQSPerturbRademacher": + mul = inputs[1] if len(inputs) > 1 and inputs[1] is not None else np.ones(inputs[0].shape[0], dtype=np.int64) + div = int(attrs.get("div", 2**15)) + n_levels = int(attrs.get("n_levels", 256)) + signed_flag = int(attrs.get("signed", 1)) + return [_perturb_rqs_rademacher(inputs[0], mul, seed, node_id, div, n_levels, signed_flag, sign)] + if op == "RQSPerturbUniform": + mul = inputs[1] if len(inputs) > 1 and inputs[1] is not None else np.ones(inputs[0].shape[0], dtype=np.int64) + div = int(attrs.get("div", 2**15)) + n_levels = int(attrs.get("n_levels", 256)) + signed_flag = int(attrs.get("signed", 1)) + return [_perturb_rqs_uniform(inputs[0], mul, seed, node_id, div, n_levels, signed_flag, sign)] + + raise NotImplementedError(f"MeZO op '{op}' not implemented") + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + +def run_onnx_graph( + onnx_path: str, + inputs: Dict[str, np.ndarray], + output_names: Optional[List[str]] = None, +) -> np.ndarray: + """ + Execute an ONNX graph in pure Python โ€” no onnxruntime required. + + Supports standard ONNX ops, Deeploy custom ops (ai.onnx.contrib domain: + Quant, Dequant, RequantShift), and MeZO perturbation ops (mezo domain: + PerturbUniform, PerturbRademacher, PerturbNormal, PerturbTriangle, + PerturbEggroll, RQSPerturbRademacher, RQSPerturbUniform). + + Args: + onnx_path: Path to the ``.onnx`` model file. + inputs: ``{name: ndarray}`` for each graph input. + output_names: If given, return those specific outputs. 
+ Otherwise the first graph output is returned. + + Returns: + First (or requested) graph output as a numpy array. + """ + model = onnx.load(onnx_path) + graph = model.graph + + # value store: name โ†’ numpy array + values: Dict[str, np.ndarray] = {} + values.update(inputs) + + for init in graph.initializer: + if init.name not in values: + values[init.name] = numpy_helper.to_array(init) + + # ONNX graph is topologically sorted by spec + for node in graph.node: + op = node.op_type + domain = (node.domain or "").strip() + attrs = _node_attrs(node) + + # Gather inputs; absent optional inputs (empty string) become None + node_inputs: List = [] + for name in node.input: + if name == "": + node_inputs.append(None) + elif name in values: + node_inputs.append(values[name]) + else: + raise KeyError( + f"Node '{node.name}' (op='{op}') needs input '{name}' " + f"which is missing from value map." + ) + + # com.microsoft is used for both ORT contrib ops (e.g. Gelu, + # Attention) and MeZO perturbation ops. Route by op name. 
+ _MEZO_OPS = { + "PerturbUniform", "PerturbRademacher", "PerturbNormal", + "PerturbTriangle", "PerturbEggroll", + "RQSPerturbRademacher", "RQSPerturbUniform", + } + + try: + if domain in ("", "ai.onnx"): + outs = _exec_standard(op, node_inputs, attrs) + elif domain == "ai.onnx.contrib": + outs = _exec_deeploy(op, node_inputs, attrs) + elif domain == "mezo" or (domain == "com.microsoft" and op in _MEZO_OPS): + outs = _exec_mezo(op, node_inputs, attrs) + elif domain == "com.microsoft": + # ORT contrib ops that map to standard implementations + outs = _exec_standard(op, node_inputs, attrs) + else: + raise NotImplementedError( + f"Unsupported domain '{domain}' for op '{op}'" + ) + except NotImplementedError: + raise + except Exception as exc: + raise RuntimeError( + f"Error in node '{node.name}' op='{op}' domain='{domain}': {exc}" + ) from exc + + for out_name, out_val in zip(node.output, outs): + if out_name: + values[out_name] = out_val + + graph_outputs = [o.name for o in graph.output] + if output_names is not None: + return [values[n] for n in output_names] + if graph_outputs: + return values[graph_outputs[0]] + raise RuntimeError("ONNX graph has no outputs") diff --git a/pyproject.toml b/pyproject.toml index 2acea90..d95c4ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,10 +34,12 @@ dependencies = [ "onnx>=1.15.0,<1.17.0", "onnx-graphsurgeon>=0.5.0", "onnxruntime-training==1.19.2", + "onnxruntime-extensions>=0.13.0", "onnxscript>=0.1.0", "onnxsim>=0.4.0", "numpy>=1.24.0,<2.0.0", "pyyaml>=6.0", + "matplotlib>=3.7.0", ] [project.optional-dependencies] @@ -105,6 +107,8 @@ markers = [ "baseline: tests that require baseline comparison", "inference: marks tests for inference mode", "training: marks tests for training mode", + "quantized: marks tests for quantized models", + "zo: marks tests for zeroth-order perturbation", ] [tool.isort] diff --git a/requirements.txt b/requirements.txt index 0d171bf..79d1280 100644 --- a/requirements.txt +++ b/requirements.txt 
@@ -3,19 +3,105 @@ # SPDX-License-Identifier: MIT # Core dependencies -torch>=2.0.0 -onnx>=1.15.0,<1.17.0 -onnx-graphsurgeon>=0.5.0 +torch==2.10.0 +onnx==1.16.2 +onnx-graphsurgeon==0.5.8 +onnx-ir==0.2.0 onnxruntime-training==1.19.2 -onnxscript>=0.1.0 -onnxsim>=0.4.0 -numpy>=1.24.0,<2.0.0 -pyyaml>=6.0 +onnxruntime_extensions==0.13.0 +onnxoptimizer==0.4.2 +onnxscript==0.5.7 +onnxsim==0.6.2 +numpy==1.26.4 +pyyaml==6.0.3 +brevitas==0.12.1 +matplotlib==3.10.8 +deepquant==0.4.3 -# Optional visualization dependencies -# Install with: pip install -e ".[visualization]" -# beautifulsoup4>=4.0.0 -# pandas>=2.0.0 +# CUDA / GPU dependencies +cuda-bindings==12.9.4 +cuda-pathfinder==1.4.2 +nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cudnn-cu12==9.10.2.21 +nvidia-cufft-cu12==11.3.3.83 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparselt-cu12==0.7.1 +nvidia-nccl-cu12==2.27.5 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvshmem-cu12==3.4.5 +nvidia-nvtx-cu12==12.8.90 +triton==3.6.0 + +# General dependencies +astor==0.8.1 +Bottleneck==1.6.0 +Cerberus==1.3.8 +certifi==2026.2.25 +charset-normalizer==3.4.5 +colorama==0.4.6 +colored==2.3.1 +colored-logs==0.2.10 +coloredlogs==15.0.1 +contourpy==1.3.2 +cycler==0.12.1 +dependencies==2.0.1 +empyrical-reloaded==0.5.12 +filelock==3.25.2 +flatbuffers==25.12.19 +fonttools==4.62.0 +fsspec==2026.2.0 +h5py==3.16.0 +hdf5plugin==5.0.0 +humanfriendly==10.0 +idna==3.11 +Jinja2==3.1.6 +jsonpickle==4.1.1 +kiwisolver==1.5.0 +Logbook==1.9.2 +lru-dict==1.3.0 +markdown-it-py==4.0.0 +MarkupSafe==3.0.3 +mdurl==0.1.2 +methodtools==0.4.7 +ml_dtypes==0.5.4 +mpmath==1.3.0 +networkx==3.4.2 +packaging==26.0 +pandas==2.3.3 +patsy==1.0.2 +peewee==3.17.3 +pillow==12.1.1 +prettytable==3.17.0 +protobuf==3.20.3 +pyecharts==2.1.0 +Pygments==2.19.2 +PyJWT==2.11.0 +pyparsing==3.3.2 
+python-dateutil==2.9.0.post0 +pytz==2026.1.post1 +requests==2.32.5 +RestrictedPython==8.1 +rich==14.3.3 +scipy==1.15.3 +simplejson==3.20.2 +six==1.17.0 +statsmodels==0.14.6 +sympy==1.14.0 +tabulate==0.10.0 +tqdm==4.67.1 +typing_extensions==4.15.0 +tzdata==2025.3 +unfoldNd==0.2.3 +urllib3==2.6.3 +wcwidth==0.6.0 +websocket-client==1.9.0 +wirerope==1.0.0 # Development dependencies # Install with: pip install -e ".[dev]" diff --git a/tests/models/conftest.py b/tests/models/conftest.py index fd4589c..88ae067 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -63,6 +63,19 @@ def epidenet_config(): } +@pytest.fixture +def qlitecnn_config(): + """Default configuration for QLiteCNN quantized model.""" + return { + "batch_size": 1, + "input_channels": 1, + "input_height": 28, + "input_width": 28, + "num_classes": 10, + "opset_version": 17, + } + + @pytest.fixture def mibminet_config(): """Default configuration for MI-BMInet model.""" diff --git a/tests/models/onnx_node_implementations.py b/tests/models/onnx_node_implementations.py new file mode 100644 index 0000000..604efc1 --- /dev/null +++ b/tests/models/onnx_node_implementations.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +"""Re-export from canonical location in onnx4deeploy.utils.""" + +from onnx4deeploy.utils.onnx_node_implementations import * # noqa: F401, F403 +from onnx4deeploy.utils.onnx_node_implementations import run_onnx_graph # noqa: F401 diff --git a/tests/models/test_inference_consistency.py b/tests/models/test_inference_consistency.py new file mode 100644 index 0000000..7ff4616 --- /dev/null +++ b/tests/models/test_inference_consistency.py @@ -0,0 +1,300 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +""" +Inference consistency tests: onnxruntime vs pure-Python run_onnx_graph. 
+ +For each model exported in inference mode, we verify that our pure-Python +ONNX executor (``run_onnx_graph``) produces numerically identical results +to onnxruntime's ``InferenceSession`` on the same random inputs. + +Models tested: + - LightweightCNN (standard ops only โ†’ ORT + run_onnx_graph agree) + - SleepConViT (contains com.microsoft/Gelu and a Squeeze with opset-12 + axes attribute; patched to opset-13 in-memory before ORT) +""" + +import os +import subprocess +import sys + +import numpy as np +import onnx +import onnx.numpy_helper as numpy_helper +import onnxruntime as ort +import pytest + +from .onnx_node_implementations import run_onnx_graph +from .test_utils import load_and_check_onnx_model + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_PROJECT_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +_CLI_SCRIPT = os.path.join(_PROJECT_ROOT, "Onnx4Deeploy.py") + +_NUM_SAMPLES = 5 +_TOLERANCE = 1e-5 + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _export_inference(model_name: str, output_dir: str) -> str: + """Run the CLI in 'infer' mode and return the path to network.onnx.""" + cmd = [ + sys.executable, _CLI_SCRIPT, + "-model", model_name, + "-mode", "infer", + "-o", output_dir, + ] + result = subprocess.run( + cmd, cwd=_PROJECT_ROOT, capture_output=True, text=True + ) + if result.returncode != 0: + pytest.fail( + f"CLI failed for '{model_name}' (rc={result.returncode}):\n" + f"stdout: {result.stdout}\nstderr: {result.stderr}" + ) + onnx_file = os.path.join(output_dir, "network.onnx") + assert os.path.exists(onnx_file), f"network.onnx not found in {output_dir}" + return onnx_file + + +def _to_opset13_compatible(onnx_file: str) -> bytes: + """ + Return a 
serialised ONNX model with all Squeeze/Unsqueeze nodes patched to + opset-13 style: the ``axes`` *attribute* is moved to a constant *input* + tensor so that onnxruntime (which validates opset-13+ rules) can load it. + + The original file on disk is not modified. + """ + model = onnx.load(onnx_file) + new_nodes = [] + extra_initializers = [] + + for node in model.graph.node: + if node.op_type not in ("Squeeze", "Unsqueeze"): + new_nodes.append(node) + continue + + # Find the axes attribute (opset-12 style) + axes_attr = next((a for a in node.attribute if a.name == "axes"), None) + if axes_attr is None: + # Already opset-13 style (axes come as a second input) or no axes + new_nodes.append(node) + continue + + # Extract axes value(s) + if axes_attr.type == onnx.AttributeProto.INT: + axes = [int(axes_attr.i)] + else: # INTS + axes = list(axes_attr.ints) + + # Create a constant initializer for the axes tensor + axes_name = f"_axes_const_{node.name}" + axes_tensor = onnx.helper.make_tensor( + name=axes_name, + data_type=onnx.TensorProto.INT64, + dims=[len(axes)], + vals=axes, + ) + extra_initializers.append(axes_tensor) + + # Rebuild the node with axes as the second input, no axes attribute + new_node = onnx.helper.make_node( + op_type=node.op_type, + inputs=[node.input[0], axes_name], + outputs=list(node.output), + name=node.name, + domain=node.domain if node.domain else "", + ) + # Copy over any other attributes (not axes) + for attr in node.attribute: + if attr.name != "axes": + new_node.attribute.append(attr) + + new_nodes.append(new_node) + + # Rebuild graph with patched nodes and extra initializers + new_graph = onnx.helper.make_graph( + nodes=new_nodes, + name=model.graph.name, + inputs=list(model.graph.input), + outputs=list(model.graph.output), + initializer=list(model.graph.initializer) + extra_initializers, + ) + for vi in model.graph.value_info: + new_graph.value_info.append(vi) + + new_model = onnx.helper.make_model( + new_graph, + 
producer_name=model.producer_name, + opset_imports=list(model.opset_import), + ) + new_model.ir_version = model.ir_version + return new_model.SerializeToString() + + +def _ort_session(onnx_file: str) -> ort.InferenceSession: + """Create an onnxruntime InferenceSession, patching Squeeze axes to opset-13.""" + return ort.InferenceSession(_to_opset13_compatible(onnx_file)) + + +def _compare_outputs( + onnx_file: str, + input_shape: tuple, + num_samples: int = _NUM_SAMPLES, + tolerance: float = _TOLERANCE, +) -> None: + """ + Run *num_samples* random inputs through both onnxruntime and run_onnx_graph + and assert the outputs agree within *tolerance*. + + ORT receives the opset-13-patched model bytes; run_onnx_graph reads the + original file (our executor handles both styles). + """ + sess = _ort_session(onnx_file) + ort_input_name = sess.get_inputs()[0].name + + failures = [] + for seed in range(num_samples): + rng = np.random.default_rng(seed) + x = rng.standard_normal(input_shape).astype(np.float32) + + # onnxruntime reference + ort_out = sess.run(None, {ort_input_name: x})[0] + + # pure-Python executor + py_out = run_onnx_graph(onnx_file, {"input": x}) + + max_diff = float(np.max(np.abs(ort_out - py_out))) + if max_diff > tolerance: + failures.append( + f" seed={seed}: max |ORT โˆ’ run_onnx_graph| = {max_diff:.2e} " + f"(limit {tolerance:.2e})" + ) + + if failures: + pytest.fail( + f"Outputs diverge on {len(failures)}/{num_samples} samples:\n" + + "\n".join(failures) + ) + + +# =========================================================================== +# LightweightCNN +# =========================================================================== + + +class TestLightweightCNNInferenceConsistency: + """ + LightweightCNN uses only standard ONNX ops (Conv, MaxPool, Relu, Gemm, + Reshape), so both onnxruntime and run_onnx_graph can execute it. + This test verifies they agree numerically. 
+ """ + + _MODEL = "LightweightCNN" + _INPUT_SHAPE = (1, 1, 28, 28) + + def test_export_produces_valid_onnx(self, model_test_dir): + """Exported network.onnx loads and has correct structure.""" + onnx_file = _export_inference(self._MODEL, model_test_dir) + model = load_and_check_onnx_model(onnx_file, skip_shape_check=True) + assert len(model.graph.node) > 0 + assert any(o.name for o in model.graph.output) + + def test_ort_and_pure_python_agree(self, model_test_dir): + """ + onnxruntime and run_onnx_graph produce identical outputs (within 1e-5) + on 5 random inputs for LightweightCNN. + """ + onnx_file = _export_inference(self._MODEL, model_test_dir) + _compare_outputs(onnx_file, self._INPUT_SHAPE) + + def test_run_onnx_graph_produces_finite_output(self, model_test_dir): + """run_onnx_graph produces finite float32 output for each random input.""" + onnx_file = _export_inference(self._MODEL, model_test_dir) + for seed in range(_NUM_SAMPLES): + rng = np.random.default_rng(seed) + x = rng.standard_normal(self._INPUT_SHAPE).astype(np.float32) + out = run_onnx_graph(onnx_file, {"input": x}) + assert out.shape == (1, 10), f"Unexpected output shape: {out.shape}" + assert np.all(np.isfinite(out)), f"Non-finite output at seed={seed}" + + def test_run_onnx_graph_output_shape(self, model_test_dir): + """Output shape is (1, num_classes=10).""" + onnx_file = _export_inference(self._MODEL, model_test_dir) + x = np.zeros(self._INPUT_SHAPE, dtype=np.float32) + out = run_onnx_graph(onnx_file, {"input": x}) + assert out.shape == (1, 10) + + +# =========================================================================== +# SleepConViT +# =========================================================================== + + +class TestSleepConViTInferenceConsistency: + """ + SleepConViT uses com.microsoft/Gelu and a Squeeze node with an ``axes`` + attribute (opset-12 style). 
Before handing the model to ORT the test + patches Squeeze/Unsqueeze nodes in-memory via ``_to_opset13_compatible`` + so that onnxruntime can load and execute it. + """ + + _MODEL = "SleepConViT" + _INPUT_SHAPE = (1, 1, 1, 3000) + _NUM_CLASSES = 4 + + def test_export_produces_valid_onnx(self, model_test_dir): + """Exported network.onnx loads without error.""" + onnx_file = _export_inference(self._MODEL, model_test_dir) + model = load_and_check_onnx_model(onnx_file, skip_shape_check=True) + assert len(model.graph.node) > 0 + + def test_run_onnx_graph_produces_finite_output(self, model_test_dir): + """ + run_onnx_graph produces finite float32 output for each random input. + This exercises com.microsoft/Gelu and the Squeeze op in our executor. + """ + onnx_file = _export_inference(self._MODEL, model_test_dir) + for seed in range(_NUM_SAMPLES): + rng = np.random.default_rng(seed) + x = rng.standard_normal(self._INPUT_SHAPE).astype(np.float32) + out = run_onnx_graph(onnx_file, {"input": x}) + assert np.all(np.isfinite(out)), f"Non-finite output at seed={seed}" + + def test_run_onnx_graph_output_shape(self, model_test_dir): + """Output shape is (1, num_classes).""" + onnx_file = _export_inference(self._MODEL, model_test_dir) + x = np.zeros(self._INPUT_SHAPE, dtype=np.float32) + out = run_onnx_graph(onnx_file, {"input": x}) + assert out.shape[0] == 1 + assert out.shape[-1] == self._NUM_CLASSES + + def test_run_onnx_graph_deterministic(self, model_test_dir): + """run_onnx_graph produces the same result on repeated calls.""" + onnx_file = _export_inference(self._MODEL, model_test_dir) + rng = np.random.default_rng(0) + x = rng.standard_normal(self._INPUT_SHAPE).astype(np.float32) + out1 = run_onnx_graph(onnx_file, {"input": x}) + out2 = run_onnx_graph(onnx_file, {"input": x}) + np.testing.assert_array_equal(out1, out2) + + def test_ort_and_pure_python_agree(self, model_test_dir): + """ + onnxruntime and run_onnx_graph produce identical outputs (within 1e-5) + on 5 random 
inputs for SleepConViT. + + The model is patched in-memory via ``_to_opset13_compatible`` before + being passed to ORT so that the opset-12-style Squeeze axes attribute + does not cause a load failure. + """ + onnx_file = _export_inference(self._MODEL, model_test_dir) + _compare_outputs(onnx_file, self._INPUT_SHAPE) diff --git a/tests/models/test_qlitecnn.py b/tests/models/test_qlitecnn.py new file mode 100644 index 0000000..cae43c4 --- /dev/null +++ b/tests/models/test_qlitecnn.py @@ -0,0 +1,169 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +""" +Tests for QLiteCNN quantized model export. + +Tests the full PTQ calibration (Brevitas) โ†’ Onnx4Deeploy export โ†’ numerical +verification pipeline for the QLiteCNN model. + +Numerical verification is performed by ``run_onnx_graph`` from +``onnx_node_implementations``, a pure Python / PyTorch graph executor that +supports standard ONNX ops, Deeploy custom nodes (Quant, Dequant, +RequantShift) and the MeZO perturbation operators (PerturbNormal, etc.). 
+""" + +import os +import subprocess +import sys + +import numpy as np +import pytest +import torch +from brevitas.quant_tensor import QuantTensor + +from onnx4deeploy.models.pytorch_models.lightweight_cnn import QLiteCNN + +from .onnx_node_implementations import run_onnx_graph +from .test_utils import load_and_check_onnx_model + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_WEIGHTS_PATH = "onnx4deeploy/models/pytorch_models/lightweight_cnn/qlite_cnn.pth" +_INPUT_SHAPE = (1, 1, 28, 28) +_TOLERANCE = 1.0 / 2**8 # 1/256 โ‰ˆ 0.0039 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _load_brevitas_model(weights_path: str, num_classes: int = 10) -> torch.nn.Module: + """Load QLiteCNN Brevitas model with pre-calibrated PTQ weights.""" + model = QLiteCNN( + batch_size=1, + input_channels=1, + num_classes=num_classes, + dropout=0.0, + ) + state_dict = torch.load(weights_path, map_location="cpu") + model.load_state_dict(state_dict, strict=False) + model.eval() + return model + + +def _run_brevitas_inference(model: torch.nn.Module, test_input: np.ndarray) -> np.ndarray: + """Run Brevitas model inference and return a float32 numpy array.""" + with torch.no_grad(): + output = model(torch.from_numpy(test_input)) + if isinstance(output, QuantTensor): + output = output.value + return output.numpy().astype(np.float32) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.quantized +class TestQLiteCNNQuantized: + """Test QLiteCNN quantized inference export and numerical correctness.""" + + def test_qlitecnn_ptq_export_and_numerical_correctness( + self, model_test_dir, qlitecnn_config + ): + """ 
+ End-to-end test: PTQ calibration โ†’ Onnx4Deeploy export โ†’ numerical check. + + Steps: + 1. Load the pre-calibrated Brevitas QLiteCNN model. + 2. Run the Onnx4Deeploy ``q-infer`` command to export the model to ONNX. + 3. Loop over 10 random input samples (seeds 0โ€“9) and verify that the + ONNX graph output matches the Brevitas reference within a tolerance + of 1/2^8 for every sample. + """ + project_root = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + + # ------------------------------------------------------------------ + # Step 1 โ€“ Load Brevitas model + # ------------------------------------------------------------------ + print("\n[PTQ] Loading QLiteCNN Brevitas model with pre-calibrated weights...") + weights_path = os.path.join(project_root, _WEIGHTS_PATH) + model = _load_brevitas_model( + weights_path, num_classes=qlitecnn_config["num_classes"] + ) + print("[PTQ] Model loaded.") + + # ------------------------------------------------------------------ + # Step 2 โ€“ Run Onnx4Deeploy q-infer export command (once) + # ------------------------------------------------------------------ + cli_script = os.path.join(project_root, "Onnx4Deeploy.py") + cmd = [ + sys.executable, cli_script, + "-model", "QLiteCNN", + "-mode", "q-infer", + "-o", model_test_dir, + ] + print(f"\n[Onnx4Deeploy] Running: {' '.join(cmd)}") + + result = subprocess.run( + cmd, + cwd=project_root, + capture_output=True, + text=True, + ) + if result.returncode != 0: + print(f"[Onnx4Deeploy] stdout:\n{result.stdout}") + print(f"[Onnx4Deeploy] stderr:\n{result.stderr}") + pytest.fail( + f"Onnx4Deeploy command failed with return code {result.returncode}" + ) + + onnx_file = os.path.join(model_test_dir, "network.onnx") + assert os.path.exists(onnx_file), f"ONNX file not found: {onnx_file}" + + # Verify basic ONNX validity (relaxed: skip strict check for custom ops) + load_and_check_onnx_model(onnx_file, skip_shape_check=True) + print(f"[Onnx4Deeploy] 
Export complete. ONNX saved at: {onnx_file}") + + # ------------------------------------------------------------------ + # Step 3 โ€“ Loop over 10 random inputs and check numerical correctness + # ------------------------------------------------------------------ + print( + "\n[Check] Running numerical check over 10 random input samples " + "(seeds 0โ€“9) ..." + ) + failures = [] + for seed in range(10): + rng = np.random.default_rng(seed) + test_input = rng.standard_normal(_INPUT_SHAPE).astype(np.float32) + + brevitas_output = _run_brevitas_inference(model, test_input) + onnx_output = run_onnx_graph(onnx_file, {"input": test_input}) + + max_diff = float(np.max(np.abs(onnx_output - brevitas_output))) + if max_diff > _TOLERANCE: + failures.append( + f" seed={seed}: max |onnx โˆ’ brevitas| = {max_diff:.6f} " + f"(limit {_TOLERANCE:.6f})\n" + f" Brevitas: {brevitas_output}\n" + f" ONNX: {onnx_output}" + ) + else: + print( + f"[Check] seed={seed} PASSED: " + f"max |onnx โˆ’ brevitas| = {max_diff:.6f} โ‰ค {_TOLERANCE:.6f}" + ) + + if failures: + pytest.fail( + f"Numerical check FAILED for {len(failures)}/10 samples:\n" + + "\n".join(failures) + ) diff --git a/tests/models/test_quant_transform.py b/tests/models/test_quant_transform.py new file mode 100644 index 0000000..e16b622 --- /dev/null +++ b/tests/models/test_quant_transform.py @@ -0,0 +1,525 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +""" +Unit tests for individual passes in onnx4deeploy/transform/quant_transform.py. + +Strategy +-------- +Each test builds a *minimal* ONNX graph that exercises exactly one pass, +runs the *original* graph through ORT as the numerical reference, applies +the pass, then runs the *transformed* graph through ``run_onnx_graph`` +(the pure-Python executor in onnx_node_implementations.py) and asserts +the outputs are identical (or within rounding for RequantShift). + +Passes under test +----------------- +1. 
``float_to_rqs_params`` pure arithmetic, no ONNX graph needed. +2. ``replace_qdq_with_deeploy`` QuantizeLinear โ†’ Quant +3. ``replace_qdq_with_deeploy`` DequantizeLinear โ†’ Dequant +4. ``replace_qdq_with_deeploy`` QDQ pair end-to-end +5. ``insert_rqs_from_map`` RequantShift insertion +""" + +import io +import tempfile +from pathlib import Path + +import numpy as np +import onnx +import onnx.helper as oh +import onnx.numpy_helper as onph +import onnx_graphsurgeon as gs +import onnxruntime as ort +import pytest + +from onnx4deeploy.transform.quant_transform import ( + float_to_rqs_params, + insert_rqs_from_map, + replace_qdq_with_deeploy, +) +from .onnx_node_implementations import run_onnx_graph, exec_requant_shift + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _save_tmp(model: onnx.ModelProto) -> str: + """Save an ONNX model to a temporary file and return the path.""" + tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) + tmp.close() + onnx.save(model, tmp.name) + return tmp.name + + +def _ort_run(model_path: str, feed: dict) -> np.ndarray: + """Run a model with ORT and return the first output.""" + sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + return sess.run(None, feed)[0] + + +def _gs_to_tmp(graph: gs.Graph) -> str: + """Export a graphsurgeon graph to a temp ONNX file.""" + graph.cleanup().toposort() + return _save_tmp(gs.export_onnx(graph)) + + +# --------------------------------------------------------------------------- +# 1. 
float_to_rqs_params โ€“ pure arithmetic +# --------------------------------------------------------------------------- + + +class TestFloatToRqsParams: + """Verify that the integer RQS params reproduce the scale ratio.""" + + def _check_ratio(self, ratio: np.ndarray) -> None: + """ + For integer inputs spanning [-128, 127], verify that the RequantShift + formula ``clip(((x * mul) + add) >> shift, -128, 127)`` gives the same + result as the reference ``clip(round(x * ratio), -128, 127)``. + """ + mul, add, div = float_to_rqs_params(ratio) + shift = int(round(np.log2(div))) if div > 1 else 0 + + x = np.arange(-128, 128, dtype=np.int64) + + # Reference: exact float multiplication, rounded and clipped + ref = np.clip(np.round(x * ratio.flatten()[0]), -128, 127).astype(np.int64) + + # RequantShift formula โ€” use int64 to avoid overflow (mul can be up to 2^30) + mul64 = np.asarray(mul, dtype=np.int64).flatten()[0] + rqs = np.clip(((x * mul64) + np.int64(add)) >> shift, -128, 127).astype(np.int64) + + # Allow ยฑ1 LSB rounding difference + assert np.all(np.abs(rqs - ref) <= 1), ( + f"RQS params mismatch for ratio={ratio}: " + f"max diff = {np.max(np.abs(rqs - ref))}" + ) + + def test_scalar_ratio_halving(self): + """scale_src / scale_dst = 0.5 โ†’ requantisation halves the value.""" + self._check_ratio(np.array(0.5)) + + def test_scalar_ratio_doubling(self): + """scale_src / scale_dst = 2.0.""" + self._check_ratio(np.array(2.0)) + + def test_scalar_ratio_identity(self): + """scale_src == scale_dst โ†’ identity (ratio = 1.0).""" + self._check_ratio(np.array(1.0)) + + def test_scalar_ratio_arbitrary(self): + """Non-power-of-two ratio.""" + self._check_ratio(np.array(0.75)) + + def test_per_channel_vector(self): + """Per-channel scale vector: check that mul is a vector.""" + ratio = np.array([0.5, 1.0, 2.0, 0.25], dtype=np.float64) + mul, add, div = float_to_rqs_params(ratio) + assert mul.shape == ratio.shape, "mul must be per-channel" + + def test_div_is_power_of_two(self): 
+ """div must always be a power of two (log2(div) is integer).""" + for r in [0.1, 0.333, 1.5, 3.7, 0.001]: + _, _, div = float_to_rqs_params(np.array(r)) + log2_div = np.log2(div) + assert abs(log2_div - round(log2_div)) < 1e-9, ( + f"div={div} is not a power of two for ratio={r}" + ) + + def test_mul_fits_int32(self): + """mul must fit in int32 for all tested ratios.""" + for r in [0.001, 0.5, 1.0, 100.0, 1234.56]: + mul, _, _ = float_to_rqs_params(np.array(r)) + assert np.all(np.abs(mul) <= 2**31 - 1) + + +# --------------------------------------------------------------------------- +# 2. replace_qdq_with_deeploy โ€“ QuantizeLinear โ†’ Quant +# --------------------------------------------------------------------------- + + +class TestReplaceQuantizeLinear: + """ + QuantizeLinear โ†’ Deeploy Quant. + + Reference: ORT runs the original QuantizeLinear node. + Transformed: run_onnx_graph runs the Quant node. + Expected: identical int8 output. + """ + + def _make_quantize_graph(self, scale: float, zero_point: int = 0) -> str: + """Build Input(float32) โ†’ QuantizeLinear(scale, zp) โ†’ Output(int8).""" + scale_t = oh.make_tensor("scale", onnx.TensorProto.FLOAT, [], [scale]) + zp_t = oh.make_tensor("zp", onnx.TensorProto.INT8, [], [zero_point]) + + node = oh.make_node( + "QuantizeLinear", + inputs=["input", "scale", "zp"], + outputs=["output"], + ) + graph = oh.make_graph( + [node], + "quant_graph", + [oh.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 8])], + [oh.make_tensor_value_info("output", onnx.TensorProto.INT8, [1, 8])], + initializer=[scale_t, zp_t], + ) + model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 13)]) + return _save_tmp(model) + + def test_quantize_per_tensor_zero_zp(self): + """Per-tensor quantization with zero_point=0.""" + original_path = self._make_quantize_graph(scale=0.5, zero_point=0) + x = np.array([[-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, -2.0]], dtype=np.float32) + + reference = _ort_run(original_path, {"input": x}) + 
+ graph = gs.import_onnx(onnx.load(original_path)) + replace_qdq_with_deeploy(graph) + transformed_path = _gs_to_tmp(graph) + + result = run_onnx_graph(transformed_path, {"input": x}) + + np.testing.assert_array_equal( + result.astype(np.int8), reference, + err_msg="Quant output differs from ORT QuantizeLinear (zero_point=0)", + ) + + def test_quantize_per_tensor_nonzero_zp(self): + """Per-tensor quantization with non-zero zero_point.""" + original_path = self._make_quantize_graph(scale=0.25, zero_point=10) + x = np.array([[-1.0, 0.0, 0.5, 1.0, 2.0, -0.5, 0.25, 0.75]], dtype=np.float32) + + reference = _ort_run(original_path, {"input": x}) + + graph = gs.import_onnx(onnx.load(original_path)) + replace_qdq_with_deeploy(graph) + transformed_path = _gs_to_tmp(graph) + + result = run_onnx_graph(transformed_path, {"input": x}) + + np.testing.assert_array_equal( + result.astype(np.int8), reference, + err_msg="Quant output differs from ORT QuantizeLinear (nonzero zp)", + ) + + def test_quantize_saturation(self): + """Values outside the quantization range must saturate to ยฑ127.""" + original_path = self._make_quantize_graph(scale=0.1, zero_point=0) + x = np.array([[-200.0, -100.0, 0.0, 100.0, 200.0, 12.7, -12.8, 0.05]], + dtype=np.float32) + + reference = _ort_run(original_path, {"input": x}) + + graph = gs.import_onnx(onnx.load(original_path)) + replace_qdq_with_deeploy(graph) + transformed_path = _gs_to_tmp(graph) + + result = run_onnx_graph(transformed_path, {"input": x}) + + np.testing.assert_array_equal( + result.astype(np.int8), reference, + err_msg="Quant saturation differs from ORT", + ) + + +# --------------------------------------------------------------------------- +# 3. replace_qdq_with_deeploy โ€“ DequantizeLinear โ†’ Dequant +# --------------------------------------------------------------------------- + + +class TestReplaceDequantizeLinear: + """ + DequantizeLinear โ†’ Deeploy Dequant. + + Reference: ORT runs the original DequantizeLinear node. 
+ Transformed: run_onnx_graph runs the Dequant node. + Expected: identical float32 output. + """ + + def _make_dequantize_graph(self, scale: float, zero_point: int = 0) -> str: + """Build Input(int8) โ†’ DequantizeLinear(scale, zp) โ†’ Output(float32).""" + scale_t = oh.make_tensor("scale", onnx.TensorProto.FLOAT, [], [scale]) + zp_t = oh.make_tensor("zp", onnx.TensorProto.INT8, [], [zero_point]) + + node = oh.make_node( + "DequantizeLinear", + inputs=["input", "scale", "zp"], + outputs=["output"], + ) + graph = oh.make_graph( + [node], + "dequant_graph", + [oh.make_tensor_value_info("input", onnx.TensorProto.INT8, [1, 8])], + [oh.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 8])], + initializer=[scale_t, zp_t], + ) + model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 13)]) + return _save_tmp(model) + + def test_dequantize_zero_zp(self): + """Per-tensor dequantization with zero_point=0.""" + original_path = self._make_dequantize_graph(scale=0.5, zero_point=0) + x = np.array([[-128, -64, -1, 0, 1, 64, 127, -100]], dtype=np.int8) + + reference = _ort_run(original_path, {"input": x}) + + graph = gs.import_onnx(onnx.load(original_path)) + replace_qdq_with_deeploy(graph) + transformed_path = _gs_to_tmp(graph) + + result = run_onnx_graph(transformed_path, {"input": x}) + + np.testing.assert_allclose( + result, reference, rtol=0, atol=1e-6, + err_msg="Dequant output differs from ORT DequantizeLinear (zero_point=0)", + ) + + def test_dequantize_nonzero_zp(self): + """Per-tensor dequantization with non-zero zero_point.""" + original_path = self._make_dequantize_graph(scale=0.25, zero_point=10) + x = np.array([[-128, -64, -10, 0, 10, 64, 127, -1]], dtype=np.int8) + + reference = _ort_run(original_path, {"input": x}) + + graph = gs.import_onnx(onnx.load(original_path)) + replace_qdq_with_deeploy(graph) + transformed_path = _gs_to_tmp(graph) + + result = run_onnx_graph(transformed_path, {"input": x}) + + np.testing.assert_allclose( + result, 
reference, rtol=0, atol=1e-6, + err_msg="Dequant output differs from ORT DequantizeLinear (nonzero zp)", + ) + + +# --------------------------------------------------------------------------- +# 4. replace_qdq_with_deeploy โ€“ full QDQ pair end-to-end +# --------------------------------------------------------------------------- + + +class TestReplaceQDQPair: + """ + Input(float32) โ†’ QuantizeLinear โ†’ DequantizeLinear โ†’ Output(float32). + + Reference: ORT runs the original QDQ graph. + Transformed: run_onnx_graph runs the Quant โ†’ Dequant graph. + Expected: identical float32 output (same rounding as ORT). + """ + + def _make_qdq_graph(self, scale: float, zero_point: int = 0) -> str: + scale_t = oh.make_tensor("scale", onnx.TensorProto.FLOAT, [], [scale]) + zp_t = oh.make_tensor("zp", onnx.TensorProto.INT8, [], [zero_point]) + + q_node = oh.make_node( + "QuantizeLinear", + inputs=["input", "scale", "zp"], + outputs=["quantized"], + ) + dq_node = oh.make_node( + "DequantizeLinear", + inputs=["quantized", "scale", "zp"], + outputs=["output"], + ) + graph = oh.make_graph( + [q_node, dq_node], + "qdq_graph", + [oh.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 8])], + [oh.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 8])], + initializer=[scale_t, zp_t], + ) + model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 13)]) + return _save_tmp(model) + + def test_qdq_roundtrip(self): + """QDQ roundtrip: float โ†’ quantize โ†’ dequantize โ†’ float.""" + original_path = self._make_qdq_graph(scale=0.5, zero_point=0) + x = np.array([[-1.0, -0.5, 0.0, 0.25, 0.5, 1.0, 1.5, 2.0]], dtype=np.float32) + + reference = _ort_run(original_path, {"input": x}) + + graph = gs.import_onnx(onnx.load(original_path)) + replace_qdq_with_deeploy(graph) + transformed_path = _gs_to_tmp(graph) + + result = run_onnx_graph(transformed_path, {"input": x}) + + np.testing.assert_allclose( + result, reference, rtol=0, atol=1e-6, + err_msg="QDQ roundtrip differs 
from ORT", + ) + + def test_qdq_roundtrip_small_scale(self): + """QDQ with a small scale (high precision).""" + original_path = self._make_qdq_graph(scale=0.00390625, zero_point=0) + x = np.random.default_rng(0).uniform(-0.5, 0.5, (1, 8)).astype(np.float32) + + reference = _ort_run(original_path, {"input": x}) + + graph = gs.import_onnx(onnx.load(original_path)) + replace_qdq_with_deeploy(graph) + transformed_path = _gs_to_tmp(graph) + + result = run_onnx_graph(transformed_path, {"input": x}) + + np.testing.assert_allclose( + result, reference, rtol=0, atol=1e-6, + err_msg="QDQ roundtrip (small scale) differs from ORT", + ) + + +# --------------------------------------------------------------------------- +# 5. insert_rqs_from_map โ€“ RequantShift insertion +# --------------------------------------------------------------------------- + + +class TestInsertRqs: + """ + RequantShift: rescale int8 tensor from scale s_src to scale s_dst. + + Reference: ORT runs QuantizeLinear(s_src) โ†’ DequantizeLinear(s_src) โ†’ + QuantizeLinear(s_dst) โ†’ DequantizeLinear(s_dst) to get the + exact floating-point result after double-quantisation. + Transformed: run_onnx_graph runs a graph where RQS replaces the second + QuantizeLinear, followed by Dequant(s_dst). + Tolerance: ยฑ1 LSB at scale s_dst (RequantShift introduces ยฑ1 rounding vs ORT). + """ + + def _make_double_qdq_graph( + self, scale_src: float, scale_dst: float + ) -> tuple: + """ + Build: float โ†’ QDQ(s_src) โ†’ QDQ(s_dst) โ†’ float + + Returns (ort_path, src_tensor_name, dst_tensor_name, gs_graph) + so that insert_rqs_from_map can be applied to the gs_graph. 
+ """ + s_src_t = oh.make_tensor("s_src", onnx.TensorProto.FLOAT, [], [scale_src]) + zp_src_t = oh.make_tensor("zp_src", onnx.TensorProto.INT8, [], [0]) + s_dst_t = oh.make_tensor("s_dst", onnx.TensorProto.FLOAT, [], [scale_dst]) + zp_dst_t = oh.make_tensor("zp_dst", onnx.TensorProto.INT8, [], [0]) + + nodes = [ + oh.make_node("QuantizeLinear", ["input", "s_src", "zp_src"], ["q_src"]), + oh.make_node("DequantizeLinear", ["q_src", "s_src", "zp_src"], ["dq_src"]), + oh.make_node("QuantizeLinear", ["dq_src", "s_dst", "zp_dst"], ["q_dst"]), + oh.make_node("DequantizeLinear", ["q_dst", "s_dst", "zp_dst"], ["output"]), + ] + graph = oh.make_graph( + nodes, "double_qdq", + [oh.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 8])], + [oh.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 8])], + initializer=[s_src_t, zp_src_t, s_dst_t, zp_dst_t], + ) + model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 13)]) + ort_path = _save_tmp(model) + return ort_path + + def _make_rqs_target_graph( + self, scale_src: float, scale_dst: float + ) -> str: + """ + Build the graph that insert_rqs_from_map will operate on: + float โ†’ Quant(s_src) โ†’ โ†’ Dequant(s_dst) โ†’ float + + We represent the "gap" as two tensors q_src and q_dst of the same dtype, + connected via an Identity (which will be replaced by RQS after the pass). 
+ """ + s_src_t = oh.make_tensor("s_src", onnx.TensorProto.FLOAT, [], [scale_src]) + zp_src_t = oh.make_tensor("zp_src", onnx.TensorProto.INT8, [], [0]) + s_dst_t = oh.make_tensor("s_dst", onnx.TensorProto.FLOAT, [], [scale_dst]) + zp_dst_t = oh.make_tensor("zp_dst", onnx.TensorProto.INT8, [], [0]) + + nodes = [ + oh.make_node("QuantizeLinear", ["input", "s_src", "zp_src"], ["q_src"]), + oh.make_node("Identity", ["q_src"], ["q_dst"]), + oh.make_node("DequantizeLinear", ["q_dst", "s_dst", "zp_dst"], ["output"]), + ] + graph_proto = oh.make_graph( + nodes, "rqs_target", + [oh.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 8])], + [oh.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 8])], + initializer=[s_src_t, zp_src_t, s_dst_t, zp_dst_t], + ) + model = oh.make_model(graph_proto, opset_imports=[oh.make_opsetid("", 13)]) + return _save_tmp(model) + + def _run_rqs_test(self, scale_src: float, scale_dst: float, seed: int = 42): + rng = np.random.default_rng(seed) + x = rng.uniform(-1.0, 1.0, (1, 8)).astype(np.float32) + + # Reference: ORT double-QDQ (quantize at s_src, requantize to s_dst) + ort_path = self._make_double_qdq_graph(scale_src, scale_dst) + reference = _ort_run(ort_path, {"input": x}) + + # Transformed: insert RequantShift between q_src and q_dst + gs_model_path = self._make_rqs_target_graph(scale_src, scale_dst) + graph = gs.import_onnx(onnx.load(gs_model_path)) + + rqs_map = { + "edges": [{ + "src_tensor": "q_src", + "dst_tensor": "q_dst", + "src_scale": scale_src, + "dst_scale": scale_dst, + }] + } + # First replace QDQ โ†’ Deeploy so the graph has Quant/Dequant nodes + replace_qdq_with_deeploy(graph) + insert_rqs_from_map(graph, rqs_map) + transformed_path = _gs_to_tmp(graph) + + result = run_onnx_graph(transformed_path, {"input": x}) + + # Allow ยฑ1 LSB tolerance at scale_dst (RQS introduces ยฑ1 rounding vs ORT) + atol = scale_dst + np.testing.assert_allclose( + result, reference, rtol=0, atol=atol, + err_msg=( + f"RQS output 
differs from ORT double-QDQ " + f"(s_src={scale_src}, s_dst={scale_dst}) " + f"by more than 1 LSB.\n" + f" ORT: {reference}\n" + f" RQS: {result}\n" + f" diff: {np.abs(result - reference)}" + ), + ) + + def test_rqs_halving(self): + """Requantise from s=0.5 to s=0.25 (factor 2 downscale).""" + self._run_rqs_test(scale_src=0.5, scale_dst=0.25) + + def test_rqs_doubling(self): + """Requantise from s=0.25 to s=0.5 (factor 2 upscale).""" + self._run_rqs_test(scale_src=0.25, scale_dst=0.5) + + def test_rqs_identity(self): + """Same scale: RequantShift should be a no-op (ยฑ1 LSB).""" + self._run_rqs_test(scale_src=0.5, scale_dst=0.5) + + def test_rqs_arbitrary_ratio(self): + """Non-power-of-two scale ratio.""" + self._run_rqs_test(scale_src=0.03125, scale_dst=0.06395246833562851) + + def test_rqs_params_roundtrip_via_exec(self): + """ + Direct unit test: exec_requant_shift with params from float_to_rqs_params + must match the reference integer requantisation. + """ + scale_src, scale_dst = 0.5, 0.25 + ratio = np.array(scale_src / scale_dst) + mul, add, div = float_to_rqs_params(ratio) + + x = np.arange(-128, 128, dtype=np.int8) + result = exec_requant_shift(x, mul, np.array(add, np.int32), np.array(div, np.int32)) + + reference = np.clip(np.round(x.astype(np.float64) * ratio), -128, 127).astype(np.int8) + + assert np.all(np.abs(result.astype(np.int32) - reference.astype(np.int32)) <= 1), ( + f"exec_requant_shift deviates from reference by more than 1 LSB: " + f"max diff = {np.max(np.abs(result.astype(np.int32) - reference.astype(np.int32)))}" + ) diff --git a/tests/models/test_zo_perturbation.py b/tests/models/test_zo_perturbation.py new file mode 100644 index 0000000..f5314c2 --- /dev/null +++ b/tests/models/test_zo_perturbation.py @@ -0,0 +1,419 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +""" +Tests for ZO (zeroth-order) perturbation model exports. 
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

_PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
_CLI_SCRIPT = os.path.join(_PROJECT_ROOT, "Onnx4Deeploy.py")

NUM_CORES = 8

# All RNG state is kept as a Python int and reduced with this mask, which
# reproduces C ``uint32_t`` wrap-around without relying on NumPy scalar
# overflow (NumPy warns on scalar overflow and, from 2.0, rejects
# out-of-range Python ints passed to ``np.uint32``).
_UINT32_MASK = 0xFFFFFFFF

# ---------------------------------------------------------------------------
# RNG Reference Implementation (matches Deeploy C++ code)
# ---------------------------------------------------------------------------


def scramble_seed(seed: int) -> np.uint32:
    """Scramble seed: seed * 1664525 + 1013904223 (mod 2^32).

    Args:
        seed: Any integer; values outside ``[0, 2**32)`` wrap modulo 2^32.

    Returns:
        The scrambled seed as ``np.uint32``.
    """
    return np.uint32((int(seed) * 1664525 + 1013904223) & _UINT32_MASK)


def xorshift32(state: np.uint32) -> np.uint32:
    """Advance a Xorshift32 PRNG by one step (shifts 13, 17, 5).

    Args:
        state: Current 32-bit state (``np.uint32`` or plain int).

    Returns:
        The next state as ``np.uint32``.
    """
    s = int(state) & _UINT32_MASK
    s ^= (s << 13) & _UINT32_MASK
    s ^= s >> 17
    s ^= (s << 5) & _UINT32_MASK
    return np.uint32(s)


def _core_chunk_bounds(size: int, core_id: int) -> tuple:
    """Return the ``[start, stop)`` slice of a flat buffer owned by *core_id*.

    The chunk length is ``size >> log2(NUM_CORES)``, plus one when ``size`` is
    not a multiple of ``NUM_CORES``; both bounds are clamped to ``size``.
    """
    log2core = NUM_CORES.bit_length() - 1  # floor(log2(NUM_CORES))
    chunk = (size >> log2core) + int((size & (NUM_CORES - 1)) != 0)
    start = min(chunk * core_id, size)
    stop = min(start + chunk, size)
    return start, stop


def _apply_per_core_noise(data, global_seed, node_id, offset_fn):
    """Add ``offset_fn(state)`` to every element of a flattened copy of *data*.

    Each core walks its own chunk with an independent Xorshift32 stream seeded
    by ``scramble_seed(global_seed + NUM_CORES * node_id + core_id)``; the
    state is advanced once before each element is perturbed.
    """
    size = data.size
    output = data.flatten()  # flatten() already returns a copy
    for core_id in range(NUM_CORES):
        start, stop = _core_chunk_bounds(size, core_id)
        state = scramble_seed(global_seed + NUM_CORES * node_id + core_id)
        for i in range(start, stop):
            state = xorshift32(state)
            output[i] = output[i] + offset_fn(state)
    return output.reshape(data.shape)


def generate_uniform_perturbation(
    data: np.ndarray,
    global_seed: int,
    node_id: int,
    eps: float,
    perturbation_sign: int = 1,
) -> np.ndarray:
    """Generate uniform perturbation matching Deeploy C++ reference.

    seed = scramble(initial_global_seed + NUM_CORES * node_id + core_id)
    RNG: Xorshift32, mapped to [-1, 1], scaled by eps * sqrt(3).

    Args:
        data: Array to perturb; it is not modified.
        global_seed: Base seed shared by all nodes.
        node_id: Index of the perturbation node, used to derive the seed.
        eps: Perturbation strength before the sqrt(3) scaling.
        perturbation_sign: Multiplier applied to the noise (e.g. +1 / -1).

    Returns:
        A new array with the shape and dtype of *data*.
    """
    scale = eps * np.sqrt(3.0)
    uint32_max = float(np.iinfo(np.uint32).max)

    def _offset(state):
        # Map uint32 to [-1, 1]
        rand_val = (float(state) / uint32_max) * 2.0 - 1.0
        return perturbation_sign * rand_val * scale

    return _apply_per_core_noise(data, global_seed, node_id, _offset)


def generate_rademacher_perturbation(
    data: np.ndarray,
    global_seed: int,
    node_id: int,
    eps: float,
    perturbation_sign: int = 1,
) -> np.ndarray:
    """Generate Rademacher perturbation (+/- eps) matching Deeploy C++ reference.

    The sign of each offset is taken from the lowest bit of the Xorshift32
    state (1 -> +eps, 0 -> -eps).

    Args:
        data: Array to perturb; it is not modified.
        global_seed: Base seed shared by all nodes.
        node_id: Index of the perturbation node, used to derive the seed.
        eps: Magnitude of every offset.
        perturbation_sign: Multiplier applied to the noise (e.g. +1 / -1).

    Returns:
        A new array with the shape and dtype of *data*.
    """

    def _offset(state):
        sign = 1 if (int(state) & 1) else -1
        return perturbation_sign * sign * eps

    return _apply_per_core_noise(data, global_seed, node_id, _offset)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Files every ZO export is expected to write into its output directory.
_ZO_EXPECTED_FILES = (
    "network_infer.onnx",
    "network_zo_train.onnx",
    "inputs.npz",
    "outputs.npz",
)

# CLI ``--noise-type`` value -> op_type of the perturbation node in the graph.
_NOISE_TYPE_TO_OP = {
    "uniform": "PerturbUniform",
    "rademacher": "PerturbRademacher",
    "eggroll": "PerturbRademacher",  # Eggroll uses Rademacher
    "triangle": "PerturbTriangle",
    # The CLI default is "gaussian", but the operator is registered as
    # PerturbNormal -- NOTE(review): confirm against the exporter.
    "gaussian": "PerturbNormal",
    "normal": "PerturbNormal",
    "rqs_rademacher": "RQSPerturbRademacher",
    "rqs_uniform": "RQSPerturbUniform",
}


def _run_cli(model: str, mode: str, output_dir: str, noise_type: str) -> subprocess.CompletedProcess:
    """Run Onnx4Deeploy CLI and return result.

    The captured stdout/stderr are echoed when the process exits non-zero so
    that pytest shows them; the caller decides whether that is a failure.
    """
    cmd = [
        sys.executable, _CLI_SCRIPT,
        "-model", model,
        "-mode", mode,
        "-o", output_dir,
        "--noise-type", noise_type,
    ]
    result = subprocess.run(
        cmd, cwd=_PROJECT_ROOT, capture_output=True, text=True, check=False
    )
    if result.returncode != 0:
        print(f"[CLI] stdout:\n{result.stdout}")
        print(f"[CLI] stderr:\n{result.stderr}")
    return result


def _verify_zo_output_files(output_dir: str, quantized: bool = False):
    """Verify expected ZO output files exist.

    Args:
        output_dir: Directory the CLI exported into.
        quantized: Accepted for call-site compatibility; the same file set is
            required for float and quantized exports.

    Raises:
        AssertionError: Listing every missing file, not only the first.
    """
    missing = [
        fpath
        for fpath in (os.path.join(output_dir, fname) for fname in _ZO_EXPECTED_FILES)
        if not os.path.exists(fpath)
    ]
    assert not missing, f"Missing expected output file(s): {', '.join(missing)}"


def _verify_onnx_valid(onnx_file: str) -> "onnx.ModelProto":
    """Load and do basic validation of ONNX model."""
    return load_and_check_onnx_model(onnx_file, skip_shape_check=True)


def _count_perturbation_nodes(model: "onnx.ModelProto", op_prefix: str = "Perturb") -> int:
    """Count graph nodes whose ``op_type`` starts with *op_prefix*."""
    return sum(1 for node in model.graph.node if node.op_type.startswith(op_prefix))


def _get_perturbation_op_type(noise_type: str) -> str:
    """Map noise type to expected ONNX op_type.

    Lookup is case-insensitive. Unknown noise types fall back to
    ``"Perturb" + noise_type.capitalize()``.
    """
    return _NOISE_TYPE_TO_OP.get(
        noise_type.lower(), f"Perturb{noise_type.capitalize()}"
    )


def _run_and_verify_zo_export(
    model_name: str,
    mode: str,
    noise_type: str,
    output_dir: str,
    quantized: bool = False,
):
    """Run CLI export and verify output files and ONNX validity.

    Returns:
        The loaded ZO training model, so callers can make further assertions.
    """
    result = _run_cli(model_name, mode, output_dir, noise_type)
    assert result.returncode == 0, (
        f"CLI failed (rc={result.returncode}):\n"
        f"stdout: {result.stdout}\nstderr: {result.stderr}"
    )

    _verify_zo_output_files(output_dir, quantized=quantized)

    # Check inference model
    _verify_onnx_valid(os.path.join(output_dir, "network_infer.onnx"))

    # Check ZO training model
    zo_model = _verify_onnx_valid(os.path.join(output_dir, "network_zo_train.onnx"))

    # Verify perturbation nodes exist (exact op_type match, not prefix match)
    expected_op = _get_perturbation_op_type(noise_type)
    perturb_count = sum(1 for n in zo_model.graph.node if n.op_type == expected_op)
    assert perturb_count > 0, (
        f"No {expected_op} nodes found in ZO training graph. "
        f"Node types: {sorted({n.op_type for n in zo_model.graph.node})}"
    )

    return zo_model


def _run_numerical_check(
    onnx_file: str,
    input_shape: tuple,
    num_samples: int = 3,
    num_classes: int = 10,
):
    """Run the ONNX graph with pure Python executor and verify it doesn't crash.

    For ZO training graphs, we verify execution succeeds (the perturbation
    ops produce finite outputs). Full numerical matching against the RNG
    reference is done in dedicated tests.

    Args:
        onnx_file: Path of the graph to execute.
        input_shape: Shape of the random ``input`` feed; axis 0 is used as the
            number of label rows.
        num_samples: Number of random inputs (seeds ``0..num_samples-1``).
        num_classes: Width of the one-hot ``label`` feed.
    """
    for seed in range(num_samples):
        rng = np.random.default_rng(seed)
        test_input = rng.standard_normal(input_shape).astype(np.float32)
        # ZO graphs need both input and label; every sample is labelled class 0.
        label = np.zeros((input_shape[0], num_classes), dtype=np.float32)
        label[:, 0] = 1.0

        feeds = {"input": test_input, "label": label}
        # Only the executor call is guarded, so a non-finite result is reported
        # as an assertion failure rather than as an executor crash.
        try:
            output = run_onnx_graph(onnx_file, feeds)
        except Exception as e:
            pytest.fail(f"run_onnx_graph failed at seed={seed}: {e}")
        assert np.all(np.isfinite(output)), f"Non-finite output at seed={seed}"
model_test_dir + ) + # Verify inference model runs through pure Python executor + infer_path = os.path.join(model_test_dir, "network_infer.onnx") + rng = np.random.default_rng(0) + test_input = rng.standard_normal(self._INPUT_SHAPE).astype(np.float32) + output = run_onnx_graph(infer_path, {"input": test_input}) + assert output is not None and np.all(np.isfinite(output)) + + def test_litecnn_rademacher(self, model_test_dir): + """LiteCNN-Rademacher: export and verify perturbation nodes.""" + _run_and_verify_zo_export( + self._MODEL, self._MODE, "rademacher", model_test_dir + ) + + def test_litecnn_eggroll(self, model_test_dir): + """LiteCNN-Eggroll: export and verify perturbation nodes (uses Rademacher).""" + _run_and_verify_zo_export( + self._MODEL, self._MODE, "eggroll", model_test_dir + ) + + +# =========================================================================== +# QLiteCNN (quantized) ZO Tests +# =========================================================================== + + +@pytest.mark.zo +@pytest.mark.quantized +class TestQLiteCNNZO: + """ZO perturbation tests for QLiteCNN (quantized).""" + + _MODEL = "QLiteCNN" + _MODE = "q-zo-train" + _INPUT_SHAPE = (1, 1, 28, 28) + + def test_qlitecnn_rqs_rademacher(self, model_test_dir): + """QLiteCNN-RQSRad: quantized ZO export with RQS Rademacher perturbation.""" + zo_model = _run_and_verify_zo_export( + self._MODEL, self._MODE, "rqs_rademacher", model_test_dir, + quantized=True, + ) + # Verify RQSPerturbRademacher nodes exist + rqs_nodes = [ + n for n in zo_model.graph.node + if n.op_type == "RQSPerturbRademacher" + ] + assert len(rqs_nodes) > 0, "No RQSPerturbRademacher nodes in quantized ZO graph" + + +# =========================================================================== +# SleepConViT (float) ZO Tests +# =========================================================================== + + +@pytest.mark.zo +class TestSleepViTZO: + """ZO perturbation tests for SleepConViT (float).""" + + _MODEL = 
"SleepConViT" + _MODE = "zo-train" + _INPUT_SHAPE = (1, 1, 1, 3000) + + def test_sleepvit_uniform(self, model_test_dir): + """SleepViT-Uniform: export and verify perturbation nodes.""" + zo_model = _run_and_verify_zo_export( + self._MODEL, self._MODE, "uniform", model_test_dir + ) + # Verify inference model runs through pure Python executor + infer_path = os.path.join(model_test_dir, "network_infer.onnx") + rng = np.random.default_rng(0) + test_input = rng.standard_normal(self._INPUT_SHAPE).astype(np.float32) + output = run_onnx_graph(infer_path, {"input": test_input}) + assert output is not None and np.all(np.isfinite(output)) + + def test_sleepvit_rademacher(self, model_test_dir): + """SleepViT-Rademacher: export and verify perturbation nodes.""" + _run_and_verify_zo_export( + self._MODEL, self._MODE, "rademacher", model_test_dir + ) + + def test_sleepvit_eggroll(self, model_test_dir): + """SleepViT-Eggroll: export and verify perturbation nodes (uses Rademacher).""" + _run_and_verify_zo_export( + self._MODEL, self._MODE, "eggroll", model_test_dir + ) + + +# =========================================================================== +# QSleepConViT (quantized) ZO Tests +# =========================================================================== + + +@pytest.mark.zo +@pytest.mark.quantized +class TestQSleepViTZO: + """ZO perturbation tests for QSleepConViT (quantized).""" + + _MODEL = "QSleepConViT" + _MODE = "q-zo-train" + _INPUT_SHAPE = (1, 1, 1, 3000) + + def test_qsleepvit_rqs_rademacher(self, model_test_dir): + """QSleepViT-RQSRad: quantized ZO export with RQS Rademacher perturbation.""" + zo_model = _run_and_verify_zo_export( + self._MODEL, self._MODE, "rqs_rademacher", model_test_dir, + quantized=True, + ) + rqs_nodes = [ + n for n in zo_model.graph.node + if n.op_type == "RQSPerturbRademacher" + ] + assert len(rqs_nodes) > 0, "No RQSPerturbRademacher nodes in quantized ZO graph" + + +# 
=========================================================================== +# RNG Seed Verification Tests +# =========================================================================== + + +@pytest.mark.zo +class TestRNGSeedComputation: + """Verify the RNG seed computation matches the Deeploy C++ reference.""" + + def test_scramble_seed(self): + """Verify scramble formula: seed * 1664525 + 1013904223.""" + assert scramble_seed(0) == np.uint32(1013904223) + assert scramble_seed(1) == np.uint32(1664525 + 1013904223) + assert scramble_seed(42) == np.uint32( + np.uint32(42) * np.uint32(1664525) + np.uint32(1013904223) + ) + + def test_xorshift32_deterministic(self): + """Verify Xorshift32 is deterministic.""" + state = np.uint32(12345) + s1 = xorshift32(state) + s2 = xorshift32(state) + assert s1 == s2 + # Verify it changes state + assert xorshift32(s1) != s1 + + def test_seed_per_node(self): + """Verify each node gets a unique seed: seed + NUM_CORES * node_id + core_id.""" + global_seed = 42 + # Two different nodes should get different seeds + seed_node0_core0 = scramble_seed(global_seed + NUM_CORES * 0 + 0) + seed_node1_core0 = scramble_seed(global_seed + NUM_CORES * 1 + 0) + assert seed_node0_core0 != seed_node1_core0 + + # Same node, different cores should get different seeds + seed_node0_core1 = scramble_seed(global_seed + NUM_CORES * 0 + 1) + assert seed_node0_core0 != seed_node0_core1 + + def test_uniform_perturbation_deterministic(self): + """Verify uniform perturbation is deterministic for same seed/node_id.""" + data = np.zeros(64, dtype=np.float32) + result1 = generate_uniform_perturbation(data, global_seed=42, node_id=0, eps=0.01) + result2 = generate_uniform_perturbation(data, global_seed=42, node_id=0, eps=0.01) + np.testing.assert_array_equal(result1, result2) + + def test_rademacher_perturbation_values(self): + """Verify Rademacher perturbation only produces +/- eps offsets.""" + data = np.zeros(64, dtype=np.float32) + eps = 0.01 + result = 
generate_rademacher_perturbation(data, global_seed=42, node_id=0, eps=eps) + # All values should be exactly +eps or -eps + np.testing.assert_array_less(np.abs(np.abs(result) - eps), 1e-7) diff --git a/tests/operators/test_perturbation_operators.py b/tests/operators/test_perturbation_operators.py new file mode 100644 index 0000000..f8bf21f --- /dev/null +++ b/tests/operators/test_perturbation_operators.py @@ -0,0 +1,389 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +""" +Tests for individual perturbation operators: PerturbUniform, PerturbRademacher, +and PerturbEggroll. + +Each test class covers: + - File generation (ONNX model, inputs.npz, outputs.npz) via the operator test + generator classes in onnx4deeploy.operators. + - Output shape correctness. + - Output finiteness. + - Determinism: the pure-Python executor produces the same result on two calls + with the same inputs. + - Reference-RNG consistency: the pure-Python executor result matches the + reference _perturb_* helper functions in onnx4deeploy.utils directly. 
+""" + +import os + +import numpy as np +import pytest + +from onnx4deeploy.operators import ( + PerturbEggrollOperatorTest, + PerturbRademacherOperatorTest, + PerturbUniformOperatorTest, +) +from onnx4deeploy.utils.onnx_node_implementations import ( + _perturb_rademacher, + _perturb_uniform, + run_onnx_graph, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_DEFAULT_SEED = 42 +_DEFAULT_EPS = 0.01 +_SHAPES = [(1, 16), (2, 32), (1, 8, 4)] + + +def _write_uniform_config(path: str, shape) -> str: + cfg_path = os.path.join(path, "config.yaml") + with open(cfg_path, "w") as f: + f.write(f"perturbuniform:\n input_shape: {list(shape)}\n") + return cfg_path + + +def _write_rademacher_config(path: str, shape) -> str: + cfg_path = os.path.join(path, "config.yaml") + with open(cfg_path, "w") as f: + f.write(f"perturbrademacher:\n input_shape: {list(shape)}\n") + return cfg_path + + +def _write_eggroll_config(path: str, shape) -> str: + cfg_path = os.path.join(path, "config.yaml") + with open(cfg_path, "w") as f: + f.write(f"perturbeggroll:\n input_shape: {list(shape)}\n") + return cfg_path + + +# --------------------------------------------------------------------------- +# PerturbUniform +# --------------------------------------------------------------------------- + + +class TestPerturbUniformOperator: + """Tests for the PerturbUniform custom ONNX operator.""" + + def test_files_generated(self, operator_test_dir): + """Verify that generate() creates the ONNX model and data files.""" + cfg = _write_uniform_config(operator_test_dir, (1, 32)) + test = PerturbUniformOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, output_file = test.generate() + + assert os.path.exists(onnx_file), "ONNX model file not created" + assert os.path.exists(input_file), "inputs.npz not created" + assert os.path.exists(output_file), 
"outputs.npz not created" + + def test_output_shape(self, operator_test_dir): + """Output shape must match input shape.""" + shape = (2, 16) + cfg = _write_uniform_config(operator_test_dir, shape) + test = PerturbUniformOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, output_file = test.generate() + + outputs = np.load(output_file) + assert "perturbed_x" in outputs + assert outputs["perturbed_x"].shape == shape + + def test_output_finite(self, operator_test_dir): + """All output values must be finite.""" + cfg = _write_uniform_config(operator_test_dir, (1, 64)) + test = PerturbUniformOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, output_file = test.generate() + + outputs = np.load(output_file) + assert np.all(np.isfinite(outputs["perturbed_x"])), "PerturbUniform output contains non-finite values" + + @pytest.mark.parametrize("shape", _SHAPES) + def test_pure_python_executor_runs(self, operator_test_dir, shape): + """run_onnx_graph executes the PerturbUniform ONNX without errors.""" + cfg = _write_uniform_config(operator_test_dir, shape) + test = PerturbUniformOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, _ = test.generate() + + x = np.load(input_file)["x"] + result = run_onnx_graph(onnx_file, {"x": x}) + assert result is not None + assert result.shape == x.shape + + def test_deterministic(self, operator_test_dir): + """Two invocations of run_onnx_graph with the same input give the same result.""" + cfg = _write_uniform_config(operator_test_dir, (1, 32)) + test = PerturbUniformOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, _ = test.generate() + + x = np.load(input_file)["x"] + out1 = run_onnx_graph(onnx_file, {"x": x}) + out2 = run_onnx_graph(onnx_file, {"x": x}) + np.testing.assert_array_equal(out1, out2, err_msg="PerturbUniform is not deterministic") + + def test_rng_reference_consistency(self, operator_test_dir): + 
"""run_onnx_graph output matches direct _perturb_uniform with seed=42, idx=0.""" + shape = (1, 32) + cfg = _write_uniform_config(operator_test_dir, shape) + test = PerturbUniformOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, _ = test.generate() + + x = np.load(input_file)["x"] + graph_out = run_onnx_graph(onnx_file, {"x": x}) + + # The ONNX node was created with seed=42, idx=0, eps=0.01*sqrt(3) + eps = float(0.01 * np.sqrt(3)) + ref = _perturb_uniform(x, global_seed=42, node_id=0, eps=eps, sign=1) + + np.testing.assert_allclose( + graph_out, ref, rtol=1e-6, atol=1e-6, + err_msg="PerturbUniform graph result does not match reference RNG" + ) + + def test_perturbation_magnitude(self, operator_test_dir): + """Perturbation magnitude is bounded by eps * sqrt(3) (uniform support [-sqrt(3), sqrt(3)]).""" + shape = (4, 64) + cfg = _write_uniform_config(operator_test_dir, shape) + test = PerturbUniformOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, _ = test.generate() + + x = np.load(input_file)["x"] + out = run_onnx_graph(onnx_file, {"x": x}) + + delta = np.abs(out - x) + eps = float(0.01 * np.sqrt(3)) + assert np.all(delta <= eps * np.sqrt(3) + 1e-5), ( + f"PerturbUniform perturbation exceeds expected bound: max={delta.max():.6f}" + ) + + +# --------------------------------------------------------------------------- +# PerturbRademacher +# --------------------------------------------------------------------------- + + +class TestPerturbRademacherOperator: + """Tests for the PerturbRademacher custom ONNX operator.""" + + def test_files_generated(self, operator_test_dir): + """Verify that generate() creates the ONNX model and data files.""" + cfg = _write_rademacher_config(operator_test_dir, (1, 32)) + test = PerturbRademacherOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, output_file = test.generate() + + assert os.path.exists(onnx_file) + assert 
os.path.exists(input_file) + assert os.path.exists(output_file) + + def test_output_shape(self, operator_test_dir): + """Output shape must match input shape.""" + shape = (2, 16) + cfg = _write_rademacher_config(operator_test_dir, shape) + test = PerturbRademacherOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, output_file = test.generate() + + outputs = np.load(output_file) + assert "perturbed_x" in outputs + assert outputs["perturbed_x"].shape == shape + + def test_output_finite(self, operator_test_dir): + """All output values must be finite.""" + cfg = _write_rademacher_config(operator_test_dir, (1, 64)) + test = PerturbRademacherOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, output_file = test.generate() + + outputs = np.load(output_file) + assert np.all(np.isfinite(outputs["perturbed_x"])) + + @pytest.mark.parametrize("shape", _SHAPES) + def test_pure_python_executor_runs(self, operator_test_dir, shape): + """run_onnx_graph executes the PerturbRademacher ONNX without errors.""" + cfg = _write_rademacher_config(operator_test_dir, shape) + test = PerturbRademacherOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, _ = test.generate() + + x = np.load(input_file)["x"] + result = run_onnx_graph(onnx_file, {"x": x}) + assert result is not None + assert result.shape == x.shape + + def test_deterministic(self, operator_test_dir): + """Two invocations with same input produce identical results.""" + cfg = _write_rademacher_config(operator_test_dir, (1, 32)) + test = PerturbRademacherOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, _ = test.generate() + + x = np.load(input_file)["x"] + out1 = run_onnx_graph(onnx_file, {"x": x}) + out2 = run_onnx_graph(onnx_file, {"x": x}) + np.testing.assert_array_equal(out1, out2) + + def test_rng_reference_consistency(self, operator_test_dir): + """run_onnx_graph output matches direct 
_perturb_rademacher with seed=42, idx=0.""" + shape = (1, 32) + cfg = _write_rademacher_config(operator_test_dir, shape) + test = PerturbRademacherOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, _ = test.generate() + + x = np.load(input_file)["x"] + graph_out = run_onnx_graph(onnx_file, {"x": x}) + + # The ONNX node was created with seed=42, idx=0, eps=0.01 + ref = _perturb_rademacher(x, global_seed=42, node_id=0, eps=0.01, sign=1) + + np.testing.assert_allclose( + graph_out, ref, rtol=1e-6, atol=1e-6, + err_msg="PerturbRademacher graph result does not match reference RNG" + ) + + def test_perturbation_is_exactly_eps(self, operator_test_dir): + """Every element of x should be perturbed by exactly ยฑeps=0.01.""" + shape = (4, 64) + cfg = _write_rademacher_config(operator_test_dir, shape) + test = PerturbRademacherOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, _ = test.generate() + + x = np.load(input_file)["x"] + out = run_onnx_graph(onnx_file, {"x": x}) + + delta = np.abs(out - x) + np.testing.assert_allclose( + delta, np.full_like(delta, 0.01), atol=1e-5, + err_msg="PerturbRademacher: perturbation magnitude should be exactly eps=0.01" + ) + + def test_perturbation_values_binary(self, operator_test_dir): + """Perturbation offsets must be exactly +eps or -eps (Rademacher property).""" + cfg = _write_rademacher_config(operator_test_dir, (1, 128)) + test = PerturbRademacherOperatorTest(config_path=cfg, save_path=operator_test_dir) + onnx_file, input_file, _ = test.generate() + + x = np.load(input_file)["x"] + out = run_onnx_graph(onnx_file, {"x": x}) + + noise = out - x + # noise values should all be +0.01 or -0.01 + eps = np.float32(0.01) + valid_mask = np.isclose(noise, eps, atol=1e-5) | np.isclose(noise, -eps, atol=1e-5) + assert np.all(valid_mask), "PerturbRademacher noise contains values other than ยฑeps" + + +# --------------------------------------------------------------------------- +# 
# PerturbEggroll
# ---------------------------------------------------------------------------


class TestPerturbEggrollOperator:
    """Tests for the PerturbEggroll custom ONNX operator.

    Each test writes an Eggroll config for a given input shape, generates the
    ONNX model plus its input/output archives, and then checks either the
    stored output or the result of executing the graph with ``run_onnx_graph``.
    """

    @staticmethod
    def _generate(operator_test_dir, shape):
        """Write an Eggroll config for ``shape`` and generate the artifacts.

        Args:
            operator_test_dir: Directory used both for the config file and as
                the ``save_path`` of the generated files.
            shape: Input tensor shape passed to ``_write_eggroll_config``.

        Returns:
            Whatever ``PerturbEggrollOperatorTest.generate()`` returns; the
            tests unpack it as ``(onnx_file, input_file, output_file)``.
        """
        cfg = _write_eggroll_config(operator_test_dir, shape)
        test = PerturbEggrollOperatorTest(config_path=cfg, save_path=operator_test_dir)
        return test.generate()

    @staticmethod
    def _load_array(npz_path, key):
        """Return array ``key`` from the archive at ``npz_path``.

        The archive is opened in a ``with`` block so its file handle is
        released immediately instead of lingering until garbage collection.
        """
        with np.load(npz_path) as archive:
            return archive[key]

    def test_files_generated(self, operator_test_dir):
        """Verify that generate() creates the ONNX model and data files."""
        onnx_file, input_file, output_file = self._generate(operator_test_dir, (4, 8))

        assert os.path.exists(onnx_file)
        assert os.path.exists(input_file)
        assert os.path.exists(output_file)

    def test_output_shape(self, operator_test_dir):
        """Output shape must match input shape."""
        shape = (4, 8)
        _, _, output_file = self._generate(operator_test_dir, shape)

        # Keep the archive open only for the membership and shape checks.
        with np.load(output_file) as outputs:
            assert "perturbed_x" in outputs
            assert outputs["perturbed_x"].shape == shape

    def test_output_finite(self, operator_test_dir):
        """All output values must be finite."""
        _, _, output_file = self._generate(operator_test_dir, (4, 16))

        perturbed = self._load_array(output_file, "perturbed_x")
        assert np.all(np.isfinite(perturbed))

    @pytest.mark.parametrize("shape", [(4, 8), (2, 16), (2, 4, 8)])
    def test_pure_python_executor_runs(self, operator_test_dir, shape):
        """run_onnx_graph executes the PerturbEggroll ONNX without errors."""
        onnx_file, input_file, _ = self._generate(operator_test_dir, shape)

        x = self._load_array(input_file, "x")
        result = run_onnx_graph(onnx_file, {"x": x})
        assert result is not None
        assert result.shape == x.shape

    def test_deterministic(self, operator_test_dir):
        """Two invocations with same input produce identical results."""
        onnx_file, input_file, _ = self._generate(operator_test_dir, (4, 8))

        x = self._load_array(input_file, "x")
        first = run_onnx_graph(onnx_file, {"x": x})
        second = run_onnx_graph(onnx_file, {"x": x})
        np.testing.assert_array_equal(first, second)

    def test_rng_reference_consistency_vectors(self, operator_test_dir):
        """
        The PerturbEggroll vectors (a and b) computed by run_onnx_graph match
        those from _perturb_rademacher applied to zero-filled column vectors.
        """
        shape = (4, 8)
        onnx_file, input_file, _ = self._generate(operator_test_dir, shape)

        x = self._load_array(input_file, "x")
        graph_out = run_onnx_graph(onnx_file, {"x": x})

        # The ONNX graph uses seed_a=13, idx=0 and seed_b=14, idx=1 for the two
        # PerturbEggroll nodes (as defined in perturbeggroll.py).
        # ``a`` has one entry per leading-axis row; ``b`` has one entry per
        # element of the flattened trailing axes.
        a_shape = [shape[0], 1]
        b_shape = [int(np.prod(shape[1:])), 1]

        # Perturbing zeros with eps=1.0 leaves just the reference noise vector.
        a_ref = _perturb_rademacher(
            np.zeros(a_shape, dtype=np.float32), global_seed=13, node_id=0, eps=1.0, sign=1
        )
        b_ref = _perturb_rademacher(
            np.zeros(b_shape, dtype=np.float32), global_seed=14, node_id=1, eps=1.0, sign=1
        )

        # PerturbEggroll output = eps * Gemm(a, b^T) with beta=0 (i.e. ignores x)
        # alpha in the graph is uniform_epsilon = 0.01 * sqrt(3)
        eps = float(0.01 * np.sqrt(3))
        expected = (eps * (a_ref @ b_ref.T)).reshape(shape)

        np.testing.assert_allclose(
            graph_out, expected, rtol=1e-5, atol=1e-5,
            err_msg="PerturbEggroll graph result does not match reference Rademacher vectors"
        )

    def test_low_rank_structure(self, operator_test_dir):
        """
        PerturbEggroll output = alpha * a @ b^T (Gemm with beta=0), which is
        exactly rank-1. The output matrix itself should have matrix rank 1.
        """
        onnx_file, input_file, _ = self._generate(operator_test_dir, (8, 16))

        x = self._load_array(input_file, "x")
        out = run_onnx_graph(onnx_file, {"x": x})

        # The output IS alpha * a @ b^T (beta=0); it is rank-1.
        # Singular values come back in descending order, so comparing the
        # first two is enough to detect a second significant component.
        sv = np.linalg.svd(out.astype(np.float64), compute_uv=False)
        # Only one non-negligible singular value
        assert sv[0] > sv[1] * 1e3, (
            f"PerturbEggroll output does not appear rank-1: sv[0]={sv[0]:.4f}, sv[1]={sv[1]:.4f}"
        )