Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions examples/cifar/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Tuple, Optional, List, Dict, Any
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data import DataLoader, ConcatDataset, RandomSampler

from apeiron.model.torch_model_harness import BaseModelHarness
from apeiron.config.configuration import Config
Expand Down Expand Up @@ -102,9 +102,16 @@ class CIFAR_VISION(BaseModelHarness):
- get_hist_data_loaders(): ConcatDataset over self.aug_history; then append self.cur_aug
"""

def __init__(self, cfg: Config, model: Optional[nn.Module] = None):
def __init__(
self,
cfg: Config,
model: Optional[nn.Module] = None,
current_ratio: float = 1.0,
):
super().__init__(cfg=cfg, model=VisionModelCifar(cfg=cfg))

self.current_ratio = current_ratio

# FULL datasets (no index split)
self.ds_train = get_cifar_train(cfg=cfg, normalize=True)
self.ds_val = get_cifar_val(cfg=cfg, normalize=True)
Expand Down Expand Up @@ -153,9 +160,23 @@ def update_data_stream(self) -> None:
nw = getattr(self.cfg.train, "num_workers", self.cfg.train.num_workers)
pin = torch.cuda.is_available()

self._cur_train_loader = make_loader(
ds_train_tf, bs, shuffle=True, num_workers=nw, pin_memory=pin
)
if self.current_ratio < 1.0:
n_cur = int(len(ds_train_tf) * self.current_ratio)
cur_sampler = RandomSampler(
ds_train_tf, replacement=True, num_samples=n_cur
)
self._cur_train_loader = make_loader(
ds_train_tf,
bs,
shuffle=True,
num_workers=nw,
pin_memory=pin,
sampler=cur_sampler,
)
else:
self._cur_train_loader = make_loader(
ds_train_tf, bs, shuffle=True, num_workers=nw, pin_memory=pin
)
self._cur_val_loader = make_loader(
ds_val_tf, bs, shuffle=False, num_workers=nw, pin_memory=pin
)
Expand Down Expand Up @@ -193,9 +214,23 @@ def get_hist_dataloaders(
nw = getattr(self.cfg.data, "num_workers", self.cfg.train.num_workers)
pin = torch.cuda.is_available()

hist_train_loader = make_loader(
ds_hist_train, bs, shuffle=True, num_workers=nw, pin_memory=pin
)
if self.current_ratio < 1.0:
n_hist = int(len(ds_hist_train) * (1.0 - self.current_ratio))
hist_sampler = RandomSampler(
ds_hist_train, replacement=True, num_samples=n_hist
)
hist_train_loader = make_loader(
ds_hist_train,
bs,
shuffle=True,
num_workers=nw,
pin_memory=pin,
sampler=hist_sampler,
)
else:
hist_train_loader = make_loader(
ds_hist_train, bs, shuffle=True, num_workers=nw, pin_memory=pin
)
hist_val_loader = make_loader(
ds_hist_val, bs, shuffle=False, num_workers=nw, pin_memory=pin
)
Expand Down
9 changes: 7 additions & 2 deletions examples/cifar/src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict, Any
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF
from apeiron.config.configuration import Config
Expand Down Expand Up @@ -159,6 +159,7 @@ def make_loader(
pin_memory: bool = True,
persistent_workers: bool = True,
prefetch_factor: int = 2,
sampler: "Sampler | None" = None,
) -> DataLoader:
"""
Builds a DataLoader from a given Dataset.
Expand All @@ -185,7 +186,11 @@ def make_loader(
DataLoader
The built DataLoader.
"""
kwargs = dict(batch_size=batch_size, shuffle=shuffle, drop_last=False)
kwargs: dict[str, Any] = dict(batch_size=batch_size, drop_last=False)
if sampler is not None:
kwargs["sampler"] = sampler
else:
kwargs["shuffle"] = shuffle
if num_workers > 0:
kwargs.update(
dict(
Expand Down
47 changes: 39 additions & 8 deletions examples/mnist/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Tuple, Optional, List, Dict, Any
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data import DataLoader, ConcatDataset, RandomSampler

from apeiron.model.torch_model_harness import BaseModelHarness
from apeiron.config.configuration import Config
Expand Down Expand Up @@ -54,9 +54,12 @@ class MNIST_CNN(BaseModelHarness):
- get_hist_dataloaders(): build loaders over FULL train/val using the chained augmentations of all historical loades
"""

def __init__(self, cfg: Config, model: nn.Module = Cnn()):
def __init__(
self, cfg: Config, model: nn.Module = Cnn(), current_ratio: float = 1.0
):
super().__init__(cfg=cfg, model=model)

self.current_ratio = current_ratio
self.eval_metrics = {"accuracy": accuracy, "loss": self.get_criterion()}
self.higher_is_better = {"accuracy": True, "loss": False}

Expand Down Expand Up @@ -127,9 +130,23 @@ def update_data_stream(self) -> None:
nw = getattr(self.cfg.train, "num_workers", self.cfg.train.num_workers)
pin = torch.cuda.is_available()

self._cur_train_loader = make_loader(
ds_train_tf, bs, shuffle=True, num_workers=nw, pin_memory=pin
)
if self.current_ratio < 1.0:
n_cur = int(len(ds_train_tf) * self.current_ratio)
cur_sampler = RandomSampler(
ds_train_tf, replacement=True, num_samples=n_cur
)
self._cur_train_loader = make_loader(
ds_train_tf,
bs,
shuffle=True,
num_workers=nw,
pin_memory=pin,
sampler=cur_sampler,
)
else:
self._cur_train_loader = make_loader(
ds_train_tf, bs, shuffle=True, num_workers=nw, pin_memory=pin
)
self._cur_val_loader = make_loader(
ds_val_tf, bs, shuffle=False, num_workers=nw, pin_memory=pin
)
Expand Down Expand Up @@ -166,9 +183,23 @@ def get_hist_dataloaders(
nw = getattr(self.cfg.data, "num_workers", self.cfg.train.num_workers)
pin = torch.cuda.is_available()

hist_train_loader = make_loader(
ds_hist_train, bs, shuffle=True, num_workers=nw, pin_memory=pin
)
if self.current_ratio < 1.0:
n_hist = int(len(ds_hist_train) * (1.0 - self.current_ratio))
hist_sampler = RandomSampler(
ds_hist_train, replacement=True, num_samples=n_hist
)
hist_train_loader = make_loader(
ds_hist_train,
bs,
shuffle=True,
num_workers=nw,
pin_memory=pin,
sampler=hist_sampler,
)
else:
hist_train_loader = make_loader(
ds_hist_train, bs, shuffle=True, num_workers=nw, pin_memory=pin
)
hist_val_loader = make_loader(
ds_hist_val, bs, shuffle=False, num_workers=nw, pin_memory=pin
)
Expand Down
9 changes: 7 additions & 2 deletions examples/mnist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Tuple, Dict, List, Any
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF

Expand Down Expand Up @@ -113,6 +113,7 @@ def make_loader(
pin_memory: bool = True,
persistent_workers: bool = True,
prefetch_factor: int = 2,
sampler: "Sampler | None" = None,
) -> DataLoader:
"""
Builds a DataLoader from a given Dataset.
Expand All @@ -139,7 +140,11 @@ def make_loader(
DataLoader
The built DataLoader.
"""
kwargs = dict(batch_size=batch_size, shuffle=shuffle, drop_last=False)
kwargs: dict[str, Any] = dict(batch_size=batch_size, drop_last=False)
if sampler is not None:
kwargs["sampler"] = sampler
else:
kwargs["shuffle"] = shuffle
if num_workers > 0:
kwargs.update(
dict(
Expand Down
6 changes: 6 additions & 0 deletions src/apeiron/config/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ class ContinualLearningCfg:
kfac_lambda: float = 0.01
kfac_ema_decay: float = 0.95

# Prioritized sampling (Raghavan & Papadimitriou, FGCS 2025)
importance_weighting: bool = False
importance_alpha: float = (
1.0 # priority exponent (1=linear, <1=flatter, >1=sharper)
)


@dataclass(frozen=True)
class DriftDetectionCfg:
Expand Down
71 changes: 70 additions & 1 deletion src/apeiron/training/continuous_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Optional

import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm

from apeiron.config.configuration import Config
Expand Down Expand Up @@ -61,6 +61,54 @@ def _safe_next(
# If we cannot inspect batch size, just accept the batch
return current_iter, [b.to(self.cfg.device) for b in batch]

def _rebuild_with_priorities(
self,
loader: DataLoader,
drift_event_id: int,
tag: str,
) -> DataLoader:
"""Rebuild a DataLoader so its sampler draws by per-sample priority.

Computes priorities = (L_current − L_anchor)^alpha over the loader's
dataset, logs the distribution, and returns a new DataLoader wired
to a WeightedRandomSampler.
"""
logger = get_logger(__name__)
priorities = self.cl_updater.compute_sample_priorities(loader, self.cfg.device)

# Diagnostic: per-round priority distribution. Non-trivial std means
# theta_star has drifted from the model and prioritized sampling will
# actually re-weight samples. Near-zero std → sampling collapses to
# uniform (the bug-fix regression signal).
with torch.no_grad():
p = priorities.float()
p_mean = p.mean().item()
p_std = p.std().item()
p_min = p.min().item()
p_max = p.max().item()
# ESS = (sum w)^2 / sum w^2, ranges from 1 (degenerate) to N (uniform).
ess = (p.sum().item() ** 2) / (p.pow(2).sum().item() + 1e-30)
ess_frac = ess / p.numel()
logger.info(
f"[priority/{tag}] drift_event_id={drift_event_id} "
f"n={p.numel()} min={p_min:.3e} mean={p_mean:.3e} "
f"max={p_max:.3e} std={p_std:.3e} ess_frac={ess_frac:.3f}"
)

sampler = WeightedRandomSampler(
weights=priorities.tolist(),
num_samples=len(priorities),
replacement=True,
)

return DataLoader(
loader.dataset,
batch_size=loader.batch_size or self.cfg.train.batch_size,
sampler=sampler,
num_workers=loader.num_workers,
drop_last=loader.drop_last,
)

def outer_cl_training_loop(
self,
drift_event_id: int = 0,
Expand All @@ -70,7 +118,28 @@ def outer_cl_training_loop(
cur_train_loader, cur_test_loader = self.modelHarness.get_train_dataloaders()
hist_train_loader, hist_test_loader = self.modelHarness.get_hist_dataloaders()

# Prioritized sampling. Always rebuild the current-task loader with
# importance-based weights when the gate is on. Additionally rebuild
# the historical loader for updaters that actually consume hist_batch
# in fwd_bwd (uses_hist_batch=True — jvp_reg). For EWC/KFAC the
# historical signal is expected to enter via mixing into the current
# loader, so reweighting hist_train_loader for them would be wasted
# work.
if self.cl_updater.importance_weighting and self.cl_updater.theta_star:
cur_train_loader = self._rebuild_with_priorities(
cur_train_loader,
drift_event_id=drift_event_id,
tag="cur",
)
if self.cl_updater.uses_hist_batch and hist_train_loader is not None:
hist_train_loader = self._rebuild_with_priorities(
hist_train_loader,
drift_event_id=drift_event_id,
tag="hist",
)

train_iter = iter(cur_train_loader)

if hist_train_loader is not None:
hist_train_iter = iter(hist_train_loader)
else:
Expand Down
Loading
Loading