diff --git a/docs/api/regime.md b/docs/api/regime.md new file mode 100644 index 0000000..fa92578 --- /dev/null +++ b/docs/api/regime.md @@ -0,0 +1,16 @@ +# Regime Surrogate + +Regime-conditional surrogate that interpolates factor recommendations +across regime descriptors (e.g. dataset size, noise level) instead of +relying on hard regime buckets. Builds on +[`fit_surrogate`][trade_study.fit_surrogate]. + +Install via the optional extra (same as the base surrogate): + +```bash +uv pip install 'trade-study[surrogate]' +``` + +::: trade_study.fit_regime_surrogate + +::: trade_study.RegimeSurrogate diff --git a/mkdocs.yml b/mkdocs.yml index e1ffbd6..fd43110 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -75,5 +75,6 @@ nav: - Pareto: api/pareto.md - Stacking: api/stacking.md - Surrogate: api/surrogate.md + - Regime Surrogate: api/regime.md - Visualization: api/viz.md - I/O: api/io.md diff --git a/src/trade_study/__init__.py b/src/trade_study/__init__.py index 0e1e514..28cfe6b 100644 --- a/src/trade_study/__init__.py +++ b/src/trade_study/__init__.py @@ -26,6 +26,7 @@ Simulator, TrialResult, ) +from .regime import RegimeSurrogate, fit_regime_surrogate from .runner import run_adaptive, run_grid, run_hyperband, run_successive_halving from .stacking import ensemble_predict, stack_bayesian, stack_scores from .study import ( @@ -48,6 +49,7 @@ "Observable", "PartialEvaluator", "Phase", + "RegimeSurrogate", "ResultsTable", "Scorer", "Simulator", @@ -60,6 +62,7 @@ "ensemble_predict", "extract_front", "feasibility_filter", + "fit_regime_surrogate", "fit_surrogate", "hypervolume", "igd_plus", diff --git a/src/trade_study/regime.py b/src/trade_study/regime.py new file mode 100644 index 0000000..036a99c --- /dev/null +++ b/src/trade_study/regime.py @@ -0,0 +1,263 @@ +"""Regime-conditional surrogate models (#105). + +Builds on :mod:`trade_study.surrogate` to share information across +regime buckets. Regime descriptors (e.g. ``n_samples``, ``noise``) are +treated as additional input dimensions of a single surrogate over +``regime_factors + factors``, so the model can interpolate factor +recommendations across continuous regime axes instead of relying on +hard buckets. + +Typical use: + +.. code-block:: python + + surrogate = fit_regime_surrogate( + results, + regime_factors=[ + Factor( + "n_samples", FactorType.CONTINUOUS, bounds=(1_000, 10_000) + ) + ], + factors=[Factor("lr", FactorType.CONTINUOUS, bounds=(1e-4, 1e-1))], + method="gp", + ) + best = surrogate.recommend({"n_samples": 2200}, objective="val_loss") + +Optional dependency: install via the ``trade-study[surrogate]`` extra. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import numpy as np + +from .design import Factor, build_grid +from .surrogate import SurrogateModel, fit_surrogate + +if TYPE_CHECKING: + from collections.abc import Sequence + + from numpy.typing import NDArray + + from .protocols import ResultsTable + + +_SUPPORTED_MODES: frozenset[str] = frozenset({"min", "max"}) + + +def _merge(regime: dict[str, Any], cfg: dict[str, Any]) -> dict[str, Any]: + """Merge a regime dict into a factor config dict (regime keys win). + + Args: + regime: Mapping of regime-feature names to values. + cfg: Mapping of design-factor names to values. + + Returns: + New dict containing both sets of keys. + """ + out = dict(cfg) + out.update(regime) + return out + + +@dataclass +class RegimeSurrogate: + """Surrogate that conditions on regime features. + + Wraps a single :class:`SurrogateModel` fit over the union of regime + descriptors and design factors. Use :func:`fit_regime_surrogate` to + construct one. + + Attributes: + inner: The underlying :class:`SurrogateModel` over the joint + ``regime_factors + factors`` input space. + regime_factors: Factors describing the regime (additional input + dimensions of the surrogate). + factors: Tunable design factors that are optimized at a given + regime by :meth:`recommend`. + """ + + inner: SurrogateModel + regime_factors: list[Factor] + factors: list[Factor] + + def predict( + self, + regime: dict[str, Any], + config: dict[str, Any], + ) -> dict[str, float]: + """Predict observables at a regime + config pair. + + Args: + regime: Mapping of regime-feature names to values. + config: Mapping of design-factor names to values. + + Returns: + Mapping from observable name to predicted scalar. + """ + return self.inner.predict(_merge(regime, config)) + + def predict_batch( + self, + regime: dict[str, Any], + configs: Sequence[dict[str, Any]], + ) -> dict[str, NDArray[np.float64]]: + """Predict observables for a batch of configs at one regime. + + Args: + regime: Mapping of regime-feature names to values. + configs: Sequence of design-factor configs to score at + ``regime``. + + Returns: + Mapping from observable name to a length-``len(configs)`` + array of predictions. + """ + merged = [_merge(regime, c) for c in configs] + return self.inner.predict_batch(merged) + + def uncertainty( + self, + regime: dict[str, Any], + config: dict[str, Any], + ) -> dict[str, float]: + """Predictive standard deviation per observable (GP only). + + Args: + regime: Mapping of regime-feature names to values. + config: Mapping of design-factor names to values. + + Returns: + Mapping from observable name to predictive standard deviation. + Propagates :class:`NotImplementedError` from the underlying + surrogate when the backend does not expose calibrated + uncertainties (non-GP backends). + """ + return self.inner.uncertainty(_merge(regime, config)) + + def recommend( + self, + regime: dict[str, Any], + *, + objective: str, + mode: str = "min", + n_candidates: int = 512, + seed: int = 0, + candidates: Sequence[dict[str, Any]] | None = None, + ) -> dict[str, Any]: + """Recommend a design-factor config at a query regime. + + Samples ``n_candidates`` configs from the design-factor space via + a scrambled Sobol' sequence and returns the one whose surrogate + prediction for ``objective`` is best under ``mode``. + + Args: + regime: Mapping of regime-feature names to values. + objective: Name of the observable to optimize. Must be one + of ``self.inner.observable_names``. + mode: ``"min"`` or ``"max"``. + n_candidates: Number of design-space samples to evaluate. + Ignored when ``candidates`` is provided. + seed: Seed for the Sobol' sampler. + candidates: Optional explicit list of design-factor configs + to score; if given, overrides ``n_candidates``. + + Returns: + The candidate config (a copy) achieving the best predicted + ``objective`` under ``mode``. + + Raises: + ValueError: If ``objective`` is not a fitted observable, if + ``mode`` is not ``"min"`` or ``"max"``, or if there are + no candidates to score. + """ + if mode not in _SUPPORTED_MODES: + msg = f"mode must be one of {sorted(_SUPPORTED_MODES)}; got {mode!r}" + raise ValueError(msg) + if objective not in self.inner.observable_names: + msg = ( + f"objective {objective!r} is not a fitted observable; " + f"available: {self.inner.observable_names}" + ) + raise ValueError(msg) + pool = ( + list(candidates) + if candidates is not None + else build_grid( + self.factors, + method="sobol", + n_samples=n_candidates, + seed=seed, + ) + ) + if not pool: + msg = "recommend: no candidates to score" + raise ValueError(msg) + preds = self.predict_batch(regime, pool)[objective] + idx = int(np.argmin(preds)) if mode == "min" else int(np.argmax(preds)) + return dict(pool[idx]) + + +def fit_regime_surrogate( + results: ResultsTable, + regime_factors: list[Factor], + factors: list[Factor], + *, + method: str = "gp", + seed: int = 0, + n_estimators: int = 200, +) -> RegimeSurrogate: + """Fit a surrogate that conditions on regime features. + + Internally fits a single :class:`SurrogateModel` over the joint + ``regime_factors + factors`` space, so observables can be + interpolated across continuous regime axes. + + Every config in ``results.configs`` must contain values for both the + regime features and the design factors. + + Args: + results: A :class:`ResultsTable` from previous study runs that + spans multiple regimes. + regime_factors: Factors describing the regime (additional input + dimensions of the surrogate; typically continuous). + factors: Tunable design factors. Together with + ``regime_factors`` these must cover every key referenced in + ``results.configs``. + method: Surrogate backend, ``"gp"`` or ``"rf"``. See + :func:`trade_study.fit_surrogate`. + seed: Random seed forwarded to the backend estimators. + n_estimators: Number of trees for the ``"rf"`` backend. + + Returns: + A fitted :class:`RegimeSurrogate`. + + Raises: + ValueError: If ``regime_factors`` is empty, if a name appears in + both ``regime_factors`` and ``factors``, or if the + underlying :func:`fit_surrogate` call fails. + """ + if not regime_factors: + msg = "fit_regime_surrogate: regime_factors must be non-empty" + raise ValueError(msg) + overlap = {f.name for f in regime_factors} & {f.name for f in factors} + if overlap: + msg = ( + f"fit_regime_surrogate: names appear in both regime_factors and " + f"factors: {sorted(overlap)}" + ) + raise ValueError(msg) + inner = fit_surrogate( + results, + [*regime_factors, *factors], + method=method, + seed=seed, + n_estimators=n_estimators, + ) + return RegimeSurrogate( + inner=inner, + regime_factors=list(regime_factors), + factors=list(factors), + ) diff --git a/tests/test_regime.py b/tests/test_regime.py new file mode 100644 index 0000000..d6068b3 --- /dev/null +++ b/tests/test_regime.py @@ -0,0 +1,266 @@ +"""Tests for regime-conditional surrogate (#105).""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from trade_study import ( + Direction, + Factor, + FactorType, + Observable, + RegimeSurrogate, + build_grid, + fit_regime_surrogate, + run_grid, +) + +if TYPE_CHECKING: + from trade_study.protocols import ResultsTable + + +class _PassWorld: + """Trivial simulator: passes config through.""" + + def generate( + self, + config: dict[str, object], + ) -> tuple[dict[str, object], dict[str, object]]: + return config, config + + +class _RegimeScorer: + """Scorer where ``loss = (lr - 0.1*n)**2``. + + The optimal ``lr`` is a linear function of regime feature ``n``. + """ + + def score( + self, + truth: object, + observations: dict[str, object], + config: dict[str, object], + ) -> dict[str, float]: + del truth, config + n = float(observations["n"]) + lr = float(observations["lr"]) + return {"loss": (lr - 0.1 * n) ** 2} + + +@pytest.fixture +def regime_factor() -> Factor: + return Factor("n", FactorType.CONTINUOUS, bounds=(0.0, 10.0)) + + +@pytest.fixture +def design_factor() -> Factor: + return Factor("lr", FactorType.CONTINUOUS, bounds=(0.0, 1.0)) + + +def _make_results( + regime_factor: Factor, + design_factor: Factor, + n: int = 64, + seed: int = 0, +) -> ResultsTable: + grid = build_grid( + [regime_factor, design_factor], + method="sobol", + n_samples=n, + seed=seed, + ) + obs = [Observable("loss", Direction.MINIMIZE)] + return run_grid(_PassWorld(), _RegimeScorer(), grid, obs) + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +def test_fit_requires_regime_factors( + regime_factor: Factor, + design_factor: Factor, +) -> None: + results = _make_results(regime_factor, design_factor, n=8) + with pytest.raises(ValueError, match="regime_factors must be non-empty"): + fit_regime_surrogate(results, [], [design_factor]) + + +def test_fit_rejects_overlapping_names( + regime_factor: Factor, + design_factor: Factor, +) -> None: + results = _make_results(regime_factor, design_factor, n=8) + dup = Factor("n", FactorType.CONTINUOUS, bounds=(0.0, 10.0)) + with pytest.raises(ValueError, match="appear in both"): + fit_regime_surrogate(results, [regime_factor], [dup]) + + +# --------------------------------------------------------------------------- +# Predict / uncertainty +# --------------------------------------------------------------------------- + + +def test_predict_returns_observable_dict( + regime_factor: Factor, + design_factor: Factor, +) -> None: + results = _make_results(regime_factor, design_factor, n=32) + sur = fit_regime_surrogate( + results, + [regime_factor], + [design_factor], + method="rf", + seed=0, + ) + assert isinstance(sur, RegimeSurrogate) + pred = sur.predict({"n": 5.0}, {"lr": 0.5}) + assert set(pred) == {"loss"} + assert isinstance(pred["loss"], float) + + +def test_predict_batch_shape( + regime_factor: Factor, + design_factor: Factor, +) -> None: + results = _make_results(regime_factor, design_factor, n=32) + sur = fit_regime_surrogate( + results, + [regime_factor], + [design_factor], + method="rf", + seed=0, + ) + pred = sur.predict_batch({"n": 5.0}, [{"lr": x} for x in [0.0, 0.25, 0.75]]) + assert pred["loss"].shape == (3,) + + +def test_gp_uncertainty_returns_floats( + regime_factor: Factor, + design_factor: Factor, +) -> None: + results = _make_results(regime_factor, design_factor, n=32) + sur = fit_regime_surrogate( + results, + [regime_factor], + [design_factor], + method="gp", + seed=0, + ) + unc = sur.uncertainty({"n": 5.0}, {"lr": 0.5}) + assert set(unc) == {"loss"} + assert unc["loss"] >= 0.0 + + +def test_rf_uncertainty_raises( + regime_factor: Factor, + design_factor: Factor, +) -> None: + results = _make_results(regime_factor, design_factor, n=16) + sur = fit_regime_surrogate( + results, + [regime_factor], + [design_factor], + method="rf", + seed=0, + ) + with pytest.raises(NotImplementedError): + sur.uncertainty({"n": 5.0}, {"lr": 0.5}) + + +# --------------------------------------------------------------------------- +# Recommend +# --------------------------------------------------------------------------- + + +def test_recommend_tracks_regime( + regime_factor: Factor, + design_factor: Factor, +) -> None: + """At ``n=2`` the optimum is ``lr~=0.2``; at ``n=8`` it is ``lr~=0.8``.""" + results = _make_results(regime_factor, design_factor, n=128, seed=0) + sur = fit_regime_surrogate( + results, + [regime_factor], + [design_factor], + method="rf", + seed=0, + n_estimators=200, + ) + low = sur.recommend({"n": 2.0}, objective="loss", n_candidates=128, seed=1) + high = sur.recommend({"n": 8.0}, objective="loss", n_candidates=128, seed=1) + assert low["lr"] < high["lr"] + assert low["lr"] == pytest.approx(0.2, abs=0.2) + assert high["lr"] == pytest.approx(0.8, abs=0.2) + + +def test_recommend_mode_max_inverts_choice( + regime_factor: Factor, + design_factor: Factor, +) -> None: + results = _make_results(regime_factor, design_factor, n=64, seed=0) + sur = fit_regime_surrogate( + results, + [regime_factor], + [design_factor], + method="rf", + seed=0, + ) + pool = [{"lr": 0.1}, {"lr": 0.5}, {"lr": 0.9}] + best_min = sur.recommend({"n": 5.0}, objective="loss", mode="min", candidates=pool) + best_max = sur.recommend({"n": 5.0}, objective="loss", mode="max", candidates=pool) + preds = sur.predict_batch({"n": 5.0}, pool)["loss"] + assert best_min["lr"] == pool[int(np.argmin(preds))]["lr"] + assert best_max["lr"] == pool[int(np.argmax(preds))]["lr"] + + +def test_recommend_rejects_unknown_objective( + regime_factor: Factor, + design_factor: Factor, +) -> None: + results = _make_results(regime_factor, design_factor, n=16) + sur = fit_regime_surrogate( + results, + [regime_factor], + [design_factor], + method="rf", + seed=0, + ) + with pytest.raises(ValueError, match="not a fitted observable"): + sur.recommend({"n": 5.0}, objective="bogus") + + +def test_recommend_rejects_bad_mode( + regime_factor: Factor, + design_factor: Factor, +) -> None: + results = _make_results(regime_factor, design_factor, n=16) + sur = fit_regime_surrogate( + results, + [regime_factor], + [design_factor], + method="rf", + seed=0, + ) + with pytest.raises(ValueError, match="mode must be"): + sur.recommend({"n": 5.0}, objective="loss", mode="bogus") + + +def test_recommend_rejects_empty_candidates( + regime_factor: Factor, + design_factor: Factor, +) -> None: + results = _make_results(regime_factor, design_factor, n=16) + sur = fit_regime_surrogate( + results, + [regime_factor], + [design_factor], + method="rf", + seed=0, + ) + with pytest.raises(ValueError, match="no candidates"): + sur.recommend({"n": 5.0}, objective="loss", candidates=[])