Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Allow custom normalization methods by subclassing `Normalization`, now based on `exca.helpers.DiscriminatedModel`. Builtin subclasses pin their pre-existing `kind` value to preserve cached experiment UIDs ([#35](https://github.com/braindecode/OpenEEGBench/pull/35)).

## [0.4.0] - 2026-05-07

### Added
Expand Down
73 changes: 40 additions & 33 deletions open_eeg_bench/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,59 @@

from __future__ import annotations

from typing import Annotated, Literal, Union
from typing import Any, ClassVar

import numpy as np
from pydantic import BaseModel, ConfigDict, Field
from pydantic import ConfigDict
from exca.helpers import DiscriminatedModel


class DivideByConstant(BaseModel):
class Normalization(DiscriminatedModel, discriminator_key="kind"):
    """Base class for all normalizations.

    Subclass to define a custom normalization; instances are dispatched
    automatically via the ``kind`` discriminator key. By default the
    discriminator value is ``cls.__name__``. Builtin subclasses pin a
    snake_case value via ``_legacy_kind`` to preserve the pre-DiscriminatedModel
    serialization format and keep cached experiment UIDs stable.
    """

    # Optional pinned discriminator value; None means "derive from the class name".
    _legacy_kind: ClassVar[str | None] = None

    # Reject unknown fields so typos in experiment configs fail loudly.
    model_config = ConfigDict(extra="forbid")

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Rename subclasses that pin ``_legacy_kind`` so the discriminator
        value (taken from ``cls.__name__``) matches the legacy snake_case
        string used before the DiscriminatedModel migration.

        NOTE(review): the rename runs *after* ``super().__init_subclass__``,
        which presumably performs the subclass registration — this assumes
        ``DiscriminatedModel`` reads ``cls.__name__`` lazily rather than
        capturing it at registration time; confirm against
        ``exca.helpers.DiscriminatedModel``.
        """
        super().__init_subclass__(**kwargs)
        # Only honor a _legacy_kind declared directly on this subclass
        # (hence cls.__dict__), not one inherited from a parent class.
        legacy = cls.__dict__.get("_legacy_kind")
        if legacy:
            cls.__name__ = legacy

    def apply(self, data: np.ndarray) -> np.ndarray:
        """Apply normalization to a window of EEG data.

        Subclasses must override; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError


class DivideByConstant(Normalization):
"""Divide data by a constant factor.

Used by LaBraM and CBraMod (factor=100) to set the EEG unit to 0.1 mV.
"""

model_config = ConfigDict(extra="forbid")
kind: Literal["divide_by_constant"] = "divide_by_constant"
_legacy_kind: ClassVar[str | None] = "divide_by_constant"
factor: float = 100.0

def apply(self, data: np.ndarray) -> np.ndarray:
return data / self.factor


class PercentileScale(BaseModel):
class PercentileScale(Normalization):
"""Per-channel percentile normalization.

Used by BIOT: each channel is divided by the q-th percentile of its
absolute amplitude so that the bulk of values falls near [-1, 1].
"""

model_config = ConfigDict(extra="forbid")
kind: Literal["percentile_scale"] = "percentile_scale"
_legacy_kind: ClassVar[str | None] = "percentile_scale"
q: float = 95.0
eps: float = 1e-8

Expand All @@ -49,14 +73,13 @@ def apply(self, data: np.ndarray) -> np.ndarray:
return data / (quantile + self.eps)


class MinMaxScale(BaseModel):
class MinMaxScale(Normalization):
"""Per-window min-max scaling to [-1, 1].

Used by BENDR.
"""

model_config = ConfigDict(extra="forbid")
kind: Literal["minmax_scale"] = "minmax_scale"
_legacy_kind: ClassVar[str | None] = "minmax_scale"

def apply(self, data: np.ndarray) -> np.ndarray:
dmin = np.min(data)
Expand All @@ -67,14 +90,13 @@ def apply(self, data: np.ndarray) -> np.ndarray:
return 2.0 * (data - dmin) / drange - 1.0


class WindowZScore(BaseModel):
class WindowZScore(Normalization):
"""Per-window z-score normalization with optional sigma clipping.

Used by REVE (clip_sigma=15) and EEGPT (channel_wise=True).
"""

model_config = ConfigDict(extra="forbid")
kind: Literal["window_zscore"] = "window_zscore"
_legacy_kind: ClassVar[str | None] = "window_zscore"
channel_wise: bool = False
clip_sigma: float | None = 15.0
eps: float = 1e-10
Expand All @@ -90,37 +112,22 @@ def apply(self, data: np.ndarray) -> np.ndarray:
return normalised


class ScaleToMV(BaseModel):
class ScaleToMV(Normalization):
"""Convert microvolt data to millivolts (divide by 1000).

Used by EEGPT during pretraining.
"""

model_config = ConfigDict(extra="forbid")
kind: Literal["scale_to_mv"] = "scale_to_mv"
_legacy_kind: ClassVar[str | None] = "scale_to_mv"

def apply(self, data: np.ndarray) -> np.ndarray:
return data / 1000.0


class NoNormalization(BaseModel):
class NoNormalization(Normalization):
"""No-op normalization (identity). Default when no normalization is needed."""

model_config = ConfigDict(extra="forbid")
kind: Literal["none"] = "none"
_legacy_kind: ClassVar[str | None] = "none"

def apply(self, data: np.ndarray) -> np.ndarray:
return data


Normalization = Annotated[
Union[
NoNormalization,
DivideByConstant,
PercentileScale,
MinMaxScale,
WindowZScore,
ScaleToMV,
],
Field(discriminator="kind"),
]
49 changes: 49 additions & 0 deletions tests/test_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

import numpy as np
import pytest
from pydantic import BaseModel

from open_eeg_bench.normalization import (
DivideByConstant,
MinMaxScale,
NoNormalization,
Normalization,
PercentileScale,
ScaleToMV,
WindowZScore,
Expand Down Expand Up @@ -58,3 +61,49 @@ def test_scale_to_mv(data):
norm = ScaleToMV()
result = norm.apply(data)
np.testing.assert_allclose(result, data / 1000.0)


def test_custom_normalization(data):
    """User-defined Normalization subclasses plug into the dispatch machinery."""

    class AddOffset(Normalization):
        offset: float = 1.0

        def apply(self, data: np.ndarray) -> np.ndarray:
            return data + self.offset

    # The subclass works when instantiated and applied directly.
    shifter = AddOffset(offset=2.5)
    np.testing.assert_allclose(shifter.apply(data), data + 2.5)

    # It is also accepted wherever a Normalization field is declared, and
    # survives a dump/validate round trip via the "kind" discriminator key.
    class Parent(BaseModel):
        norm: Normalization

    holder = Parent(norm=AddOffset(offset=3.0))
    assert isinstance(holder.norm, AddOffset)

    reloaded = Parent.model_validate(holder.model_dump())
    assert isinstance(reloaded.norm, AddOffset)
    assert reloaded.norm.offset == 3.0


@pytest.mark.parametrize(
    "cls, legacy_kind",
    [
        (DivideByConstant, "divide_by_constant"),
        (PercentileScale, "percentile_scale"),
        (MinMaxScale, "minmax_scale"),
        (WindowZScore, "window_zscore"),
        (ScaleToMV, "scale_to_mv"),
        (NoNormalization, "none"),
    ],
)
def test_legacy_kind_serialization(cls, legacy_kind):
    """Builtin subclasses serialize with their historical ``kind`` string,
    keeping cached experiment UIDs stable across the DiscriminatedModel
    migration, and round-trip back to the same concrete class."""
    payload = cls().model_dump()
    assert payload["kind"] == legacy_kind
    rebuilt = Normalization.model_validate(payload)
    assert isinstance(rebuilt, cls)
Loading