From ae7014b3463d62a2cc683684f0704b8bd8854629 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Tue, 12 May 2026 23:33:17 -0400 Subject: [PATCH 1/8] Add documentation for creating model harness This document provides a guide on creating a custom model harness for Apeiron using BaseModelHarness, detailing required lifecycle methods, examples, and best practices. --- docs/create_model_harness.md | 506 +++++++++++++++++++++++++++++++++++ 1 file changed, 506 insertions(+) create mode 100644 docs/create_model_harness.md diff --git a/docs/create_model_harness.md b/docs/create_model_harness.md new file mode 100644 index 0000000..01055f1 --- /dev/null +++ b/docs/create_model_harness.md @@ -0,0 +1,506 @@ +# Creating an Application Model Harness + +This document describes how to create a custom model harness for Apeiron using `BaseModelHarness`. + +A model harness is the integration layer between your application and the Apeiron continual-learning framework. It is responsible for: + +- Managing data streams +- Providing dataloaders +- Configuring optimizers and loss functions +- Defining evaluation metrics +- Supporting checkpointing and drift evaluation + +--- + +# Overview + +To integrate a model into Apeiron, create a subclass of: + +```python +from apeiron.model.torch_model_harness import BaseModelHarness +``` + +Your subclass adapts your application's: + +- datasets +- models +- training streams +- evaluation logic + +to Apeiron's runtime lifecycle. + +--- + +# Required Lifecycle Methods + +Your harness must implement the following methods. + +| Method | Purpose | +|---|---| +| `get_optmizer()` | Return the optimizer | +| `update_data_stream()` | Refresh or replace active stream data | +| `get_stream_dataloader()` | Return continual-learning stream loader | +| `get_hist_dataloaders()` | Return historical train/validation loaders | +| `get_train_dataloaders()` | Return active train/validation loaders | +| `get_criterion()` | Return the training loss function | + +--- + +# Minimal Harness Example + +```python +import torch +from torch import nn +from torch.optim import Adam +from torch.utils.data import DataLoader, TensorDataset + +from apeiron.config.configuration import Config +from apeiron.model.torch_model_harness import BaseModelHarness + + +class ApplicationHarness(BaseModelHarness): + + def __init__(self, cfg: Config, model: nn.Module): + super().__init__(cfg, model) + + self._train_loader = self._make_loader(split="train") + self._val_loader = self._make_loader(split="val") + + self._hist_train_loader = None + self._hist_val_loader = None + + self.eval_metrics["accuracy"] = self._accuracy + + def get_optmizer(self): + return Adam( + self.model.parameters(), + lr=self.cfg.model.lr + ) + + def update_data_stream(self) -> None: + """ + Refresh stream data. + + Replace this with application-specific logic. + """ + + self._train_loader = self._make_loader(split="train") + self._val_loader = self._make_loader(split="val") + + def get_stream_dataloader(self) -> DataLoader: + return self._train_loader + + def get_hist_dataloaders(self): + return ( + self._hist_train_loader, + self._hist_val_loader + ) + + def get_train_dataloaders(self): + return ( + self._train_loader, + self._val_loader + ) + + def get_criterion(self): + return nn.CrossEntropyLoss() + + def _make_loader(self, split: str) -> DataLoader: + + x = torch.randn(128, 10) + y = torch.randint(0, 2, (128,)) + + dataset = TensorDataset(x, y) + + return DataLoader( + dataset, + batch_size=self.cfg.model.batch_size + ) + + @staticmethod + def _accuracy( + y_hat: torch.Tensor, + y: torch.Tensor + ) -> torch.Tensor: + + preds = y_hat.argmax(dim=1) + + return (preds == y).float().mean() +``` + +--- + +# Constructor Requirements + +Every harness constructor should: + +1. Accept: + - `Config` + - `torch.nn.Module` + +2. Call: + +```python +super().__init__(cfg, model) +``` + +This initializes: + +- device placement +- configuration access +- metric registry + +Example: + +```python +def __init__(self, cfg: Config, model: nn.Module): + super().__init__(cfg, model) +``` + +--- + +# Optimizer Configuration + +Implement: + +```python +get_optmizer() +``` + +This method should return a PyTorch optimizer. + +Example: + +```python +def get_optmizer(self): + return Adam( + self.model.parameters(), + lr=self.cfg.model.lr + ) +``` + +Parameter groups are supported: + +```python +return Adam([ + {"params": backbone.parameters(), "lr": 1e-5}, + {"params": head.parameters(), "lr": 1e-3}, +]) +``` + +--- + +# Data Stream Management + +Implement: + +```python +update_data_stream() +``` + +This method is responsible for refreshing or replacing the current data stream. + +Typical use cases: + +- sliding windows +- simulated drift +- streaming inference +- periodic dataset refresh +- online learning + +Example: + +```python +def update_data_stream(self): + self._train_loader = load_new_stream() +``` + +--- + +# Stream Dataloader + +Implement: + +```python +get_stream_dataloader() +``` + +This dataloader is used for continual learning. + +Example: + +```python +def get_stream_dataloader(self): + return self._train_loader +``` + +--- + +# Historical Dataloaders + +Implement: + +```python +get_hist_dataloaders() +``` + +Used for: + +- retention testing +- drift measurement +- historical evaluation + +Expected return type: + +```python +(train_loader, val_loader) +``` + +If no historical data exists: + +```python +return (None, None) +``` + +--- + +# Train and Validation Dataloaders + +Implement: + +```python +get_train_dataloaders() +``` + +Expected return: + +```python +(train_loader, val_loader) +``` + +The validation loader (`index 1`) is used internally by: + +```python +eval() +history_eval() +``` + +Example: + +```python +def get_train_dataloaders(self): + return self._train_loader, self._val_loader +``` + +--- + +# Loss Functions + +Implement: + +```python +get_criterion() +``` + +Example: + +```python +def get_criterion(self): + return nn.CrossEntropyLoss() +``` + +Any PyTorch-compatible loss function is supported. + +--- + +# Evaluation Metrics + +Metrics are stored in: + +```python +self.eval_metrics +``` + +Each metric must accept: + +```python +(y_hat, y) +``` + +and return: + +- tensor +- float +- scalar numeric value + +Example: + +```python +self.eval_metrics["accuracy"] = self._accuracy +``` + +Metric implementation: + +```python +@staticmethod +def _accuracy(y_hat, y): + preds = y_hat.argmax(dim=1) + return (preds == y).float().mean() +``` + +--- + +# Batch Format + +By default, batches are expected to be: + +```python +(x, y) +``` + +If your dataloader returns: + +- dictionaries +- metadata +- custom objects +- multimodal batches + +override: + +```python +_unpack() +``` + +Example: + +```python +def _unpack(self, batch): + + x = batch["features"] + y = batch["labels"] + + return x, y +``` + +--- + +# Evaluation Lifecycle + +The framework provides: + +```python +eval() +``` + +and: + +```python +history_eval() +``` + +These methods: + +- switch model to evaluation mode +- iterate over validation loaders +- compute registered metrics +- aggregate metric averages + +No implementation is required unless custom behavior is needed. + +--- + +# Checkpointing + +Checkpointing is automatically enabled when: + +```python +cfg.model.max_ckpts > 0 +``` + +and: + +```python +cfg.model.ckpts_path +``` + +are configured. + +Save checkpoints with: + +```python +save_ckpt(event) +``` + +Example: + +```python +self.save_ckpt(event=4) +``` + +Generated checkpoint files: + +```text +drift_adaptation_4.pt +``` + +A `latest` pointer file is also maintained automatically. + +Older checkpoints are removed once the retention limit is exceeded. + +--- + +# Recommended Project Structure + +Example application layout: + +```text +application/ +├── data/ +├── models/ +├── harness/ +│ └── application_harness.py +├── training/ +└── configs/ +``` + +--- + +# Best Practices + +## Keep stream logic isolated + +Avoid embedding stream refresh logic throughout the application. + +Prefer: + +```python +update_data_stream() +``` + +as the single source of truth. + +--- + +## Use `_unpack()` for compatibility + +Avoid hardcoding batch assumptions in evaluation or training loops. + +Override `_unpack()` instead. + +--- + +## Register metrics once + +Register metrics during initialization: + +```python +self.eval_metrics["f1"] = self._f1 +``` + +Avoid re-registering metrics dynamically. + +--- + +## Keep loaders persistent + +Avoid reconstructing datasets unnecessarily unless drift or stream updates require it. From 1b918816cc642cf12ad6c2d064c6645d4afa3eb4 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Fri, 5 Jun 2026 23:05:56 -0400 Subject: [PATCH 2/8] Add configuration options --- docs/configurations.md | 193 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 docs/configurations.md diff --git a/docs/configurations.md b/docs/configurations.md new file mode 100644 index 0000000..532739c --- /dev/null +++ b/docs/configurations.md @@ -0,0 +1,193 @@ + +# Apeiron Configuration Reference + +This document describes all currently supported TOML configuration options available in Apeiron, based on `src/apeiron/config/configuration.py`. + +## Root-Level Configuration + +```toml +seed = 1337 +device = "auto" +multi_gpu = false +verbosity = "INFO" +``` + +| Option | Type | Description | +|----------|------|-------------| +| `seed` | int | Random seed used throughout the experiment. | +| `device` | str | Execution device (`auto`, `cpu`, `cuda`, `cuda:N`, `mps`). | +| `multi_gpu` | bool | Enables multi-GPU selection logic when CUDA is available. | +| `verbosity` | str | Logging verbosity level. | + +## [model] + +```toml +[model] +name = "myModel" +pretrained_path = "" +max_ckpts = 0 +ckpts_path = "" +``` + +| Option | Type | Description | +|----------|------|-------------| +| `name` | str | Model identifier used by the example factory. | +| `pretrained_path` | str | Path to checkpoint weights. | +| `max_ckpts` | int | Number of retained CL checkpoints. | +| `ckpts_path` | str | Checkpoint output directory. | + +Unsupported unless `ModelCfg` is extended: + +```toml +architecture_path = "..." +class_name = "..." +input_dim = 512 +num_classes = 10 +``` + +## [data] + +```toml +[data] +name = "dataset" +path = "data.csv" +batch_size = 32 +``` + +| Option | Type | Description | +|----------|------|-------------| +| `name` | str | Dataset/example identifier. | +| `path` | str | Dataset location. | +| `batch_size` | int | Streaming batch size. | + +Unsupported unless `DataCfg` is extended: + +```toml +target_col = "label" +num_features = 512 +feature_columns = ["a", "b"] +``` + +## [train] + +```toml +[train] +batch_size = 64 +num_workers = 4 +init_lr = 0.001 +grad_accumulation_steps = 1 +max_iter = 600 +``` + +| Option | Type | Description | +|----------|------|-------------| +| `batch_size` | int | Training batch size. | +| `num_workers` | int | DataLoader workers. | +| `init_lr` | float | Initial learning rate. | +| `grad_accumulation_steps` | int | Gradient accumulation count. | +| `max_iter` | int | Maximum CL iterations. | + +## [continual_learning] + +```toml +[continual_learning] +update_mode = "base" +jvp_lambda = 0.001 +jvp_deltax_norm = 1 +ewc_lambda = 1000.0 +ewc_ema_decay = 0.95 +kfac_lambda = 0.01 +kfac_ema_decay = 0.95 +``` + +## [drift_detection] + +Core settings: + +```toml +[drift_detection] +detector_name = "ADWINDetector" +detection_interval = 10 +aggregation = "mean" +metric_index = 0 +reset_after_learning = false +max_stream_updates = 20 +``` + +ADWIN: + +```toml +adwin_delta = 0.002 +adwin_minor_threshold = 0.3 +adwin_moderate_threshold = 0.6 +``` + +KSWIN: + +```toml +kswin_alpha = 0.005 +kswin_window_size = 100 +kswin_stat_size = 30 +``` + +Page-Hinkley: + +```toml +ph_min_instances = 30 +ph_delta = 0.005 +ph_threshold = 50 +ph_alpha = 0.9999 +``` + +## [visualization] + +```toml +[visualization] +input = "output/results.csv" +``` + +## [logging] + +```toml +[logging] +backend = "wandb" +experiment_name = "experiment" +mlflow_tracking_uri = "http://localhost:5000" +``` + +## Complete Valid Example + +```toml +seed = 1337 +device = "auto" +multi_gpu = false +verbosity = "INFO" + +[model] +name = "mnist" +pretrained_path = "examples/mnist/mnist.pth" +max_ckpts = 0 +ckpts_path = "output/mnist" + +[data] +name = "mnist" +path = "" +batch_size = 32 + +[train] +batch_size = 64 +num_workers = 4 +init_lr = 0.001 + +[continual_learning] +update_mode = "base" + +[drift_detection] +detector_name = "ADWINDetector" + +[logging] +backend = "none" + +[visualization] +input = "output/results.csv" +``` From 4cd01b2a94718010d85f47c3722eb1aa0fb421c8 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Sun, 7 Jun 2026 00:36:40 -0400 Subject: [PATCH 3/8] More documentation --- docs/configurations.md | 183 ++++++++++++++++++++++------ src/apeiron/config/configuration.py | 5 +- 2 files changed, 151 insertions(+), 37 deletions(-) diff --git a/docs/configurations.md b/docs/configurations.md index 532739c..0a71f5e 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -1,7 +1,8 @@ # Apeiron Configuration Reference -This document describes all currently supported TOML configuration options available in Apeiron, based on `src/apeiron/config/configuration.py`. +This document describes all currently supported TOML configuration options available in Apeiron. +The code to parse the configuration file can be found in `src/apeiron/config/configuration.py`. ## Root-Level Configuration @@ -14,36 +15,28 @@ verbosity = "INFO" | Option | Type | Description | |----------|------|-------------| -| `seed` | int | Random seed used throughout the experiment. | -| `device` | str | Execution device (`auto`, `cpu`, `cuda`, `cuda:N`, `mps`). | -| `multi_gpu` | bool | Enables multi-GPU selection logic when CUDA is available. | -| `verbosity` | str | Logging verbosity level. | +| `seed` | int | Required, random seed used throughout the experiment. | +| `device` | str | Required, execution device (`auto`, `cpu`, `cuda`, `cuda:N`, `mps`). | +| `multi_gpu` | bool | Enables multi-GPU selection logic when CUDA is available. Default: false. | +| `verbosity` | str | Logging verbosity level (DEBUG, INFO, INFO:n, WARNING, ERROR, CRITICAL). Default: INFO. | ## [model] ```toml [model] name = "myModel" -pretrained_path = "" -max_ckpts = 0 -ckpts_path = "" +pretrained_path = "/path/to/model.pt" +max_ckpts = 1 +ckpts_path = "/path/to/checkpoint" ``` | Option | Type | Description | |----------|------|-------------| -| `name` | str | Model identifier used by the example factory. | -| `pretrained_path` | str | Path to checkpoint weights. | -| `max_ckpts` | int | Number of retained CL checkpoints. | +| `name` | str | Required, model identifier used by the example factory. | +| `pretrained_path` | str | Path to checkpoint weights. Default: none. | +| `max_ckpts` | int | Number of retained model checkpoints after CL. Default: 0 (no checkpoint is saved). | | `ckpts_path` | str | Checkpoint output directory. | -Unsupported unless `ModelCfg` is extended: - -```toml -architecture_path = "..." -class_name = "..." -input_dim = 512 -num_classes = 10 -``` ## [data] @@ -56,17 +49,10 @@ batch_size = 32 | Option | Type | Description | |----------|------|-------------| -| `name` | str | Dataset/example identifier. | -| `path` | str | Dataset location. | -| `batch_size` | int | Streaming batch size. | - -Unsupported unless `DataCfg` is extended: +| `name` | str | Required, dataset/example identifier. | +| `path` | str | Reguired, dataset location. | +| `batch_size` | int | Streaming batch size. Default 1 (inference/drift detection is done on each individual sample). | -```toml -target_col = "label" -num_features = 512 -feature_columns = ["a", "b"] -``` ## [train] @@ -81,27 +67,46 @@ max_iter = 600 | Option | Type | Description | |----------|------|-------------| -| `batch_size` | int | Training batch size. | -| `num_workers` | int | DataLoader workers. | -| `init_lr` | float | Initial learning rate. | -| `grad_accumulation_steps` | int | Gradient accumulation count. | -| `max_iter` | int | Maximum CL iterations. | +| `batch_size` | int | Required, taining batch size. | +| `num_workers` | int | Required, dataLoader workers. | +| `init_lr` | float | Required, initial learning rate. | +| `grad_accumulation_steps` | int | Gradient accumulation count. Default: 1. | +| `max_iter` | int | Maximum CL iterations. Default: 600. | + ## [continual_learning] ```toml [continual_learning] update_mode = "base" + jvp_lambda = 0.001 jvp_deltax_norm = 1 + ewc_lambda = 1000.0 ewc_ema_decay = 0.95 + kfac_lambda = 0.01 kfac_ema_decay = 0.95 ``` +| Option | Type | Description | +|----------|------|-------------| +| `update_mode` | str | Required, CL strategy to use (base, jvp_reg, ewc_online, kfac_online, none). | +| `jvp_lambda` | float | Weight for JVP regularization term (`jvp_reg` mode). Default: 0.001. | +| `jvp_deltax_norm` | float | Scale factor for JVP input perturbation direction. Default: 1. | +| `ewc_lambda` | float | EWC regularization strength (`ewc_online` mode). Default: 1000. | +| `ewc_ema_decay` | float | EMA decay for online Fisher prior in EWC. Default: 0.95. | +| `kfac_lambda` | float | KFAC penalty strength (`kfac_online` mode). Default: 0.01. | +| `kfac_ema_decay` | float | EMA decay for running Kronecker factors in KFAC mode. Default: 0.95. | + +Details about the CL algorithms available can be found in [docs/continuous_learning.md](continuous_learning.md) + + ## [drift_detection] +All the parameters are optional. Default values are provided in the examples below. + Core settings: ```toml @@ -114,6 +119,15 @@ reset_after_learning = false max_stream_updates = 20 ``` +| Option | Type | Description | +|----------|------|-------------| +| `detector_name` | str | Drift detection algorithm (ADWINDetector, KSWINDetector, PageHinkleyDetector, ModelPerformanceDetector, ModelEvalDetector, EnsembleDetector). | +| `detection_interval` | int | Check drift every N monitored batches. If `<= 0`, checks are disabled. | +| `aggregation` | str | How buffered metric values are aggregated before detector update. Supported by monitor: `mean`, `median`, `last`. | +| `metric_index` | int | Index into `modelHarness.eval_metrics` order. | +| `reset_after_learning` | bool | If true, detector state resets after each CL event. | +| `max_stream_updates` | int | Monitoring stops after this many stream extensions. | + ADWIN: ```toml @@ -122,6 +136,12 @@ adwin_minor_threshold = 0.3 adwin_moderate_threshold = 0.6 ``` +| Option | Type | Description | +|----------|------|-------------| +| `adwin_delta` | float | ADWIN confidence/sensitivity parameter. | +| `adwin_minor_threshold` | float | ADWIN regime threshold (CL boundary). | +| `adwin_moderate_threshold` | float | ADWIN regime threshold (fine-tuning boundary). | + KSWIN: ```toml @@ -130,6 +150,12 @@ kswin_window_size = 100 kswin_stat_size = 30 ``` +| Option | Type | Description | +|----------|------|-------------| +| `kswin_alpha` | float | KSWIN significance level. | +| `kswin_window_size` | int | KSWIN reference window size. | +| `kswin_stat_size` | int | KSWIN recent sample window size. | + Page-Hinkley: ```toml @@ -139,13 +165,95 @@ ph_threshold = 50 ph_alpha = 0.9999 ``` +| Option | Type | Description | +|----------|------|-------------| +| `ph_min_instances` | int | Page-Hinkley warm-up count before detection is meaningful. | +| `ph_delta` | float | Page-Hinkley change magnitude parameter. | +| `ph_threshold` | float | Page-Hinkley trigger threshold. | +| `ph_alpha` | float | Page-Hinkley forgetting factor. | + + +Details about the drift detection algorithms available can be found in [docs/drift_detectors.md](drift_detectors.md) + + ## [visualization] +The visualization configuration options are used to store the results of the metrics captured during the run. + ```toml [visualization] input = "output/results.csv" ``` +| Option | Type | Description | +|----------|------|-------------| +| `input` | str | Required, path to the CSV output file storing the metrics. | + +Metrics used: +``` +cl/cperf_detector_flop +cl/cperf_detector_flops +cl/cperf_detector_time +cl/cperf_infer_flop +cl/cperf_infer_flops +cl/cperf_infer_time +cl/cperf_optimizer_flop +cl/cperf_optimizer_flops +cl/cperf_optimizer_time +cl/cperf_update_fwd_bwd_flop +cl/cperf_update_fwd_bwd_flops +cl/cperf_update_fwd_bwd_time +cl/drift_event_id +cl/jvp_reg_forgetting_loss +cl/jvp_reg_generation_loss +cl/jvp_reg_total_loss +cl/step +drift/confidence +drift/cperf_detector_flop +drift/cperf_detector_flops +drift/cperf_detector_time +drift/cperf_infer_flop +drift/cperf_infer_flops +drift/cperf_infer_time +drift/cperf_optimizer_flop +drift/cperf_optimizer_flops +drift/cperf_optimizer_time +drift/cperf_update_fwd_bwd_flop +drift/cperf_update_fwd_bwd_flops +drift/cperf_update_fwd_bwd_time +drift/detected +drift/metric_0 +drift/regime +drift/score +drift/step +eval/accuracy +eval/loss +eval/step +eval/test_curr_acc +eval/test_hist_acc +``` + +Example output file: +```csv +step,metric,value +10,eval/accuracy,62.5 +10,eval/loss,2.0406203269958496 +10,eval/step,10 +10,drift/score,0.0 +10,drift/regime,stable +10,drift/confidence,0.998 +10,drift/metric_0,44.6875 +10,drift/cperf_infer_flop,1624834048.0 +10,drift/cperf_infer_time,0.00556622925796546 +10,drift/cperf_infer_flops,291909293113.4319 +10,drift/cperf_detector_flop,0.0 +10,drift/cperf_detector_time,8.56249826028943e-05 +10,drift/cperf_detector_flops,0.0 +10,cl/jvp_reg_total_loss,3.4211268424987793 +10,cl/jvp_reg_forgetting_loss,0.0 +10,cl/jvp_reg_generation_loss,3.4211268424987793 +``` + ## [logging] ```toml @@ -155,6 +263,13 @@ experiment_name = "experiment" mlflow_tracking_uri = "http://localhost:5000" ``` +| Option | Type | Description | +|----------|------|-------------| +| `backend` | str | Logging backend (wandb, mlflow, or none). Default: wandb. | +| `experiment_name` | str | W&B project name or MLflow experiment name. Default: None. | +| `mlflow_tracking_uri` | str | MLflow tracking server URI. Default: None. | + + ## Complete Valid Example ```toml diff --git a/src/apeiron/config/configuration.py b/src/apeiron/config/configuration.py index df55747..c2e9e42 100644 --- a/src/apeiron/config/configuration.py +++ b/src/apeiron/config/configuration.py @@ -86,8 +86,7 @@ def _select_best_gpu() -> int | None: @dataclass(frozen=True) class ModelCfg: name: str - pretrained_path: str - + pretrained_path: str = "" # FIFO checkpointing: 0 disables, N keeps last N post-CL snapshots max_ckpts: int = 0 ckpts_path: str = "" @@ -178,7 +177,7 @@ class Config: seed: int device: str - multi_gpu: bool + multi_gpu: bool = False verbosity: str = "INFO" visualization: VisualizationCfg | None = None logging: LoggingCfg | None = None From 871f154becf0cf2728ff8257440da82baac40002 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Sun, 7 Jun 2026 11:11:59 -0400 Subject: [PATCH 4/8] configuration updates --- docs/README.md | 20 ++++++-------------- docs/configurations.md | 9 +++++++++ 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/docs/README.md b/docs/README.md index 7740f72..fcc24af 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,16 +1,18 @@ # BaseSim Documentation -This directory contains the detailed reference docs for the framework's three main extension points: +This directory contains the detailed reference docs for the framework's three main extension points and configurations: +- `configurations.md`: required and optional configuration settings - `model_harness.md`: model + data-stream integration contract - `drift_detectors.md`: detector classes, detector config, and detector wiring - `continuous_learning.md`: continual-learning trainer, updater modes, and training config ## Read Order -1. Start with `model_harness.md` to understand how models and stream loaders are exposed. -2. Read `drift_detectors.md` to see how monitoring decisions are made. -3. Read `continuous_learning.md` to understand what happens after drift is detected. +1. Start with `configurations.md` to learn on the required and optional configuration parameters used by Apeiron. +2. Start with `model_harness.md` to understand how models and stream loaders are exposed. +3. Read `drift_detectors.md` to see how monitoring decisions are made. +4. Read `continuous_learning.md` to understand what happens after drift is detected. ## Runtime Flow @@ -20,13 +22,3 @@ This directory contains the detailed reference docs for the framework's three ma 4. On drift, `src/training/continuous_trainer.py` runs a CL loop with an updater from `src/training/updater/create_updater.py`. 5. Logging is stage-aware (`eval`, `drift`, `cl`) via `src/logger/`. -## Config Sections - -The parser expects these TOML sections: - -- Required: `[model]`, `[data]`, `[train]`, `[drift_detection]` -- Optional: `[continual_learning]`, `[visualization]` -- Required top-level keys: `seed`, `device`, `multi_gpu` -- Common top-level optional key: `verbosity` - -See each doc file for field-by-field meaning and defaults. diff --git a/docs/configurations.md b/docs/configurations.md index 0a71f5e..db40d63 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -4,6 +4,15 @@ This document describes all currently supported TOML configuration options available in Apeiron. The code to parse the configuration file can be found in `src/apeiron/config/configuration.py`. +The parser expects these TOML sections: + +- Required: `[model]`, `[data]`, `[train]`, `[drift_detection]` +- Optional: `[continual_learning]`, `[visualization]` +- Required top-level keys: `seed`, `device` +- Common top-level optional key: `verbosity` + +See each section for field-by-field meaning and defaults. + ## Root-Level Configuration ```toml From b6b9951864090a151f580cc555b2a8ea9a798154 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Sun, 7 Jun 2026 11:12:27 -0400 Subject: [PATCH 5/8] Moved the section for creating model harness inside the documentation md --- docs/create_model_harness.md | 506 ----------------------------------- docs/model_harness.md | 266 ++++++++++++++++++ 2 files changed, 266 insertions(+), 506 deletions(-) delete mode 100644 docs/create_model_harness.md diff --git a/docs/create_model_harness.md b/docs/create_model_harness.md deleted file mode 100644 index 01055f1..0000000 --- a/docs/create_model_harness.md +++ /dev/null @@ -1,506 +0,0 @@ -# Creating an Application Model Harness - -This document describes how to create a custom model harness for Apeiron using `BaseModelHarness`. - -A model harness is the integration layer between your application and the Apeiron continual-learning framework. It is responsible for: - -- Managing data streams -- Providing dataloaders -- Configuring optimizers and loss functions -- Defining evaluation metrics -- Supporting checkpointing and drift evaluation - ---- - -# Overview - -To integrate a model into Apeiron, create a subclass of: - -```python -from apeiron.model.torch_model_harness import BaseModelHarness -``` - -Your subclass adapts your application's: - -- datasets -- models -- training streams -- evaluation logic - -to Apeiron's runtime lifecycle. - ---- - -# Required Lifecycle Methods - -Your harness must implement the following methods. - -| Method | Purpose | -|---|---| -| `get_optmizer()` | Return the optimizer | -| `update_data_stream()` | Refresh or replace active stream data | -| `get_stream_dataloader()` | Return continual-learning stream loader | -| `get_hist_dataloaders()` | Return historical train/validation loaders | -| `get_train_dataloaders()` | Return active train/validation loaders | -| `get_criterion()` | Return the training loss function | - ---- - -# Minimal Harness Example - -```python -import torch -from torch import nn -from torch.optim import Adam -from torch.utils.data import DataLoader, TensorDataset - -from apeiron.config.configuration import Config -from apeiron.model.torch_model_harness import BaseModelHarness - - -class ApplicationHarness(BaseModelHarness): - - def __init__(self, cfg: Config, model: nn.Module): - super().__init__(cfg, model) - - self._train_loader = self._make_loader(split="train") - self._val_loader = self._make_loader(split="val") - - self._hist_train_loader = None - self._hist_val_loader = None - - self.eval_metrics["accuracy"] = self._accuracy - - def get_optmizer(self): - return Adam( - self.model.parameters(), - lr=self.cfg.model.lr - ) - - def update_data_stream(self) -> None: - """ - Refresh stream data. - - Replace this with application-specific logic. - """ - - self._train_loader = self._make_loader(split="train") - self._val_loader = self._make_loader(split="val") - - def get_stream_dataloader(self) -> DataLoader: - return self._train_loader - - def get_hist_dataloaders(self): - return ( - self._hist_train_loader, - self._hist_val_loader - ) - - def get_train_dataloaders(self): - return ( - self._train_loader, - self._val_loader - ) - - def get_criterion(self): - return nn.CrossEntropyLoss() - - def _make_loader(self, split: str) -> DataLoader: - - x = torch.randn(128, 10) - y = torch.randint(0, 2, (128,)) - - dataset = TensorDataset(x, y) - - return DataLoader( - dataset, - batch_size=self.cfg.model.batch_size - ) - - @staticmethod - def _accuracy( - y_hat: torch.Tensor, - y: torch.Tensor - ) -> torch.Tensor: - - preds = y_hat.argmax(dim=1) - - return (preds == y).float().mean() -``` - ---- - -# Constructor Requirements - -Every harness constructor should: - -1. Accept: - - `Config` - - `torch.nn.Module` - -2. Call: - -```python -super().__init__(cfg, model) -``` - -This initializes: - -- device placement -- configuration access -- metric registry - -Example: - -```python -def __init__(self, cfg: Config, model: nn.Module): - super().__init__(cfg, model) -``` - ---- - -# Optimizer Configuration - -Implement: - -```python -get_optmizer() -``` - -This method should return a PyTorch optimizer. - -Example: - -```python -def get_optmizer(self): - return Adam( - self.model.parameters(), - lr=self.cfg.model.lr - ) -``` - -Parameter groups are supported: - -```python -return Adam([ - {"params": backbone.parameters(), "lr": 1e-5}, - {"params": head.parameters(), "lr": 1e-3}, -]) -``` - ---- - -# Data Stream Management - -Implement: - -```python -update_data_stream() -``` - -This method is responsible for refreshing or replacing the current data stream. - -Typical use cases: - -- sliding windows -- simulated drift -- streaming inference -- periodic dataset refresh -- online learning - -Example: - -```python -def update_data_stream(self): - self._train_loader = load_new_stream() -``` - ---- - -# Stream Dataloader - -Implement: - -```python -get_stream_dataloader() -``` - -This dataloader is used for continual learning. - -Example: - -```python -def get_stream_dataloader(self): - return self._train_loader -``` - ---- - -# Historical Dataloaders - -Implement: - -```python -get_hist_dataloaders() -``` - -Used for: - -- retention testing -- drift measurement -- historical evaluation - -Expected return type: - -```python -(train_loader, val_loader) -``` - -If no historical data exists: - -```python -return (None, None) -``` - ---- - -# Train and Validation Dataloaders - -Implement: - -```python -get_train_dataloaders() -``` - -Expected return: - -```python -(train_loader, val_loader) -``` - -The validation loader (`index 1`) is used internally by: - -```python -eval() -history_eval() -``` - -Example: - -```python -def get_train_dataloaders(self): - return self._train_loader, self._val_loader -``` - ---- - -# Loss Functions - -Implement: - -```python -get_criterion() -``` - -Example: - -```python -def get_criterion(self): - return nn.CrossEntropyLoss() -``` - -Any PyTorch-compatible loss function is supported. - ---- - -# Evaluation Metrics - -Metrics are stored in: - -```python -self.eval_metrics -``` - -Each metric must accept: - -```python -(y_hat, y) -``` - -and return: - -- tensor -- float -- scalar numeric value - -Example: - -```python -self.eval_metrics["accuracy"] = self._accuracy -``` - -Metric implementation: - -```python -@staticmethod -def _accuracy(y_hat, y): - preds = y_hat.argmax(dim=1) - return (preds == y).float().mean() -``` - ---- - -# Batch Format - -By default, batches are expected to be: - -```python -(x, y) -``` - -If your dataloader returns: - -- dictionaries -- metadata -- custom objects -- multimodal batches - -override: - -```python -_unpack() -``` - -Example: - -```python -def _unpack(self, batch): - - x = batch["features"] - y = batch["labels"] - - return x, y -``` - ---- - -# Evaluation Lifecycle - -The framework provides: - -```python -eval() -``` - -and: - -```python -history_eval() -``` - -These methods: - -- switch model to evaluation mode -- iterate over validation loaders -- compute registered metrics -- aggregate metric averages - -No implementation is required unless custom behavior is needed. - ---- - -# Checkpointing - -Checkpointing is automatically enabled when: - -```python -cfg.model.max_ckpts > 0 -``` - -and: - -```python -cfg.model.ckpts_path -``` - -are configured. - -Save checkpoints with: - -```python -save_ckpt(event) -``` - -Example: - -```python -self.save_ckpt(event=4) -``` - -Generated checkpoint files: - -```text -drift_adaptation_4.pt -``` - -A `latest` pointer file is also maintained automatically. - -Older checkpoints are removed once the retention limit is exceeded. - ---- - -# Recommended Project Structure - -Example application layout: - -```text -application/ -├── data/ -├── models/ -├── harness/ -│ └── application_harness.py -├── training/ -└── configs/ -``` - ---- - -# Best Practices - -## Keep stream logic isolated - -Avoid embedding stream refresh logic throughout the application. - -Prefer: - -```python -update_data_stream() -``` - -as the single source of truth. - ---- - -## Use `_unpack()` for compatibility - -Avoid hardcoding batch assumptions in evaluation or training loops. - -Override `_unpack()` instead. - ---- - -## Register metrics once - -Register metrics during initialization: - -```python -self.eval_metrics["f1"] = self._f1 -``` - -Avoid re-registering metrics dynamically. - ---- - -## Keep loaders persistent - -Avoid reconstructing datasets unnecessarily unless drift or stream updates require it. diff --git a/docs/model_harness.md b/docs/model_harness.md index 7fe2f8f..7acad4c 100644 --- a/docs/model_harness.md +++ b/docs/model_harness.md @@ -114,3 +114,269 @@ Supported `model.name` values for CIFAR/ImageNet loaders: - `get_stream_dataloader()` must return a non-`None` loader after `update_data_stream()`. - `get_train_dataloaders()` must return non-`None` loaders after `update_data_stream()`. - If your batch is not exactly `(x, y)`, override `_unpack` in your harness. + +# Creating a Custom Model Harness + +A custom model harness is the integration point between a user model and Apeiron. The harness allows Apeiron to monitor model performance, detect drift, trigger continual learning, and manage retraining without requiring modifications to the underlying scientific model. + +## Required Files + +A custom example typically consists of: + +```text +examples// +├── __init__.py +├── model.py +├── utils.py +└── .toml +``` + +In addition, the example must be registered in: + +```text +examples/utils.py +``` + +through the `get_example()` factory. + +## Step 1: Create the Harness Class + +All custom harnesses must inherit from `BaseModelHarness`. + +```python +class MyHarness(BaseModelHarness): + def __init__(self, cfg): + model = MyModel() + super().__init__(cfg=cfg, model=model) +``` + +The base class handles: + +- Device placement +- Checkpoint support +- Evaluation loops +- Continual-learning integration +- Drift-monitor integration + +## Step 2: Define Evaluation Metrics + +Apeiron's drift detectors operate on metric streams. + +At least one metric should be exposed through: + +```python +self.eval_metrics = { + "accuracy": accuracy, +} +``` + +or for regression: + +```python +self.eval_metrics = { + "mse": regression_mse, +} +``` + +The order of metrics is important because drift detection uses: + +```toml +metric_index = 0 +``` + +which refers to the first metric in `eval_metrics`. + +## Step 3: Implement Required Methods + +### Optimizer + +```python +def get_optmizer(self): + return torch.optim.Adam( + self.model.parameters(), + lr=self.cfg.train.init_lr, + ) +``` + +### Criterion + +```python +def get_criterion(self): + return torch.nn.CrossEntropyLoss() +``` + +or for regression: + +```python +def get_criterion(self): + return torch.nn.MSELoss() +``` + +### Current Training Loaders + +```python +def get_train_dataloaders(self): + return self.train_loader, self.val_loader +``` + +These loaders are used during continual-learning updates. + +### Stream Loader + +```python +def get_stream_dataloader(self): + return self.stream_loader +``` + +This loader is used during monitoring and drift detection. + +### Historical Loaders + +```python +def get_hist_dataloaders(self): + if self.task_counter == 0: + return None, None + + return self.hist_train_loader, self.hist_val_loader +``` + +Historical loaders provide replay data for continual-learning algorithms such as EWC and KFAC. + +### Stream Updates + +```python +def update_data_stream(self): + self.task_counter += 1 + + self.train_loader = build_train_loader() + self.val_loader = build_val_loader() + self.stream_loader = build_stream_loader() +``` + +This function is called by the monitor whenever a new stream segment becomes active. + +## Step 4: Build Dataset Utilities + +Dataset loading should be implemented in `utils.py`. + +Typical responsibilities include: + +- Reading datasets +- Feature preprocessing +- Label extraction +- Train/validation splitting +- Drift simulation +- DataLoader construction + +Example: + +```python +def make_loader(dataset, batch_size): + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + ) +``` + +## Step 5: Register the Example + +Add a branch to `examples/utils.py`: + +```python +elif cfg.data.name == "mydataset": + from examples.mydataset.model import MyHarness + + return MyHarness(cfg=cfg) +``` + +The value must match: + +```toml +[data] +name = "mydataset" +``` + +## Step 6: Create the Configuration File + +Create: + +```text +examples//.toml +``` + +Requirested and optional parameters are described in [configurations.md](configurations.md). + +Only keys defined in `configuration.py` are allowed. Custom dataset-specific parameters should be implemented inside the harness or utility code unless the configuration dataclasses are extended. + +## Step 7: Validate the Integration + +Validate TOML syntax: + +```bash +python -c "import tomllib; tomllib.load(open('examples/mydataset/mydataset.toml', 'rb')); print('TOML OK')" +``` + +Validate factory registration: + +```bash +poetry run python -c "from examples.utils import get_example; print('factory OK')" +``` + +## Step 8: Run a Smoke Test + +Before running a full experiment, execute a small CPU-only test: + +```bash +poetry run python -m src.main \ + --config examples/mydataset/mydataset.toml \ + --set train.max_iter=2 \ + --set drift_detection.max_stream_updates=2 \ + --set drift_detection.detection_interval=1 \ + --set device=cpu \ + --set logging.backend=none +``` + +A successful smoke test confirms: + +- Configuration loading +- Harness construction +- Dataset loading +- Drift-monitor integration +- Continual-learning integration +- Metric reporting + +## Common Errors + +### `KeyError: 'model'` + +The TOML file is missing the `[model]` section. + +### `ImportError` when loading the harness + +The factory registration in `examples/utils.py` is missing or incorrect. + +### `TypeError: DataCfg.__init__() got an unexpected keyword argument ...` + +A TOML key is not defined in the configuration dataclasses. + +### `IndexError: list index out of range` during drift detection + +No evaluation metrics are being emitted. Ensure `self.eval_metrics` contains at least one metric. + +### `get_stream_dataloader()` returns `None` + +The monitor cannot evaluate data until `update_data_stream()` constructs and returns a valid stream loader. + +## Summary + +A custom harness only needs to provide six core behaviors: + +1. Construct the model. +2. Construct the optimizer. +3. Construct the criterion. +4. Provide current train/validation loaders. +5. Provide historical replay loaders. +6. Update the active data stream. + +Once those interfaces are implemented, Apeiron can automatically provide monitoring, drift detection, continual learning, checkpointing, logging, and experiment management for the underlying scientific model. From 7e6afba03a428081671e2b4792e3d23e0c09f5fa Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Sun, 7 Jun 2026 15:57:28 -0400 Subject: [PATCH 6/8] wip --- docs/README.md | 2 +- docs/configurations.md | 8 +++++--- docs/continuous_learning.md | 17 +++++++++++++++++ src/apeiron/config/configuration.py | 2 +- 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/docs/README.md b/docs/README.md index fcc24af..6cfc43d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -10,7 +10,7 @@ This directory contains the detailed reference docs for the framework's three ma ## Read Order 1. Start with `configurations.md` to learn on the required and optional configuration parameters used by Apeiron. -2. Start with `model_harness.md` to understand how models and stream loaders are exposed. +2. Continue with `model_harness.md` to understand how models and stream loaders are exposed. 3. Read `drift_detectors.md` to see how monitoring decisions are made. 4. Read `continuous_learning.md` to understand what happens after drift is detected. diff --git a/docs/configurations.md b/docs/configurations.md index db40d63..971b7f2 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -7,7 +7,7 @@ The code to parse the configuration file can be found in `src/apeiron/config/con The parser expects these TOML sections: - Required: `[model]`, `[data]`, `[train]`, `[drift_detection]` -- Optional: `[continual_learning]`, `[visualization]` +- Optional: `[continual_learning]`, `[visualization]`, `[logging]` - Required top-level keys: `seed`, `device` - Common top-level optional key: `verbosity` @@ -187,7 +187,7 @@ Details about the drift detection algorithms available can be found in [docs/dri ## [visualization] -The visualization configuration options are used to store the results of the metrics captured during the run. +The visualization configuration options are optional and are used to store the results of the metrics captured during the run. ```toml [visualization] @@ -196,7 +196,7 @@ input = "output/results.csv" | Option | Type | Description | |----------|------|-------------| -| `input` | str | Required, path to the CSV output file storing the metrics. | +| `input` | str | Path to the CSV output file storing the metrics. Default: output/output.csv. | Metrics used: ``` @@ -265,6 +265,8 @@ step,metric,value ## [logging] +All the parameters are optional. + ```toml [logging] backend = "wandb" diff --git a/docs/continuous_learning.md b/docs/continuous_learning.md index 4d777f5..2495e83 100644 --- a/docs/continuous_learning.md +++ b/docs/continuous_learning.md @@ -126,3 +126,20 @@ ewc_ema_decay = 0.95 kfac_lambda = 1e-2 kfac_ema_decay = 0.95 ``` + +Codes wanting to do continual learning should use the `ContinuousTrainer` class that takes the configuration parameters, model harness, logger and the profiler. + +```python +cfg = build_config ( argv ) + +xtrainer = ContinuousTrainer( + cfg=cfg, + modelHarness=modelHarness, + logger=logger, + profiler=flops_profiler, +) + +trainer.outer_cl_training_loop(drift_event_id=drift_count) +if modelHarness.ckpts_enabled: + ckptpath = modelHarness.save_ckpt(event=drift_count) +``` diff --git a/src/apeiron/config/configuration.py b/src/apeiron/config/configuration.py index c2e9e42..e79a566 100644 --- a/src/apeiron/config/configuration.py +++ b/src/apeiron/config/configuration.py @@ -155,7 +155,7 @@ class DriftDetectionCfg: @dataclass(frozen=True) class VisualizationCfg: - input: str = "output/cl_only.csv" # CSV path where run metrics are written + input: str = "output/output.csv" # CSV path where run metrics are written @dataclass(frozen=True) From 386476ac829c718c6b391a1a5501a9ac39aae733 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Sun, 7 Jun 2026 16:04:04 -0400 Subject: [PATCH 7/8] wip --- docs/drift_detectors.md | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/docs/drift_detectors.md b/docs/drift_detectors.md index e0bc359..cbbce6c 100644 --- a/docs/drift_detectors.md +++ b/docs/drift_detectors.md @@ -182,3 +182,40 @@ adwin_delta = 0.002 adwin_minor_threshold = 0.3 adwin_moderate_threshold = 0.6 ``` + +Code for the workflow in the Monitor implementation: + +```python +detector = load_drift_detector(cfg) + +data = modelHarness.get_stream_dataloader() + +for batch_idx, batch in tqdm( + enumerate(val_loader), + desc="Inference on batches", + leave=False, +): + # Inference on batch and compute all metrics + metrics = self._evaluate_batch(batch) + metric_buffer.append(metrics) + + +metric_idx = cfg.drift_detection.metric_index +metric_values = [m[metric_idx] for m in metric_buffer] + +# aggregate metrics +aggregation = cfg.drift_detection.aggregation +if aggregation == "mean": + agg_metric = float(np.mean(metric_values)) +elif aggregation == "median": + agg_metric = float(np.median(metric_values)) +elif aggregation == "last": + agg_metric = float(metric_values[-1]) + +drift_signal = detector.update(agg_metric) +if drift_signal.drift_detected: + handle_drift(drift_signal) + +self.detector.reset() +modelHarness.update_data_stream() +``` From 1a428c49aaee5a7037922fa32c6485ce20131b51 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Mon, 8 Jun 2026 11:29:18 -0400 Subject: [PATCH 8/8] CI pass --- tests/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_config.py b/tests/test_config.py index 7c90b0b..2f88468 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -278,7 +278,7 @@ def test_drift_detection_defaults(self): def test_visualization_cfg(self): viz = VisualizationCfg() - assert viz.input == "output/cl_only.csv" + assert viz.input == "output/output.csv" # ---------------------------------------------------------------------------