Ridge probe: balanced class weighting + configurable dtype#32

Merged
PierreGtch merged 4 commits into main from balance_ridge
Apr 28, 2026

Conversation

@PierreGtch (Contributor) commented Apr 28, 2026

Summary

  • Add class_weight to RidgeProbingTraining ("balanced" or None); default changed to "balanced" so imbalanced classification datasets are no longer collapsed to majority-class predictions.
  • Add dtype to RidgeProbingTraining ("float32" or "float64", default "float64"). "float32" enables running on devices without double support, notably Apple MPS.
  • Run the eigendecomposition + Ws/biases construction on CPU unconditionally (the matrices are at most max_features × max_features so the detour is free) — torch.linalg.eigh is not implemented on MPS. Streaming forward + accumulation stay on the configured device.
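
The `class_weight="balanced"` semantics mirror scikit-learn's. A quick standalone illustration of why the default changed, using sklearn's `RidgeClassifier` on synthetic data (not the streaming probe itself):

```python
import numpy as np
from sklearn.linear_model import RidgeClassifier
from sklearn.metrics import balanced_accuracy_score

rng = np.random.default_rng(0)
# ~97/3 imbalance with a separable minority class, echoing REVE x chbmit
X = rng.normal(size=(1000, 8))
y = (rng.random(1000) < 0.03).astype(int)
X[y == 1] += 1.5

for cw in (None, "balanced"):
    clf = RidgeClassifier(class_weight=cw).fit(X, y)
    print(cw, balanced_accuracy_score(y, clf.predict(X)))
# class_weight=None collapses to near-majority-class predictions (balanced
# accuracy close to chance); "balanced" recovers the minority class
```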

Validation: REVE × chbmit (97/3 imbalance)

| device / dtype | class_weight | test_balanced_accuracy | fit_time |
|---|---|---|---|
| cpu / float64 | None (previous default) | 0.5000 (chance) | 1158.8 s |
| cpu / float64 | "balanced" (new default) | 0.8594 | 1235.6 s |
| mps / float32 | "balanced" | 0.8594 | 651.5 s |

Test plan

  • pytest tests/test_ridge_probe.py (14/14 pass, including new test_balanced_class_weight_recovers_minority and test_balanced_class_weight_noop_when_classes_balanced)
  • pytest tests/test_default_configs.py (44/44 pass)
  • End-to-end REVE × chbmit on cpu/float64 with both class_weight settings
  • End-to-end REVE × chbmit on mps/float32 with class_weight="balanced"

Adds a `class_weight` parameter (default `"balanced"`) to the streaming
ridge probe. When enabled, an extra label-only pass over the train loader
computes sklearn-style per-class weights `w[c] = N / (n_classes * count[c])`,
which are then applied to every sufficient-statistic accumulator (A, B,
s_h, s_h2, s_y, N) so that the weighted-least-squares fit, weighted
centering, and weighted standardization are all internally consistent.
Regression silently ignores the parameter.
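
A minimal numpy sketch of the weight computation and the invariant it preserves (the helper name `balanced_class_weights` and the toy data are illustrative, not the project code):

```python
import numpy as np

def balanced_class_weights(y):
    """sklearn-style 'balanced' weights: w[c] = N / (n_classes * count[c])."""
    classes, counts = np.unique(y, return_counts=True)
    return classes, len(y) / (len(classes) * counts)

y = np.array([0] * 97 + [1] * 3)              # 97/3 imbalance
classes, w = balanced_class_weights(y)
sample_w = w[np.searchsorted(classes, y)]     # per-sample weight fed into A, B, s_h, s_h2, s_y, N

# the total weighted sample count is unchanged...
assert np.isclose(sample_w.sum(), len(y))
# ...but each class now contributes equal total mass (50.0 each here)
assert np.isclose(sample_w[y == 0].sum(), sample_w[y == 1].sum())
```

Because every accumulator is scaled by the same per-sample weight, the weighted mean, variance, and normal equations all refer to the same reweighted distribution.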

Verified end-to-end on REVE × chbmit (97/3 imbalance):
  unweighted:    test_balanced_accuracy = 0.5000 (chance)
  balanced:      test_balanced_accuracy = 0.8594

Adds a `dtype: Literal["float32", "float64"]` parameter (default
`"float64"`) threaded through `RidgeProbingTraining`,
`StreamingRidgeProbeLearner`, and `_fit_streaming_ridge`. All previously
hardcoded float64 accumulators, eigendecomposition tensors, and predict
paths now honor this dtype.

`"float64"` remains the recommended precision; `"float32"` exists for
devices that don't support double, notably Apple MPS.
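
A sketch of how such a dtype switch typically threads through the accumulators (hypothetical helper, not the actual `_fit_streaming_ridge` code):

```python
import torch

DTYPES = {"float32": torch.float32, "float64": torch.float64}

def init_accumulators(n_features: int, dtype: str, device: str = "cpu"):
    # every sufficient statistic is allocated in the configured dtype, so no
    # float64 tensor is ever created on a device without double support
    dt = DTYPES[dtype]
    A = torch.zeros(n_features, n_features, dtype=dt, device=device)
    s_h = torch.zeros(n_features, dtype=dt, device=device)
    return A, s_h

A, s_h = init_accumulators(16, "float32")
assert A.dtype == s_h.dtype == torch.float32
```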

To make MPS actually work, the eigendecomposition + Ws/biases construction
(operating on `(D, D)` matrices ≤ `max_features`) is now run on CPU
unconditionally — `torch.linalg.eigh` is not implemented on MPS, and these
matrices are small enough that the CPU detour is free. The streaming
backbone forward and statistics accumulation stay on the configured device.

Verified REVE × chbmit on MPS with `class_weight="balanced"`,
`dtype="float32"`: test_balanced_accuracy=0.8594 (matches CPU/float64),
fit_time 651s (vs 1235s on CPU).
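
The CPU detour amounts to a few lines. A sketch, assuming the sufficient statistics reduce the fit to a small `(D, D)` system (the helper name is hypothetical):

```python
import torch

def solve_ridge_on_cpu(A, b, lam, device):
    # A: (D, D) Gram accumulator, b: (D,) target accumulator, on `device`.
    # torch.linalg.eigh is not implemented on MPS, so hop to CPU for the
    # decomposition; D <= max_features, so the transfer is cheap.
    evals, evecs = torch.linalg.eigh(A.cpu())
    w = evecs @ ((evecs.T @ b.cpu()) / (evals + lam))
    return w.to(device)  # the solution moves back to the streaming device

# sanity check against a direct solve of (A + lam * I) w = b
D = 5
X = torch.randn(40, D, dtype=torch.float64)
y = torch.randn(40, dtype=torch.float64)
A, b, lam = X.T @ X, X.T @ y, 1e-2
w = solve_ridge_on_cpu(A, b, lam, "cpu")
assert torch.allclose(
    w, torch.linalg.solve(A + lam * torch.eye(D, dtype=torch.float64), b)
)
```

One eigendecomposition also amortizes across a whole regularization path, since `1 / (evals + lam)` can be re-evaluated per `lam` without refactoring `A`.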
@PierreGtch (Contributor, Author) commented:

CC @tomMoral

@PierreGtch PierreGtch merged commit 789d87a into main Apr 28, 2026
6 checks passed
@tomMoral tomMoral deleted the balance_ridge branch April 28, 2026 21:49
Comment thread open_eeg_bench/ridge_probe.py
