From 4e38b04ef9a53a80d0e0352208c3598c808b0e9b Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Tue, 28 Apr 2026 20:03:30 +0200 Subject: [PATCH 1/4] Ridge probe: balanced class weighting for imbalanced datasets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a `class_weight` parameter (default `"balanced"`) to the streaming ridge probe. When enabled, an extra label-only pass over the train loader computes sklearn-style per-class weights `w[c] = N / (n_classes * count[c])`, which are then applied to every sufficient-statistic accumulator (A, B, s_h, s_h2, s_y, N) so that the weighted-least-squares fit, weighted centering, and weighted standardization are all internally consistent. Regression silently ignores the parameter. Verified end-to-end on REVE × chbmit (97/3 imbalance): unweighted: test_balanced_accuracy = 0.5000 (chance) balanced: test_balanced_accuracy = 0.8594 --- open_eeg_bench/ridge_probe.py | 63 ++++++++++++++++--- open_eeg_bench/training.py | 2 + tests/test_ridge_probe.py | 114 ++++++++++++++++++++++++++++++++++ 3 files changed, 170 insertions(+), 9 deletions(-) diff --git a/open_eeg_bench/ridge_probe.py b/open_eeg_bench/ridge_probe.py index 0b43fff..4867c86 100644 --- a/open_eeg_bench/ridge_probe.py +++ b/open_eeg_bench/ridge_probe.py @@ -32,6 +32,22 @@ def _default_lambdas() -> list[float]: return [10**e for e in range(-8, 9)] +def _balanced_class_weights( + train_loader: "DataLoader", n_classes: int, device: str +) -> torch.Tensor: + """One label-only pass over train; returns sklearn-style "balanced" weights. + + ``w[c] = N / (n_classes * count[c])`` so that ``sum_i w_{y_i} == N`` and the + effective sample size is preserved. Empty classes are clamped to 1 sample + to avoid div-by-zero (their weight then has no effect on accumulation). + """ + counts = torch.zeros(n_classes, dtype=torch.float64, device=device) + for batch in train_loader: + y = batch[1].to(device).long() + counts += torch.bincount(y, minlength=n_classes).to(counts.dtype) + return counts.sum() / (n_classes * counts.clamp(min=1.0)) + + def _make_projection_matrix( n_features: int, n_components: int, seed: int, device: str ) -> torch.Tensor: @@ -58,6 +74,7 @@ def _fit_streaming_ridge( device: str, max_features: int | None = None, projection_seed: int = 0, + class_weight: str | None = "balanced", ) -> dict: """Fit streaming ridge probe, select λ on val, return weights + diagnostics. @@ -65,6 +82,14 @@ def _fit_streaming_ridge( features are projected down to ``max_features`` dimensions via a Gaussian random projection (seeded by ``projection_seed``) before accumulation. + ``class_weight`` controls per-sample weighting in the weighted-least-squares + fit. Only meaningful when ``n_classes`` is set (classification); silently + ignored for regression. Supported values: + + * ``"balanced"`` (default): sklearn-style ``N / (n_classes * count[c])`` + weights. Costs one extra label-only pass over the train loader. + * ``None``: unweighted (every sample contributes equally). + Returns dict with keys: W (D,C), bias (C,), best_lambda (float), val_scores (dict λ→score), lambdas (list[float]), n_classes, n_features D (after projection if applied), projection (torch.Tensor | None). @@ -72,10 +97,19 @@ def _fit_streaming_ridge( model.eval() model.to(device) - # ----- Pass 1: accumulate sufficient statistics on train ----- + # ----- Pass 0 (optional, classification only): per-class weights from labels ----- + class_weights = None + if class_weight == "balanced" and n_classes is not None: + class_weights = _balanced_class_weights(train_loader, n_classes, device) + elif class_weight not in (None, "balanced"): + raise ValueError( + f"class_weight must be 'balanced' or None, got {class_weight!r}." + ) + + # ----- Pass 1: accumulate (optionally weighted) sufficient statistics on train ----- A = B = s_h = s_h2 = s_y = None projection = None # (k, D_orig) float64; lazily built once D_orig is known - N = 0 + N = 0.0 # sum of sample weights (== n_samples when unweighted) with torch.no_grad(): for batch in train_loader: x, y = batch[0], batch[1] @@ -94,9 +128,17 @@ def _fit_streaming_ridge( ) if projection is not None: h64 = h64 @ projection.T # (B, k) - y_enc = _encode_targets(y.to(device), n_classes) + y_dev = y.to(device) + y_enc = _encode_targets(y_dev, n_classes) y64 = y_enc.double() + if class_weights is not None: + w = class_weights[y_dev.long()] # (B,) float64 + else: + w = torch.ones(h64.shape[0], dtype=torch.float64, device=device) + hw = h64 * w.unsqueeze(1) # (B, D), each row h_i scaled by w_i + yw = y64 * w.unsqueeze(1) # (B, C) + if A is None: D = h64.shape[1] C = y64.shape[1] @@ -106,12 +148,12 @@ def _fit_streaming_ridge( s_h2 = torch.zeros(D, dtype=torch.float64, device=device) s_y = torch.zeros(C, dtype=torch.float64, device=device) - A += h64.T @ h64 - B += h64.T @ y64 - s_h += h64.sum(0) - s_h2 += (h64**2).sum(0) - s_y += y64.sum(0) - N += h64.shape[0] + A += hw.T @ h64 + B += hw.T @ y64 + s_h += hw.sum(0) + s_h2 += (hw * h64).sum(0) + s_y += yw.sum(0) + N += float(w.sum().item()) if A is None: raise ValueError("Empty train_loader — no features accumulated.") @@ -277,6 +319,7 @@ def __init__( val_set, max_features: int | None = None, projection_seed: int = 0, + class_weight: str | None = "balanced", verbose: int = 1, ): self.model_ = feature_extractor @@ -288,6 +331,7 @@ def __init__( self.val_set = val_set self.max_features = max_features self.projection_seed = projection_seed + self.class_weight = class_weight self.verbose = verbose self._result: dict | None = None @@ -321,6 +365,7 @@ def fit(self, train_set, y=None): device=self.device, max_features=self.max_features, projection_seed=self.projection_seed, + class_weight=self.class_weight, ) if self.verbose: metric = "R²" if self.n_classes_ is None else "balanced_acc" diff --git a/open_eeg_bench/training.py b/open_eeg_bench/training.py index 1db88e3..5f1523d 100644 --- a/open_eeg_bench/training.py +++ b/open_eeg_bench/training.py @@ -219,6 +219,7 @@ class RidgeProbingTraining(BaseModel): device: str = "cpu" lambdas: list[float] | None = None # None → use the learner's default fixed log-spaced grid max_features: int | None = 5000 # if set and D > max_features, Gaussian random-project to max_features + class_weight: Literal["balanced"] | None = "balanced" # "balanced" → sklearn-style per-class weights; None → unweighted def build_learner(self, model, callbacks, n_classes, val_set, verbose=1, seed: int = 0): from open_eeg_bench.ridge_probe import StreamingRidgeProbeLearner @@ -234,5 +235,6 @@ def build_learner(self, model, callbacks, n_classes, val_set, verbose=1, seed: i val_set=val_set, max_features=self.max_features, projection_seed=seed, + class_weight=self.class_weight, verbose=verbose, ) diff --git a/tests/test_ridge_probe.py b/tests/test_ridge_probe.py index 90744ae..ffeb07e 100644 --- a/tests/test_ridge_probe.py +++ b/tests/test_ridge_probe.py @@ -69,6 +69,7 @@ def test_fit_streaming_ridge_matches_sklearn_classification(classif_data, max_fe device="cpu", max_features=max_features, projection_seed=projection_seed, + class_weight=None, # sklearn Ridge has no class_weight; disable to compare 1-to-1 ) if max_features is not None: @@ -171,6 +172,119 @@ def test_fit_streaming_ridge_regression_matches_sklearn(regression_data): assert out["val_scores"][lam] > 0.9 +@pytest.fixture +def imbalanced_classif_data(): + """Imbalanced binary classification: ~95% class 0, ~5% class 1. + + Both classes are linearly separable in feature space, so a class-aware + classifier can recover the minority class; an unweighted ridge will + collapse to predicting the majority class. + """ + rng = np.random.default_rng(0) + N, D = 4000, 20 + p_minority = 0.05 + y = (rng.random(N) < p_minority).astype(np.int64) + # Class-conditional means separated along one direction. + mu = np.zeros((2, D), dtype=np.float32) + mu[1, 0] = 3.0 + X = (mu[y] + rng.standard_normal((N, D)).astype(np.float32)) + # Shuffle then split 60/20/20 (preserves the imbalance in each split). + idx = rng.permutation(N) + X, y = X[idx], y[idx] + i1, i2 = int(0.6 * N), int(0.8 * N) + return { + "train": (X[:i1], y[:i1]), + "val": (X[i1:i2], y[i1:i2]), + "test": (X[i2:], y[i2:]), + "C": 2, + } + + +def test_balanced_class_weight_recovers_minority(imbalanced_classif_data): + """On 95/5 imbalanced data, class_weight='balanced' beats unweighted on minority recall. + + Unweighted ridge minimizes squared error, which on heavily imbalanced data + is dominated by the majority class — minority recall collapses near 0. + Balanced weighting restores per-class equal contribution and the linearly + separable signal becomes recoverable. + """ + from open_eeg_bench.ridge_probe import StreamingRidgeProbeLearner + + X_tr, y_tr = imbalanced_classif_data["train"] + X_val, y_val = imbalanced_classif_data["val"] + X_te, y_te = imbalanced_classif_data["test"] + C = imbalanced_classif_data["C"] + + train_set = TensorDataset(torch.from_numpy(X_tr), torch.from_numpy(y_tr)) + val_set = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val)) + test_set = TensorDataset(torch.from_numpy(X_te), torch.from_numpy(y_te)) + + def fit_predict(class_weight): + learner = StreamingRidgeProbeLearner( + feature_extractor=nn.Identity(), + n_classes=C, + batch_size=64, + num_workers=0, + device="cpu", + lambdas=[1e-2, 1.0, 1e2], + val_set=val_set, + class_weight=class_weight, + verbose=0, + ) + learner.fit(train_set, y=None) + return learner.predict(test_set) + + pred_unweighted = fit_predict(None) + pred_balanced = fit_predict("balanced") + + minority = y_te == 1 + recall_unweighted = (pred_unweighted[minority] == 1).mean() + recall_balanced = (pred_balanced[minority] == 1).mean() + + # Unweighted collapses to majority class — minority recall ~0. + assert recall_unweighted < 0.2, f"unweighted minority recall too high: {recall_unweighted}" + # Balanced recovers the linearly separable signal — minority recall well above chance. + assert recall_balanced > 0.7, f"balanced minority recall too low: {recall_balanced}" + + +def test_balanced_class_weight_noop_when_classes_balanced(classif_data): + """class_weight='balanced' on already-balanced classes must equal the unweighted fit. + + sklearn-style balanced weights are ``N / (n_classes * count[c])``: when every + class has the same count, all weights equal 1 exactly, so weighted and + unweighted accumulators produce identical W and bias. + """ + from open_eeg_bench.ridge_probe import _fit_streaming_ridge + + X_tr, y_tr = classif_data["train"] + X_val, y_val = classif_data["val"] + C = classif_data["C"] + # Trim the train set so each class appears exactly the same number of times. + per_class = min(int((y_tr == c).sum()) for c in range(C)) + keep = np.concatenate([np.where(y_tr == c)[0][:per_class] for c in range(C)]) + X_tr_bal, y_tr_bal = X_tr[keep], y_tr[keep] + common = dict( + model=nn.Identity(), + train_loader=_loader(X_tr_bal, y_tr_bal), + val_loader=_loader(X_val, y_val), + n_classes=C, + lambdas=[1.0], + device="cpu", + ) + out_unw = _fit_streaming_ridge(**common, class_weight=None) + out_bal = _fit_streaming_ridge( + model=nn.Identity(), + train_loader=_loader(X_tr_bal, y_tr_bal), + val_loader=_loader(X_val, y_val), + n_classes=C, + lambdas=[1.0], + device="cpu", + class_weight="balanced", + ) + torch.testing.assert_close(out_bal["W"], out_unw["W"]) + torch.testing.assert_close(out_bal["bias"], out_unw["bias"]) + + def test_fit_streaming_ridge_raises_on_nan_features(classif_data): """NaN features (e.g. from a buggy backbone) yield a clear error, not IndexError.""" from open_eeg_bench.ridge_probe import _fit_streaming_ridge From 19701f3cc46dd2a1eaccf462ffbf372256cf5ae4 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Tue, 28 Apr 2026 20:07:19 +0200 Subject: [PATCH 2/4] Changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 735f3ba..7932ba3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add `ScratchBackbone` for benchmarking DL models without pretrained weights. Only compatible with `FullFinetune` (enforced by the `Experiment` validator) ([#30](https://github.com/braindecode/OpenEEGBench/pull/30)). +- Add `class_weight` parameter to `RidgeProbingTraining` (`"balanced"` or `None`); **default changed to `"balanced"`** — pass `None` for the previous unweighted behavior. ### Changed - Fill in the Zenodo concept DOI (`10.5281/zenodo.19698863`) in the README DOI badge and the BibTeX snippet, and add it as an `identifiers` entry in `CITATION.cff`. From c84df6db2377e87c44caa67de57956e6baaf12f3 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Tue, 28 Apr 2026 20:50:19 +0200 Subject: [PATCH 3/4] Ridge probe: configurable dtype for MPS/float32 support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a `dtype: Literal["float32", "float64"]` parameter (default `"float64"`) threaded through `RidgeProbingTraining`, `StreamingRidgeProbeLearner`, and `_fit_streaming_ridge`. All previously hardcoded float64 accumulators, eigendecomposition tensors, and predict paths now honor this dtype. `"float64"` remains the recommended precision; `"float32"` exists for devices that don't support double, notably Apple MPS. To make MPS actually work, the eigendecomposition + Ws/biases construction (operating on `(D, D)` matrices ≤ `max_features`) is now run on CPU unconditionally — `torch.linalg.eigh` is not implemented on MPS, and these matrices are small enough that the CPU detour is free. The streaming backbone forward and statistics accumulation stay on the configured device. Verified REVE × chbmit on MPS with `class_weight="balanced"`, `dtype="float32"`: test_balanced_accuracy=0.8594 (matches CPU/float64), fit_time 651s (vs 1235s on CPU). --- CHANGELOG.md | 1 + open_eeg_bench/ridge_probe.py | 122 ++++++++++++++++++++++------------ open_eeg_bench/training.py | 2 + 3 files changed, 83 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7932ba3..fff5c6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add `ScratchBackbone` for benchmarking DL models without pretrained weights. Only compatible with `FullFinetune` (enforced by the `Experiment` validator) ([#30](https://github.com/braindecode/OpenEEGBench/pull/30)). - Add `class_weight` parameter to `RidgeProbingTraining` (`"balanced"` or `None`); **default changed to `"balanced"`** — pass `None` for the previous unweighted behavior. +- Add `dtype` parameter to `RidgeProbingTraining` (`"float32"` or `"float64"`, default `"float64"`). Use `"float32"` only when necessary, e.g. on Apple MPS which does not support float64. ### Changed - Fill in the Zenodo concept DOI (`10.5281/zenodo.19698863`) in the README DOI badge and the BibTeX snippet, and add it as an `identifiers` entry in `CITATION.cff`. diff --git a/open_eeg_bench/ridge_probe.py b/open_eeg_bench/ridge_probe.py index 4867c86..52641d9 100644 --- a/open_eeg_bench/ridge_probe.py +++ b/open_eeg_bench/ridge_probe.py @@ -32,8 +32,19 @@ def _default_lambdas() -> list[float]: return [10**e for e in range(-8, 9)] +_STR_TO_DTYPE = {"float32": torch.float32, "float64": torch.float64} + + +def _resolve_dtype(dtype: str) -> torch.dtype: + if dtype not in _STR_TO_DTYPE: + raise ValueError( + f"dtype must be 'float32' or 'float64', got {dtype!r}." + ) + return _STR_TO_DTYPE[dtype] + + def _balanced_class_weights( - train_loader: "DataLoader", n_classes: int, device: str + train_loader: "DataLoader", n_classes: int, device: str, dtype: torch.dtype ) -> torch.Tensor: """One label-only pass over train; returns sklearn-style "balanced" weights. @@ -41,7 +52,7 @@ def _balanced_class_weights( effective sample size is preserved. Empty classes are clamped to 1 sample to avoid div-by-zero (their weight then has no effect on accumulation). """ - counts = torch.zeros(n_classes, dtype=torch.float64, device=device) + counts = torch.zeros(n_classes, dtype=dtype, device=device) for batch in train_loader: y = batch[1].to(device).long() counts += torch.bincount(y, minlength=n_classes).to(counts.dtype) @@ -49,7 +60,7 @@ def _balanced_class_weights( def _make_projection_matrix( - n_features: int, n_components: int, seed: int, device: str + n_features: int, n_components: int, seed: int, device: str, dtype: torch.dtype ) -> torch.Tensor: """Return a (n_components, n_features) Gaussian random projection matrix. @@ -62,7 +73,7 @@ def _make_projection_matrix( rp = GaussianRandomProjection(n_components=n_components, random_state=seed) P = rp._make_random_matrix(n_components, n_features) - return torch.from_numpy(np.asarray(P)).to(device=device, dtype=torch.float64) + return torch.from_numpy(np.asarray(P)).to(device=device, dtype=dtype) def _fit_streaming_ridge( @@ -75,6 +86,7 @@ def _fit_streaming_ridge( max_features: int | None = None, projection_seed: int = 0, class_weight: str | None = "balanced", + dtype: str = "float64", ) -> dict: """Fit streaming ridge probe, select λ on val, return weights + diagnostics. @@ -90,17 +102,25 @@ def _fit_streaming_ridge( weights. Costs one extra label-only pass over the train loader. * ``None``: unweighted (every sample contributes equally). + ``dtype`` controls the precision of all internal accumulators, the + eigendecomposition, and the returned weights. ``"float64"`` (default) is + recommended for numerical precision: covariances and eigendecompositions + can lose meaningful accuracy in single precision. Use ``"float32"`` only + when necessary, e.g. on devices like Apple's MPS that do not support + float64. + Returns dict with keys: W (D,C), bias (C,), best_lambda (float), val_scores (dict λ→score), lambdas (list[float]), n_classes, n_features D (after projection if applied), projection (torch.Tensor | None). """ model.eval() model.to(device) + torch_dtype = _resolve_dtype(dtype) # ----- Pass 0 (optional, classification only): per-class weights from labels ----- class_weights = None if class_weight == "balanced" and n_classes is not None: - class_weights = _balanced_class_weights(train_loader, n_classes, device) + class_weights = _balanced_class_weights(train_loader, n_classes, device, torch_dtype) elif class_weight not in (None, "balanced"): raise ValueError( f"class_weight must be 'balanced' or None, got {class_weight!r}." @@ -108,50 +128,51 @@ def _fit_streaming_ridge( # ----- Pass 1: accumulate (optionally weighted) sufficient statistics on train ----- A = B = s_h = s_h2 = s_y = None - projection = None # (k, D_orig) float64; lazily built once D_orig is known + projection = None # (k, D_orig); lazily built once D_orig is known N = 0.0 # sum of sample weights (== n_samples when unweighted) with torch.no_grad(): for batch in train_loader: x, y = batch[0], batch[1] h = model(x.to(device)) - h64 = h.double() + h_acc = h.to(dtype=torch_dtype) if ( projection is None and max_features is not None - and h64.shape[1] > max_features + and h_acc.shape[1] > max_features ): projection = _make_projection_matrix( - n_features=h64.shape[1], + n_features=h_acc.shape[1], n_components=max_features, seed=projection_seed, device=device, + dtype=torch_dtype, ) if projection is not None: - h64 = h64 @ projection.T # (B, k) + h_acc = h_acc @ projection.T # (B, k) y_dev = y.to(device) y_enc = _encode_targets(y_dev, n_classes) - y64 = y_enc.double() + y_acc = y_enc.to(dtype=torch_dtype) if class_weights is not None: - w = class_weights[y_dev.long()] # (B,) float64 + w = class_weights[y_dev.long()] # (B,) else: - w = torch.ones(h64.shape[0], dtype=torch.float64, device=device) - hw = h64 * w.unsqueeze(1) # (B, D), each row h_i scaled by w_i - yw = y64 * w.unsqueeze(1) # (B, C) + w = torch.ones(h_acc.shape[0], dtype=torch_dtype, device=device) + hw = h_acc * w.unsqueeze(1) # (B, D), each row h_i scaled by w_i + yw = y_acc * w.unsqueeze(1) # (B, C) if A is None: - D = h64.shape[1] - C = y64.shape[1] - A = torch.zeros(D, D, dtype=torch.float64, device=device) - B = torch.zeros(D, C, dtype=torch.float64, device=device) - s_h = torch.zeros(D, dtype=torch.float64, device=device) - s_h2 = torch.zeros(D, dtype=torch.float64, device=device) - s_y = torch.zeros(C, dtype=torch.float64, device=device) - - A += hw.T @ h64 - B += hw.T @ y64 + D = h_acc.shape[1] + C = y_acc.shape[1] + A = torch.zeros(D, D, dtype=torch_dtype, device=device) + B = torch.zeros(D, C, dtype=torch_dtype, device=device) + s_h = torch.zeros(D, dtype=torch_dtype, device=device) + s_h2 = torch.zeros(D, dtype=torch_dtype, device=device) + s_y = torch.zeros(C, dtype=torch_dtype, device=device) + + A += hw.T @ h_acc + B += hw.T @ y_acc s_h += hw.sum(0) - s_h2 += (hw * h64).sum(0) + s_h2 += (hw * h_acc).sum(0) s_y += yw.sum(0) N += float(w.sum().item()) @@ -184,21 +205,32 @@ def _fit_streaming_ridge( C_zy = C_xy * inv_std.unsqueeze(1) # ----- Eigendecomposition (once, on correlation matrix) ----- - eigvals, Q = torch.linalg.eigh(C_zz) - Ct = Q.T @ C_zy # (D, C) + # Some accelerators (notably Apple MPS) don't implement torch.linalg.eigh. + # The decomposition runs on a (D, D) matrix at most ``max_features`` wide, + # which is small and fast on CPU — do it there unconditionally and move the + # resulting weights back to ``device`` before the val pass. + C_zz_cpu = C_zz.cpu() + C_zy_cpu = C_zy.cpu() + inv_std_cpu = inv_std.cpu() + h_bar_cpu = h_bar.cpu() + y_bar_cpu = y_bar.cpu() + eigvals, Q = torch.linalg.eigh(C_zz_cpu) + Ct = Q.T @ C_zy_cpu # (D, C) if lambdas is None: lambdas = _default_lambdas() # ----- Solve for each λ in eigenbasis, convert back to original space ----- K = len(lambdas) - Ws = torch.zeros(K, D, C, dtype=torch.float64, device=device) - biases = torch.zeros(K, C, dtype=torch.float64, device=device) + Ws_cpu = torch.zeros(K, D, C, dtype=torch_dtype) + biases_cpu = torch.zeros(K, C, dtype=torch_dtype) for k, lam in enumerate(lambdas): denom = (eigvals + lam).unsqueeze(1) # (D, 1) W_z = Q @ (Ct / denom) # (D, C) in standardized space - Ws[k] = W_z * inv_std.unsqueeze(1) # back to original: w_i / std_i - biases[k] = y_bar - Ws[k].T @ h_bar # (C,) + Ws_cpu[k] = W_z * inv_std_cpu.unsqueeze(1) # back to original: w_i / std_i + biases_cpu[k] = y_bar_cpu - Ws_cpu[k].T @ h_bar_cpu # (C,) + Ws = Ws_cpu.to(device) + biases = biases_cpu.to(device) # ----- Pass 2: streaming λ selection on val ----- val_scores = _streaming_val_scores( @@ -210,6 +242,7 @@ def _fit_streaming_ridge( y_bar_train=y_bar, device=device, projection=projection, + dtype=torch_dtype, ) # tensor of shape (K,) # Among tied best scores, pick the largest λ (most regularization): @@ -247,6 +280,7 @@ def _streaming_val_scores( n_classes: int | None, y_bar_train: torch.Tensor, device: str, + dtype: torch.dtype, projection: torch.Tensor | None = None, ) -> torch.Tensor: """Accumulate per-λ metric streaming on val. Returns (K,) scores (higher=better). @@ -259,15 +293,15 @@ def _streaming_val_scores( if is_regression: # R² = 1 - SS_res / SS_tot (SS_tot using train mean; matches sklearn behavior closely enough) - ss_res = torch.zeros(K, dtype=torch.float64, device=device) - ss_tot_scalar = torch.zeros((), dtype=torch.float64, device=device) + ss_res = torch.zeros(K, dtype=dtype, device=device) + ss_tot_scalar = torch.zeros((), dtype=dtype, device=device) with torch.no_grad(): for batch in val_loader: x, y = batch[0], batch[1] - h = model(x.to(device)).double() + h = model(x.to(device)).to(dtype=dtype) if projection is not None: h = h @ projection.T - y_enc = _encode_targets(y.to(device), n_classes).double() + y_enc = _encode_targets(y.to(device), n_classes).to(dtype=dtype) preds = torch.einsum("kdc,bd->kbc", Ws, h) + biases.unsqueeze(1) res = preds - y_enc.unsqueeze(0) ss_res += (res**2).sum(dim=(1, 2)) @@ -275,11 +309,11 @@ def _streaming_val_scores( return 1.0 - ss_res / ss_tot_scalar.clamp(min=1e-12) # Classification: balanced accuracy via confusion matrix (K, C, C) - confusion = torch.zeros(K, C, C, dtype=torch.float64, device=device) + confusion = torch.zeros(K, C, C, dtype=dtype, device=device) with torch.no_grad(): for batch in val_loader: x, y = batch[0], batch[1] - h = model(x.to(device)).double() + h = model(x.to(device)).to(dtype=dtype) if projection is not None: h = h @ projection.T y_true = y.to(device).long() @@ -320,6 +354,7 @@ def __init__( max_features: int | None = None, projection_seed: int = 0, class_weight: str | None = "balanced", + dtype: str = "float64", verbose: int = 1, ): self.model_ = feature_extractor @@ -332,6 +367,7 @@ def __init__( self.max_features = max_features self.projection_seed = projection_seed self.class_weight = class_weight + self.dtype = dtype self.verbose = verbose self._result: dict | None = None @@ -366,6 +402,7 @@ def fit(self, train_set, y=None): max_features=self.max_features, projection_seed=self.projection_seed, class_weight=self.class_weight, + dtype=self.dtype, ) if self.verbose: metric = "R²" if self.n_classes_ is None else "balanced_acc" @@ -385,9 +422,10 @@ def predict(self, test_set) -> np.ndarray: if self._result is None: raise RuntimeError("Call .fit() before .predict().") - W = self._result["W"] # (D, C) float64 - bias = self._result["bias"] # (C,) float64 - projection = self._result.get("projection") # (k, D_orig) float64 or None + W = self._result["W"] # (D, C) + bias = self._result["bias"] # (C,) + projection = self._result.get("projection") # (k, D_orig) or None + torch_dtype = _resolve_dtype(self.dtype) loader = DataLoader( test_set, batch_size=self.batch_size, @@ -402,7 +440,7 @@ def predict(self, test_set) -> np.ndarray: with torch.no_grad(): for batch in loader: x = batch[0] if isinstance(batch, (list, tuple)) else batch - h = self.model_(x.to(self.device)).double() + h = self.model_(x.to(self.device)).to(dtype=torch_dtype) if projection is not None: h = h @ projection.T y_hat = h @ W + bias # (B, C) diff --git a/open_eeg_bench/training.py b/open_eeg_bench/training.py index 5f1523d..0c60bc5 100644 --- a/open_eeg_bench/training.py +++ b/open_eeg_bench/training.py @@ -220,6 +220,7 @@ class RidgeProbingTraining(BaseModel): lambdas: list[float] | None = None # None → use the learner's default fixed log-spaced grid max_features: int | None = 5000 # if set and D > max_features, Gaussian random-project to max_features class_weight: Literal["balanced"] | None = "balanced" # "balanced" → sklearn-style per-class weights; None → unweighted + dtype: Literal["float32", "float64"] = "float64" # "float64" recommended for precision; use "float32" only if necessary (e.g. Apple MPS) def build_learner(self, model, callbacks, n_classes, val_set, verbose=1, seed: int = 0): from open_eeg_bench.ridge_probe import StreamingRidgeProbeLearner @@ -236,5 +237,6 @@ def build_learner(self, model, callbacks, n_classes, val_set, verbose=1, seed: i max_features=self.max_features, projection_seed=seed, class_weight=self.class_weight, + dtype=self.dtype, verbose=verbose, ) From d8e0273453bdc7fe4b3b9e501a99f506cc3cdc55 Mon Sep 17 00:00:00 2001 From: Pierre Guetschel Date: Tue, 28 Apr 2026 20:53:58 +0200 Subject: [PATCH 4/4] Changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fff5c6f..2975688 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add `ScratchBackbone` for benchmarking DL models without pretrained weights. Only compatible with `FullFinetune` (enforced by the `Experiment` validator) ([#30](https://github.com/braindecode/OpenEEGBench/pull/30)). -- Add `class_weight` parameter to `RidgeProbingTraining` (`"balanced"` or `None`); **default changed to `"balanced"`** — pass `None` for the previous unweighted behavior. -- Add `dtype` parameter to `RidgeProbingTraining` (`"float32"` or `"float64"`, default `"float64"`). Use `"float32"` only when necessary, e.g. on Apple MPS which does not support float64. +- Add `class_weight` parameter to `RidgeProbingTraining` (`"balanced"` or `None`); **default changed to `"balanced"`** — pass `None` for the previous unweighted behavior ([#32](https://github.com/braindecode/OpenEEGBench/pull/32)). +- Add `dtype` parameter to `RidgeProbingTraining` (`"float32"` or `"float64"`, default `"float64"`). Use `"float32"` only when necessary, e.g. on Apple MPS which does not support float64 ([#32](https://github.com/braindecode/OpenEEGBench/pull/32)). ### Changed - Fill in the Zenodo concept DOI (`10.5281/zenodo.19698863`) in the README DOI badge and the BibTeX snippet, and add it as an `identifiers` entry in `CITATION.cff`.