feat(train): add prioritized sampling for CL importance weighting#105
Open
krm9c wants to merge 2 commits into
Open
feat(train): add prioritized sampling for CL importance weighting#105krm9c wants to merge 2 commits into
krm9c wants to merge 2 commits into
Conversation
Introduces per-sample priority-based replay for continual-learning updates. At the start of each CL round the trainer rebuilds the current-task DataLoader with a WeightedRandomSampler whose weights are priority_i = (L(w_current, x_i) - L(theta_star, x_i))^alpha, so samples the model has forgotten relative to the previous CL anchor are sampled more often. Training loss itself is left unchanged (no gradient distortion). - BaseUpdater: theta_star anchor + _unreduced_criterion + compute_sample_priorities; cl_postprocessing refreshes the anchor. - ContinuousTrainer: gated by cl_updater.importance_weighting; rebuilds cur_train_loader with WeightedRandomSampler when enabled. - ContinualLearningCfg: new importance_weighting (default False) and importance_alpha (default 1.0) fields. - create_updater: assigns the two config values onto the updater. - examples/mnist + examples/cifar: optional current_ratio constructor arg drives a RandomSampler so current/historical loaders draw a balanced fraction of samples (no behavior change when current_ratio=1). - make_loader (both example utils): accepts a sampler kwarg that overrides shuffle. - Tests: new tests/test_importance_weighting.py (17 tests) and importance_weighting=False on MagicMock updaters in the outer-loop trainer tests so the new sampling branch stays off in mock-driven runs. Reference: Raghavan & Papadimitriou, FGCS 2025.
Refresh theta_star in cl_preprocessing instead of cl_postprocessing so that the anchor lags the model by one CL round. Previously theta_star was overwritten with the post-training weights, which meant compute_sample_priorities always saw L_current == L_anchor on the next round and WeightedRandomSampler collapsed to uniform — making importance_weighting=true a no-op. EWC's Fisher commit and KFAC's A/G EMA commit stay in cl_postprocessing; only the anchor refresh moved. Add BaseUpdater.uses_hist_batch class flag (default False). JVPRegUpdater overrides to True since its fwd_bwd actually consumes hist_batch. The trainer rebuilds cur_train_loader with priority-weighted sampling on every CL round and additionally rebuilds hist_train_loader only when the updater's fwd_bwd will read it — so for EWC/KFAC, where the historical signal is meant to enter via mixing into the current loader, we don't waste work reshaping a loader whose batches get discarded. Extract the rebuild + priority-stats logging into ContinuousTrainer._rebuild_with_priorities and emit a tagged [priority/cur|hist] diagnostic line per round so the priority distribution (n, min/mean/max/std, effective-sample-size fraction) is visible in the run log. Update tests/test_importance_weighting.py to pin the new contract: cl_preprocessing refreshes theta_star, cl_postprocessing must not. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Introduces per-sample priority-based replay for continual-learning updates. At the start of each CL round the trainer rebuilds the current-task DataLoader with a WeightedRandomSampler whose weights are priority_i = (L(w_current, x_i) - L(theta_star, x_i))^alpha, so samples the model has forgotten relative to the previous CL anchor are sampled more often. Training loss itself is left unchanged (no gradient distortion).
Summary
Adds prioritized replay sampling to the CL pipeline. When
continual_learning.importance_weighting = true, every CL round draws current-task samples proportionally to how much each sample's loss has worsened since the previous CL anchor (theta_star). The standard training loss is unchanged, so this is purely a data-distribution intervention — no gradient distortion.Motivation & Context
Forgetting under continual learning is concentrated on a subset of samples that the most recent update step degraded; uniformly sampling the current-task loader spends most of the compute budget on already-easy examples. Replaying the high-forgetting samples more often is a low-overhead way to push the CL update toward the regions of input space that actually need it. Method follows Raghavan & Papadimitriou, FGCS 2025.
This replaces an earlier, abandoned per-sample loss-weighting variant that distorted gradients via softmax weights and was sensitive to a temperature hyperparameter; prioritized sampling decouples the "what to focus on" signal from the loss itself.
Approach
theta_star) and an unreduced-loss helper (_unreduced_criterion) ontoBaseUpdaterso every updater can use them.BaseUpdater.compute_sample_priorities(loader, device): runs the current model andfunctional_call(model, theta_star, x)over the loader, returns(L_cur - L_anchor).clamp(min=1e-8).pow(alpha)as a 1-D priority tensor (eval mode,@torch.no_grad).ContinuousTrainer.outer_cl_training_loop, whencl_updater.importance_weighting and cl_updater.theta_star, rebuildcur_train_loaderwithWeightedRandomSampler(weights=priorities.tolist(), num_samples=len(priorities), replacement=True).BaseUpdater.cl_postprocessingrefreshestheta_starto the post-CL parameters, so the next round measures forgetting relative to "where we ended last time."current_ratioconstructor arg (default1.0, fully backward-compatible) that lets the current/historical loaders each draw a fraction of one dataset viaRandomSampler— useful when running paired ablations.Trade-off:
compute_sample_prioritiesadds one extra full forward pass over the current-task dataset per CL round (not per inner iter). Architecture is signal-agnostic — the priority can be swapped for Fisher info, predictive entropy, ensemble disagreement, etc. without changing the trainer.Screenshots / Logs (optional)
End-to-end MNIST run with
update_mode=ewc_online,importance_weighting=true,importance_alpha=1.0is still in progress at PR-open time. Will paste accuracy / forgetting numbers as a comment once it finishes.API / CLI Changes
ContinualLearningCfg.importance_weighting: bool = False(new)ContinualLearningCfg.importance_alpha: float = 1.0(new)BaseUpdater.theta_star: dict[str, Tensor](new attribute)BaseUpdater._unreduced_criterion(outputs, y) -> Tensor(new)BaseUpdater.compute_sample_priorities(loader, device) -> Tensor(new)BaseUpdater.cl_postprocessing()now refreshestheta_star(was a no-op).MNIST_CNN.__init__(..., current_ratio: float = 1.0)(new kwarg, default preserves old behavior)CIFAR_VISION.__init__(..., current_ratio: float = 1.0)(new kwarg, default preserves old behavior)examples/mnist/utils.py::make_loader(..., sampler: Sampler | None = None)(new kwarg, overridesshufflewhen provided)examples/cifar/src/utils.py::make_loader(..., sampler: Sampler | None = None)(new kwarg, overridesshufflewhen provided)Breaking Changes
Performance (optional)
importance_weighting=true: one full forward pass over the current-task dataset to compute priorities, plus the model's own forward pass underfunctional_call(theta_star)— i.e. ~2× one inference pass over the dataset, once per drift event. Negligible relative tomax_iterinner training iters at typical settings.importance_weighting=false(the trainer skips the entire block).Security & Privacy
WeightedRandomSampler)Dependencies
torch.func.functional_call,torch.utils.data.WeightedRandomSampler,torch.utils.data.RandomSampler.Testing Plan
tests/test_importance_weighting.py(17 tests): unreduced criterion shapes, priority shape/positivity/variance/alpha sharpness,WeightedRandomSamplerintegration,fwd_bwdstandard-loss invariant, theta_star init/inheritance/postprocessing, anchor-vs-current loss divergence, no-grad through anchor, balanced sampling, config field defaults/customs,create_updaterwiring.tests/test_continuous_trainer.pypasses; three outer-loop tests setimportance_weighting=FalseonMagicMockupdaters so the new sampling branch stays off (otherwiseMagicMock.importance_weightingis truthy and the trainer tries to callcompute_sample_prioritieson the mock).importance_weighting=truerun in progress at PR-open time; results to be added as a comment.poetry run python -m src.main --config examples/mnist/mnist.toml --set continual_learning.update_mode=ewc_online --set continual_learning.importance_weighting=true --set continual_learning.importance_alpha=1.0Documentation
compute_sample_priorities,_unreduced_criterion,cl_postprocessing,current_ratioin example harnesses)Checklist
ruff format --checkruff check .mypy .pytest(197 passed, 1 pre-existing failuretest_mnist_first_drift_losses_match_referencethat also fails on cleanmain, 11 skipped)compute_sample_priorities, sampler-vs-shuffle inmake_loader)Risk & Rollback Plan
Low risk. The new behavior is fully gated by
continual_learning.importance_weighting, which defaults toFalse. Existing configs and existing runs are byte-identical tomain. Rollback is a straight revert of this PR — no schema / data migrations involved.Notes for Reviewers
src/apeiron/training/updater/base.py(the newcompute_sample_prioritiesis the substantive change) and thensrc/apeiron/training/continuous_trainer.py(the trainer hook).current_ratioplumbing inexamples/{mnist,cifar}/...is an independent feature bundled in the same commit because it's the experimental setup the prioritized-sampling ablation was run with. Skipping it would leave the example harnesses unable to reproduce the published numbers.src/apeiron/training/updater/jvp_reg.py— the old loss-weighting code that the abandoned variant added there does not exist on currentmain(jvp_reg was rewritten upstream), so there is nothing to remove.compute_sample_prioritiesis decoupled from the updater type and works withjvp_reg,ewc_online,kfac_online, andbasealike.__init__still re-creates its owntheta_staraftersuper().__init__runs; that is a harmless re-bind (the base'stheta_staris overwritten by EWC's device-placed version). Intentionally left as-is per discussion.