From ecd54fa8a0002e90b55397cfd8afa7a9c095b91f Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 14 May 2026 16:51:23 +0200 Subject: [PATCH 1/5] Use `DiscriminatedModel` --- open_eeg_bench/normalization.py | 46 ++++++++++++++------------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/open_eeg_bench/normalization.py b/open_eeg_bench/normalization.py index b365052..308f9a3 100644 --- a/open_eeg_bench/normalization.py +++ b/open_eeg_bench/normalization.py @@ -10,19 +10,29 @@ from __future__ import annotations -from typing import Annotated, Literal, Union +from typing import Literal import numpy as np -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict +from exca.helpers import DiscriminatedModel -class DivideByConstant(BaseModel): +class Normalization(DiscriminatedModel, discriminator_key="kind"): + """Base class for all normalizations.""" + + model_config = ConfigDict(extra="forbid") + + def apply(self, data: np.ndarray) -> np.ndarray: + """Apply normalization to a window of EEG data.""" + raise NotImplementedError + + +class DivideByConstant(Normalization): """Divide data by a constant factor. Used by LaBraM and CBraMod (factor=100) to set the EEG unit to 0.1 mV. """ - model_config = ConfigDict(extra="forbid") kind: Literal["divide_by_constant"] = "divide_by_constant" factor: float = 100.0 @@ -30,14 +40,13 @@ def apply(self, data: np.ndarray) -> np.ndarray: return data / self.factor -class PercentileScale(BaseModel): +class PercentileScale(Normalization): """Per-channel percentile normalization. Used by BIOT: each channel is divided by the q-th percentile of its absolute amplitude so that the bulk of values falls near [-1, 1]. """ - model_config = ConfigDict(extra="forbid") kind: Literal["percentile_scale"] = "percentile_scale" q: float = 95.0 eps: float = 1e-8 @@ -49,13 +58,12 @@ def apply(self, data: np.ndarray) -> np.ndarray: return data / (quantile + self.eps) -class MinMaxScale(BaseModel): +class MinMaxScale(Normalization): """Per-window min-max scaling to [-1, 1]. Used by BENDR. """ - model_config = ConfigDict(extra="forbid") kind: Literal["minmax_scale"] = "minmax_scale" def apply(self, data: np.ndarray) -> np.ndarray: @@ -67,13 +75,12 @@ def apply(self, data: np.ndarray) -> np.ndarray: return 2.0 * (data - dmin) / drange - 1.0 -class WindowZScore(BaseModel): +class WindowZScore(Normalization): """Per-window z-score normalization with optional sigma clipping. Used by REVE (clip_sigma=15) and EEGPT (channel_wise=True). """ - model_config = ConfigDict(extra="forbid") kind: Literal["window_zscore"] = "window_zscore" channel_wise: bool = False clip_sigma: float | None = 15.0 @@ -90,37 +97,22 @@ def apply(self, data: np.ndarray) -> np.ndarray: return normalised -class ScaleToMV(BaseModel): +class ScaleToMV(Normalization): """Convert microvolt data to millivolts (divide by 1000). Used by EEGPT during pretraining. """ - model_config = ConfigDict(extra="forbid") kind: Literal["scale_to_mv"] = "scale_to_mv" def apply(self, data: np.ndarray) -> np.ndarray: return data / 1000.0 -class NoNormalization(BaseModel): +class NoNormalization(Normalization): """No-op normalization (identity). Default when no normalization is needed.""" - model_config = ConfigDict(extra="forbid") kind: Literal["none"] = "none" def apply(self, data: np.ndarray) -> np.ndarray: return data - - -Normalization = Annotated[ - Union[ - NoNormalization, - DivideByConstant, - PercentileScale, - MinMaxScale, - WindowZScore, - ScaleToMV, - ], - Field(discriminator="kind"), -] From b22ad12849a8b22292617c8050a397417e296730 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 14 May 2026 17:25:45 +0200 Subject: [PATCH 2/5] Better handle the discrimination field --- open_eeg_bench/normalization.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/open_eeg_bench/normalization.py b/open_eeg_bench/normalization.py index 308f9a3..faaec61 100644 --- a/open_eeg_bench/normalization.py +++ b/open_eeg_bench/normalization.py @@ -10,7 +10,7 @@ from __future__ import annotations -from typing import Literal +from typing import Any, ClassVar import numpy as np from pydantic import ConfigDict @@ -18,10 +18,25 @@ class Normalization(DiscriminatedModel, discriminator_key="kind"): - """Base class for all normalizations.""" + """Base class for all normalizations. + + Subclass to define a custom normalization; instances are dispatched + automatically via the ``kind`` discriminator key. By default the + discriminator value is ``cls.__name__``. Builtin subclasses pin a + snake_case value via ``_legacy_kind`` to preserve the pre-DiscriminatedModel + serialization format and keep cached experiment UIDs stable. + """ + + _legacy_kind: ClassVar[str | None] = None model_config = ConfigDict(extra="forbid") + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + legacy = cls.__dict__.get("_legacy_kind") + if legacy: + cls.__name__ = legacy + def apply(self, data: np.ndarray) -> np.ndarray: """Apply normalization to a window of EEG data.""" raise NotImplementedError @@ -33,7 +48,7 @@ class DivideByConstant(Normalization): Used by LaBraM and CBraMod (factor=100) to set the EEG unit to 0.1 mV. """ - kind: Literal["divide_by_constant"] = "divide_by_constant" + _legacy_kind: ClassVar[str | None] = "divide_by_constant" factor: float = 100.0 def apply(self, data: np.ndarray) -> np.ndarray: @@ -47,7 +62,7 @@ class PercentileScale(Normalization): absolute amplitude so that the bulk of values falls near [-1, 1]. """ - kind: Literal["percentile_scale"] = "percentile_scale" + _legacy_kind: ClassVar[str | None] = "percentile_scale" q: float = 95.0 eps: float = 1e-8 @@ -64,7 +79,7 @@ class MinMaxScale(Normalization): Used by BENDR. """ - kind: Literal["minmax_scale"] = "minmax_scale" + _legacy_kind: ClassVar[str | None] = "minmax_scale" def apply(self, data: np.ndarray) -> np.ndarray: dmin = np.min(data) @@ -81,7 +96,7 @@ class WindowZScore(Normalization): Used by REVE (clip_sigma=15) and EEGPT (channel_wise=True). """ - kind: Literal["window_zscore"] = "window_zscore" + _legacy_kind: ClassVar[str | None] = "window_zscore" channel_wise: bool = False clip_sigma: float | None = 15.0 eps: float = 1e-10 @@ -103,7 +118,7 @@ class ScaleToMV(Normalization): Used by EEGPT during pretraining. """ - kind: Literal["scale_to_mv"] = "scale_to_mv" + _legacy_kind: ClassVar[str | None] = "scale_to_mv" def apply(self, data: np.ndarray) -> np.ndarray: return data / 1000.0 @@ -112,7 +127,7 @@ def apply(self, data: np.ndarray) -> np.ndarray: class NoNormalization(Normalization): """No-op normalization (identity). Default when no normalization is needed.""" - kind: Literal["none"] = "none" + _legacy_kind: ClassVar[str | None] = "none" def apply(self, data: np.ndarray) -> np.ndarray: return data From 072539bb66371575ec52d94a2e2f24c0cf8e29d0 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 14 May 2026 17:25:55 +0200 Subject: [PATCH 3/5] Add test --- tests/test_normalization.py | 49 +++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 73b18e9..99b1110 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -2,10 +2,13 @@ import numpy as np import pytest +from pydantic import BaseModel from open_eeg_bench.normalization import ( DivideByConstant, MinMaxScale, + NoNormalization, + Normalization, PercentileScale, ScaleToMV, WindowZScore, @@ -58,3 +61,49 @@ def test_scale_to_mv(data): norm = ScaleToMV() result = norm.apply(data) np.testing.assert_allclose(result, data / 1000.0) + + +def test_custom_normalization(data): + """A user-defined Normalization subclass should integrate seamlessly.""" + + class AddOffset(Normalization): + offset: float = 1.0 + + def apply(self, data: np.ndarray) -> np.ndarray: + return data + self.offset + + # Direct use + norm = AddOffset(offset=2.5) + np.testing.assert_allclose(norm.apply(data), data + 2.5) + + # Usable as a Normalization field on a parent model, with round-trip + # serialization through the "kind" discriminator key. + class Parent(BaseModel): + norm: Normalization + + parent = Parent(norm=AddOffset(offset=3.0)) + assert isinstance(parent.norm, AddOffset) + + restored = Parent.model_validate(parent.model_dump()) + assert isinstance(restored.norm, AddOffset) + assert restored.norm.offset == 3.0 + + +@pytest.mark.parametrize( + "cls, legacy_kind", + [ + (DivideByConstant, "divide_by_constant"), + (PercentileScale, "percentile_scale"), + (MinMaxScale, "minmax_scale"), + (WindowZScore, "window_zscore"), + (ScaleToMV, "scale_to_mv"), + (NoNormalization, "none"), + ], +) +def test_legacy_kind_serialization(cls, legacy_kind): + """Builtin subclasses must keep their pre-DiscriminatedModel ``kind`` value + so cached experiment UIDs remain stable.""" + dump = cls().model_dump() + assert dump["kind"] == legacy_kind + restored = Normalization.model_validate(dump) + assert isinstance(restored, cls) From 9a7eab34d043564c18d03cfd2f6416a441913067 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 14 May 2026 17:26:15 +0200 Subject: [PATCH 4/5] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6de7076..4d89cc9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Allow custom normalization methods by subclassing `Normalization`, now based on `exca.helpers.DiscriminatedModel`. Builtin subclasses pin their pre-existing `kind` value to preserve cached experiment UIDs ([#113](https://github.com/braindecode/OpenEEGBench/pull/113)). + ## [0.4.0] - 2026-05-07 ### Added From a8a6717dc72b03055cd4b692abd543c5214895eb Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Thu, 14 May 2026 17:26:49 +0200 Subject: [PATCH 5/5] fix pr number --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d89cc9..bbef5ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added -- Allow custom normalization methods by subclassing `Normalization`, now based on `exca.helpers.DiscriminatedModel`. Builtin subclasses pin their pre-existing `kind` value to preserve cached experiment UIDs ([#113](https://github.com/braindecode/OpenEEGBench/pull/113)). +- Allow custom normalization methods by subclassing `Normalization`, now based on `exca.helpers.DiscriminatedModel`. Builtin subclasses pin their pre-existing `kind` value to preserve cached experiment UIDs ([#35](https://github.com/braindecode/OpenEEGBench/pull/35)). ## [0.4.0] - 2026-05-07