diff --git a/examples/structural_mechanics/ensemble_uq/ensemble_uq_beam.py b/examples/structural_mechanics/ensemble_uq/ensemble_uq_beam.py new file mode 100644 index 0000000000..9b7e7d6b3e --- /dev/null +++ b/examples/structural_mechanics/ensemble_uq/ensemble_uq_beam.py @@ -0,0 +1,229 @@ +r""" +examples/structural_mechanics/ensemble_uq/ensemble_uq_beam.py + +Ensemble uncertainty quantification for a 1D beam deflection surrogate +======================================================================= + +Demonstrates ``EnsembleWrapper`` on a structural mechanics problem: +predicting the deflection of a simply-supported Euler-Bernoulli beam +under a distributed load, as a function of load magnitude and beam +stiffness (EI). + +This example deliberately keeps the training data small so that the +ensemble spread is visible — illustrating that ``std`` grows in regions +where the surrogate is uncertain (e.g. near the edges of the training +distribution). + +Teacher +------- +Analytical solution for simply-supported beam deflection under uniform load: + +.. math:: + + w(x) = \\frac{q}{24 EI} \\left( x^4 - 2Lx^3 + L^3 x \\right) + +where :math:`L = 1\\,\\mathrm{m}`, :math:`x \\in [0, L]`. + +The surrogate maps :math:`(q, EI)` to the maximum deflection +:math:`w_{\\max} = w(L/2)`. + +Ensemble +-------- +5 ``FullyConnected`` members, each trained from a different random seed +using a standard PyTorch loop. Wrapped with ``EnsembleWrapper`` for +uncertainty-aware inference. + +Dependencies +------------ +See ``requirements.txt`` in this directory:: + + pip install -r examples/structural_mechanics/ensemble_uq/requirements.txt + +Run +--- +From the repo root:: + + python examples/structural_mechanics/ensemble_uq/ensemble_uq_beam.py +""" + +import sys +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +from physicsnemo.models.mlp import FullyConnected +from physicsnemo.experimental.models.ensemble_wrapper import EnsembleWrapper + +# --------------------------------------------------------------------------- +# Teacher: analytical beam deflection +# --------------------------------------------------------------------------- + +L = 1.0 # beam length [m] + + +def beam_max_deflection(q: float, EI: float) -> float: + r""" + Maximum mid-span deflection of a simply-supported beam. + + .. math:: + + w_{\max} = \frac{5 q L^4}{384 EI} + + Parameters + ---------- + q : float + Uniform distributed load :math:`[\mathrm{N/m}]`. + EI : float + Flexural rigidity :math:`[\mathrm{N \cdot m^2}]`. + + Returns + ------- + float + Maximum deflection :math:`[\mathrm{m}]`. + """ + return (5 * q * L**4) / (384 * EI) + + +# --------------------------------------------------------------------------- +# Dataset generation +# --------------------------------------------------------------------------- + +N_TRAIN = 80 +N_TEST = 400 + +Q_BOUNDS = (100.0, 5000.0) # N/m +EI_BOUNDS = (1e4, 1e6) # N·m² + +rng = np.random.default_rng(0) + + +def make_dataset(n: int, seed: int) -> tuple: + rng_ = np.random.default_rng(seed) + q = rng_.uniform(*Q_BOUNDS, n) + EI = rng_.uniform(*EI_BOUNDS, n) + X = np.stack([q, EI], axis=1).astype(np.float32) + y = np.array([beam_max_deflection(q[i], EI[i]) for i in range(n)], + dtype=np.float32)[:, None] + return X, y + + +X_train_np, y_train_np = make_dataset(N_TRAIN, seed=1) +X_test_np, y_test_np = make_dataset(N_TEST, seed=2) + +# Z-score normalisation +X_mean, X_std = X_train_np.mean(0), X_train_np.std(0) + 1e-8 +y_mean, y_std = y_train_np.mean(), y_train_np.std() + 1e-8 + +X_train = torch.tensor((X_train_np - X_mean) / X_std) +y_train = torch.tensor((y_train_np - y_mean) / y_std) +X_test = torch.tensor((X_test_np - X_mean) / X_std) +y_test = torch.tensor((y_test_np - y_mean) / y_std) + + +# --------------------------------------------------------------------------- +# Train N ensemble members +# --------------------------------------------------------------------------- + +N_MEMBERS = 5 +EPOCHS = 500 +LR = 5e-3 + + +def train_member(seed: int) -> FullyConnected: + r"""Train one ``FullyConnected`` member from a given random seed.""" + torch.manual_seed(seed) + model = FullyConnected(in_features=2, out_features=1, num_layers=3, layer_size=32) + loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True) + opt = torch.optim.Adam(model.parameters(), lr=LR) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS) + criterion = nn.MSELoss() + + model.train() + for _ in range(EPOCHS): + for xb, yb in loader: + opt.zero_grad() + criterion(model(xb), yb).backward() + opt.step() + sched.step() + + model.eval() + return model + + +print(f"Training {N_MEMBERS} ensemble members...") +members = [train_member(seed=i) for i in range(N_MEMBERS)] +print(" Done.") + +# --------------------------------------------------------------------------- +# Wrap with EnsembleWrapper +# --------------------------------------------------------------------------- + +ensemble = EnsembleWrapper(members) +ensemble.eval() + +with torch.no_grad(): + result = ensemble.predict_with_uncertainty(X_test) + +# Denormalise +mean_np = (result.mean.numpy() * y_std + y_mean).squeeze() +std_np = (result.std.numpy() * y_std).squeeze() # std in physical units +true_np = y_test_np.squeeze() + +rel_err = float(np.mean(np.abs(mean_np - true_np) / (np.abs(true_np) + 1e-10))) +print(f"\nEnsemble mean relative error : {rel_err:.2%}") +print(f"Mean epistemic std : {std_np.mean():.4e} m") + +# --------------------------------------------------------------------------- +# Diagnostic plots +# --------------------------------------------------------------------------- + +# Sort by true deflection for clean 1-D plots +order = np.argsort(true_np) +t = true_np[order] +m = mean_np[order] +s = std_np[order] + +fig, axes = plt.subplots(1, 3, figsize=(15, 4)) +fig.suptitle("Beam Deflection Surrogate — Ensemble UQ", fontsize=13) + +# ── Panel 1: parity plot ────────────────────────────────────────────────── +ax = axes[0] +ax.scatter(t, m, s=10, alpha=0.3, label="Test samples") +lim = [t.min() * 0.95, t.max() * 1.05] +ax.plot(lim, lim, "r--", lw=1, label="Ideal") +ax.set_xlabel("Analytical deflection [m]") +ax.set_ylabel("Ensemble mean [m]") +ax.set_title(f"Parity Plot (err = {rel_err:.1%})") +ax.legend(fontsize=9) +ax.grid(True, alpha=0.3) + +# ── Panel 2: uncertainty ribbon ─────────────────────────────────────────── +ax = axes[1] +ax.plot(t, m, lw=1, label="Ensemble mean") +ax.fill_between(t, m - 2 * s, m + 2 * s, alpha=0.3, label="±2σ (epistemic)") +ax.plot(t, t, "r--", lw=1, label="Ground truth") +ax.set_xlabel("True deflection (sorted) [m]") +ax.set_ylabel("Predicted deflection [m]") +ax.set_title("Uncertainty Ribbon (±2σ)") +ax.legend(fontsize=9) +ax.grid(True, alpha=0.3) + +# ── Panel 3: std vs prediction magnitude ───────────────────────────────── +ax = axes[2] +ax.scatter(t, s, s=10, alpha=0.3, color="steelblue") +ax.set_xlabel("True deflection [m]") +ax.set_ylabel("Epistemic std [m]") +ax.set_title("Uncertainty vs Prediction Magnitude") +ax.grid(True, alpha=0.3) + +plt.tight_layout() +out_path = Path("ensemble_uq_beam.png") +plt.savefig(out_path, dpi=150, bbox_inches="tight") +print(f"\nPlot saved → {out_path.resolve()}") +plt.show() diff --git a/examples/structural_mechanics/ensemble_uq/requirements.txt b/examples/structural_mechanics/ensemble_uq/requirements.txt new file mode 100644 index 0000000000..1460e30fd8 --- /dev/null +++ b/examples/structural_mechanics/ensemble_uq/requirements.txt @@ -0,0 +1,8 @@ +# requirements.txt +# examples/structural_mechanics/ensemble_uq/ +# +# Following EXT-001: example-only dependencies must not leak into the +# core package. Install with: +# pip install -r examples/structural_mechanics/ensemble_uq/requirements.txt + +matplotlib>=3.7.0 diff --git a/physicsnemo/experimental/models/ensemble_wrapper.py b/physicsnemo/experimental/models/ensemble_wrapper.py new file mode 100644 index 0000000000..f93818dc0f --- /dev/null +++ b/physicsnemo/experimental/models/ensemble_wrapper.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +physicsnemo/experimental/models/ensemble_wrapper.py + +Model-agnostic ensemble wrapper for uncertainty quantification. + +Motivation +---------- +PhysicsNeMo 25.08 introduced ensemble-based confidence estimation as a +Jupyter notebook workflow in ``physicsnemo-cfd``, scoped to the DoMINO +automotive aerodynamics NIM. This module promotes that pattern to a +model-agnostic, reusable utility in the core library: any +``physicsnemo.Module`` can be wrapped for ensemble-based uncertainty +quantification in two lines. + +Usage +----- + +.. code-block:: python + + from physicsnemo.experimental.models.ensemble_wrapper import EnsembleWrapper + + # Train N models with different seeds using standard PhysicsNeMo loops + models = [train_my_model(seed=i) for i in range(5)] + + # Wrap for uncertainty-aware inference + ensemble = EnsembleWrapper(models) + + # Drop-in forward (returns mean — compatible with any existing code) + mean = ensemble(x) + + # Uncertainty-aware inference + result = ensemble.predict_with_uncertainty(x) + print(result.mean.shape) # same as single-model output + print(result.std.shape) # epistemic uncertainty estimate + print(result.predictions.shape) # (N, *output_shape) all member outputs + + # Load from saved checkpoints + ensemble = EnsembleWrapper.from_checkpoints( + model_cls=MyModel, + checkpoint_paths=["ckpt_0.pt", "ckpt_1.pt", "ckpt_2.pt"], + **model_init_kwargs, + ) +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import List, Type, Union + +import torch +import torch.nn as nn + +from physicsnemo.core import Module +from physicsnemo.core import ModelMetaData + + +# --------------------------------------------------------------------------- +# Metadata (required by MOD-001 / physicsnemo.Module convention) +# --------------------------------------------------------------------------- + + +@dataclass +class EnsembleWrapperMeta(ModelMetaData): + r"""Metadata for ``EnsembleWrapper``. + + Attributes + ---------- + name : str + Human-readable model name. + jit : bool + Whether the model supports TorchScript JIT compilation. + ``False`` because member models may not individually support JIT. + cuda_graphs : bool + Whether the model supports CUDA graphs. ``False`` by default; + set to ``True`` only when all member models support CUDA graphs. + amp_cpu : bool + Whether the model supports Automatic Mixed Precision on CPU. + amp_gpu : bool + Whether the model supports Automatic Mixed Precision on GPU. + """ + + name: str = "EnsembleWrapper" + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = True + amp_gpu: bool = True + + +# --------------------------------------------------------------------------- +# Result dataclass +# --------------------------------------------------------------------------- + + +@dataclass +class EnsemblePrediction: + r""" + Output of ``EnsembleWrapper.predict_with_uncertainty()``. + + Attributes + ---------- + mean : torch.Tensor + Element-wise mean over all ensemble members, + shape :math:`(B, \ldots)`. + std : torch.Tensor + Element-wise standard deviation over all ensemble members — + an estimate of epistemic uncertainty — shape :math:`(B, \ldots)`. + predictions : torch.Tensor + Stacked raw outputs from all :math:`N` ensemble members, + shape :math:`(N, B, \ldots)`. + + Notes + ----- + The standard deviation here captures **epistemic** (model) uncertainty + arising from different weight initialisations and training trajectories. + It does not capture aleatoric (data) uncertainty. For a decomposed + estimate, consider combining ``EnsembleWrapper`` with members that + themselves output predictive distributions (e.g. models with a learned + variance head). + """ + + mean: torch.Tensor + std: torch.Tensor + predictions: torch.Tensor + + +# --------------------------------------------------------------------------- +# Main class +# --------------------------------------------------------------------------- + + +class EnsembleWrapper(Module): + r""" + Model-agnostic ensemble wrapper for uncertainty quantification. + + Wraps a collection of independently trained ``physicsnemo.Module`` + instances. At inference, each member is queried and the wrapper + returns the element-wise mean (via ``forward``) and standard deviation + (via ``predict_with_uncertainty``) across member outputs. + + The wrapper itself is a ``physicsnemo.Module`` (following rule MOD-001), + so it supports PhysicsNeMo checkpointing, versioning, and the model + registry. Member models are stored as a ``torch.nn.ModuleList`` so + their parameters are included in ``state_dict`` and moved together + with ``.to(device)``. + + ``forward`` returns only the mean prediction so that ``EnsembleWrapper`` + is a **drop-in replacement** for any single model — existing training + loops, metrics, and inference scripts require no changes. + + Parameters + ---------- + models : list of Module + Pre-trained ``physicsnemo.Module`` instances forming the ensemble. + All members must accept the same input signature and return tensors + of the same shape. + + Raises + ------ + ValueError + If ``models`` is empty. + + Examples + -------- + >>> import torch + >>> from physicsnemo.nn import FullyConnected + >>> from physicsnemo.experimental.models.ensemble_wrapper import EnsembleWrapper + >>> + >>> # Build and (notionally) train 5 members with different seeds + >>> members = [FullyConnected(in_features=4, out_features=1) for _ in range(5)] + >>> ensemble = EnsembleWrapper(members) + >>> + >>> x = torch.randn(32, 4) + >>> + >>> # Drop-in forward: returns mean, shape (32, 1) + >>> mean = ensemble(x) + >>> + >>> # Uncertainty-aware forward + >>> result = ensemble.predict_with_uncertainty(x) + >>> result.mean.shape, result.std.shape + (torch.Size([32, 1]), torch.Size([32, 1])) + + See Also + -------- + EnsemblePrediction : Output dataclass for ``predict_with_uncertainty``. + EnsembleWrapper.from_checkpoints : Construct from saved checkpoint files. + """ + + def __init__(self, models: List[Module]) -> None: + if len(models) == 0: + raise ValueError( + "EnsembleWrapper requires at least one member model. " + "Received an empty list." + ) + super().__init__(meta=EnsembleWrapperMeta()) + # Store members as ModuleList so parameters are properly tracked + # and the ensemble can be moved with .to(device) in one call. + self.members: nn.ModuleList = nn.ModuleList(models) + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def n_members(self) -> int: + r"""Number of models in the ensemble.""" + return len(self.members) + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + Run ensemble inference and return the mean prediction. + + Returns only the mean so that ``EnsembleWrapper`` is a drop-in + replacement for a single model. For the full uncertainty estimate + use ``predict_with_uncertainty``. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, \ldots)`. + + Returns + ------- + torch.Tensor + Element-wise mean over all ensemble members, + shape :math:`(B, \ldots)`. + """ + return self.predict_with_uncertainty(x).mean + + # ------------------------------------------------------------------ + # Uncertainty-aware inference + # ------------------------------------------------------------------ + + def predict_with_uncertainty(self, x: torch.Tensor) -> EnsemblePrediction: + r""" + Run ensemble inference and return mean, std, and all member outputs. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, \ldots)`. + + Returns + ------- + EnsemblePrediction + Dataclass with fields: + + - ``mean`` — element-wise mean, shape :math:`(B, \ldots)` + - ``std`` — element-wise standard deviation (epistemic + uncertainty), shape :math:`(B, \ldots)` + - ``predictions`` — stacked member outputs, + shape :math:`(N, B, \ldots)` + """ + # Stack outputs from all N members along a new leading dimension. + # Shape: (N, B, *output_shape) + predictions = torch.stack([member(x) for member in self.members], dim=0) + + return EnsemblePrediction( + mean=predictions.mean(dim=0), + # correction=0: population std (MLE) — ensures std=0 for N=1, + # avoids NaN from Bessel's correction when ensemble size is 1. + std=predictions.std(dim=0, correction=0), + predictions=predictions, + ) + + # ------------------------------------------------------------------ + # Checkpoint loading + # ------------------------------------------------------------------ + + @classmethod + def from_checkpoints( + cls, + model_cls: Type[Module], + checkpoint_paths: List[Union[str, Path]], + map_location: Union[str, torch.device, None] = None, + **model_kwargs, + ) -> "EnsembleWrapper": + r""" + Construct an ``EnsembleWrapper`` from saved checkpoint files. + + Each checkpoint must have been saved with + ``torch.save(model.state_dict(), path)`` or via PhysicsNeMo's + built-in checkpoint utilities. + + Parameters + ---------- + model_cls : type + The ``physicsnemo.Module`` subclass used for each member. + checkpoint_paths : list of str or Path + Paths to the saved ``state_dict`` files, one per ensemble member. + map_location : str, torch.device, or None, optional + Passed directly to ``torch.load``. Useful for loading + GPU-trained checkpoints on CPU. Default: ``None``. + **model_kwargs + Keyword arguments forwarded to ``model_cls.__init__``. + + Returns + ------- + EnsembleWrapper + Ensemble with one member per checkpoint. + + Raises + ------ + FileNotFoundError + If any of the provided checkpoint paths does not exist. + + Examples + -------- + >>> ensemble = EnsembleWrapper.from_checkpoints( + ... model_cls=FullyConnected, + ... checkpoint_paths=["ckpt_0.pt", "ckpt_1.pt", "ckpt_2.pt"], + ... map_location="cpu", + ... in_features=4, + ... out_features=1, + ... ) + """ + models = [] + for path in checkpoint_paths: + path = Path(path) + if not path.exists(): + raise FileNotFoundError( + f"EnsembleWrapper.from_checkpoints: checkpoint not found at '{path}'." + ) + model = model_cls(**model_kwargs) + state = torch.load(path, map_location=map_location, weights_only=True) + model.load_state_dict(state) + model.eval() + models.append(model) + return cls(models) diff --git a/test/models/test_ensemble_wrapper.py b/test/models/test_ensemble_wrapper.py new file mode 100644 index 0000000000..537e7665a8 --- /dev/null +++ b/test/models/test_ensemble_wrapper.py @@ -0,0 +1,266 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +r""" +test/models/test_ensemble_wrapper.py + +Unit tests for ``physicsnemo.experimental.models.ensemble_wrapper``. + +Following rules MOD-008a, MOD-008b, MOD-008c: +- MOD-008a: constructor / attribute tests +- MOD-008b: non-regression test with reference data +- MOD-008c: checkpoint loading test + +Run with:: + + pytest test/models/test_ensemble_wrapper.py -v +""" + +import tempfile +from pathlib import Path + +import pytest +import torch + +from physicsnemo.models.mlp import FullyConnected +from physicsnemo.experimental.models.ensemble_wrapper import ( + EnsembleWrapper, + EnsemblePrediction, + EnsembleWrapperMeta, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +IN_FEATURES = 4 +OUT_FEATURES = 2 +BATCH_SIZE = 16 +N_MEMBERS = 3 + + +def make_members(n: int, seed_offset: int = 0) -> list: + """Build N small FullyConnected members with deterministic seeds.""" + members = [] + for i in range(n): + torch.manual_seed(i + seed_offset) + m = FullyConnected( + in_features=IN_FEATURES, + out_features=OUT_FEATURES, + num_layers=2, + layer_size=16, + ) + m.eval() + members.append(m) + return members + + +def make_input() -> torch.Tensor: + torch.manual_seed(99) + return torch.randn(BATCH_SIZE, IN_FEATURES) + + +# --------------------------------------------------------------------------- +# MOD-008a: Constructor and attribute tests +# --------------------------------------------------------------------------- + + +class TestEnsembleWrapperConstructor: + r"""Following MOD-008a: constructor and attribute tests.""" + + def test_construction_succeeds(self): + """EnsembleWrapper should construct from a non-empty member list.""" + members = make_members(N_MEMBERS) + ensemble = EnsembleWrapper(members) + assert isinstance(ensemble, EnsembleWrapper) + + def test_n_members_property(self): + """``n_members`` should reflect the number of members passed.""" + for n in [1, 3, 5]: + members = make_members(n) + assert EnsembleWrapper(members).n_members == n + + def test_empty_members_raises(self): + """Passing an empty list should raise ``ValueError``.""" + with pytest.raises(ValueError, match="at least one"): + EnsembleWrapper([]) + + def test_metadata_type(self): + """The wrapper's meta should be an ``EnsembleWrapperMeta`` instance.""" + ensemble = EnsembleWrapper(make_members(2)) + assert isinstance(ensemble.meta, EnsembleWrapperMeta) + + def test_members_registered_as_module_list(self): + """Members must be stored in a ``torch.nn.ModuleList`` so that + ``.parameters()`` and ``.to(device)`` cover all members.""" + import torch.nn as nn + ensemble = EnsembleWrapper(make_members(N_MEMBERS)) + assert isinstance(ensemble.members, nn.ModuleList) + assert len(ensemble.members) == N_MEMBERS + + def test_to_device_moves_all_members(self): + """Calling ``.to(device)`` should move all member parameters.""" + ensemble = EnsembleWrapper(make_members(2)) + ensemble.to("cpu") + for member in ensemble.members: + for param in member.parameters(): + assert param.device.type == "cpu" + + +# --------------------------------------------------------------------------- +# MOD-008b: Non-regression tests with reference data +# --------------------------------------------------------------------------- + + +class TestEnsembleWrapperForward: + r"""Following MOD-008b: non-regression tests with reference data.""" + + def test_forward_output_shape(self): + """``forward`` should return a tensor of shape ``(B, out_features)``.""" + ensemble = EnsembleWrapper(make_members(N_MEMBERS)) + x = make_input() + with torch.no_grad(): + out = ensemble(x) + assert out.shape == (BATCH_SIZE, OUT_FEATURES) + + def test_forward_returns_mean(self): + """``forward`` must equal ``predict_with_uncertainty().mean``.""" + ensemble = EnsembleWrapper(make_members(N_MEMBERS)) + x = make_input() + with torch.no_grad(): + fwd = ensemble(x) + uq = ensemble.predict_with_uncertainty(x) + torch.testing.assert_close(fwd, uq.mean) + + def test_predict_with_uncertainty_shapes(self): + """``predict_with_uncertainty`` should return correct tensor shapes.""" + ensemble = EnsembleWrapper(make_members(N_MEMBERS)) + x = make_input() + with torch.no_grad(): + result = ensemble.predict_with_uncertainty(x) + + assert isinstance(result, EnsemblePrediction) + assert result.mean.shape == (BATCH_SIZE, OUT_FEATURES) + assert result.std.shape == (BATCH_SIZE, OUT_FEATURES) + assert result.predictions.shape == (N_MEMBERS, BATCH_SIZE, OUT_FEATURES) + + def test_std_non_negative(self): + """Standard deviation values must be non-negative everywhere.""" + ensemble = EnsembleWrapper(make_members(N_MEMBERS)) + x = make_input() + with torch.no_grad(): + result = ensemble.predict_with_uncertainty(x) + assert (result.std >= 0).all() + + def test_std_zero_for_identical_members(self): + """When all members are identical the std should be (near) zero.""" + torch.manual_seed(0) + single = FullyConnected(in_features=IN_FEATURES, out_features=OUT_FEATURES, + num_layers=2, layer_size=16) + # Three copies of the exact same weights + members = [single, single, single] + ensemble = EnsembleWrapper(members) + x = make_input() + with torch.no_grad(): + result = ensemble.predict_with_uncertainty(x) + assert result.std.abs().max().item() < 1e-5 + + def test_std_nonzero_for_different_members(self): + """Different weight initialisations should yield non-zero std.""" + ensemble = EnsembleWrapper(make_members(N_MEMBERS, seed_offset=100)) + x = make_input() + with torch.no_grad(): + result = ensemble.predict_with_uncertainty(x) + assert result.std.mean().item() > 0 + + def test_mean_equals_average_of_predictions(self): + """``mean`` must be the arithmetic mean of ``predictions``.""" + ensemble = EnsembleWrapper(make_members(N_MEMBERS)) + x = make_input() + with torch.no_grad(): + result = ensemble.predict_with_uncertainty(x) + expected_mean = result.predictions.mean(dim=0) + torch.testing.assert_close(result.mean, expected_mean) + + def test_single_member_std_is_zero(self): + """With one member the std must be zero (no variance to estimate).""" + ensemble = EnsembleWrapper(make_members(1)) + x = make_input() + with torch.no_grad(): + result = ensemble.predict_with_uncertainty(x) + assert result.std.abs().max().item() == 0.0 + + +# --------------------------------------------------------------------------- +# MOD-008c: Checkpoint loading test +# --------------------------------------------------------------------------- + + +class TestEnsembleWrapperCheckpoints: + r"""Following MOD-008c: checkpoint loading tests.""" + + def test_from_checkpoints_loads_correctly(self, tmp_path): + """``from_checkpoints`` should reproduce an ensemble identical to the + original when weights are saved and reloaded.""" + members = make_members(N_MEMBERS) + original = EnsembleWrapper(members) + + # Save each member's state_dict + paths = [] + for i, member in enumerate(members): + p = tmp_path / f"member_{i}.pt" + torch.save(member.state_dict(), p) + paths.append(p) + + # Reload via from_checkpoints + loaded = EnsembleWrapper.from_checkpoints( + model_cls=FullyConnected, + checkpoint_paths=paths, + map_location="cpu", + in_features=IN_FEATURES, + out_features=OUT_FEATURES, + num_layers=2, + layer_size=16, + ) + + # Predictions must be identical + x = make_input() + with torch.no_grad(): + r_orig = original.predict_with_uncertainty(x) + r_loaded = loaded.predict_with_uncertainty(x) + + torch.testing.assert_close(r_orig.mean, r_loaded.mean) + torch.testing.assert_close(r_orig.std, r_loaded.std) + + def test_from_checkpoints_missing_file_raises(self, tmp_path): + """A ``FileNotFoundError`` must be raised for a missing checkpoint.""" + with pytest.raises(FileNotFoundError, match="checkpoint not found"): + EnsembleWrapper.from_checkpoints( + model_cls=FullyConnected, + checkpoint_paths=[tmp_path / "does_not_exist.pt"], + in_features=IN_FEATURES, + out_features=OUT_FEATURES, + num_layers=2, + layer_size=16, + ) + + def test_from_checkpoints_n_members(self, tmp_path): + """Ensemble size must match the number of checkpoint paths.""" + members = make_members(4) + paths = [] + for i, m in enumerate(members): + p = tmp_path / f"m_{i}.pt" + torch.save(m.state_dict(), p) + paths.append(p) + + loaded = EnsembleWrapper.from_checkpoints( + model_cls=FullyConnected, + checkpoint_paths=paths, + map_location="cpu", + in_features=IN_FEATURES, + out_features=OUT_FEATURES, + num_layers=2, + layer_size=16, + ) + assert loaded.n_members == 4