diff --git a/docs/README.md b/docs/README.md index 7740f72..6cfc43d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,16 +1,18 @@ # BaseSim Documentation -This directory contains the detailed reference docs for the framework's three main extension points: +This directory contains the detailed reference docs for the framework's three main extension points and configurations: +- `configurations.md`: required and optional configuration settings - `model_harness.md`: model + data-stream integration contract - `drift_detectors.md`: detector classes, detector config, and detector wiring - `continuous_learning.md`: continual-learning trainer, updater modes, and training config ## Read Order -1. Start with `model_harness.md` to understand how models and stream loaders are exposed. -2. Read `drift_detectors.md` to see how monitoring decisions are made. -3. Read `continuous_learning.md` to understand what happens after drift is detected. +1. Start with `configurations.md` to learn on the required and optional configuration parameters used by Apeiron. +2. Continue with `model_harness.md` to understand how models and stream loaders are exposed. +3. Read `drift_detectors.md` to see how monitoring decisions are made. +4. Read `continuous_learning.md` to understand what happens after drift is detected. ## Runtime Flow @@ -20,13 +22,3 @@ This directory contains the detailed reference docs for the framework's three ma 4. On drift, `src/training/continuous_trainer.py` runs a CL loop with an updater from `src/training/updater/create_updater.py`. 5. Logging is stage-aware (`eval`, `drift`, `cl`) via `src/logger/`. -## Config Sections - -The parser expects these TOML sections: - -- Required: `[model]`, `[data]`, `[train]`, `[drift_detection]` -- Optional: `[continual_learning]`, `[visualization]` -- Required top-level keys: `seed`, `device`, `multi_gpu` -- Common top-level optional key: `verbosity` - -See each doc file for field-by-field meaning and defaults. diff --git a/docs/configurations.md b/docs/configurations.md new file mode 100644 index 0000000..971b7f2 --- /dev/null +++ b/docs/configurations.md @@ -0,0 +1,319 @@ + +# Apeiron Configuration Reference + +This document describes all currently supported TOML configuration options available in Apeiron. +The code to parse the configuration file can be found in `src/apeiron/config/configuration.py`. + +The parser expects these TOML sections: + +- Required: `[model]`, `[data]`, `[train]`, `[drift_detection]` +- Optional: `[continual_learning]`, `[visualization]`, `[logging]` +- Required top-level keys: `seed`, `device` +- Common top-level optional key: `verbosity` + +See each section for field-by-field meaning and defaults. + +## Root-Level Configuration + +```toml +seed = 1337 +device = "auto" +multi_gpu = false +verbosity = "INFO" +``` + +| Option | Type | Description | +|----------|------|-------------| +| `seed` | int | Required, random seed used throughout the experiment. | +| `device` | str | Required, execution device (`auto`, `cpu`, `cuda`, `cuda:N`, `mps`). | +| `multi_gpu` | bool | Enables multi-GPU selection logic when CUDA is available. Default: false. | +| `verbosity` | str | Logging verbosity level (DEBUG, INFO, INFO:n, WARNING, ERROR, CRITICAL). Default: INFO. | + +## [model] + +```toml +[model] +name = "myModel" +pretrained_path = "/path/to/model.pt" +max_ckpts = 1 +ckpts_path = "/path/to/checkpoint" +``` + +| Option | Type | Description | +|----------|------|-------------| +| `name` | str | Required, model identifier used by the example factory. | +| `pretrained_path` | str | Path to checkpoint weights. Default: none. | +| `max_ckpts` | int | Number of retained model checkpoints after CL. Default: 0 (no checkpoint is saved). | +| `ckpts_path` | str | Checkpoint output directory. | + + +## [data] + +```toml +[data] +name = "dataset" +path = "data.csv" +batch_size = 32 +``` + +| Option | Type | Description | +|----------|------|-------------| +| `name` | str | Required, dataset/example identifier. | +| `path` | str | Reguired, dataset location. | +| `batch_size` | int | Streaming batch size. Default 1 (inference/drift detection is done on each individual sample). | + + +## [train] + +```toml +[train] +batch_size = 64 +num_workers = 4 +init_lr = 0.001 +grad_accumulation_steps = 1 +max_iter = 600 +``` + +| Option | Type | Description | +|----------|------|-------------| +| `batch_size` | int | Required, taining batch size. | +| `num_workers` | int | Required, dataLoader workers. | +| `init_lr` | float | Required, initial learning rate. | +| `grad_accumulation_steps` | int | Gradient accumulation count. Default: 1. | +| `max_iter` | int | Maximum CL iterations. Default: 600. | + + +## [continual_learning] + +```toml +[continual_learning] +update_mode = "base" + +jvp_lambda = 0.001 +jvp_deltax_norm = 1 + +ewc_lambda = 1000.0 +ewc_ema_decay = 0.95 + +kfac_lambda = 0.01 +kfac_ema_decay = 0.95 +``` + +| Option | Type | Description | +|----------|------|-------------| +| `update_mode` | str | Required, CL strategy to use (base, jvp_reg, ewc_online, kfac_online, none). | +| `jvp_lambda` | float | Weight for JVP regularization term (`jvp_reg` mode). Default: 0.001. | +| `jvp_deltax_norm` | float | Scale factor for JVP input perturbation direction. Default: 1. | +| `ewc_lambda` | float | EWC regularization strength (`ewc_online` mode). Default: 1000. | +| `ewc_ema_decay` | float | EMA decay for online Fisher prior in EWC. Default: 0.95. | +| `kfac_lambda` | float | KFAC penalty strength (`kfac_online` mode). Default: 0.01. | +| `kfac_ema_decay` | float | EMA decay for running Kronecker factors in KFAC mode. Default: 0.95. | + +Details about the CL algorithms available can be found in [docs/continuous_learning.md](continuous_learning.md) + + +## [drift_detection] + +All the parameters are optional. Default values are provided in the examples below. + +Core settings: + +```toml +[drift_detection] +detector_name = "ADWINDetector" +detection_interval = 10 +aggregation = "mean" +metric_index = 0 +reset_after_learning = false +max_stream_updates = 20 +``` + +| Option | Type | Description | +|----------|------|-------------| +| `detector_name` | str | Drift detection algorithm (ADWINDetector, KSWINDetector, PageHinkleyDetector, ModelPerformanceDetector, ModelEvalDetector, EnsembleDetector). | +| `detection_interval` | int | Check drift every N monitored batches. If `<= 0`, checks are disabled. | +| `aggregation` | str | How buffered metric values are aggregated before detector update. Supported by monitor: `mean`, `median`, `last`. | +| `metric_index` | int | Index into `modelHarness.eval_metrics` order. | +| `reset_after_learning` | bool | If true, detector state resets after each CL event. | +| `max_stream_updates` | int | Monitoring stops after this many stream extensions. | + +ADWIN: + +```toml +adwin_delta = 0.002 +adwin_minor_threshold = 0.3 +adwin_moderate_threshold = 0.6 +``` + +| Option | Type | Description | +|----------|------|-------------| +| `adwin_delta` | float | ADWIN confidence/sensitivity parameter. | +| `adwin_minor_threshold` | float | ADWIN regime threshold (CL boundary). | +| `adwin_moderate_threshold` | float | ADWIN regime threshold (fine-tuning boundary). | + +KSWIN: + +```toml +kswin_alpha = 0.005 +kswin_window_size = 100 +kswin_stat_size = 30 +``` + +| Option | Type | Description | +|----------|------|-------------| +| `kswin_alpha` | float | KSWIN significance level. | +| `kswin_window_size` | int | KSWIN reference window size. | +| `kswin_stat_size` | int | KSWIN recent sample window size. | + +Page-Hinkley: + +```toml +ph_min_instances = 30 +ph_delta = 0.005 +ph_threshold = 50 +ph_alpha = 0.9999 +``` + +| Option | Type | Description | +|----------|------|-------------| +| `ph_min_instances` | int | Page-Hinkley warm-up count before detection is meaningful. | +| `ph_delta` | float | Page-Hinkley change magnitude parameter. | +| `ph_threshold` | float | Page-Hinkley trigger threshold. | +| `ph_alpha` | float | Page-Hinkley forgetting factor. | + + +Details about the drift detection algorithms available can be found in [docs/drift_detectors.md](drift_detectors.md) + + +## [visualization] + +The visualization configuration options are optional and are used to store the results of the metrics captured during the run. + +```toml +[visualization] +input = "output/results.csv" +``` + +| Option | Type | Description | +|----------|------|-------------| +| `input` | str | Path to the CSV output file storing the metrics. Default: output/output.csv. | + +Metrics used: +``` +cl/cperf_detector_flop +cl/cperf_detector_flops +cl/cperf_detector_time +cl/cperf_infer_flop +cl/cperf_infer_flops +cl/cperf_infer_time +cl/cperf_optimizer_flop +cl/cperf_optimizer_flops +cl/cperf_optimizer_time +cl/cperf_update_fwd_bwd_flop +cl/cperf_update_fwd_bwd_flops +cl/cperf_update_fwd_bwd_time +cl/drift_event_id +cl/jvp_reg_forgetting_loss +cl/jvp_reg_generation_loss +cl/jvp_reg_total_loss +cl/step +drift/confidence +drift/cperf_detector_flop +drift/cperf_detector_flops +drift/cperf_detector_time +drift/cperf_infer_flop +drift/cperf_infer_flops +drift/cperf_infer_time +drift/cperf_optimizer_flop +drift/cperf_optimizer_flops +drift/cperf_optimizer_time +drift/cperf_update_fwd_bwd_flop +drift/cperf_update_fwd_bwd_flops +drift/cperf_update_fwd_bwd_time +drift/detected +drift/metric_0 +drift/regime +drift/score +drift/step +eval/accuracy +eval/loss +eval/step +eval/test_curr_acc +eval/test_hist_acc +``` + +Example output file: +```csv +step,metric,value +10,eval/accuracy,62.5 +10,eval/loss,2.0406203269958496 +10,eval/step,10 +10,drift/score,0.0 +10,drift/regime,stable +10,drift/confidence,0.998 +10,drift/metric_0,44.6875 +10,drift/cperf_infer_flop,1624834048.0 +10,drift/cperf_infer_time,0.00556622925796546 +10,drift/cperf_infer_flops,291909293113.4319 +10,drift/cperf_detector_flop,0.0 +10,drift/cperf_detector_time,8.56249826028943e-05 +10,drift/cperf_detector_flops,0.0 +10,cl/jvp_reg_total_loss,3.4211268424987793 +10,cl/jvp_reg_forgetting_loss,0.0 +10,cl/jvp_reg_generation_loss,3.4211268424987793 +``` + +## [logging] + +All the parameters are optional. + +```toml +[logging] +backend = "wandb" +experiment_name = "experiment" +mlflow_tracking_uri = "http://localhost:5000" +``` + +| Option | Type | Description | +|----------|------|-------------| +| `backend` | str | Logging backend (wandb, mlflow, or none). Default: wandb. | +| `experiment_name` | str | W&B project name or MLflow experiment name. Default: None. | +| `mlflow_tracking_uri` | str | MLflow tracking server URI. Default: None. | + + +## Complete Valid Example + +```toml +seed = 1337 +device = "auto" +multi_gpu = false +verbosity = "INFO" + +[model] +name = "mnist" +pretrained_path = "examples/mnist/mnist.pth" +max_ckpts = 0 +ckpts_path = "output/mnist" + +[data] +name = "mnist" +path = "" +batch_size = 32 + +[train] +batch_size = 64 +num_workers = 4 +init_lr = 0.001 + +[continual_learning] +update_mode = "base" + +[drift_detection] +detector_name = "ADWINDetector" + +[logging] +backend = "none" + +[visualization] +input = "output/results.csv" +``` diff --git a/docs/continuous_learning.md b/docs/continuous_learning.md index 4d777f5..2495e83 100644 --- a/docs/continuous_learning.md +++ b/docs/continuous_learning.md @@ -126,3 +126,20 @@ ewc_ema_decay = 0.95 kfac_lambda = 1e-2 kfac_ema_decay = 0.95 ``` + +Codes wanting to do continual learning should use the `ContinuousTrainer` class that takes the configuration parameters, model harness, logger and the profiler. + +```python +cfg = build_config ( argv ) + +xtrainer = ContinuousTrainer( + cfg=cfg, + modelHarness=modelHarness, + logger=logger, + profiler=flops_profiler, +) + +trainer.outer_cl_training_loop(drift_event_id=drift_count) +if modelHarness.ckpts_enabled: + ckptpath = modelHarness.save_ckpt(event=drift_count) +``` diff --git a/docs/drift_detectors.md b/docs/drift_detectors.md index e0bc359..cbbce6c 100644 --- a/docs/drift_detectors.md +++ b/docs/drift_detectors.md @@ -182,3 +182,40 @@ adwin_delta = 0.002 adwin_minor_threshold = 0.3 adwin_moderate_threshold = 0.6 ``` + +Code for the workflow in the Monitor implementation: + +```python +detector = load_drift_detector(cfg) + +data = modelHarness.get_stream_dataloader() + +for batch_idx, batch in tqdm( + enumerate(val_loader), + desc="Inference on batches", + leave=False, +): + # Inference on batch and compute all metrics + metrics = self._evaluate_batch(batch) + metric_buffer.append(metrics) + + +metric_idx = cfg.drift_detection.metric_index +metric_values = [m[metric_idx] for m in metric_buffer] + +# aggregate metrics +aggregation = cfg.drift_detection.aggregation +if aggregation == "mean": + agg_metric = float(np.mean(metric_values)) +elif aggregation == "median": + agg_metric = float(np.median(metric_values)) +elif aggregation == "last": + agg_metric = float(metric_values[-1]) + +drift_signal = detector.update(agg_metric) +if drift_signal.drift_detected: + handle_drift(drift_signal) + +self.detector.reset() +modelHarness.update_data_stream() +``` diff --git a/docs/model_harness.md b/docs/model_harness.md index 7fe2f8f..7acad4c 100644 --- a/docs/model_harness.md +++ b/docs/model_harness.md @@ -114,3 +114,269 @@ Supported `model.name` values for CIFAR/ImageNet loaders: - `get_stream_dataloader()` must return a non-`None` loader after `update_data_stream()`. - `get_train_dataloaders()` must return non-`None` loaders after `update_data_stream()`. - If your batch is not exactly `(x, y)`, override `_unpack` in your harness. + +# Creating a Custom Model Harness + +A custom model harness is the integration point between a user model and Apeiron. The harness allows Apeiron to monitor model performance, detect drift, trigger continual learning, and manage retraining without requiring modifications to the underlying scientific model. + +## Required Files + +A custom example typically consists of: + +```text +examples// +├── __init__.py +├── model.py +├── utils.py +└── .toml +``` + +In addition, the example must be registered in: + +```text +examples/utils.py +``` + +through the `get_example()` factory. + +## Step 1: Create the Harness Class + +All custom harnesses must inherit from `BaseModelHarness`. + +```python +class MyHarness(BaseModelHarness): + def __init__(self, cfg): + model = MyModel() + super().__init__(cfg=cfg, model=model) +``` + +The base class handles: + +- Device placement +- Checkpoint support +- Evaluation loops +- Continual-learning integration +- Drift-monitor integration + +## Step 2: Define Evaluation Metrics + +Apeiron's drift detectors operate on metric streams. + +At least one metric should be exposed through: + +```python +self.eval_metrics = { + "accuracy": accuracy, +} +``` + +or for regression: + +```python +self.eval_metrics = { + "mse": regression_mse, +} +``` + +The order of metrics is important because drift detection uses: + +```toml +metric_index = 0 +``` + +which refers to the first metric in `eval_metrics`. + +## Step 3: Implement Required Methods + +### Optimizer + +```python +def get_optmizer(self): + return torch.optim.Adam( + self.model.parameters(), + lr=self.cfg.train.init_lr, + ) +``` + +### Criterion + +```python +def get_criterion(self): + return torch.nn.CrossEntropyLoss() +``` + +or for regression: + +```python +def get_criterion(self): + return torch.nn.MSELoss() +``` + +### Current Training Loaders + +```python +def get_train_dataloaders(self): + return self.train_loader, self.val_loader +``` + +These loaders are used during continual-learning updates. + +### Stream Loader + +```python +def get_stream_dataloader(self): + return self.stream_loader +``` + +This loader is used during monitoring and drift detection. + +### Historical Loaders + +```python +def get_hist_dataloaders(self): + if self.task_counter == 0: + return None, None + + return self.hist_train_loader, self.hist_val_loader +``` + +Historical loaders provide replay data for continual-learning algorithms such as EWC and KFAC. + +### Stream Updates + +```python +def update_data_stream(self): + self.task_counter += 1 + + self.train_loader = build_train_loader() + self.val_loader = build_val_loader() + self.stream_loader = build_stream_loader() +``` + +This function is called by the monitor whenever a new stream segment becomes active. + +## Step 4: Build Dataset Utilities + +Dataset loading should be implemented in `utils.py`. + +Typical responsibilities include: + +- Reading datasets +- Feature preprocessing +- Label extraction +- Train/validation splitting +- Drift simulation +- DataLoader construction + +Example: + +```python +def make_loader(dataset, batch_size): + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + ) +``` + +## Step 5: Register the Example + +Add a branch to `examples/utils.py`: + +```python +elif cfg.data.name == "mydataset": + from examples.mydataset.model import MyHarness + + return MyHarness(cfg=cfg) +``` + +The value must match: + +```toml +[data] +name = "mydataset" +``` + +## Step 6: Create the Configuration File + +Create: + +```text +examples//.toml +``` + +Requirested and optional parameters are described in [configurations.md](configurations.md). + +Only keys defined in `configuration.py` are allowed. Custom dataset-specific parameters should be implemented inside the harness or utility code unless the configuration dataclasses are extended. + +## Step 7: Validate the Integration + +Validate TOML syntax: + +```bash +python -c "import tomllib; tomllib.load(open('examples/mydataset/mydataset.toml', 'rb')); print('TOML OK')" +``` + +Validate factory registration: + +```bash +poetry run python -c "from examples.utils import get_example; print('factory OK')" +``` + +## Step 8: Run a Smoke Test + +Before running a full experiment, execute a small CPU-only test: + +```bash +poetry run python -m src.main \ + --config examples/mydataset/mydataset.toml \ + --set train.max_iter=2 \ + --set drift_detection.max_stream_updates=2 \ + --set drift_detection.detection_interval=1 \ + --set device=cpu \ + --set logging.backend=none +``` + +A successful smoke test confirms: + +- Configuration loading +- Harness construction +- Dataset loading +- Drift-monitor integration +- Continual-learning integration +- Metric reporting + +## Common Errors + +### `KeyError: 'model'` + +The TOML file is missing the `[model]` section. + +### `ImportError` when loading the harness + +The factory registration in `examples/utils.py` is missing or incorrect. + +### `TypeError: DataCfg.__init__() got an unexpected keyword argument ...` + +A TOML key is not defined in the configuration dataclasses. + +### `IndexError: list index out of range` during drift detection + +No evaluation metrics are being emitted. Ensure `self.eval_metrics` contains at least one metric. + +### `get_stream_dataloader()` returns `None` + +The monitor cannot evaluate data until `update_data_stream()` constructs and returns a valid stream loader. + +## Summary + +A custom harness only needs to provide six core behaviors: + +1. Construct the model. +2. Construct the optimizer. +3. Construct the criterion. +4. Provide current train/validation loaders. +5. Provide historical replay loaders. +6. Update the active data stream. + +Once those interfaces are implemented, Apeiron can automatically provide monitoring, drift detection, continual learning, checkpointing, logging, and experiment management for the underlying scientific model. diff --git a/src/apeiron/config/configuration.py b/src/apeiron/config/configuration.py index df55747..e79a566 100644 --- a/src/apeiron/config/configuration.py +++ b/src/apeiron/config/configuration.py @@ -86,8 +86,7 @@ def _select_best_gpu() -> int | None: @dataclass(frozen=True) class ModelCfg: name: str - pretrained_path: str - + pretrained_path: str = "" # FIFO checkpointing: 0 disables, N keeps last N post-CL snapshots max_ckpts: int = 0 ckpts_path: str = "" @@ -156,7 +155,7 @@ class DriftDetectionCfg: @dataclass(frozen=True) class VisualizationCfg: - input: str = "output/cl_only.csv" # CSV path where run metrics are written + input: str = "output/output.csv" # CSV path where run metrics are written @dataclass(frozen=True) @@ -178,7 +177,7 @@ class Config: seed: int device: str - multi_gpu: bool + multi_gpu: bool = False verbosity: str = "INFO" visualization: VisualizationCfg | None = None logging: LoggingCfg | None = None diff --git a/tests/test_config.py b/tests/test_config.py index 7c90b0b..2f88468 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -278,7 +278,7 @@ def test_drift_detection_defaults(self): def test_visualization_cfg(self): viz = VisualizationCfg() - assert viz.input == "output/cl_only.csv" + assert viz.input == "output/output.csv" # ---------------------------------------------------------------------------