feat(train): add prioritized sampling for CL importance weighting #105

Open
krm9c wants to merge 2 commits into main from prioritized-sampling

krm9c (Collaborator) commented Jun 9, 2026

Introduces per-sample priority-based replay for continual-learning updates. At the start of each CL round the trainer rebuilds the current-task DataLoader with a WeightedRandomSampler whose weights are priority_i = max(L(w_current, x_i) - L(theta_star, x_i), 1e-8)^alpha, so samples the model has forgotten relative to the previous CL anchor are sampled more often. The training loss itself is left unchanged (no gradient distortion).

  • BaseUpdater: theta_star anchor + _unreduced_criterion + compute_sample_priorities; cl_postprocessing refreshes the anchor.
  • ContinuousTrainer: gated by cl_updater.importance_weighting; rebuilds cur_train_loader with WeightedRandomSampler when enabled.
  • ContinualLearningCfg: new importance_weighting (default False) and importance_alpha (default 1.0) fields.
  • create_updater: assigns the two config values onto the updater.
  • examples/mnist + examples/cifar: optional current_ratio constructor arg drives a RandomSampler so current/historical loaders draw a balanced fraction of samples (no behavior change when current_ratio=1).
  • make_loader (both example utils): accepts a sampler kwarg that overrides shuffle.
  • Tests: new tests/test_importance_weighting.py (17 tests) and importance_weighting=False on MagicMock updaters in the outer-loop trainer tests so the new sampling branch stays off in mock-driven runs.

Summary

Adds prioritized replay sampling to the CL pipeline. When continual_learning.importance_weighting = true, every CL round draws current-task samples proportionally to how much each sample's loss has worsened since the previous CL anchor (theta_star). The standard training loss is unchanged, so this is purely a data-distribution intervention — no gradient distortion.

Motivation & Context

Forgetting under continual learning is concentrated on a subset of samples that the most recent update step degraded; uniformly sampling the current-task loader spends most of the compute budget on already-easy examples. Replaying the high-forgetting samples more often is a low-overhead way to push the CL update toward the regions of input space that actually need it. Method follows Raghavan & Papadimitriou, FGCS 2025.

This replaces an earlier, abandoned per-sample loss-weighting variant that distorted gradients via softmax weights and was sensitive to a temperature hyperparameter; prioritized sampling decouples the "what to focus on" signal from the loss itself.

Approach

  • Lift the anchor weights (theta_star) and an unreduced-loss helper (_unreduced_criterion) onto BaseUpdater so every updater can use them.
  • Add BaseUpdater.compute_sample_priorities(loader, device): runs the current model and functional_call(model, theta_star, x) over the loader, returns (L_cur - L_anchor).clamp(min=1e-8).pow(alpha) as a 1-D priority tensor (eval mode, @torch.no_grad).
  • In ContinuousTrainer.outer_cl_training_loop, when cl_updater.importance_weighting and cl_updater.theta_star, rebuild cur_train_loader with WeightedRandomSampler(weights=priorities.tolist(), num_samples=len(priorities), replacement=True).
  • BaseUpdater.cl_postprocessing refreshes theta_star to the post-CL parameters, so the next round measures forgetting relative to "where we ended last time."
  • Example harnesses get an optional current_ratio constructor arg (default 1.0, fully backward-compatible) that lets the current/historical loaders each draw a fraction of one dataset via RandomSampler — useful when running paired ablations.
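The priority computation and sampler rebuild described above can be sketched roughly as follows. This is a minimal standalone illustration, not the code in this PR: `sample_priorities` is a hypothetical free function standing in for `BaseUpdater.compute_sample_priorities`, the criterion is assumed to be cross-entropy with `reduction="none"`, and the toy model, dataset, and perturbation are made up for the example.

```python
import torch
from torch import nn
from torch.func import functional_call
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler


@torch.no_grad()
def sample_priorities(model, theta_star, loader, device, alpha=1.0):
    """Return (L_cur - L_anchor).clamp(min=1e-8) ** alpha, in loader order."""
    criterion = nn.CrossEntropyLoss(reduction="none")  # assumed unreduced loss
    was_training = model.training
    model.eval()
    chunks = []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        loss_cur = criterion(model(x), y)
        # Same module, evaluated with the anchor parameters swapped in.
        loss_anchor = criterion(functional_call(model, theta_star, (x,)), y)
        chunks.append((loss_cur - loss_anchor).clamp(min=1e-8).pow(alpha))
    model.train(was_training)
    return torch.cat(chunks).cpu()


torch.manual_seed(0)
dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 3, (32,)))
model = nn.Linear(4, 3)

# Anchor: a detached snapshot of the parameters before the model moves.
theta_star = {k: v.detach().clone() for k, v in model.named_parameters()}
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.5 * torch.randn_like(p))  # stand-in for a training round

# Score without shuffling so priorities[i] lines up with dataset[i].
scoring_loader = DataLoader(dataset, batch_size=8, shuffle=False)
priorities = sample_priorities(model, theta_star, scoring_loader, "cpu")

sampler = WeightedRandomSampler(
    weights=priorities.tolist(), num_samples=len(priorities), replacement=True
)
cur_train_loader = DataLoader(dataset, batch_size=8, sampler=sampler)
```

One detail the sketch makes explicit: the loader used for scoring has to iterate in dataset order, otherwise the weights handed to WeightedRandomSampler would be attached to the wrong indices.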

Trade-off: compute_sample_priorities adds two inference-only passes over the current-task dataset per CL round (one with the current weights, one with theta_star), not per inner iteration. The architecture is signal-agnostic: the priority can be swapped for Fisher information, predictive entropy, ensemble disagreement, etc. without changing the trainer.

Screenshots / Logs (optional)

End-to-end MNIST run with update_mode=ewc_online, importance_weighting=true, importance_alpha=1.0 is still in progress at PR-open time. Will paste accuracy / forgetting numbers as a comment once it finishes.

API / CLI Changes

  • ContinualLearningCfg.importance_weighting: bool = False (new)
  • ContinualLearningCfg.importance_alpha: float = 1.0 (new)
  • BaseUpdater.theta_star: dict[str, Tensor] (new attribute)
  • BaseUpdater._unreduced_criterion(outputs, y) -> Tensor (new)
  • BaseUpdater.compute_sample_priorities(loader, device) -> Tensor (new)
  • BaseUpdater.cl_postprocessing() now refreshes theta_star (was a no-op).
  • MNIST_CNN.__init__(..., current_ratio: float = 1.0) (new kwarg, default preserves old behavior)
  • CIFAR_VISION.__init__(..., current_ratio: float = 1.0) (new kwarg, default preserves old behavior)
  • examples/mnist/utils.py::make_loader(..., sampler: Sampler | None = None) (new kwarg, overrides shuffle when provided)
  • examples/cifar/src/utils.py::make_loader(..., sampler: Sampler | None = None) (new kwarg, overrides shuffle when provided)
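For the make_loader change, "overrides shuffle" matters because torch's DataLoader rejects a sampler combined with shuffle=True. A minimal sketch of what such a helper could look like, assuming a simplified signature (the real example utilities take more arguments):

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset


def make_loader(dataset, batch_size, shuffle=True, sampler=None):
    """Simplified stand-in: an explicit sampler takes precedence over shuffle."""
    if sampler is not None:
        # DataLoader raises ValueError when sampler and shuffle=True are
        # both given, so shuffle is forced off here.
        shuffle = False
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler)


demo_dataset = TensorDataset(torch.arange(10))
demo_sampler = SequentialSampler(demo_dataset)
# shuffle=True is passed on purpose; the sampler still decides the order.
demo_loader = make_loader(demo_dataset, batch_size=4, shuffle=True, sampler=demo_sampler)
demo_order = torch.cat([batch[0] for batch in demo_loader]).tolist()
```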

Breaking Changes

  • None. All new fields and kwargs have defaults that reproduce existing behavior. Existing TOML configs continue to load and run unchanged.

Performance (optional)

  • Extra cost per CL round when importance_weighting=true: two inference-only passes over the current-task dataset to compute priorities, one with the current weights and one under functional_call with theta_star, once per drift event. Negligible relative to max_iter inner training iterations at typical settings.
  • No measurable cost when importance_weighting=false (the trainer skips the entire block).
  • End-to-end MNIST + EWC benchmark numbers will be appended once the smoke run completes.

Security & Privacy

  • No secrets committed
  • Input validation added where needed (priorities clamped at 1e-8 to avoid zero/negative weights in WeightedRandomSampler)

Dependencies

  • None added or removed. Uses existing torch.func.functional_call, torch.utils.data.WeightedRandomSampler, torch.utils.data.RandomSampler.

Testing Plan

  • Unit tests — new tests/test_importance_weighting.py (17 tests): unreduced criterion shapes, priority shape/positivity/variance/alpha sharpness, WeightedRandomSampler integration, fwd_bwd standard-loss invariant, theta_star init/inheritance/postprocessing, anchor-vs-current loss divergence, no-grad through anchor, balanced sampling, config field defaults/customs, create_updater wiring.
  • Integration tests — existing tests/test_continuous_trainer.py passes; three outer-loop tests set importance_weighting=False on MagicMock updaters so the new sampling branch stays off (otherwise MagicMock.importance_weighting is truthy and the trainer tries to call compute_sample_priorities on the mock).
  • e2e / smoke test — MNIST + EWC + importance_weighting=true run in progress at PR-open time; results to be added as a comment.
  • Manual steps: poetry run python -m src.main --config examples/mnist/mnist.toml --set continual_learning.update_mode=ewc_online --set continual_learning.importance_weighting=true --set continual_learning.importance_alpha=1.0
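The MagicMock caveat from the integration-test item can be reproduced in isolation. The attribute name is the one from this PR; everything else is plain unittest.mock behavior:

```python
from unittest.mock import MagicMock

updater = MagicMock()
# An attribute that was never set is auto-created as another MagicMock,
# and MagicMock objects are truthy by default.
auto_attr_is_truthy = bool(updater.importance_weighting)

# Pinning the flag, as the outer-loop tests do, keeps the branch off.
updater.importance_weighting = False
pinned_is_truthy = bool(updater.importance_weighting)

print(auto_attr_is_truthy, pinned_is_truthy)  # True False
```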

Documentation

  • Docstrings updated (compute_sample_priorities, _unreduced_criterion, cl_postprocessing, current_ratio in example harnesses)
  • User docs / README updated — not in this PR; defaults are off, no user-facing TOML change is required.
  • CHANGELOG entry — repo has no CHANGELOG file.

Checklist

  • Code formatted (Ruff) → ruff format --check
  • Lint passes (Ruff) → ruff check .
  • Types pass (mypy) → mypy .
  • Tests pass (pytest) → pytest (197 passed, 1 pre-existing failure test_mnist_first_drift_losses_match_reference that also fails on clean main, 11 skipped)
  • Backward compatibility considered (all new fields/kwargs have defaults that reproduce existing behavior)
  • Adequate comments for tricky parts (anchor construction in compute_sample_priorities, sampler-vs-shuffle in make_loader)
  • CI green — pending CI run on this branch

Risk & Rollback Plan

Low risk. The new behavior is fully gated by continual_learning.importance_weighting, which defaults to False. Existing configs and existing runs are byte-identical to main. Rollback is a straight revert of this PR — no schema / data migrations involved.

Notes for Reviewers

  • Start at src/apeiron/training/updater/base.py (the new compute_sample_priorities is the substantive change) and then src/apeiron/training/continuous_trainer.py (the trainer hook).
  • The current_ratio plumbing in examples/{mnist,cifar}/... is an independent feature bundled in the same commit because it's the experimental setup the prioritized-sampling ablation was run with. Skipping it would leave the example harnesses unable to reproduce the published numbers.
  • I did not touch src/apeiron/training/updater/jvp_reg.py — the old loss-weighting code that the abandoned variant added there does not exist on current main (jvp_reg was rewritten upstream), so there is nothing to remove. compute_sample_priorities is decoupled from the updater type and works with jvp_reg, ewc_online, kfac_online, and base alike.
  • EWC's __init__ still re-creates its own theta_star after super().__init__ runs; that is a harmless re-bind (the base's theta_star is overwritten by EWC's device-placed version). Intentionally left as-is per discussion.
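For reviewers unfamiliar with the current_ratio plumbing, one way a ratio could be turned into a RandomSampler is sketched below. `fraction_sampler` is a hypothetical helper, and the rounding rule is an assumption; the example harnesses may derive the sample count differently. It relies on RandomSampler accepting num_samples with replacement=False, which recent PyTorch releases allow.

```python
import torch
from torch.utils.data import RandomSampler, TensorDataset


def fraction_sampler(dataset, ratio):
    """Hypothetical helper: draw round(ratio * len(dataset)) distinct indices."""
    if not 0.0 < ratio <= 1.0:
        raise ValueError("ratio must be in (0, 1]")
    num_samples = max(1, int(round(ratio * len(dataset))))
    return RandomSampler(dataset, replacement=False, num_samples=num_samples)


toy_dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))
half_indices = list(fraction_sampler(toy_dataset, 0.5))
full_indices = list(fraction_sampler(toy_dataset, 1.0))
```

With ratio 1.0 the sampler yields a full permutation of the dataset, which matches the "no behavior change" default in terms of which samples are visited per epoch.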

krm9c requested review from ScSteffen and anagainaru on June 9, 2026, 18:08

Refresh theta_star in cl_preprocessing instead of cl_postprocessing so
that the anchor lags the model by one CL round. Previously theta_star
was overwritten with the post-training weights, which meant
compute_sample_priorities always saw L_current == L_anchor on the next
round and WeightedRandomSampler collapsed to uniform — making
importance_weighting=true a no-op. EWC's Fisher commit and KFAC's A/G
EMA commit stay in cl_postprocessing; only the anchor refresh moved.
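
The collapse to uniform sampling is easy to see numerically. The sketch below uses plain Python lists and made-up loss values; `toy_priorities` and `sampling_probs` are illustrative helpers, not functions from the repository. The normalization mirrors what weighted sampling does with unnormalized weights.

```python
def toy_priorities(loss_cur, loss_anchor, alpha=1.0, floor=1e-8):
    """max(L_cur - L_anchor, floor) ** alpha for each sample."""
    return [max(c - a, floor) ** alpha for c, a in zip(loss_cur, loss_anchor)]


def sampling_probs(weights):
    """Normalize weights into per-sample draw probabilities."""
    total = sum(weights)
    return [w / total for w in weights]


loss_cur = [0.9, 0.2, 1.5, 0.4]

# Anchor refreshed after training: it equals the current model, every
# difference is zero, every weight hits the floor, sampling is uniform.
stale = sampling_probs(toy_priorities(loss_cur, loss_cur))

# Anchor lagging by one round: samples whose loss got worse stand out.
loss_anchor = [0.3, 0.2, 0.5, 0.6]
lagged = sampling_probs(toy_priorities(loss_cur, loss_anchor))
```

In the lagged case the third sample (loss up by 1.0) is drawn most often, the first (up by 0.6) next, and the two samples that did not get worse sit at the floor.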

Add BaseUpdater.uses_hist_batch class flag (default False). JVPRegUpdater
overrides to True since its fwd_bwd actually consumes hist_batch. The
trainer rebuilds cur_train_loader with priority-weighted sampling on
every CL round and additionally rebuilds hist_train_loader only when
the updater's fwd_bwd will read it — so for EWC/KFAC, where the
historical signal is meant to enter via mixing into the current loader,
we don't waste work reshaping a loader whose batches get discarded.
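
A stripped-down picture of the flag and the trainer-side decision, using hypothetical stand-in classes (the real updaters carry far more state, and `loaders_to_rebuild` is an illustrative helper rather than a trainer method):

```python
class BaseUpdater:
    # Default: fwd_bwd does not read hist_batch.
    uses_hist_batch = False


class JVPRegUpdater(BaseUpdater):
    # fwd_bwd consumes hist_batch, so the historical loader matters.
    uses_hist_batch = True


class EWCUpdater(BaseUpdater):
    pass  # inherits uses_hist_batch = False


def loaders_to_rebuild(updater):
    """Names of the loaders worth reweighting for this updater."""
    names = ["cur_train_loader"]
    if updater.uses_hist_batch:
        names.append("hist_train_loader")
    return names


jvp_plan = loaders_to_rebuild(JVPRegUpdater())
ewc_plan = loaders_to_rebuild(EWCUpdater())
```

A class-level flag keeps the trainer free of isinstance checks: a new updater opts in by overriding one attribute.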

Extract the rebuild + priority-stats logging into
ContinuousTrainer._rebuild_with_priorities and emit a tagged
[priority/cur|hist] diagnostic line per round so the priority
distribution (n, min/mean/max/std, effective-sample-size fraction) is
visible in the run log.
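
The statistics in that diagnostic line could be computed along these lines. The effective-sample-size formula here is the common Kish definition, (sum w)^2 / sum w^2, which is an assumption about what the trainer logs; `priority_stats` and the printed format are illustrative only.

```python
import statistics


def priority_stats(weights):
    """Summary statistics for a list of positive sampling weights."""
    n = len(weights)
    total = sum(weights)
    # Kish effective sample size: n for uniform weights, 1 when a single
    # sample carries all the weight.
    ess = total * total / sum(w * w for w in weights)
    return {
        "n": n,
        "min": min(weights),
        "mean": total / n,
        "max": max(weights),
        "std": statistics.pstdev(weights),
        "ess_frac": ess / n,
    }


uniform_stats = priority_stats([1.0, 1.0, 1.0, 1.0])
peaked_stats = priority_stats([1.0, 0.0, 0.0, 0.0])
print("[priority/cur]", uniform_stats)
```

An ess_frac near 1.0 means the sampler is effectively uniform; a value near 1/n means a handful of samples dominate the round.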

Update tests/test_importance_weighting.py to pin the new contract:
cl_preprocessing refreshes theta_star, cl_postprocessing must not.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>