diff --git a/CHANGELOG.md b/CHANGELOG.md
index 735f3ba..2975688 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 - Add `ScratchBackbone` for benchmarking DL models without pretrained weights. Only compatible with `FullFinetune` (enforced by the `Experiment` validator) ([#30](https://github.com/braindecode/OpenEEGBench/pull/30)).
+- Add `class_weight` parameter to `RidgeProbingTraining` (`"balanced"` or `None`); **default changed to `"balanced"`** — pass `None` for the previous unweighted behavior ([#32](https://github.com/braindecode/OpenEEGBench/pull/32)).
+- Add `dtype` parameter to `RidgeProbingTraining` (`"float32"` or `"float64"`, default `"float64"`). Use `"float32"` only when necessary, e.g. on Apple MPS, which does not support float64 ([#32](https://github.com/braindecode/OpenEEGBench/pull/32)).
 
 ### Changed
 - Fill in the Zenodo concept DOI (`10.5281/zenodo.19698863`) in the README DOI badge and the BibTeX snippet, and add it as an `identifiers` entry in `CITATION.cff`.
diff --git a/open_eeg_bench/ridge_probe.py b/open_eeg_bench/ridge_probe.py
index 0b43fff..52641d9 100644
--- a/open_eeg_bench/ridge_probe.py
+++ b/open_eeg_bench/ridge_probe.py
@@ -32,8 +32,35 @@ def _default_lambdas() -> list[float]:
     return [10**e for e in range(-8, 9)]
 
 
+_STR_TO_DTYPE = {"float32": torch.float32, "float64": torch.float64}
+
+
+def _resolve_dtype(dtype: str) -> torch.dtype:
+    if dtype not in _STR_TO_DTYPE:
+        raise ValueError(
+            f"dtype must be 'float32' or 'float64', got {dtype!r}."
+        )
+    return _STR_TO_DTYPE[dtype]
+
+
+def _balanced_class_weights(
+    train_loader: "DataLoader", n_classes: int, device: str, dtype: torch.dtype
+) -> torch.Tensor:
+    """One label-only pass over train; returns sklearn-style "balanced" weights.
+
+    ``w[c] = N / (n_classes * count[c])`` so that ``sum_i w_{y_i} == N`` and the
+    effective sample size is preserved. Empty classes are clamped to 1 sample
+    to avoid div-by-zero (their weight then has no effect on accumulation).
+    """
+    counts = torch.zeros(n_classes, dtype=dtype, device=device)
+    for batch in train_loader:
+        y = batch[1].to(device).long()
+        counts += torch.bincount(y, minlength=n_classes).to(counts.dtype)
+    return counts.sum() / (n_classes * counts.clamp(min=1.0))
+
+
 def _make_projection_matrix(
-    n_features: int, n_components: int, seed: int, device: str
+    n_features: int, n_components: int, seed: int, device: str, dtype: torch.dtype
 ) -> torch.Tensor:
     """Return a (n_components, n_features) Gaussian random projection matrix.
@@ -46,7 +73,7 @@
     rp = GaussianRandomProjection(n_components=n_components, random_state=seed)
     P = rp._make_random_matrix(n_components, n_features)
-    return torch.from_numpy(np.asarray(P)).to(device=device, dtype=torch.float64)
+    return torch.from_numpy(np.asarray(P)).to(device=device, dtype=dtype)
 
 
 def _fit_streaming_ridge(
@@ -58,6 +85,8 @@
     device: str,
     max_features: int | None = None,
     projection_seed: int = 0,
+    class_weight: str | None = "balanced",
+    dtype: str = "float64",
 ) -> dict:
     """Fit streaming ridge probe, select λ on val, return weights + diagnostics.
@@ -65,53 +94,87 @@
     features are projected down to ``max_features`` dimensions via a Gaussian
     random projection (seeded by ``projection_seed``) before accumulation.
 
+    ``class_weight`` controls per-sample weighting in the weighted-least-squares
+    fit. Only meaningful when ``n_classes`` is set (classification); silently
+    ignored for regression. Supported values:
+
+    * ``"balanced"`` (default): sklearn-style ``N / (n_classes * count[c])``
+      weights. Costs one extra label-only pass over the train loader.
+    * ``None``: unweighted (every sample contributes equally).
+
+    ``dtype`` controls the precision of all internal accumulators, the
+    eigendecomposition, and the returned weights. ``"float64"`` (default) is
+    recommended for numerical precision: covariances and eigendecompositions
+    can lose meaningful accuracy in single precision. Use ``"float32"`` only
+    when necessary, e.g. on devices like Apple's MPS that do not support
+    float64.
+
     Returns dict with keys: W (D,C), bias (C,), best_lambda (float),
     val_scores (dict λ→score), lambdas (list[float]), n_classes, n_features D
     (after projection if applied), projection (torch.Tensor | None).
     """
     model.eval()
     model.to(device)
+    torch_dtype = _resolve_dtype(dtype)
+
+    # ----- Pass 0 (optional, classification only): per-class weights from labels -----
+    class_weights = None
+    if class_weight == "balanced" and n_classes is not None:
+        class_weights = _balanced_class_weights(train_loader, n_classes, device, torch_dtype)
+    elif class_weight not in (None, "balanced"):
+        raise ValueError(
+            f"class_weight must be 'balanced' or None, got {class_weight!r}."
+        )
 
-    # ----- Pass 1: accumulate sufficient statistics on train -----
+    # ----- Pass 1: accumulate (optionally weighted) sufficient statistics on train -----
     A = B = s_h = s_h2 = s_y = None
-    projection = None  # (k, D_orig) float64; lazily built once D_orig is known
-    N = 0
+    projection = None  # (k, D_orig); lazily built once D_orig is known
+    N = 0.0  # sum of sample weights (== n_samples when unweighted)
     with torch.no_grad():
         for batch in train_loader:
             x, y = batch[0], batch[1]
             h = model(x.to(device))
-            h64 = h.double()
+            h_acc = h.to(dtype=torch_dtype)
             if (
                 projection is None
                 and max_features is not None
-                and h64.shape[1] > max_features
+                and h_acc.shape[1] > max_features
             ):
                 projection = _make_projection_matrix(
-                    n_features=h64.shape[1],
+                    n_features=h_acc.shape[1],
                     n_components=max_features,
                     seed=projection_seed,
                     device=device,
+                    dtype=torch_dtype,
                 )
             if projection is not None:
-                h64 = h64 @ projection.T  # (B, k)
-            y_enc = _encode_targets(y.to(device), n_classes)
-            y64 = y_enc.double()
+                h_acc = h_acc @ projection.T  # (B, k)
+            y_dev = y.to(device)
+            y_enc = _encode_targets(y_dev, n_classes)
+            y_acc = y_enc.to(dtype=torch_dtype)
+
+            if class_weights is not None:
+                w = class_weights[y_dev.long()]  # (B,)
+            else:
+                w = torch.ones(h_acc.shape[0], dtype=torch_dtype, device=device)
+            hw = h_acc * w.unsqueeze(1)  # (B, D), each row h_i scaled by w_i
+            yw = y_acc * w.unsqueeze(1)  # (B, C)
 
             if A is None:
-                D = h64.shape[1]
-                C = y64.shape[1]
-                A = torch.zeros(D, D, dtype=torch.float64, device=device)
-                B = torch.zeros(D, C, dtype=torch.float64, device=device)
-                s_h = torch.zeros(D, dtype=torch.float64, device=device)
-                s_h2 = torch.zeros(D, dtype=torch.float64, device=device)
-                s_y = torch.zeros(C, dtype=torch.float64, device=device)
-
-            A += h64.T @ h64
-            B += h64.T @ y64
-            s_h += h64.sum(0)
-            s_h2 += (h64**2).sum(0)
-            s_y += y64.sum(0)
-            N += h64.shape[0]
+                D = h_acc.shape[1]
+                C = y_acc.shape[1]
+                A = torch.zeros(D, D, dtype=torch_dtype, device=device)
+                B = torch.zeros(D, C, dtype=torch_dtype, device=device)
+                s_h = torch.zeros(D, dtype=torch_dtype, device=device)
+                s_h2 = torch.zeros(D, dtype=torch_dtype, device=device)
+                s_y = torch.zeros(C, dtype=torch_dtype, device=device)
+
+            A += hw.T @ h_acc
+            B += hw.T @ y_acc
+            s_h += hw.sum(0)
+            s_h2 += (hw * h_acc).sum(0)
+            s_y += yw.sum(0)
+            N += float(w.sum().item())
 
     if A is None:
         raise ValueError("Empty train_loader — no features accumulated.")
@@ -142,21 +205,32 @@
     C_zy = C_xy * inv_std.unsqueeze(1)
 
     # ----- Eigendecomposition (once, on correlation matrix) -----
-    eigvals, Q = torch.linalg.eigh(C_zz)
-    Ct = Q.T @ C_zy  # (D, C)
+    # Some accelerators (notably Apple MPS) don't implement torch.linalg.eigh.
+    # The decomposition runs on a (D, D) matrix at most ``max_features`` wide,
+    # which is small and fast on CPU — do it there unconditionally and move the
+    # resulting weights back to ``device`` before the val pass.
+    C_zz_cpu = C_zz.cpu()
+    C_zy_cpu = C_zy.cpu()
+    inv_std_cpu = inv_std.cpu()
+    h_bar_cpu = h_bar.cpu()
+    y_bar_cpu = y_bar.cpu()
+    eigvals, Q = torch.linalg.eigh(C_zz_cpu)
+    Ct = Q.T @ C_zy_cpu  # (D, C)
 
     if lambdas is None:
         lambdas = _default_lambdas()
 
     # ----- Solve for each λ in eigenbasis, convert back to original space -----
     K = len(lambdas)
-    Ws = torch.zeros(K, D, C, dtype=torch.float64, device=device)
-    biases = torch.zeros(K, C, dtype=torch.float64, device=device)
+    Ws_cpu = torch.zeros(K, D, C, dtype=torch_dtype)
+    biases_cpu = torch.zeros(K, C, dtype=torch_dtype)
     for k, lam in enumerate(lambdas):
         denom = (eigvals + lam).unsqueeze(1)  # (D, 1)
         W_z = Q @ (Ct / denom)  # (D, C) in standardized space
-        Ws[k] = W_z * inv_std.unsqueeze(1)  # back to original: w_i / std_i
-        biases[k] = y_bar - Ws[k].T @ h_bar  # (C,)
+        Ws_cpu[k] = W_z * inv_std_cpu.unsqueeze(1)  # back to original: w_i / std_i
+        biases_cpu[k] = y_bar_cpu - Ws_cpu[k].T @ h_bar_cpu  # (C,)
+    Ws = Ws_cpu.to(device)
+    biases = biases_cpu.to(device)
 
     # ----- Pass 2: streaming λ selection on val -----
     val_scores = _streaming_val_scores(
@@ -168,6 +242,7 @@
         y_bar_train=y_bar,
         device=device,
         projection=projection,
+        dtype=torch_dtype,
     )  # tensor of shape (K,)
 
     # Among tied best scores, pick the largest λ (most regularization):
@@ -205,6 +280,7 @@
     n_classes: int | None,
     y_bar_train: torch.Tensor,
     device: str,
+    dtype: torch.dtype,
     projection: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """Accumulate per-λ metric streaming on val. Returns (K,) scores (higher=better).
@@ -217,15 +293,15 @@
 
     if is_regression:
         # R² = 1 - SS_res / SS_tot (SS_tot using train mean; matches sklearn behavior closely enough)
-        ss_res = torch.zeros(K, dtype=torch.float64, device=device)
-        ss_tot_scalar = torch.zeros((), dtype=torch.float64, device=device)
+        ss_res = torch.zeros(K, dtype=dtype, device=device)
+        ss_tot_scalar = torch.zeros((), dtype=dtype, device=device)
         with torch.no_grad():
             for batch in val_loader:
                 x, y = batch[0], batch[1]
-                h = model(x.to(device)).double()
+                h = model(x.to(device)).to(dtype=dtype)
                 if projection is not None:
                     h = h @ projection.T
-                y_enc = _encode_targets(y.to(device), n_classes).double()
+                y_enc = _encode_targets(y.to(device), n_classes).to(dtype=dtype)
                 preds = torch.einsum("kdc,bd->kbc", Ws, h) + biases.unsqueeze(1)
                 res = preds - y_enc.unsqueeze(0)
                 ss_res += (res**2).sum(dim=(1, 2))
@@ -233,11 +309,11 @@
         return 1.0 - ss_res / ss_tot_scalar.clamp(min=1e-12)
 
     # Classification: balanced accuracy via confusion matrix (K, C, C)
-    confusion = torch.zeros(K, C, C, dtype=torch.float64, device=device)
+    confusion = torch.zeros(K, C, C, dtype=dtype, device=device)
     with torch.no_grad():
         for batch in val_loader:
             x, y = batch[0], batch[1]
-            h = model(x.to(device)).double()
+            h = model(x.to(device)).to(dtype=dtype)
             if projection is not None:
                 h = h @ projection.T
             y_true = y.to(device).long()
@@ -277,6 +353,8 @@ def __init__(
         self,
         val_set,
         max_features: int | None = None,
         projection_seed: int = 0,
+        class_weight: str | None = "balanced",
+        dtype: str = "float64",
         verbose: int = 1,
     ):
         self.model_ = feature_extractor
@@ -288,6 +366,8 @@ def __init__(
         self.val_set = val_set
         self.max_features = max_features
         self.projection_seed = projection_seed
+        self.class_weight = class_weight
+        self.dtype = dtype
         self.verbose = verbose
         self._result: dict | None = None
@@ -321,6 +401,8 @@ def fit(self, train_set, y=None):
             device=self.device,
             max_features=self.max_features,
             projection_seed=self.projection_seed,
+            class_weight=self.class_weight,
+            dtype=self.dtype,
         )
         if self.verbose:
             metric = "R²" if self.n_classes_ is None else "balanced_acc"
@@ -340,9 +422,10 @@ def predict(self, test_set) -> np.ndarray:
         if self._result is None:
             raise RuntimeError("Call .fit() before .predict().")
 
-        W = self._result["W"]  # (D, C) float64
-        bias = self._result["bias"]  # (C,) float64
-        projection = self._result.get("projection")  # (k, D_orig) float64 or None
+        W = self._result["W"]  # (D, C)
+        bias = self._result["bias"]  # (C,)
+        projection = self._result.get("projection")  # (k, D_orig) or None
+        torch_dtype = _resolve_dtype(self.dtype)
         loader = DataLoader(
             test_set,
             batch_size=self.batch_size,
@@ -357,7 +440,7 @@
         with torch.no_grad():
             for batch in loader:
                 x = batch[0] if isinstance(batch, (list, tuple)) else batch
-                h = self.model_(x.to(self.device)).double()
+                h = self.model_(x.to(self.device)).to(dtype=torch_dtype)
                 if projection is not None:
                     h = h @ projection.T
                 y_hat = h @ W + bias  # (B, C)
diff --git a/open_eeg_bench/training.py b/open_eeg_bench/training.py
index 1db88e3..0c60bc5 100644
--- a/open_eeg_bench/training.py
+++ b/open_eeg_bench/training.py
@@ -219,6 +219,8 @@ class RidgeProbingTraining(BaseModel):
     device: str = "cpu"
     lambdas: list[float] | None = None  # None → use the learner's default fixed log-spaced grid
    max_features: int | None = 5000  # if set and D > max_features, Gaussian random-project to max_features
+    class_weight: Literal["balanced"] | None = "balanced"  # "balanced" → sklearn-style per-class weights; None → unweighted
+    dtype: Literal["float32", "float64"] = "float64"  # "float64" recommended for precision; use "float32" only if necessary (e.g. Apple MPS)
 
     def build_learner(self, model, callbacks, n_classes, val_set, verbose=1, seed: int = 0):
         from open_eeg_bench.ridge_probe import StreamingRidgeProbeLearner
@@ -234,5 +236,7 @@ def build_learner(self, model, callbacks, n_classes, val_set, verbose=1, seed: i
             val_set=val_set,
             max_features=self.max_features,
             projection_seed=seed,
+            class_weight=self.class_weight,
+            dtype=self.dtype,
             verbose=verbose,
         )
diff --git a/tests/test_ridge_probe.py b/tests/test_ridge_probe.py
index 90744ae..ffeb07e 100644
--- a/tests/test_ridge_probe.py
+++ b/tests/test_ridge_probe.py
@@ -69,6 +69,7 @@ def test_fit_streaming_ridge_matches_sklearn_classification(classif_data, max_fe
         device="cpu",
         max_features=max_features,
         projection_seed=projection_seed,
+        class_weight=None,  # sklearn Ridge has no class_weight; disable to compare 1-to-1
     )
 
     if max_features is not None:
@@ -171,6 +172,119 @@ def test_fit_streaming_ridge_regression_matches_sklearn(regression_data):
         assert out["val_scores"][lam] > 0.9
 
+
+@pytest.fixture
+def imbalanced_classif_data():
+    """Imbalanced binary classification: ~95% class 0, ~5% class 1.
+
+    Both classes are linearly separable in feature space, so a class-aware
+    classifier can recover the minority class; an unweighted ridge will
+    collapse to predicting the majority class.
+    """
+    rng = np.random.default_rng(0)
+    N, D = 4000, 20
+    p_minority = 0.05
+    y = (rng.random(N) < p_minority).astype(np.int64)
+    # Class-conditional means separated along one direction.
+    mu = np.zeros((2, D), dtype=np.float32)
+    mu[1, 0] = 3.0
+    X = mu[y] + rng.standard_normal((N, D)).astype(np.float32)
+    # Shuffle then split 60/20/20 (preserves the imbalance in each split).
+    idx = rng.permutation(N)
+    X, y = X[idx], y[idx]
+    i1, i2 = int(0.6 * N), int(0.8 * N)
+    return {
+        "train": (X[:i1], y[:i1]),
+        "val": (X[i1:i2], y[i1:i2]),
+        "test": (X[i2:], y[i2:]),
+        "C": 2,
+    }
+
+
+def test_balanced_class_weight_recovers_minority(imbalanced_classif_data):
+    """On 95/5 imbalanced data, class_weight='balanced' beats unweighted on minority recall.
+
+    Unweighted ridge minimizes squared error, which on heavily imbalanced data
+    is dominated by the majority class — minority recall collapses near 0.
+    Balanced weighting restores per-class equal contribution and the linearly
+    separable signal becomes recoverable.
+    """
+    from open_eeg_bench.ridge_probe import StreamingRidgeProbeLearner
+
+    X_tr, y_tr = imbalanced_classif_data["train"]
+    X_val, y_val = imbalanced_classif_data["val"]
+    X_te, y_te = imbalanced_classif_data["test"]
+    C = imbalanced_classif_data["C"]
+
+    train_set = TensorDataset(torch.from_numpy(X_tr), torch.from_numpy(y_tr))
+    val_set = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))
+    test_set = TensorDataset(torch.from_numpy(X_te), torch.from_numpy(y_te))
+
+    def fit_predict(class_weight):
+        learner = StreamingRidgeProbeLearner(
+            feature_extractor=nn.Identity(),
+            n_classes=C,
+            batch_size=64,
+            num_workers=0,
+            device="cpu",
+            lambdas=[1e-2, 1.0, 1e2],
+            val_set=val_set,
+            class_weight=class_weight,
+            verbose=0,
+        )
+        learner.fit(train_set, y=None)
+        return learner.predict(test_set)
+
+    pred_unweighted = fit_predict(None)
+    pred_balanced = fit_predict("balanced")
+
+    minority = y_te == 1
+    recall_unweighted = (pred_unweighted[minority] == 1).mean()
+    recall_balanced = (pred_balanced[minority] == 1).mean()
+
+    # Unweighted collapses to majority class — minority recall ~0.
+    assert recall_unweighted < 0.2, f"unweighted minority recall too high: {recall_unweighted}"
+    # Balanced recovers the linearly separable signal — minority recall well above chance.
+    assert recall_balanced > 0.7, f"balanced minority recall too low: {recall_balanced}"
+
+
+def test_balanced_class_weight_noop_when_classes_balanced(classif_data):
+    """class_weight='balanced' on already-balanced classes must equal the unweighted fit.
+
+    sklearn-style balanced weights are ``N / (n_classes * count[c])``: when every
+    class has the same count, all weights equal 1 exactly, so weighted and
+    unweighted accumulators produce identical W and bias.
+    """
+    from open_eeg_bench.ridge_probe import _fit_streaming_ridge
+
+    X_tr, y_tr = classif_data["train"]
+    X_val, y_val = classif_data["val"]
+    C = classif_data["C"]
+    # Trim the train set so each class appears exactly the same number of times.
+    per_class = min(int((y_tr == c).sum()) for c in range(C))
+    keep = np.concatenate([np.where(y_tr == c)[0][:per_class] for c in range(C)])
+    X_tr_bal, y_tr_bal = X_tr[keep], y_tr[keep]
+    common = dict(
+        model=nn.Identity(),
+        train_loader=_loader(X_tr_bal, y_tr_bal),
+        val_loader=_loader(X_val, y_val),
+        n_classes=C,
+        lambdas=[1.0],
+        device="cpu",
+    )
+    out_unw = _fit_streaming_ridge(**common, class_weight=None)
+    out_bal = _fit_streaming_ridge(**common, class_weight="balanced")
+    torch.testing.assert_close(out_bal["W"], out_unw["W"])
+    torch.testing.assert_close(out_bal["bias"], out_unw["bias"])
+
+
 def test_fit_streaming_ridge_raises_on_nan_features(classif_data):
     """NaN features (e.g. from a buggy backbone) yield a clear error, not IndexError."""
     from open_eeg_bench.ridge_probe import _fit_streaming_ridge
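For reviewers: a minimal standalone sketch (illustration only, not part of the diff) of the sklearn-style "balanced" formula that `_balanced_class_weights` implements. It checks the two properties the docstrings and tests above rely on: the per-sample weights sum back to N, and equal class counts give exactly unit weights (the invariant behind `test_balanced_class_weight_noop_when_classes_balanced`). The helper name `balanced_weights` is hypothetical and exists only for this sketch.

import torch

def balanced_weights(y: torch.Tensor, n_classes: int) -> torch.Tensor:
    # Hypothetical stand-in for _balanced_class_weights, minus the loader loop:
    # w[c] = N / (n_classes * count[c]); empty classes are clamped to one
    # "sample" so the division is safe.
    counts = torch.bincount(y, minlength=n_classes).double()
    return counts.sum() / (n_classes * counts.clamp(min=1.0))

y = torch.tensor([0] * 95 + [1] * 5)   # the 95/5 split from the new fixture
w = balanced_weights(y, n_classes=2)   # approx. tensor([0.5263, 10.0000])
assert abs(w[y].sum().item() - len(y)) < 1e-9  # sum_i w_{y_i} == N
assert balanced_weights(torch.tensor([0, 1, 0, 1]), 2).tolist() == [1.0, 1.0]

Minority samples here carry 19x the weight of majority samples, so the weighted accumulators A, B, s_h, s_h2, and s_y in Pass 1 see both classes with equal total mass, which is what `test_balanced_class_weight_recovers_minority` exercises end to end.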