diff --git a/Onnx4Deeploy.py b/Onnx4Deeploy.py index a9dda2c..bef210b 100644 --- a/Onnx4Deeploy.py +++ b/Onnx4Deeploy.py @@ -46,6 +46,7 @@ def list_available_models(): SimpleCnnExporter, SimpleMlpExporter, SleepConViTExporter, + SpeechNetExporter, TinyTransformerExporter, TinyViTExporter, ) @@ -226,6 +227,14 @@ def list_available_models(): "input_shape": "(B, 16, 49)", "classes": 10, }, + # EMG / Bio-Signal Models + "SpeechNet": { + "class": SpeechNetExporter, + "description": "SpeechNet (SilentWear EMG silent speech, ~15K params)", + "input_shape": "(B, 1, 14, 700)", + "classes": 9, + "config": {"num_channels": 14, "time_steps": 700, "num_classes": 9}, + }, "LightweightCNN": { "class": LightweightCnnExporter, "description": "Lightweight CNN (Compact CNN for image classification)", diff --git a/onnx4deeploy/models/__init__.py b/onnx4deeploy/models/__init__.py index eee9554..9b025a6 100644 --- a/onnx4deeploy/models/__init__.py +++ b/onnx4deeploy/models/__init__.py @@ -17,6 +17,7 @@ from .simple_cnn_exporter import SimpleCnnExporter from .simple_mlp_exporter import SimpleMlpExporter from .sleep_convit_exporter import SleepConViTExporter +from .speechnet_exporter import SpeechNetExporter from .tiny_transformer_exporter import TinyTransformerExporter from .tinyvit_exporter import TinyViTExporter @@ -36,4 +37,5 @@ "SleepConViTExporter", "TinyTransformerExporter", "TinyViTExporter", + "SpeechNetExporter", ] diff --git a/onnx4deeploy/models/pytorch_models/speechnet/__init__.py b/onnx4deeploy/models/pytorch_models/speechnet/__init__.py new file mode 100644 index 0000000..964190e --- /dev/null +++ b/onnx4deeploy/models/pytorch_models/speechnet/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +"""SpeechNet model for ONNX export (SilentWear EMG silent speech recognition).""" + +from .speechnet import SpeechNetDeploy + +__all__ = ["SpeechNetDeploy"] diff --git a/onnx4deeploy/models/pytorch_models/speechnet/speechnet.py b/onnx4deeploy/models/pytorch_models/speechnet/speechnet.py new file mode 100644 index 0000000..0bd4b26 --- /dev/null +++ b/onnx4deeploy/models/pytorch_models/speechnet/speechnet.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +""" +SpeechNet — Deployment-ready variant for Deeploy on PULP MCUs. + +Based on the SpeechNet architecture from: + Spacone et al., "SilentWear: an Ultra-Low Power Wearable System for + EMG-based Silent Speech Recognition", arXiv: 2603.02847. + +Differences from the upstream SilentWear SpeechNet: + - Input is 4-D (B, 1, C, T) directly; no unsqueeze in forward(). + - MaxPool2d replaced by AvgPool2d for Deeploy tiling/gradient compatibility. + - No Dropout (deployment / inference mode). + +Paper default configuration (5 blocks): + Input: (1, 1, 14, 700) 14 EMG channels, 700 samples (1.4 s @ 500 Hz) + Block 0: Conv2d(1, 8, k=(1,4), pad=(0,2)) -> BN -> ReLU -> AvgPool(1,8) + Block 1: Conv2d(8, 16, k=(1,16), pad=(0,8)) -> BN -> ReLU -> AvgPool(1,4) + Block 2: Conv2d(16,16, k=(1,8), pad=(0,4)) -> BN -> ReLU -> AvgPool(1,4) + Block 3: Conv2d(16,32, k=(7,1), pad=(0,0)) -> BN -> ReLU -> AvgPool(1,1) + Block 4: Conv2d(32,32, k=(7,1), pad=(0,0)) -> BN -> ReLU -> AvgPool(1,1) + Global: AdaptiveAvgPool2d(1,1) -> Flatten -> Linear(32, num_classes) +""" + +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn + + +class SpeechNetDeploy(nn.Module): + """ + Deployment-ready SpeechNet for Deeploy. + + Parameters + ---------- + num_channels : int + Number of EMG input channels (H dimension). Default: 14. + time_steps : int + Number of time samples (W dimension). Default: 700. + num_classes : int + Number of output classes. Default: 9 (8 commands + rest). + blocks_config : list of dict, optional + Per-block configuration. Each dict has keys: + out_channels (int), kernel (tuple), pool (tuple). + If None, uses the paper's 5-block default. + """ + + def __init__( + self, + num_channels: int = 14, + time_steps: int = 700, + num_classes: int = 9, + blocks_config: Optional[List[Dict[str, Any]]] = None, + ): + super().__init__() + self.num_channels = num_channels + self.time_steps = time_steps + self.num_classes = num_classes + + if blocks_config is None: + blocks_config = [ + dict(out_channels=8, kernel=(1, 4), pool=(1, 8)), + dict(out_channels=16, kernel=(1, 16), pool=(1, 4)), + dict(out_channels=16, kernel=(1, 8), pool=(1, 4)), + dict(out_channels=32, kernel=(7, 1), pool=(1, 1)), + dict(out_channels=32, kernel=(7, 1), pool=(1, 1)), + ] + + self.blocks = nn.ModuleList() + in_ch = 1 + + for cfg in blocks_config: + out_ch = cfg["out_channels"] + k_c, k_t = cfg["kernel"] + k_c, k_t = int(k_c), int(k_t) + pool_c, pool_t = cfg.get("pool", (1, 1)) + pool_c, pool_t = int(pool_c), int(pool_t) + + layers: List[nn.Module] = [ + nn.Conv2d( + in_ch, + out_ch, + kernel_size=(k_c, k_t), + stride=(1, 1), + padding=(0, k_t // 2), + bias=True, + ), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=False), + nn.AvgPool2d(kernel_size=(pool_c, pool_t), stride=(pool_c, pool_t)), + ] + + self.blocks.append(nn.Sequential(*layers)) + in_ch = out_ch + + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + self._fc_in = in_ch # stored as Python int for static ONNX reshape + self.fc = nn.Linear(in_ch, num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. x: (B, 1, C, T).""" + for block in self.blocks: + x = block(x) + x = self.global_pool(x) + # Static reshape avoiding dynamic Shape ops in ONNX. + # After GlobalAvgPool2d(1,1): (1, C, 1, 1) → (1, C) + # Both dims are Python ints known at trace time. + x = x.reshape(1, self._fc_in) + x = self.fc(x) + return x diff --git a/onnx4deeploy/models/speechnet_exporter.py b/onnx4deeploy/models/speechnet_exporter.py new file mode 100644 index 0000000..16b01b0 --- /dev/null +++ b/onnx4deeploy/models/speechnet_exporter.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +"""SpeechNet Model Exporter — SilentWear EMG silent speech recognition. + +Based on: Spacone et al., "SilentWear: an Ultra-Low Power Wearable System for +EMG-based Silent Speech Recognition", arXiv: 2603.02847. + +Default configuration (paper, FP32): + Input: (1, 1, 14, 700) — 1-ch, 14 EMG channels × 700 time samples (1.4 s @ 500 Hz) + Classes: 9 (8 commands + rest) + ~15 K parameters + +The deployment variant uses AvgPool (not MaxPool) and omits Dropout for Deeploy +tiling/gradient compatibility. BatchNorm folds into Conv during ONNX export. +""" + +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np +import torch + +from ..core.base_exporter import BaseONNXExporter + + +class SpeechNetExporter(BaseONNXExporter): + """ONNX exporter for SpeechNet (SilentWear EMG silent speech).""" + + def __init__(self, save_path: str = None, config_file: str = "config.yaml"): + super().__init__(save_path, config_file) + self.model_config = {} + + # ------------------------------------------------------------------ # + # Configuration # + # ------------------------------------------------------------------ # + + def load_config(self) -> Dict[str, Any]: + config = { + "batch_size": 1, + "num_channels": 14, # EMG differential channels + "time_steps": 700, # 1.4 s @ 500 Hz + "num_classes": 9, # 8 commands + rest + "opset_version": 17, + # Training + "training_strategy": "full", # "full" | "last_layer" | "custom" + "custom_trainable_params": [], + "learning_rate": 0.001, + "n_batches": 4, + "n_accum": 1, + "data_size": None, + } + + if hasattr(self, "_config_overrides") and self._config_overrides: + config.update(self._config_overrides) + + self.model_config = config + return config + + # ------------------------------------------------------------------ # + # Model factory # + # ------------------------------------------------------------------ # + + def create_model(self) -> torch.nn.Module: + from .pytorch_models.speechnet.speechnet import SpeechNetDeploy + + return SpeechNetDeploy( + num_channels=self.model_config["num_channels"], + time_steps=self.model_config["time_steps"], + num_classes=self.model_config["num_classes"], + ) + + # ------------------------------------------------------------------ # + # Shape helpers # + # ------------------------------------------------------------------ # + + def get_input_shape(self) -> Tuple[int, ...]: + return ( + self.config["batch_size"], + 1, # single input channel + self.config["num_channels"], + self.config["time_steps"], + ) + + def _get_config_string(self) -> str: + return ( + f"_speechnet_{self.config['num_channels']}ch" + f"_{self.config['time_steps']}t" + f"_{self.config['num_classes']}cls" + ) + + # ------------------------------------------------------------------ # + # Training strategy # + # ------------------------------------------------------------------ # + + def get_trainable_params(self, all_param_names: List[str]) -> List[str]: + """ + Pattern-based trainable parameter selection. + + Strategies: + - "full": Train all parameters (default). + - "last_layer": Only the FC classifier. + - "custom": Explicit list from config["custom_trainable_params"]. + """ + strategy = self.config.get("training_strategy", "full") + + _FREEZE = { + "full": lambda n: False, + "last_layer": lambda n: "fc" not in n, + "custom": lambda n: n not in self.config.get("custom_trainable_params", []), + } + + if strategy not in _FREEZE: + print(f" Unknown strategy '{strategy}', using 'full'") + strategy = "full" + + requires_grad = [n for n in all_param_names if not _FREEZE[strategy](n)] + frozen = [n for n in all_param_names if _FREEZE[strategy](n)] + + print(f"\n Training Strategy: '{strategy}'") + print( + f" Total: {len(all_param_names)} Trainable: {len(requires_grad)} Frozen: {len(frozen)}" + ) + if frozen: + print(f" Frozen: {frozen[:5]}{'...' if len(frozen) > 5 else ''}") + return requires_grad + + # ------------------------------------------------------------------ # + # Inference test data # + # ------------------------------------------------------------------ # + + def save_test_data(self, model: torch.nn.Module, save_dir: str): + print(" Saving inference test data...") + input_shape = self.get_input_shape() + test_input = np.random.randn(*input_shape).astype(np.float32) + + was_training = model.training + model.eval() + with torch.no_grad(): + test_output = model(torch.from_numpy(test_input)).numpy() + if was_training: + model.train() + + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + np.savez(save_path / "inputs.npz", input=test_input) + np.savez(save_path / "outputs.npz", output=test_output) + print(f" Input: {test_input.shape} Output: {test_output.shape}") + + # ------------------------------------------------------------------ # + # Training test data # + # ------------------------------------------------------------------ # + + def create_training_test_data( + self, n_batches: int = None, num_data_inputs: int = 2, n_accum: int = None + ) -> None: + """ + Save inputs.npz / outputs.npz for training-mode validation. + + Inputs are random EMG-shaped float tensors; labels are random int64 class indices. + """ + import onnx + import onnxruntime as ort + + if n_batches is None: + n_batches = self.config.get("n_batches", 4) + if n_accum is None: + n_accum = int(self.config.get("n_accum", 1)) + if n_batches % n_accum != 0: + n_batches = max((n_batches // n_accum) * n_accum, n_accum) + print(f" n_batches adjusted to {n_batches} (must be divisible by n_accum={n_accum})") + n_steps = n_batches // n_accum + + save_dir = Path(self.paths["output_dir"]) + save_dir.mkdir(parents=True, exist_ok=True) + + input_shape = self.get_input_shape() + num_classes = self.config.get("num_classes", 9) + learning_rate = float(self.config.get("learning_rate", 0.001)) + + print( + f" Training sim: n_batches={n_batches} n_accum={n_accum} n_steps={n_steps} lr={learning_rate}" + ) + + _data_size_cfg = self.config.get("data_size", None) + effective_data_size = ( + int(_data_size_cfg) + if (_data_size_cfg and int(_data_size_cfg) < n_batches) + else n_batches + ) + + rng = np.random.default_rng(42) + test_inputs = [ + rng.standard_normal(input_shape).astype(np.float32) for _ in range(effective_data_size) + ] + labels_list = [ + rng.integers(0, num_classes, size=(input_shape[0],)).astype(np.int64) + for _ in range(effective_data_size) + ] + + init_map: dict = self._load_init_map(self.paths["network_infer"]) + + train_model_onnx = onnx.load(self.paths["network_train"]) + grad_tensor_map: dict = {} + for node in train_model_onnx.graph.node: + if "InPlaceAccumulator" in node.op_type and len(node.input) >= 2: + grad_tensor_name = node.input[1] + if grad_tensor_name.endswith("_grad"): + grad_tensor_map[grad_tensor_name[:-5]] = grad_tensor_name + + for grad_name in grad_tensor_map.values(): + vi = onnx.helper.make_tensor_value_info(grad_name, onnx.TensorProto.FLOAT, None) + train_model_onnx.graph.output.append(vi) + + session = ort.InferenceSession( + train_model_onnx.SerializeToString(), providers=["CPUExecutionProvider"] + ) + session_output_names = [o.name for o in session.get_outputs()] + print(f" Training model inputs: {[i.name for i in session.get_inputs()]}") + print(f" Training model outputs: {session_output_names}") + + current_weights = {k: v.copy() for k, v in init_map.items()} + all_losses: list = [] + feed_mb0: dict = {} + + for update_step in range(n_steps): + accumulated_grads = { + pname: np.zeros_like(current_weights[pname]) + for pname in grad_tensor_map + if pname in current_weights + } + + for accum_step in range(n_accum): + mb = update_step * n_accum + accum_step + + feed = self._build_input_feed( + session, + param_values=current_weights, + test_input=test_inputs[mb % effective_data_size], + labels=labels_list[mb % effective_data_size], + lazy_reset_grad=(accum_step == 0), + ) + + if mb == 0: + feed_mb0 = {k: v.copy() if hasattr(v, "copy") else v for k, v in feed.items()} + + raw_outputs = session.run(None, feed) + outputs_raw = dict(zip(session_output_names, raw_outputs)) + + for out_name, out_val in outputs_raw.items(): + if "loss" in out_name.lower() and "grad" not in out_name.lower(): + all_losses.append(float(np.array(out_val).flatten()[0])) + break + + for pname, grad_name in grad_tensor_map.items(): + if grad_name in outputs_raw and pname in accumulated_grads: + accumulated_grads[pname] += outputs_raw[grad_name] + + for pname, acc_grad in accumulated_grads.items(): + current_weights[pname] -= learning_rate * acc_grad + + outputs_dict: dict = {k: v for k, v in current_weights.items()} + outputs_dict["loss"] = np.array(all_losses, dtype=np.float32) + print(f" Reference losses: {all_losses}") + + final_model = onnx.load(self.paths["network"]) + final_input_names = [inp.name for inp in final_model.graph.input] + grad_acc_names = {n for n in final_input_names if self._GRAD_ACC_SUFFIX in n} + non_grad_names = [n for n in final_input_names if n not in grad_acc_names] + + save_dict: dict = {} + for npz_idx, name in enumerate(non_grad_names): + if name in feed_mb0: + save_dict[f"arr_{npz_idx:04d}"] = feed_mb0[name] + else: + print(f" non-grad input '{name}' not found in feed -- skipping") + + session_type: dict = {inp.name: inp.type for inp in session.get_inputs()} + data_names = non_grad_names[:num_data_inputs] + for mb in range(1, effective_data_size): + for buf_idx, data_name in enumerate(data_names): + inp_type = session_type.get(data_name, "tensor(float)") + if inp_type == "tensor(int64)": + save_dict[f"mb{mb}_arr_{buf_idx:04d}"] = labels_list[mb] + else: + save_dict[f"mb{mb}_arr_{buf_idx:04d}"] = test_inputs[mb] + + save_dict["meta_data_size"] = np.array([effective_data_size], dtype=np.int32) + save_dict["meta_n_batches"] = np.array([n_batches], dtype=np.int32) + save_dict["meta_n_accum"] = np.array([n_accum], dtype=np.int32) + np.savez(save_dir / "inputs.npz", **save_dict) + + n_params = sum(1 for n in non_grad_names if n in init_map) + n_grad = len(grad_acc_names) + print( + f" inputs.npz: {len(non_grad_names)} base tensors " + f"(data + {n_params} params; {n_grad} grad-acc-buf(s) omitted) " + f"+ {(effective_data_size - 1) * num_data_inputs} DATA entries" + ) + + np.savez(save_dir / "outputs.npz", **outputs_dict) + n_updated = sum(1 for k in outputs_dict if k in init_map) + print(f" outputs.npz: {len(outputs_dict)} tensors ({n_updated} updated params + loss)") diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 75bed86..d84e276 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -203,6 +203,18 @@ def tiny_transformer_config(): } +@pytest.fixture +def speechnet_config(): + """Default SpeechNet configuration for fast tests (SilentWear EMG).""" + return { + "batch_size": 1, + "num_channels": 14, + "time_steps": 700, + "num_classes": 9, + "opset_version": 17, + } + + @pytest.fixture def sleep_convit_config(): """Default SleepConViT configuration (tests use library defaults). diff --git a/tests/models/test_speechnet.py b/tests/models/test_speechnet.py new file mode 100644 index 0000000..834edb1 --- /dev/null +++ b/tests/models/test_speechnet.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: MIT + +""" +Tests for SpeechNet model export (SilentWear EMG silent speech recognition). + +Tests inference and training graph generation for the 5-block SpeechNet CNN +(14 EMG channels, 700 time samples, 9 classes). +""" + +import os + +import numpy as np +import pytest + +from onnx4deeploy.models.speechnet_exporter import SpeechNetExporter + +from .test_utils import ( + create_random_input, + verify_inference_export, + verify_onnxruntime_compatibility, + verify_trainable_params, + verify_training_export, +) + + +@pytest.mark.inference +class TestSpeechNetInference: + """Test SpeechNet model inference mode export.""" + + def test_speechnet_inference_export(self, model_test_dir, speechnet_config): + """Test SpeechNet model inference export with full verification.""" + exporter = SpeechNetExporter(save_path=model_test_dir) + exporter._config_overrides = {k: v for k, v in speechnet_config.items()} + + expected_input_shape = [ + speechnet_config["batch_size"], + 1, + speechnet_config["num_channels"], + speechnet_config["time_steps"], + ] + + verify_inference_export( + exporter, + model_test_dir, + expected_input_shape=expected_input_shape, + expected_batch_size=speechnet_config["batch_size"], + expected_output_classes=speechnet_config["num_classes"], + ) + + def test_speechnet_onnxruntime_inference(self, model_test_dir, speechnet_config): + """Test exported SpeechNet model runs correctly with ONNX Runtime.""" + exporter = SpeechNetExporter(save_path=model_test_dir) + exporter._config_overrides = {k: v for k, v in speechnet_config.items()} + onnx_file = exporter.export(mode="infer") + + input_shape = ( + speechnet_config["batch_size"], + 1, + speechnet_config["num_channels"], + speechnet_config["time_steps"], + ) + test_input = create_random_input(input_shape) + expected_output_shape = (speechnet_config["batch_size"], speechnet_config["num_classes"]) + verify_onnxruntime_compatibility( + onnx_file, test_input, expected_output_shape, input_name="input" + ) + + def test_speechnet_trainable_params(self, model_test_dir, speechnet_config): + """Test SpeechNet trainable parameter identification.""" + exporter = SpeechNetExporter(save_path=model_test_dir) + exporter._config_overrides = {k: v for k, v in speechnet_config.items()} + exporter.config = exporter.load_config() + + model = exporter.create_model() + verify_trainable_params(exporter, model) + + def test_speechnet_last_layer_strategy(self, model_test_dir, speechnet_config): + """Test SpeechNet last-layer-only training strategy.""" + exporter = SpeechNetExporter(save_path=model_test_dir) + cfg = {k: v for k, v in speechnet_config.items()} + cfg["training_strategy"] = "last_layer" + exporter._config_overrides = cfg + exporter.config = exporter.load_config() + + model = exporter.create_model() + all_params = [name for name, _ in model.named_parameters()] + trainable = exporter.get_trainable_params(all_params) + + assert len(trainable) > 0 + assert len(trainable) < len(all_params) + + +@pytest.mark.training +class TestSpeechNetTraining: + """Test SpeechNet model training mode export.""" + + def test_speechnet_training_export(self, model_test_dir, speechnet_config): + """Test SpeechNet training graph generation (smoke test with random data).""" + exporter = SpeechNetExporter(save_path=model_test_dir) + exporter._config_overrides = { + **speechnet_config, + "n_batches": 2, + "n_accum": 1, + "dataset": "random", + } + onnx_file = verify_training_export(exporter, model_test_dir) + assert os.path.exists(onnx_file) + + def test_speechnet_training_npz_layout(self, model_test_dir, speechnet_config): + """Test inputs.npz / outputs.npz are generated with the correct layout.""" + n_batches = 4 + exporter = SpeechNetExporter(save_path=model_test_dir) + exporter._config_overrides = { + **speechnet_config, + "n_batches": n_batches, + "n_accum": 1, + "dataset": "random", + } + onnx_file = verify_training_export(exporter, model_test_dir) + output_dir = os.path.dirname(onnx_file) + + npz_in = np.load(os.path.join(output_dir, "inputs.npz"), allow_pickle=True) + assert "meta_data_size" in npz_in + assert "meta_n_batches" in npz_in + assert int(npz_in["meta_n_batches"]) == n_batches + + npz_out = np.load(os.path.join(output_dir, "outputs.npz"), allow_pickle=True) + assert "loss" in npz_out + assert len(npz_out["loss"]) == n_batches