2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Add `ScratchBackbone` for benchmarking DL models without pretrained weights. Only compatible with `FullFinetune` (enforced by the `Experiment` validator) ([#30](https://github.com/braindecode/OpenEEGBench/pull/30)).
- Add `class_weight` parameter to `RidgeProbingTraining` (`"balanced"` or `None`); **default changed to `"balanced"`** — pass `None` for the previous unweighted behavior ([#32](https://github.com/braindecode/OpenEEGBench/pull/32)).
- Add `dtype` parameter to `RidgeProbingTraining` (`"float32"` or `"float64"`, default `"float64"`). Use `"float32"` only when necessary, e.g. on Apple MPS which does not support float64 ([#32](https://github.com/braindecode/OpenEEGBench/pull/32)).
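For example, a minimal sketch of the two new options (names are illustrative; assumes the remaining `RidgeProbingTraining` fields keep their defaults):

```python
from open_eeg_bench.training import RidgeProbingTraining

# Defaults: class_weight="balanced", dtype="float64"
probe_cfg = RidgeProbingTraining()

# Previous unweighted behavior, in single precision for Apple MPS:
probe_cfg_mps = RidgeProbingTraining(
    device="mps",
    class_weight=None,   # disable balanced per-class weighting
    dtype="float32",     # MPS does not support float64
)
```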

### Changed
- Fill in the Zenodo concept DOI (`10.5281/zenodo.19698863`) in the README DOI badge and the BibTeX snippet, and add it as an `identifiers` entry in `CITATION.cff`.
165 changes: 124 additions & 41 deletions open_eeg_bench/ridge_probe.py
@@ -32,8 +32,35 @@ def _default_lambdas() -> list[float]:
return [10**e for e in range(-8, 9)]


_STR_TO_DTYPE = {"float32": torch.float32, "float64": torch.float64}


def _resolve_dtype(dtype: str) -> torch.dtype:
if dtype not in _STR_TO_DTYPE:
raise ValueError(
f"dtype must be 'float32' or 'float64', got {dtype!r}."
)
return _STR_TO_DTYPE[dtype]


def _balanced_class_weights(
train_loader: "DataLoader", n_classes: int, device: str, dtype: torch.dtype
) -> torch.Tensor:
"""One label-only pass over train; returns sklearn-style "balanced" weights.

``w[c] = N / (n_classes * count[c])`` so that ``sum_i w_{y_i} == N`` and the
effective sample size is preserved. Empty classes are clamped to 1 sample
to avoid div-by-zero (their weight then has no effect on accumulation).
"""
counts = torch.zeros(n_classes, dtype=dtype, device=device)
for batch in train_loader:
y = batch[1].to(device).long()
counts += torch.bincount(y, minlength=n_classes).to(counts.dtype)
return counts.sum() / (n_classes * counts.clamp(min=1.0))
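A quick standalone check of the formula with illustrative counts (30 samples of class 0, 10 of class 1, so `N = 40`):

```python
import torch

counts = torch.tensor([30.0, 10.0])      # per-class counts, N = 40
w = counts.sum() / (2 * counts)          # tensor([0.6667, 2.0000])
assert torch.isclose((w * counts).sum(), counts.sum())  # effective N preserved
```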


def _make_projection_matrix(
n_features: int, n_components: int, seed: int, device: str
n_features: int, n_components: int, seed: int, device: str, dtype: torch.dtype
) -> torch.Tensor:
"""Return a (n_components, n_features) Gaussian random projection matrix.

@@ -46,7 +73,7 @@ def _make_projection_matrix(

rp = GaussianRandomProjection(n_components=n_components, random_state=seed)
P = rp._make_random_matrix(n_components, n_features)
return torch.from_numpy(np.asarray(P)).to(device=device, dtype=torch.float64)
return torch.from_numpy(np.asarray(P)).to(device=device, dtype=dtype)


def _fit_streaming_ridge(
@@ -58,60 +85,96 @@
device: str,
max_features: int | None = None,
projection_seed: int = 0,
class_weight: str | None = "balanced",
dtype: str = "float64",
) -> dict:
"""Fit streaming ridge probe, select λ on val, return weights + diagnostics.

If ``max_features`` is set and the backbone emits more features than that,
features are projected down to ``max_features`` dimensions via a Gaussian
random projection (seeded by ``projection_seed``) before accumulation.

``class_weight`` controls per-sample weighting in the weighted-least-squares
fit. Only meaningful when ``n_classes`` is set (classification); silently
ignored for regression. Supported values:

* ``"balanced"`` (default): sklearn-style ``N / (n_classes * count[c])``
weights. Costs one extra label-only pass over the train loader.
* ``None``: unweighted (every sample contributes equally).

``dtype`` controls the precision of all internal accumulators, the
eigendecomposition, and the returned weights. ``"float64"`` (default) is
recommended for numerical precision: covariances and eigendecompositions
can lose meaningful accuracy in single precision. Use ``"float32"`` only
when necessary, e.g. on devices like Apple's MPS that do not support
float64.

Returns dict with keys: W (D,C), bias (C,), best_lambda (float),
val_scores (dict λ→score), lambdas (list[float]), n_classes, n_features D
(after projection if applied), projection (torch.Tensor | None).
"""
model.eval()
model.to(device)
torch_dtype = _resolve_dtype(dtype)

# ----- Pass 0 (optional, classification only): per-class weights from labels -----
class_weights = None
if class_weight == "balanced" and n_classes is not None:
class_weights = _balanced_class_weights(train_loader, n_classes, device, torch_dtype)
elif class_weight not in (None, "balanced"):
raise ValueError(
f"class_weight must be 'balanced' or None, got {class_weight!r}."
)

# ----- Pass 1: accumulate sufficient statistics on train -----
# ----- Pass 1: accumulate (optionally weighted) sufficient statistics on train -----
A = B = s_h = s_h2 = s_y = None
projection = None # (k, D_orig) float64; lazily built once D_orig is known
N = 0
projection = None # (k, D_orig); lazily built once D_orig is known
N = 0.0 # sum of sample weights (== n_samples when unweighted)
with torch.no_grad():
for batch in train_loader:
x, y = batch[0], batch[1]
h = model(x.to(device))
h64 = h.double()
h_acc = h.to(dtype=torch_dtype)
if (
projection is None
and max_features is not None
and h64.shape[1] > max_features
and h_acc.shape[1] > max_features
):
projection = _make_projection_matrix(
n_features=h64.shape[1],
n_features=h_acc.shape[1],
n_components=max_features,
seed=projection_seed,
device=device,
dtype=torch_dtype,
)
if projection is not None:
h64 = h64 @ projection.T # (B, k)
y_enc = _encode_targets(y.to(device), n_classes)
y64 = y_enc.double()
h_acc = h_acc @ projection.T # (B, k)
y_dev = y.to(device)
y_enc = _encode_targets(y_dev, n_classes)
y_acc = y_enc.to(dtype=torch_dtype)

if class_weights is not None:
w = class_weights[y_dev.long()] # (B,)
else:
w = torch.ones(h_acc.shape[0], dtype=torch_dtype, device=device)
hw = h_acc * w.unsqueeze(1) # (B, D), each row h_i scaled by w_i
yw = y_acc * w.unsqueeze(1) # (B, C)

if A is None:
D = h64.shape[1]
C = y64.shape[1]
A = torch.zeros(D, D, dtype=torch.float64, device=device)
B = torch.zeros(D, C, dtype=torch.float64, device=device)
s_h = torch.zeros(D, dtype=torch.float64, device=device)
s_h2 = torch.zeros(D, dtype=torch.float64, device=device)
s_y = torch.zeros(C, dtype=torch.float64, device=device)

A += h64.T @ h64
B += h64.T @ y64
s_h += h64.sum(0)
s_h2 += (h64**2).sum(0)
s_y += y64.sum(0)
N += h64.shape[0]
D = h_acc.shape[1]
C = y_acc.shape[1]
A = torch.zeros(D, D, dtype=torch_dtype, device=device)
B = torch.zeros(D, C, dtype=torch_dtype, device=device)
s_h = torch.zeros(D, dtype=torch_dtype, device=device)
s_h2 = torch.zeros(D, dtype=torch_dtype, device=device)
s_y = torch.zeros(C, dtype=torch_dtype, device=device)

A += hw.T @ h_acc
B += hw.T @ y_acc
s_h += hw.sum(0)
s_h2 += (hw * h_acc).sum(0)
s_y += yw.sum(0)
N += float(w.sum().item())

if A is None:
raise ValueError("Empty train_loader — no features accumulated.")
@@ -142,21 +205,32 @@ def _fit_streaming_ridge(
C_zy = C_xy * inv_std.unsqueeze(1)

# ----- Eigendecomposition (once, on correlation matrix) -----
eigvals, Q = torch.linalg.eigh(C_zz)
Ct = Q.T @ C_zy # (D, C)
# Some accelerators (notably Apple MPS) don't implement torch.linalg.eigh.
# The decomposition runs on a (D, D) matrix at most ``max_features`` wide,
# which is small and fast on CPU — do it there unconditionally and move the
# resulting weights back to ``device`` before the val pass.
C_zz_cpu = C_zz.cpu()
C_zy_cpu = C_zy.cpu()
inv_std_cpu = inv_std.cpu()
h_bar_cpu = h_bar.cpu()
y_bar_cpu = y_bar.cpu()
eigvals, Q = torch.linalg.eigh(C_zz_cpu)
Ct = Q.T @ C_zy_cpu # (D, C)

if lambdas is None:
lambdas = _default_lambdas()

# ----- Solve for each λ in eigenbasis, convert back to original space -----
K = len(lambdas)
Ws = torch.zeros(K, D, C, dtype=torch.float64, device=device)
biases = torch.zeros(K, C, dtype=torch.float64, device=device)
Ws_cpu = torch.zeros(K, D, C, dtype=torch_dtype)
biases_cpu = torch.zeros(K, C, dtype=torch_dtype)
for k, lam in enumerate(lambdas):
denom = (eigvals + lam).unsqueeze(1) # (D, 1)
W_z = Q @ (Ct / denom) # (D, C) in standardized space
Ws[k] = W_z * inv_std.unsqueeze(1) # back to original: w_i / std_i
biases[k] = y_bar - Ws[k].T @ h_bar # (C,)
Ws_cpu[k] = W_z * inv_std_cpu.unsqueeze(1) # back to original: w_i / std_i
biases_cpu[k] = y_bar_cpu - Ws_cpu[k].T @ h_bar_cpu # (C,)
Ws = Ws_cpu.to(device)
biases = biases_cpu.to(device)
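For reference, the loop above uses the standard eigenbasis form of ridge: with `C_zz = Q diag(eigvals) Q.T`, the solution `(C_zz + lam * I)^{-1} @ C_zy` equals `Q @ ((Q.T @ C_zy) / (eigvals + lam))`, so a single `eigh` serves every value of λ. A small standalone check with random, illustrative matrices:

```python
import torch

D, C, lam = 8, 3, 0.1
Czz = torch.randn(D, D, dtype=torch.float64)
Czz = Czz @ Czz.T + torch.eye(D, dtype=torch.float64)   # symmetric, well conditioned
Czy = torch.randn(D, C, dtype=torch.float64)

direct = torch.linalg.solve(Czz + lam * torch.eye(D, dtype=torch.float64), Czy)
eigvals, Q = torch.linalg.eigh(Czz)
via_eig = Q @ ((Q.T @ Czy) / (eigvals + lam).unsqueeze(1))
assert torch.allclose(direct, via_eig)
```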

# ----- Pass 2: streaming λ selection on val -----
val_scores = _streaming_val_scores(
Expand All @@ -168,6 +242,7 @@ def _fit_streaming_ridge(
y_bar_train=y_bar,
device=device,
projection=projection,
dtype=torch_dtype,
) # tensor of shape (K,)

# Among tied best scores, pick the largest λ (most regularization):
@@ -205,6 +280,7 @@ def _streaming_val_scores(
n_classes: int | None,
y_bar_train: torch.Tensor,
device: str,
dtype: torch.dtype,
projection: torch.Tensor | None = None,
) -> torch.Tensor:
"""Accumulate per-λ metric streaming on val. Returns (K,) scores (higher=better).
@@ -217,27 +293,27 @@

if is_regression:
# R² = 1 - SS_res / SS_tot (SS_tot using train mean; matches sklearn behavior closely enough)
ss_res = torch.zeros(K, dtype=torch.float64, device=device)
ss_tot_scalar = torch.zeros((), dtype=torch.float64, device=device)
ss_res = torch.zeros(K, dtype=dtype, device=device)
ss_tot_scalar = torch.zeros((), dtype=dtype, device=device)
with torch.no_grad():
for batch in val_loader:
x, y = batch[0], batch[1]
h = model(x.to(device)).double()
h = model(x.to(device)).to(dtype=dtype)
if projection is not None:
h = h @ projection.T
y_enc = _encode_targets(y.to(device), n_classes).double()
y_enc = _encode_targets(y.to(device), n_classes).to(dtype=dtype)
preds = torch.einsum("kdc,bd->kbc", Ws, h) + biases.unsqueeze(1)
res = preds - y_enc.unsqueeze(0)
ss_res += (res**2).sum(dim=(1, 2))
ss_tot_scalar += ((y_enc - y_bar_train) ** 2).sum()
return 1.0 - ss_res / ss_tot_scalar.clamp(min=1e-12)

# Classification: balanced accuracy via confusion matrix (K, C, C)
confusion = torch.zeros(K, C, C, dtype=torch.float64, device=device)
confusion = torch.zeros(K, C, C, dtype=dtype, device=device)
with torch.no_grad():
for batch in val_loader:
x, y = batch[0], batch[1]
h = model(x.to(device)).double()
h = model(x.to(device)).to(dtype=dtype)
if projection is not None:
h = h @ projection.T
y_true = y.to(device).long()
@@ -277,6 +353,8 @@ def __init__(
val_set,
max_features: int | None = None,
projection_seed: int = 0,
class_weight: str | None = "balanced",
dtype: str = "float64",
verbose: int = 1,
):
self.model_ = feature_extractor
@@ -288,6 +366,8 @@ def __init__(
self.val_set = val_set
self.max_features = max_features
self.projection_seed = projection_seed
self.class_weight = class_weight
self.dtype = dtype
self.verbose = verbose
self._result: dict | None = None

@@ -321,6 +401,8 @@ def fit(self, train_set, y=None):
device=self.device,
max_features=self.max_features,
projection_seed=self.projection_seed,
class_weight=self.class_weight,
dtype=self.dtype,
)
if self.verbose:
metric = "R²" if self.n_classes_ is None else "balanced_acc"
@@ -340,9 +422,10 @@ def predict(self, test_set) -> np.ndarray:
if self._result is None:
raise RuntimeError("Call .fit() before .predict().")

W = self._result["W"] # (D, C) float64
bias = self._result["bias"] # (C,) float64
projection = self._result.get("projection") # (k, D_orig) float64 or None
W = self._result["W"] # (D, C)
bias = self._result["bias"] # (C,)
projection = self._result.get("projection") # (k, D_orig) or None
torch_dtype = _resolve_dtype(self.dtype)
loader = DataLoader(
test_set,
batch_size=self.batch_size,
@@ -357,7 +440,7 @@
with torch.no_grad():
for batch in loader:
x = batch[0] if isinstance(batch, (list, tuple)) else batch
h = self.model_(x.to(self.device)).double()
h = self.model_(x.to(self.device)).to(dtype=torch_dtype)
if projection is not None:
h = h @ projection.T
y_hat = h @ W + bias # (B, C)
4 changes: 4 additions & 0 deletions open_eeg_bench/training.py
@@ -219,6 +219,8 @@ class RidgeProbingTraining(BaseModel):
device: str = "cpu"
lambdas: list[float] | None = None # None → use the learner's default fixed log-spaced grid
max_features: int | None = 5000 # if set and D > max_features, Gaussian random-project to max_features
class_weight: Literal["balanced"] | None = "balanced" # "balanced" → sklearn-style per-class weights; None → unweighted
dtype: Literal["float32", "float64"] = "float64" # "float64" recommended for precision; use "float32" only if necessary (e.g. Apple MPS)

def build_learner(self, model, callbacks, n_classes, val_set, verbose=1, seed: int = 0):
from open_eeg_bench.ridge_probe import StreamingRidgeProbeLearner
@@ -234,5 +236,7 @@ def build_learner(self, model, callbacks, n_classes, val_set, verbose=1, seed: i
val_set=val_set,
max_features=self.max_features,
projection_seed=seed,
class_weight=self.class_weight,
dtype=self.dtype,
verbose=verbose,
)