From 34fc6d93a09ed2c522473dfd389a1b8ed209c56c Mon Sep 17 00:00:00 2001 From: Krishnan Raghavan Date: Tue, 9 Jun 2026 12:53:30 -0500 Subject: [PATCH 1/2] Add prioritized sampling for CL importance weighting Introduces per-sample priority-based replay for continual-learning updates. At the start of each CL round the trainer rebuilds the current-task DataLoader with a WeightedRandomSampler whose weights are priority_i = max(L(w_current, x_i) - L(theta_star, x_i), 1e-8)^alpha, so samples the model has forgotten relative to the previous CL anchor are sampled more often, while samples whose loss has not increased keep a small positive weight. Training loss itself is left unchanged (no gradient distortion). - BaseUpdater: theta_star anchor + _unreduced_criterion + compute_sample_priorities; cl_postprocessing refreshes the anchor. - ContinuousTrainer: gated by cl_updater.importance_weighting; rebuilds cur_train_loader with WeightedRandomSampler when enabled. - ContinualLearningCfg: new importance_weighting (default False) and importance_alpha (default 1.0) fields. - create_updater: assigns the two config values onto the updater. - examples/mnist + examples/cifar: optional current_ratio constructor arg drives a RandomSampler so current/historical loaders draw a balanced fraction of samples (no behavior change when current_ratio=1). - make_loader (both example utils): accepts a sampler kwarg that overrides shuffle. - Tests: new tests/test_importance_weighting.py (20 tests) and importance_weighting=False on MagicMock updaters in the outer-loop trainer tests so the new sampling branch stays off in mock-driven runs. Reference: Raghavan & Papadimitriou, FGCS 2025. 
--- examples/cifar/model.py | 51 ++- examples/cifar/src/utils.py | 9 +- examples/mnist/model.py | 47 ++- examples/mnist/utils.py | 9 +- src/apeiron/config/configuration.py | 6 + src/apeiron/training/continuous_trainer.py | 20 +- src/apeiron/training/updater/base.py | 55 ++- .../training/updater/create_updater.py | 34 +- tests/test_continuous_trainer.py | 3 + tests/test_importance_weighting.py | 336 ++++++++++++++++++ 10 files changed, 531 insertions(+), 39 deletions(-) create mode 100644 tests/test_importance_weighting.py diff --git a/examples/cifar/model.py b/examples/cifar/model.py index c1b25e2..d154683 100644 --- a/examples/cifar/model.py +++ b/examples/cifar/model.py @@ -3,7 +3,7 @@ from typing import Tuple, Optional, List, Dict, Any from torch import nn, Tensor from torch.optim import Optimizer -from torch.utils.data import DataLoader, ConcatDataset +from torch.utils.data import DataLoader, ConcatDataset, RandomSampler from apeiron.model.torch_model_harness import BaseModelHarness from apeiron.config.configuration import Config @@ -102,9 +102,16 @@ class CIFAR_VISION(BaseModelHarness): - get_hist_data_loaders(): ConcatDataset over self.aug_history; then append self.cur_aug """ - def __init__(self, cfg: Config, model: Optional[nn.Module] = None): + def __init__( + self, + cfg: Config, + model: Optional[nn.Module] = None, + current_ratio: float = 1.0, + ): super().__init__(cfg=cfg, model=VisionModelCifar(cfg=cfg)) + self.current_ratio = current_ratio + # FULL datasets (no index split) self.ds_train = get_cifar_train(cfg=cfg, normalize=True) self.ds_val = get_cifar_val(cfg=cfg, normalize=True) @@ -153,9 +160,23 @@ def update_data_stream(self) -> None: nw = getattr(self.cfg.train, "num_workers", self.cfg.train.num_workers) pin = torch.cuda.is_available() - self._cur_train_loader = make_loader( - ds_train_tf, bs, shuffle=True, num_workers=nw, pin_memory=pin - ) + if self.current_ratio < 1.0: + n_cur = int(len(ds_train_tf) * self.current_ratio) + cur_sampler = 
RandomSampler( + ds_train_tf, replacement=True, num_samples=n_cur + ) + self._cur_train_loader = make_loader( + ds_train_tf, + bs, + shuffle=True, + num_workers=nw, + pin_memory=pin, + sampler=cur_sampler, + ) + else: + self._cur_train_loader = make_loader( + ds_train_tf, bs, shuffle=True, num_workers=nw, pin_memory=pin + ) self._cur_val_loader = make_loader( ds_val_tf, bs, shuffle=False, num_workers=nw, pin_memory=pin ) @@ -193,9 +214,23 @@ def get_hist_dataloaders( nw = getattr(self.cfg.data, "num_workers", self.cfg.train.num_workers) pin = torch.cuda.is_available() - hist_train_loader = make_loader( - ds_hist_train, bs, shuffle=True, num_workers=nw, pin_memory=pin - ) + if self.current_ratio < 1.0: + n_hist = int(len(ds_hist_train) * (1.0 - self.current_ratio)) + hist_sampler = RandomSampler( + ds_hist_train, replacement=True, num_samples=n_hist + ) + hist_train_loader = make_loader( + ds_hist_train, + bs, + shuffle=True, + num_workers=nw, + pin_memory=pin, + sampler=hist_sampler, + ) + else: + hist_train_loader = make_loader( + ds_hist_train, bs, shuffle=True, num_workers=nw, pin_memory=pin + ) hist_val_loader = make_loader( ds_hist_val, bs, shuffle=False, num_workers=nw, pin_memory=pin ) diff --git a/examples/cifar/src/utils.py b/examples/cifar/src/utils.py index a5c6041..c40f5b2 100644 --- a/examples/cifar/src/utils.py +++ b/examples/cifar/src/utils.py @@ -3,7 +3,7 @@ from typing import Dict, Any import torch from torch import nn -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Dataset, DataLoader, Sampler from torchvision import datasets, transforms import torchvision.transforms.functional as TF from apeiron.config.configuration import Config @@ -159,6 +159,7 @@ def make_loader( pin_memory: bool = True, persistent_workers: bool = True, prefetch_factor: int = 2, + sampler: "Sampler | None" = None, ) -> DataLoader: """ Builds a DataLoader from a given Dataset. @@ -185,7 +186,11 @@ def make_loader( DataLoader The built DataLoader. 
""" - kwargs = dict(batch_size=batch_size, shuffle=shuffle, drop_last=False) + kwargs: dict[str, Any] = dict(batch_size=batch_size, drop_last=False) + if sampler is not None: + kwargs["sampler"] = sampler + else: + kwargs["shuffle"] = shuffle if num_workers > 0: kwargs.update( dict( diff --git a/examples/mnist/model.py b/examples/mnist/model.py index 3dd5fe9..bcc711c 100644 --- a/examples/mnist/model.py +++ b/examples/mnist/model.py @@ -5,7 +5,7 @@ from typing import Tuple, Optional, List, Dict, Any from torch import nn, Tensor from torch.optim import Optimizer -from torch.utils.data import DataLoader, ConcatDataset +from torch.utils.data import DataLoader, ConcatDataset, RandomSampler from apeiron.model.torch_model_harness import BaseModelHarness from apeiron.config.configuration import Config @@ -54,9 +54,12 @@ class MNIST_CNN(BaseModelHarness): - get_hist_dataloaders(): build loaders over FULL train/val using the chained augmentations of all historical loades """ - def __init__(self, cfg: Config, model: nn.Module = Cnn()): + def __init__( + self, cfg: Config, model: nn.Module = Cnn(), current_ratio: float = 1.0 + ): super().__init__(cfg=cfg, model=model) + self.current_ratio = current_ratio self.eval_metrics = {"accuracy": accuracy, "loss": self.get_criterion()} self.higher_is_better = {"accuracy": True, "loss": False} @@ -127,9 +130,23 @@ def update_data_stream(self) -> None: nw = getattr(self.cfg.train, "num_workers", self.cfg.train.num_workers) pin = torch.cuda.is_available() - self._cur_train_loader = make_loader( - ds_train_tf, bs, shuffle=True, num_workers=nw, pin_memory=pin - ) + if self.current_ratio < 1.0: + n_cur = int(len(ds_train_tf) * self.current_ratio) + cur_sampler = RandomSampler( + ds_train_tf, replacement=True, num_samples=n_cur + ) + self._cur_train_loader = make_loader( + ds_train_tf, + bs, + shuffle=True, + num_workers=nw, + pin_memory=pin, + sampler=cur_sampler, + ) + else: + self._cur_train_loader = make_loader( + ds_train_tf, bs, 
shuffle=True, num_workers=nw, pin_memory=pin + ) self._cur_val_loader = make_loader( ds_val_tf, bs, shuffle=False, num_workers=nw, pin_memory=pin ) @@ -166,9 +183,23 @@ def get_hist_dataloaders( nw = getattr(self.cfg.data, "num_workers", self.cfg.train.num_workers) pin = torch.cuda.is_available() - hist_train_loader = make_loader( - ds_hist_train, bs, shuffle=True, num_workers=nw, pin_memory=pin - ) + if self.current_ratio < 1.0: + n_hist = int(len(ds_hist_train) * (1.0 - self.current_ratio)) + hist_sampler = RandomSampler( + ds_hist_train, replacement=True, num_samples=n_hist + ) + hist_train_loader = make_loader( + ds_hist_train, + bs, + shuffle=True, + num_workers=nw, + pin_memory=pin, + sampler=hist_sampler, + ) + else: + hist_train_loader = make_loader( + ds_hist_train, bs, shuffle=True, num_workers=nw, pin_memory=pin + ) hist_val_loader = make_loader( ds_hist_val, bs, shuffle=False, num_workers=nw, pin_memory=pin ) diff --git a/examples/mnist/utils.py b/examples/mnist/utils.py index 297b0b0..fae790e 100644 --- a/examples/mnist/utils.py +++ b/examples/mnist/utils.py @@ -3,7 +3,7 @@ from typing import Tuple, Dict, List, Any import torch import torchvision -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Dataset, DataLoader, Sampler from torchvision import datasets, transforms import torchvision.transforms.functional as TF @@ -113,6 +113,7 @@ def make_loader( pin_memory: bool = True, persistent_workers: bool = True, prefetch_factor: int = 2, + sampler: "Sampler | None" = None, ) -> DataLoader: """ Builds a DataLoader from a given Dataset. @@ -139,7 +140,11 @@ def make_loader( DataLoader The built DataLoader. 
""" - kwargs = dict(batch_size=batch_size, shuffle=shuffle, drop_last=False) + kwargs: dict[str, Any] = dict(batch_size=batch_size, drop_last=False) + if sampler is not None: + kwargs["sampler"] = sampler + else: + kwargs["shuffle"] = shuffle if num_workers > 0: kwargs.update( dict( diff --git a/src/apeiron/config/configuration.py b/src/apeiron/config/configuration.py index e79a566..a7ccd60 100644 --- a/src/apeiron/config/configuration.py +++ b/src/apeiron/config/configuration.py @@ -124,6 +124,12 @@ class ContinualLearningCfg: kfac_lambda: float = 0.01 kfac_ema_decay: float = 0.95 + # Prioritized sampling (Raghavan & Papadimitriou, FGCS 2025) + importance_weighting: bool = False + importance_alpha: float = ( + 1.0 # priority exponent (1=linear, <1=flatter, >1=sharper) + ) + @dataclass(frozen=True) class DriftDetectionCfg: diff --git a/src/apeiron/training/continuous_trainer.py b/src/apeiron/training/continuous_trainer.py index 160e6b0..bd1990e 100644 --- a/src/apeiron/training/continuous_trainer.py +++ b/src/apeiron/training/continuous_trainer.py @@ -4,7 +4,7 @@ from typing import Any, Optional import torch -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, WeightedRandomSampler from tqdm import tqdm from apeiron.config.configuration import Config @@ -70,6 +70,24 @@ def outer_cl_training_loop( cur_train_loader, cur_test_loader = self.modelHarness.get_train_dataloaders() hist_train_loader, hist_test_loader = self.modelHarness.get_hist_dataloaders() + # Prioritized sampling: rebuild loader with importance-based sampler + if self.cl_updater.importance_weighting and self.cl_updater.theta_star: + priorities = self.cl_updater.compute_sample_priorities( + cur_train_loader, self.cfg.device + ) + sampler = WeightedRandomSampler( + weights=priorities.tolist(), + num_samples=len(priorities), + replacement=True, + ) + cur_train_loader = DataLoader( + cur_train_loader.dataset, + batch_size=cur_train_loader.batch_size or 
self.cfg.train.batch_size, + sampler=sampler, + num_workers=cur_train_loader.num_workers, + drop_last=cur_train_loader.drop_last, + ) + train_iter = iter(cur_train_loader) if hist_train_loader is not None: hist_train_iter = iter(hist_train_loader) diff --git a/src/apeiron/training/updater/base.py b/src/apeiron/training/updater/base.py index 206476d..15505b5 100644 --- a/src/apeiron/training/updater/base.py +++ b/src/apeiron/training/updater/base.py @@ -1,9 +1,12 @@ from __future__ import annotations +import copy from typing import Callable, Optional import torch import torch.nn as nn +from torch.func import functional_call +from torch.utils.data import DataLoader from apeiron.config.configuration import Config from apeiron.model.torch_model_harness import BaseModelHarness @@ -27,6 +30,49 @@ def __init__(self, cfg: Config, modelHarness: BaseModelHarness) -> None: self.criterion: Callable[..., torch.Tensor] = modelHarness.get_criterion() self.model: nn.Module = modelHarness.model + # Anchor weights for importance weighting (shared across all updaters) + self.theta_star: dict[str, torch.Tensor] = { + n: p.detach().clone() + for n, p in self.model.named_parameters() + if p.requires_grad + } + self.importance_weighting: bool = False + self.importance_alpha: float = 1.0 + + def _unreduced_criterion( + self, outputs: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + """Compute per-sample loss (no reduction).""" + crit = copy.copy(self.criterion) + crit.reduction = "none" # type: ignore[attr-defined] + return crit(outputs, y) + + @torch.no_grad() + def compute_sample_priorities( + self, loader: DataLoader, device: str + ) -> torch.Tensor: + """Compute per-sample priority = (L_current - L_anchor)^alpha for prioritized sampling. + + Returns a 1-D tensor of priorities, one per sample in dataset order. 
+ """ + self.model.eval() + anchor = {n: p.detach() for n, p in self.model.named_parameters()} + anchor.update(self.theta_star) + + all_deltas: list[torch.Tensor] = [] + for batch in loader: + x, y = batch[0].to(device), batch[1].to(device) + cur_loss = self._unreduced_criterion(self.model(x), y) + anchor_out = functional_call(self.model, anchor, (x,)) + anchor_loss = self._unreduced_criterion(anchor_out, y) + delta = (cur_loss - anchor_loss).clamp(min=1e-8) + all_deltas.append(delta.cpu()) + + self.model.train() + priorities = torch.cat(all_deltas) + # Apply alpha exponent (alpha=1 → linear, <1 → flatter, >1 → sharper) + return priorities.pow(self.importance_alpha) + def fwd_bwd( self, batch: tuple[torch.Tensor, torch.Tensor], @@ -49,8 +95,13 @@ def cl_preprocessing(self) -> None: @torch.no_grad() def cl_postprocessing(self) -> None: - """Hook called after the training loop ends""" - pass + """Hook called after the training loop ends. + + Updates theta_star to current model parameters. + """ + for n, p in self.model.named_parameters(): + if p.requires_grad and n in self.theta_star: + self.theta_star[n].copy_(p.detach()) @torch.no_grad() def update_pre_fwd_bwd(self) -> None: diff --git a/src/apeiron/training/updater/create_updater.py b/src/apeiron/training/updater/create_updater.py index 37e7d76..e1c2892 100644 --- a/src/apeiron/training/updater/create_updater.py +++ b/src/apeiron/training/updater/create_updater.py @@ -18,29 +18,31 @@ def create_updater(cfg: Config, modelHarness: BaseModelHarness) -> BaseUpdater: Raises: NotImplementedError: If the specified updater mode is not implemented. 
""" - if cfg.continual_learning.update_mode == "base": - return BaseUpdater(cfg=cfg, modelHarness=modelHarness) + updater: BaseUpdater - if cfg.continual_learning.update_mode == "ewc_online": + if cfg.continual_learning.update_mode == "base": + updater = BaseUpdater(cfg=cfg, modelHarness=modelHarness) + elif cfg.continual_learning.update_mode == "ewc_online": from apeiron.training.updater.ewc import OnlineEWCUpdater - return OnlineEWCUpdater(cfg=cfg, modelHarness=modelHarness) - - if cfg.continual_learning.update_mode == "kfac_online": + updater = OnlineEWCUpdater(cfg=cfg, modelHarness=modelHarness) + elif cfg.continual_learning.update_mode == "kfac_online": from apeiron.training.updater.kfac import OnlineKFACUpdater - return OnlineKFACUpdater(cfg=cfg, modelHarness=modelHarness) - - if cfg.continual_learning.update_mode == "jvp_reg": + updater = OnlineKFACUpdater(cfg=cfg, modelHarness=modelHarness) + elif cfg.continual_learning.update_mode == "jvp_reg": from apeiron.training.updater.jvp_reg import JVPRegUpdater - return JVPRegUpdater(cfg=cfg, modelHarness=modelHarness) - - if cfg.continual_learning.update_mode == "none": + updater = JVPRegUpdater(cfg=cfg, modelHarness=modelHarness) + elif cfg.continual_learning.update_mode == "none": from apeiron.training.updater.no_updater import NoUpdater - return NoUpdater(cfg=cfg, modelHarness=modelHarness) + updater = NoUpdater(cfg=cfg, modelHarness=modelHarness) + else: + raise NotImplementedError( + f"Unknown update_mode: {cfg.continual_learning.update_mode}" + ) - raise NotImplementedError( - f"Unknown update_mode: {cfg.continual_learning.update_mode}" - ) + updater.importance_weighting = cfg.continual_learning.importance_weighting + updater.importance_alpha = cfg.continual_learning.importance_alpha + return updater diff --git a/tests/test_continuous_trainer.py b/tests/test_continuous_trainer.py index 46db544..b581055 100644 --- a/tests/test_continuous_trainer.py +++ b/tests/test_continuous_trainer.py @@ -216,6 +216,7 @@ 
def test_updater_lifecycle(self, default_cfg, make_harness): mock_updater = MagicMock() mock_updater.fwd_bwd.return_value = 0.5 mock_updater.update_post_fwd_bwd.return_value = 0.1 + mock_updater.importance_weighting = False trainer.cl_updater = mock_updater trainer.outer_cl_training_loop(drift_event_id=1) @@ -246,6 +247,7 @@ def test_runs_max_iter_iterations(self, default_cfg, make_harness): mock_updater = MagicMock() mock_updater.fwd_bwd.return_value = 0.5 mock_updater.update_post_fwd_bwd.return_value = 0.1 + mock_updater.importance_weighting = False trainer.cl_updater = mock_updater trainer.outer_cl_training_loop(drift_event_id=1) @@ -321,6 +323,7 @@ def test_passes_history_batches(self, default_cfg, make_harness): mock_updater = MagicMock() mock_updater.fwd_bwd.return_value = 0.5 mock_updater.update_post_fwd_bwd.return_value = 0.1 + mock_updater.importance_weighting = False trainer.cl_updater = mock_updater trainer.outer_cl_training_loop(drift_event_id=1) diff --git a/tests/test_importance_weighting.py b/tests/test_importance_weighting.py new file mode 100644 index 0000000..ba5439e --- /dev/null +++ b/tests/test_importance_weighting.py @@ -0,0 +1,336 @@ +"""Tests for prioritized sampling importance weighting in CL training.""" + +from __future__ import annotations + +from dataclasses import replace + +import torch +import torch.nn as nn +from torch.utils.data import ( + DataLoader, + RandomSampler, + TensorDataset, + WeightedRandomSampler, +) + +from apeiron.config.configuration import ContinualLearningCfg +from apeiron.training.updater.base import BaseUpdater +from apeiron.training.updater.ewc import OnlineEWCUpdater +from apeiron.training.updater.kfac import OnlineKFACUpdater +from apeiron.training.updater.create_updater import create_updater + + +# --------------------------------------------------------------------------- +# TestUnreducedCriterion +# --------------------------------------------------------------------------- +class TestUnreducedCriterion: + 
def test_nll_loss_shape(self, default_cfg, make_harness): + """NLLLoss with reduction='none' returns per-sample tensor.""" + harness = make_harness(default_cfg) + harness.get_criterion = lambda: nn.NLLLoss() + updater = BaseUpdater(cfg=default_cfg, modelHarness=harness) + + outputs = torch.log_softmax(torch.randn(8, 3), dim=1) + y = torch.randint(0, 3, (8,)) + result = updater._unreduced_criterion(outputs, y) + assert result.shape == (8,) + + def test_cross_entropy_shape(self, default_cfg, make_harness): + """CrossEntropyLoss with reduction='none' returns per-sample tensor.""" + harness = make_harness(default_cfg) + updater = BaseUpdater(cfg=default_cfg, modelHarness=harness) + + outputs = torch.randn(8, 3) + y = torch.randint(0, 3, (8,)) + result = updater._unreduced_criterion(outputs, y) + assert result.shape == (8,) + + +# --------------------------------------------------------------------------- +# TestComputeSamplePriorities +# --------------------------------------------------------------------------- +class TestComputeSamplePriorities: + def test_priorities_shape(self, default_cfg, make_harness): + """compute_sample_priorities returns one priority per sample.""" + cfg = replace( + default_cfg, + continual_learning=ContinualLearningCfg(importance_weighting=True), + ) + harness = make_harness(cfg) + updater = create_updater(cfg, harness) + + ds = TensorDataset(torch.randn(20, 4), torch.randint(0, 3, (20,))) + loader = DataLoader(ds, batch_size=8) + priorities = updater.compute_sample_priorities(loader, "cpu") + assert priorities.shape == (20,) + + def test_priorities_positive(self, default_cfg, make_harness): + """All priorities should be positive (clamped at 1e-8).""" + cfg = replace( + default_cfg, + continual_learning=ContinualLearningCfg(importance_weighting=True), + ) + harness = make_harness(cfg) + updater = create_updater(cfg, harness) + + ds = TensorDataset(torch.randn(16, 4), torch.randint(0, 3, (16,))) + loader = DataLoader(ds, batch_size=8) + 
priorities = updater.compute_sample_priorities(loader, "cpu") + assert (priorities > 0).all() + + def test_priorities_differ_after_param_change(self, default_cfg, make_harness): + """After modifying params, priorities should differ from uniform.""" + cfg = replace( + default_cfg, + continual_learning=ContinualLearningCfg(importance_weighting=True), + ) + harness = make_harness(cfg) + updater = create_updater(cfg, harness) + + # Modify model away from anchor + with torch.no_grad(): + for p in harness.model.parameters(): + p.add_(torch.randn_like(p) * 0.5) + + ds = TensorDataset(torch.randn(16, 4), torch.randint(0, 3, (16,))) + loader = DataLoader(ds, batch_size=8) + priorities = updater.compute_sample_priorities(loader, "cpu") + # Not all priorities should be identical + assert priorities.std() > 0 + + def test_alpha_controls_sharpness(self, default_cfg, make_harness): + """Higher alpha should produce more varied priorities.""" + harness = make_harness(default_cfg) + + # Create updaters FIRST (anchors theta_star at current params) + updater_low = BaseUpdater(cfg=default_cfg, modelHarness=harness) + updater_low.importance_weighting = True + updater_low.importance_alpha = 0.5 + + updater_high = BaseUpdater(cfg=default_cfg, modelHarness=harness) + updater_high.importance_weighting = True + updater_high.importance_alpha = 2.0 + + # THEN modify model away from anchor + with torch.no_grad(): + for p in harness.model.parameters(): + p.add_(torch.randn_like(p) * 0.5) + + ds = TensorDataset(torch.randn(32, 4), torch.randint(0, 3, (32,))) + loader = DataLoader(ds, batch_size=16) + + p_low = updater_low.compute_sample_priorities(loader, "cpu") + p_high = updater_high.compute_sample_priorities(loader, "cpu") + + # Higher alpha → more variance in priorities + assert p_high.std() > p_low.std() + + def test_weighted_sampler_from_priorities(self, default_cfg, make_harness): + """Priorities can be used with WeightedRandomSampler.""" + cfg = replace( + default_cfg, + 
continual_learning=ContinualLearningCfg(importance_weighting=True), + ) + harness = make_harness(cfg) + updater = create_updater(cfg, harness) + + ds = TensorDataset(torch.randn(20, 4), torch.randint(0, 3, (20,))) + loader = DataLoader(ds, batch_size=8) + priorities = updater.compute_sample_priorities(loader, "cpu") + + sampler = WeightedRandomSampler( + priorities, num_samples=len(priorities), replacement=True + ) + new_loader = DataLoader(ds, batch_size=8, sampler=sampler) + batch = next(iter(new_loader)) + assert batch[0].shape[0] == 8 + + +# --------------------------------------------------------------------------- +# TestStandardFwdBwd +# --------------------------------------------------------------------------- +class TestStandardFwdBwd: + def test_fwd_bwd_uses_standard_loss(self, default_cfg, make_harness): + """fwd_bwd always uses standard criterion (sampling handles importance).""" + harness = make_harness(default_cfg) + updater = BaseUpdater(cfg=default_cfg, modelHarness=harness) + + x = torch.randn(4, 4) + y = torch.randint(0, 3, (4,)) + harness.model.zero_grad() + loss = updater.fwd_bwd((x, y)) + assert isinstance(loss, float) + assert loss >= 0.0 + + +# --------------------------------------------------------------------------- +# TestThetaStar +# --------------------------------------------------------------------------- +class TestThetaStar: + def test_base_initializes_theta_star(self, dummy_harness): + """BaseUpdater should initialize theta_star for all requires_grad params.""" + updater = BaseUpdater(cfg=dummy_harness.cfg, modelHarness=dummy_harness) + param_names = { + n for n, p in dummy_harness.model.named_parameters() if p.requires_grad + } + assert set(updater.theta_star.keys()) == param_names + + def test_ewc_inherits_theta_star(self, default_cfg, make_harness): + """EWC should use theta_star from BaseUpdater (no duplicate init).""" + cfg = replace( + default_cfg, + continual_learning=ContinualLearningCfg(update_mode="ewc_online"), + ) + 
harness = make_harness(cfg) + updater = OnlineEWCUpdater(cfg=cfg, modelHarness=harness) + param_names = { + n for n, p in harness.model.named_parameters() if p.requires_grad + } + assert set(updater.theta_star.keys()) == param_names + + def test_kfac_keeps_partial_theta_star(self, default_cfg, make_harness): + """KFAC stores theta_star only for supported layers (Linear/Conv2d).""" + cfg = replace( + default_cfg, + continual_learning=ContinualLearningCfg(update_mode="kfac_online"), + ) + harness = make_harness(cfg) + updater = OnlineKFACUpdater(cfg=cfg, modelHarness=harness) + # KFAC theta_star uses module names, not parameter names + assert len(updater.theta_star) > 0 + + def test_theta_star_updates_in_postprocessing(self, dummy_harness): + """cl_postprocessing should update theta_star to current params.""" + updater = BaseUpdater(cfg=dummy_harness.cfg, modelHarness=dummy_harness) + + # Modify model parameters + with torch.no_grad(): + for p in dummy_harness.model.parameters(): + p.add_(1.0) + + updater.cl_postprocessing() + + # theta_star should now match current params + for n, p in dummy_harness.model.named_parameters(): + if p.requires_grad and n in updater.theta_star: + assert torch.allclose(updater.theta_star[n], p.detach()) + + +# --------------------------------------------------------------------------- +# TestFunctionalCallWeighting +# --------------------------------------------------------------------------- +class TestFunctionalCallWeighting: + def test_anchor_loss_differs_after_training(self, default_cfg, make_harness): + """After modifying model params, anchor loss should differ from current loss.""" + cfg = replace( + default_cfg, + continual_learning=ContinualLearningCfg(importance_weighting=True), + ) + harness = make_harness(cfg) + updater = create_updater(cfg, harness) + + x = torch.randn(4, 4) + y = torch.randint(0, 3, (4,)) + + # Modify model + with torch.no_grad(): + for p in harness.model.parameters(): + p.add_(torch.randn_like(p) * 0.5) + + 
outputs = harness.model(x) + per_sample_loss = updater._unreduced_criterion(outputs, y) + + with torch.no_grad(): + from torch.func import functional_call + + anchor = {n: p.detach() for n, p in harness.model.named_parameters()} + anchor.update(updater.theta_star) + anchor_out = functional_call(harness.model, anchor, (x,)) + anchor_loss = updater._unreduced_criterion(anchor_out, y) + + # Losses should differ since params changed + assert not torch.allclose(per_sample_loss, anchor_loss, atol=1e-6) + + def test_no_gradients_through_anchor(self, default_cfg, make_harness): + """Anchor computation should not create gradients.""" + cfg = replace( + default_cfg, + continual_learning=ContinualLearningCfg(importance_weighting=True), + ) + harness = make_harness(cfg) + updater = create_updater(cfg, harness) + + x = torch.randn(4, 4) + y = torch.randint(0, 3, (4,)) + + with torch.no_grad(): + anchor = {n: p.detach() for n, p in harness.model.named_parameters()} + anchor.update(updater.theta_star) + from torch.func import functional_call + + anchor_out = functional_call(harness.model, anchor, (x,)) + anchor_loss = updater._unreduced_criterion(anchor_out, y) + + assert not anchor_loss.requires_grad + + +# --------------------------------------------------------------------------- +# TestBalancedSampling +# --------------------------------------------------------------------------- +class TestBalancedSampling: + def test_sampler_overrides_shuffle(self): + """When sampler is provided, shuffle should not be passed (no error).""" + ds = TensorDataset(torch.randn(100, 4), torch.randint(0, 3, (100,))) + sampler = RandomSampler(ds, replacement=True, num_samples=50) + loader = DataLoader(ds, batch_size=8, sampler=sampler) + batch = next(iter(loader)) + assert batch[0].shape[0] == 8 + + def test_no_sampler_default_behavior(self): + """Without sampler, DataLoader uses shuffle normally.""" + ds = TensorDataset(torch.randn(100, 4), torch.randint(0, 3, (100,))) + loader = DataLoader(ds, 
batch_size=8, shuffle=True) + batch = next(iter(loader)) + assert batch[0].shape[0] == 8 + + def test_current_ratio_controls_num_samples(self): + """RandomSampler with num_samples controls how many samples are drawn.""" + ds = TensorDataset(torch.randn(100, 4), torch.randint(0, 3, (100,))) + ratio = 0.7 + n_samples = int(len(ds) * ratio) + sampler = RandomSampler(ds, replacement=True, num_samples=n_samples) + + loader = DataLoader(ds, batch_size=16, sampler=sampler) + total = sum(batch[0].shape[0] for batch in loader) + assert total == n_samples + + +# --------------------------------------------------------------------------- +# TestConfigFields +# --------------------------------------------------------------------------- +class TestConfigFields: + def test_default_values(self): + """ContinualLearningCfg should have correct defaults for importance fields.""" + cfg = ContinualLearningCfg() + assert cfg.importance_weighting is False + assert cfg.importance_alpha == 1.0 + + def test_custom_values(self): + """ContinualLearningCfg should accept custom importance values.""" + cfg = ContinualLearningCfg(importance_weighting=True, importance_alpha=0.5) + assert cfg.importance_weighting is True + assert cfg.importance_alpha == 0.5 + + def test_create_updater_passes_config(self, default_cfg, make_harness): + """create_updater should set importance fields on the updater.""" + cfg = replace( + default_cfg, + continual_learning=ContinualLearningCfg( + importance_weighting=True, + importance_alpha=2.0, + ), + ) + harness = make_harness(cfg) + updater = create_updater(cfg, harness) + assert updater.importance_weighting is True + assert updater.importance_alpha == 2.0 From 8f7f3612422e439a0353e9cf9f96c621391d7497 Mon Sep 17 00:00:00 2001 From: Krishnan Raghavan Date: Wed, 17 Jun 2026 17:20:40 -0500 Subject: [PATCH 2/2] Fix theta_star lifecycle and conditionally prioritize historical loader MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 
Refresh theta_star in cl_preprocessing instead of cl_postprocessing so that the anchor lags the model by one CL round. Previously theta_star was overwritten with the post-training weights, which meant compute_sample_priorities always saw L_current == L_anchor on the next round and WeightedRandomSampler collapsed to uniform — making importance_weighting=true a no-op. EWC's Fisher commit and KFAC's A/G EMA commit stay in cl_postprocessing; only the anchor refresh moved. Add BaseUpdater.uses_hist_batch class flag (default False). JVPRegUpdater overrides to True since its fwd_bwd actually consumes hist_batch. The trainer rebuilds cur_train_loader with priority-weighted sampling on every CL round and additionally rebuilds hist_train_loader only when the updater's fwd_bwd will read it — so for EWC/KFAC, where the historical signal is meant to enter via mixing into the current loader, we don't waste work reshaping a loader whose batches get discarded. Extract the rebuild + priority-stats logging into ContinuousTrainer._rebuild_with_priorities and emit a tagged [priority/cur|hist] diagnostic line per round so the priority distribution (n, min/mean/max/std, effective-sample-size fraction) is visible in the run log. Update tests/test_importance_weighting.py to pin the new contract: cl_preprocessing refreshes theta_star, cl_postprocessing must not. 
Co-Authored-By: Claude Opus 4.7 --- src/apeiron/training/continuous_trainer.py | 81 ++++++++++++++++++---- src/apeiron/training/updater/base.py | 27 ++++++-- src/apeiron/training/updater/ewc.py | 27 +++++--- src/apeiron/training/updater/jvp_reg.py | 6 ++ src/apeiron/training/updater/kfac.py | 17 +++-- tests/test_importance_weighting.py | 36 ++++++++-- 6 files changed, 153 insertions(+), 41 deletions(-) diff --git a/src/apeiron/training/continuous_trainer.py b/src/apeiron/training/continuous_trainer.py index bd1990e..50bf9bc 100644 --- a/src/apeiron/training/continuous_trainer.py +++ b/src/apeiron/training/continuous_trainer.py @@ -61,6 +61,54 @@ def _safe_next( # If we cannot inspect batch size, just accept the batch return current_iter, [b.to(self.cfg.device) for b in batch] + def _rebuild_with_priorities( + self, + loader: DataLoader, + drift_event_id: int, + tag: str, + ) -> DataLoader: + """Rebuild a DataLoader so its sampler draws by per-sample priority. + + Computes priorities = (L_current − L_anchor)^alpha over the loader's + dataset, logs the distribution, and returns a new DataLoader wired + to a WeightedRandomSampler. + """ + logger = get_logger(__name__) + priorities = self.cl_updater.compute_sample_priorities(loader, self.cfg.device) + + # Diagnostic: per-round priority distribution. Non-trivial std means + # theta_star has drifted from the model and prioritized sampling will + # actually re-weight samples. Near-zero std → sampling collapses to + # uniform (the bug-fix regression signal). + with torch.no_grad(): + p = priorities.float() + p_mean = p.mean().item() + p_std = p.std().item() + p_min = p.min().item() + p_max = p.max().item() + # ESS = (sum w)^2 / sum w^2, ranges from 1 (degenerate) to N (uniform). 
+ ess = (p.sum().item() ** 2) / (p.pow(2).sum().item() + 1e-30) + ess_frac = ess / p.numel() + logger.info( + f"[priority/{tag}] drift_event_id={drift_event_id} " + f"n={p.numel()} min={p_min:.3e} mean={p_mean:.3e} " + f"max={p_max:.3e} std={p_std:.3e} ess_frac={ess_frac:.3f}" + ) + + sampler = WeightedRandomSampler( + weights=priorities.tolist(), + num_samples=len(priorities), + replacement=True, + ) + + return DataLoader( + loader.dataset, + batch_size=loader.batch_size or self.cfg.train.batch_size, + sampler=sampler, + num_workers=loader.num_workers, + drop_last=loader.drop_last, + ) + def outer_cl_training_loop( self, drift_event_id: int = 0, @@ -70,25 +118,28 @@ def outer_cl_training_loop( cur_train_loader, cur_test_loader = self.modelHarness.get_train_dataloaders() hist_train_loader, hist_test_loader = self.modelHarness.get_hist_dataloaders() - # Prioritized sampling: rebuild loader with importance-based sampler + # Prioritized sampling. Always rebuild the current-task loader with + # importance-based weights when the gate is on. Additionally rebuild + # the historical loader for updaters that actually consume hist_batch + # in fwd_bwd (uses_hist_batch=True — jvp_reg). For EWC/KFAC the + # historical signal is expected to enter via mixing into the current + # loader, so reweighting hist_train_loader for them would be wasted + # work. 
if self.cl_updater.importance_weighting and self.cl_updater.theta_star: - priorities = self.cl_updater.compute_sample_priorities( - cur_train_loader, self.cfg.device - ) - sampler = WeightedRandomSampler( - weights=priorities.tolist(), - num_samples=len(priorities), - replacement=True, - ) - cur_train_loader = DataLoader( - cur_train_loader.dataset, - batch_size=cur_train_loader.batch_size or self.cfg.train.batch_size, - sampler=sampler, - num_workers=cur_train_loader.num_workers, - drop_last=cur_train_loader.drop_last, + cur_train_loader = self._rebuild_with_priorities( + cur_train_loader, + drift_event_id=drift_event_id, + tag="cur", ) + if self.cl_updater.uses_hist_batch and hist_train_loader is not None: + hist_train_loader = self._rebuild_with_priorities( + hist_train_loader, + drift_event_id=drift_event_id, + tag="hist", + ) train_iter = iter(cur_train_loader) + if hist_train_loader is not None: hist_train_iter = iter(hist_train_loader) else: diff --git a/src/apeiron/training/updater/base.py b/src/apeiron/training/updater/base.py index 15505b5..b7a8222 100644 --- a/src/apeiron/training/updater/base.py +++ b/src/apeiron/training/updater/base.py @@ -22,8 +22,16 @@ class BaseUpdater: cfg: Configuration object. criterion: Loss function for training. model: Neural network model to update. + uses_hist_batch: True iff fwd_bwd consumes hist_batch directly. + Read by the trainer to decide whether prioritizing the + *historical* loader has any effect. Default False — EWC / KFAC + receive their historical-data signal through mixing into the + current loader, not through the hist_batch argument, so + reweighting hist_train_loader for them would be wasted work. 
""" + uses_hist_batch: bool = False + def __init__(self, cfg: Config, modelHarness: BaseModelHarness) -> None: """Initialize updater with config and model harness.""" self.cfg = cfg @@ -36,6 +44,7 @@ def __init__(self, cfg: Config, modelHarness: BaseModelHarness) -> None: for n, p in self.model.named_parameters() if p.requires_grad } + self.importance_weighting: bool = False self.importance_alpha: float = 1.0 @@ -90,19 +99,23 @@ def fwd_bwd( @torch.no_grad() def cl_preprocessing(self) -> None: - """Hook called before before the training loop starts""" - pass - - @torch.no_grad() - def cl_postprocessing(self) -> None: - """Hook called after the training loop ends. + """Hook called before the training loop starts. - Updates theta_star to current model parameters. + Snapshots current model parameters into theta_star. Runs AFTER + compute_sample_priorities in the outer CL loop, so theta_star ends + up one round behind the model — the lag the prioritized-sampling + delta `L(w_current, x) - L(theta_star, x)` needs to be non-zero + on the next round. """ for n, p in self.model.named_parameters(): if p.requires_grad and n in self.theta_star: self.theta_star[n].copy_(p.detach()) + @torch.no_grad() + def cl_postprocessing(self) -> None: + """Hook called after the training loop ends.""" + pass + @torch.no_grad() def update_pre_fwd_bwd(self) -> None: """Hook called before gradient computation.""" diff --git a/src/apeiron/training/updater/ewc.py b/src/apeiron/training/updater/ewc.py index d418184..5070f16 100644 --- a/src/apeiron/training/updater/ewc.py +++ b/src/apeiron/training/updater/ewc.py @@ -47,7 +47,14 @@ def __init__(self, cfg: Config, modelHarness: BaseModelHarness) -> None: @torch.no_grad() def cl_preprocessing(self) -> None: - """Called once before the CL loop starts.""" + """Called once before the CL loop starts. 
+ + Refreshes the EWC anchor θ* to the current model weights so that + the EWC quadratic penalty during this round pulls toward the + previous-task converged weights. Runs AFTER compute_sample_priorities + in the outer CL loop, so importance sampling sees the one-round + lag it needs. + """ # Allocate CL accumulators directly on correct device self._cl_fisher_accum = { n: torch.zeros_like(p, device=self.device) @@ -56,14 +63,22 @@ def cl_preprocessing(self) -> None: } self._cl_steps = 0 + # Refresh anchor θ* = current model weights (previous-task end) + for name, p in self.model.named_parameters(): + if p.requires_grad and name in self.theta_star: + self.theta_star[name].copy_(p.detach()) + @torch.no_grad() def cl_postprocessing(self) -> None: """ Called once after the CL loop finishes. - Commits the new prior: + Commits the new Fisher prior: F* <- fisher_decay * F* + F_cl_avg - θ* <- θ_final + + Anchor θ* is NOT refreshed here — it is refreshed at the next + cl_preprocessing, after compute_sample_priorities has had a chance + to read the stale value. """ if self._cl_steps == 0: return @@ -75,12 +90,6 @@ def cl_postprocessing(self) -> None: self.fisher[name].mul_(self.fisher_decay) self.fisher[name].add_(self._cl_fisher_accum[name] / float(self._cl_steps)) - # Update anchor θ* - with torch.no_grad(): - for name, p in self.model.named_parameters(): - if p.requires_grad: - self.theta_star[name].copy_(p.detach()) - # cleanup self._cl_fisher_accum = None self._cl_steps = 0 diff --git a/src/apeiron/training/updater/jvp_reg.py b/src/apeiron/training/updater/jvp_reg.py index e4b615f..4533da6 100644 --- a/src/apeiron/training/updater/jvp_reg.py +++ b/src/apeiron/training/updater/jvp_reg.py @@ -23,6 +23,12 @@ class JVPRegUpdater(BaseUpdater): regularization term to prevent catastrophic forgetting. """ + # fwd_bwd reads hist_batch (memory buffer) directly — JVP term and + # the explicit historical-replay backward both depend on it. 
Signals + # the trainer that prioritizing hist_train_loader will actually move + # the gradient for this updater. + uses_hist_batch: bool = True + def __init__(self, cfg: Config, modelHarness: BaseModelHarness) -> None: """Initialize JVP updater with config and model harness.""" super().__init__(cfg, modelHarness) diff --git a/src/apeiron/training/updater/kfac.py b/src/apeiron/training/updater/kfac.py index bd466f8..ad0a30c 100644 --- a/src/apeiron/training/updater/kfac.py +++ b/src/apeiron/training/updater/kfac.py @@ -97,6 +97,12 @@ def hook(module, grad_input, grad_output): # ------------------------------------------------------------------ @torch.no_grad() def cl_preprocessing(self): + """Refreshes KFAC anchor θ* and resets per-round accumulators. + + θ* is set to the current model weights here (not in postprocessing) + so that compute_sample_priorities in the outer CL loop reads the + previous round's anchor before it gets overwritten. + """ self._A_accum = { k: torch.zeros_like(v, device=self.device) for k, v in self.A.items() } @@ -105,8 +111,14 @@ def cl_preprocessing(self): } self._cl_steps = 0 + # Refresh anchor θ* = current per-module weights (previous-task end). + for name, module in self.model.named_modules(): + if self._supported(module): + self.theta_star[name].copy_(module.weight.detach()) + @torch.no_grad() def cl_postprocessing(self): + """Commits KFAC factor EMAs. 
Anchor θ* is NOT touched here.""" if self._cl_steps == 0: return if self._A_accum is None or self._G_accum is None: @@ -119,11 +131,6 @@ def cl_postprocessing(self): self.A[name].add_(self._A_accum[name] / self._cl_steps) self.G[name].add_(self._G_accum[name] / self._cl_steps) - with torch.no_grad(): - for name, module in self.model.named_modules(): - if self._supported(module): - self.theta_star[name].copy_(module.weight.detach()) - self._A_accum = None self._G_accum = None self._cl_steps = 0 diff --git a/tests/test_importance_weighting.py b/tests/test_importance_weighting.py index ba5439e..4257f0a 100644 --- a/tests/test_importance_weighting.py +++ b/tests/test_importance_weighting.py @@ -199,22 +199,48 @@ def test_kfac_keeps_partial_theta_star(self, default_cfg, make_harness): # KFAC theta_star uses module names, not parameter names assert len(updater.theta_star) > 0 - def test_theta_star_updates_in_postprocessing(self, dummy_harness): - """cl_postprocessing should update theta_star to current params.""" + def test_theta_star_updates_in_preprocessing(self, dummy_harness): + """cl_preprocessing should snapshot current params into theta_star.""" updater = BaseUpdater(cfg=dummy_harness.cfg, modelHarness=dummy_harness) - # Modify model parameters + # Modify model parameters AFTER init so theta_star is stale. with torch.no_grad(): for p in dummy_harness.model.parameters(): p.add_(1.0) - updater.cl_postprocessing() + updater.cl_preprocessing() - # theta_star should now match current params + # theta_star should now match current params. for n, p in dummy_harness.model.named_parameters(): if p.requires_grad and n in updater.theta_star: assert torch.allclose(updater.theta_star[n], p.detach()) + def test_theta_star_lags_after_postprocessing(self, dummy_harness): + """cl_postprocessing must NOT refresh theta_star. + + The prioritized-sampling delta L(w_current) - L(theta_star) is only + non-zero if theta_star lags the model by one CL round. 
Postprocessing + must therefore leave theta_star untouched; the refresh happens in + the NEXT round's cl_preprocessing (after priorities are computed). + """ + updater = BaseUpdater(cfg=dummy_harness.cfg, modelHarness=dummy_harness) + + # Snapshot theta_star as it stands right after init. + original = {n: t.detach().clone() for n, t in updater.theta_star.items()} + + # Simulate a training round moving the model. + with torch.no_grad(): + for p in dummy_harness.model.parameters(): + p.add_(1.0) + + updater.cl_postprocessing() + + # theta_star must still equal the pre-training snapshot. + for n, t in updater.theta_star.items(): + assert torch.allclose(t, original[n]), ( + f"cl_postprocessing modified theta_star[{n}] — lifecycle bug" + ) + # --------------------------------------------------------------------------- # TestFunctionalCallWeighting