@runtime_checkable
class PartialEvaluator(Protocol):
    """Structural interface for trials that can be scored at a partial budget.

    Successive-halving and Hyperband (#104) rely on this protocol to drop
    weak configurations after spending only a small share of the full
    budget on them. The runner never interprets the budget: it can stand
    for epochs, MCMC iterations, a dataset fraction, a mesh resolution, or
    seconds of wall time.

    Implementations should read ``budget`` as "run from scratch up to this
    much work". A trial promoted from rung *r* to rung *r+1* is therefore
    evaluated again at the larger budget instead of being resumed from the
    smaller one, which is the canonical Hyperband formulation. Caching
    intermediate state internally is a permitted optimization.
    """

    def evaluate(self, config: dict[str, Any], budget: float) -> dict[str, float]:
        """Score ``config`` using ``budget`` units of resource.

        Args:
            config: Factor values that define the trial.
            budget: Amount of resource to spend (epochs, iterations, dataset
                fraction, wall seconds, ...). A larger value means a
                higher-fidelity evaluation.

        Returns:
            Observable name to scalar value. The mapping must contain the
            metric that drives early stopping.
        """
        ...
+ + Returns: + Sorted list of unique observable names. + """ + names: set[str] = set() + for rung in rung_results: + for _, _, obs in rung: + names.update(obs) + return sorted(names) + + +def _sh_validate_inputs( + trials: list[dict[str, Any]], + rungs: list[float], + eta: float, + metric: str, + mode: str, +) -> None: + """Validate successive-halving arguments. + + Args: + trials: Candidate configs. + rungs: Budget per rung, ascending. + eta: Halving factor. + metric: Observable name used for ranking. + mode: ``"min"`` or ``"max"``. + + Raises: + ValueError: If any argument is invalid. + """ + if not trials: + msg = "run_successive_halving: trials must be non-empty" + raise ValueError(msg) + if len(rungs) < 1: + msg = "run_successive_halving: rungs must contain at least one budget" + raise ValueError(msg) + if any(b <= 0 for b in rungs): + msg = "run_successive_halving: rung budgets must be positive" + raise ValueError(msg) + if list(rungs) != sorted(rungs): + msg = "run_successive_halving: rungs must be ascending" + raise ValueError(msg) + if eta <= 1: + msg = "run_successive_halving: eta must be > 1" + raise ValueError(msg) + if not metric: + msg = "run_successive_halving: metric must be a non-empty string" + raise ValueError(msg) + if mode not in {"min", "max"}: + msg = f"run_successive_halving: mode must be 'min' or 'max', got {mode!r}" + raise ValueError(msg) + + +def run_successive_halving( + trials: list[dict[str, Any]], + sim: PartialEvaluator, + *, + rungs: list[float], + eta: float = 3.0, + metric: str, + mode: str = "min", +) -> ResultsTable: + """Successive-halving multi-fidelity early-stopping (#104). + + Evaluates every trial at the lowest rung, keeps the top ``1/eta`` by + ``metric`` (according to ``mode``), promotes survivors to the next + budget, and repeats until the highest rung. Every (trial, rung) + evaluation is recorded as one row in the returned :class:`ResultsTable`, + with ``rung`` index and ``budget`` stored in per-row metadata. 
+ + Args: + trials: Candidate configurations to evaluate. + sim: A :class:`PartialEvaluator` whose ``evaluate(config, budget)`` + returns observables including ``metric``. + rungs: Strictly ascending list of budgets (e.g. epochs, iterations). + Length determines the number of halving rounds. + eta: Reduction factor between rungs (>1). Each rung keeps + ``ceil(n_prev / eta)`` survivors. Defaults to 3 per Li et al. + (2017). + metric: Observable name used to rank trials at each rung. + mode: ``"min"`` (lower is better) or ``"max"``. + + Returns: + :class:`ResultsTable` whose rows are (trial, rung) evaluations. + Per-row metadata contains ``rung`` (0-indexed), ``budget``, + ``trial_index`` (position in the input ``trials`` list), + ``promoted`` (whether this trial advanced past this rung), and + ``wall_seconds``. Propagates :class:`ValueError` from the input + validator when arguments are invalid. + + Raises: + KeyError: If ``metric`` is missing from a returned observables + dict. + """ + _sh_validate_inputs(trials, rungs, eta, metric, mode) + + # rung_records[r] = list of (trial_idx, budget, observables) at rung r + rung_records: list[list[tuple[int, float, dict[str, float], float]]] = [ + [] for _ in rungs + ] + survivors: list[int] = list(range(len(trials))) + + for r, budget in enumerate(rungs): + for trial_idx in survivors: + t0 = time.perf_counter() + obs = sim.evaluate(trials[trial_idx], budget) + wall = time.perf_counter() - t0 + if metric not in obs: + msg = ( + f"run_successive_halving: PartialEvaluator did not return " + f"metric {metric!r} at rung {r} for trial {trial_idx}" + ) + raise KeyError(msg) + rung_records[r].append((trial_idx, budget, obs, wall)) + + if r < len(rungs) - 1: + ranked = sorted( + rung_records[r], + key=lambda row: row[2][metric], + reverse=(mode == "max"), + ) + n_keep = max(1, int(np.ceil(len(ranked) / eta))) + survivors = [row[0] for row in ranked[:n_keep]] + + obs_names = _sh_collect_observables([ + [(idx, b, o) for idx, b, o, _w 
in rung] for rung in rung_records + ]) + + promoted_at_rung: list[set[int]] = [set() for _ in rungs] + for r in range(len(rungs) - 1): + ranked = sorted( + rung_records[r], + key=lambda row: row[2][metric], + reverse=(mode == "max"), + ) + n_keep = max(1, int(np.ceil(len(ranked) / eta))) + promoted_at_rung[r] = {row[0] for row in ranked[:n_keep]} + + configs: list[dict[str, Any]] = [] + score_rows: list[list[float]] = [] + metadata: list[dict[str, Any]] = [] + for r, rung in enumerate(rung_records): + for trial_idx, budget, obs, wall in rung: + configs.append(trials[trial_idx]) + score_rows.append([obs.get(name, float("nan")) for name in obs_names]) + metadata.append({ + "rung": r, + "budget": budget, + "trial_index": trial_idx, + "promoted": trial_idx in promoted_at_rung[r], + "wall_seconds": wall, + }) + + return ResultsTable( + configs=configs, + scores=np.array(score_rows) if score_rows else np.zeros((0, len(obs_names))), + observable_names=obs_names, + metadata=metadata, + ) + + +def _hyperband_brackets( + max_budget: float, + eta: float, +) -> list[tuple[int, float]]: + """Compute the (n_trials, min_budget) per Hyperband bracket. + + Implements the bracket schedule from Li et al. (2017), Algorithm 1. + + Args: + max_budget: Maximum resource ``R`` allocated to a single trial. + eta: Reduction factor (>1). + + Returns: + List of ``(n_initial_trials, min_budget)`` tuples, one per bracket. 
+ """ + s_max = int(np.floor(np.log(max_budget) / np.log(eta))) + budget_total = (s_max + 1) * max_budget + brackets: list[tuple[int, float]] = [] + for s in range(s_max, -1, -1): + n = int(np.ceil(budget_total / max_budget * eta**s / (s + 1))) + r = max_budget * eta ** (-s) + brackets.append((n, r)) + return brackets + + +def run_hyperband( + trial_factory: Callable[[int, int], list[dict[str, Any]]], + sim: PartialEvaluator, + *, + max_budget: float, + eta: float = 3.0, + metric: str, + mode: str = "min", +) -> ResultsTable: + """Hyperband: multi-bracket successive-halving (#104). + + Wraps :func:`run_successive_halving` with the bracket schedule from + Li et al. (2017). Each bracket trades off the number of initial + trials against the minimum budget per trial; together they hedge + against picking either ratio wrong. + + Args: + trial_factory: Callable ``(bracket_index, n_trials) -> trials`` that + returns a fresh list of candidate configs for each bracket. + Typically this wraps :func:`trade_study.build_grid` with a + bracket-derived seed so brackets sample different points. + sim: A :class:`PartialEvaluator`. + max_budget: Maximum resource ``R`` per trial. + eta: Reduction factor (>1). Defaults to 3. + metric: Observable used for ranking within each bracket. + mode: ``"min"`` or ``"max"``. + + Returns: + Concatenated :class:`ResultsTable` across all brackets, with an + additional ``bracket`` field in each row's metadata. + + Raises: + ValueError: If ``max_budget <= 0`` or ``eta <= 1``. 
+ """ + if max_budget <= 0: + msg = "run_hyperband: max_budget must be positive" + raise ValueError(msg) + if eta <= 1: + msg = "run_hyperband: eta must be > 1" + raise ValueError(msg) + + brackets = _hyperband_brackets(max_budget, eta) + + all_configs: list[dict[str, Any]] = [] + all_scores: list[list[float]] = [] + all_metadata: list[dict[str, Any]] = [] + obs_names: list[str] = [] + + for bracket_idx, (n_initial, r_min) in enumerate(brackets): + trials = trial_factory(bracket_idx, n_initial) + if not trials: + continue + s = len(brackets) - 1 - bracket_idx + rungs = [r_min * eta**i for i in range(s + 1)] + bracket_results = run_successive_halving( + trials, + sim, + rungs=rungs, + eta=eta, + metric=metric, + mode=mode, + ) + if not obs_names: + obs_names = bracket_results.observable_names + elif obs_names != bracket_results.observable_names: + # Pad / reorder so all brackets share the column layout. + union = sorted(set(obs_names) | set(bracket_results.observable_names)) + obs_names = union + + for i, cfg in enumerate(bracket_results.configs): + all_configs.append(cfg) + row = bracket_results.scores[i] + all_scores.append([ + float(row[bracket_results.observable_names.index(n)]) + if n in bracket_results.observable_names + else float("nan") + for n in obs_names + ]) + meta = dict(bracket_results.metadata[i]) + meta["bracket"] = bracket_idx + all_metadata.append(meta) + + return ResultsTable( + configs=all_configs, + scores=np.array(all_scores) if all_scores else np.zeros((0, len(obs_names))), + observable_names=obs_names, + metadata=all_metadata, + ) diff --git a/tests/test_runner.py b/tests/test_runner.py index ec0ef19..8b193e4 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -9,7 +9,12 @@ from trade_study.design import Factor, FactorType from trade_study.protocols import Annotation, Direction, Observable, TrialResult -from trade_study.runner import run_adaptive, run_grid +from trade_study.runner import ( + run_adaptive, + run_grid, + 
# ---------------------------------------------------------------------------
# Successive halving / Hyperband (#104)
# ---------------------------------------------------------------------------


class _ToyPartialEvaluator:
    """Toy PartialEvaluator: loss is ``target + exp(-budget / 5)``.

    The evaluator is deterministic. At any fixed budget the trial with the
    smaller ``target`` has the lower loss, so it should survive successive
    halving in ``"min"`` mode.
    """

    def evaluate(
        self,
        config: dict[str, Any],
        budget: float,
    ) -> dict[str, float]:
        """Return a budget-decayed loss for ``config``.

        Args:
            config: Must contain ``target`` (asymptotic loss).
            budget: Resource budget; larger ⇒ better fidelity.

        Returns:
            Dict with ``loss`` and ``budget`` observables.
        """
        target = float(config["target"])
        loss = target + np.exp(-budget / 5.0)
        return {"loss": loss, "budget": float(budget)}


def test_successive_halving_keeps_best() -> None:
    """The single final-rung survivor is the trial with the lowest target."""
    sim = _ToyPartialEvaluator()
    trials = [{"target": t} for t in [0.1, 0.5, 0.9, 0.05, 0.7, 0.3, 0.6, 0.2, 0.8]]
    results = run_successive_halving(
        trials,
        sim,
        rungs=[1.0, 3.0, 9.0],
        eta=3.0,
        metric="loss",
        mode="min",
    )
    # Final rung should contain ceil(ceil(9/3)/3) = 1 trial,
    # which must be the lowest target.
    final = [m for m in results.metadata if m["rung"] == 2]
    assert len(final) == 1
    survivor_idx = final[0]["trial_index"]
    assert trials[survivor_idx]["target"] == pytest.approx(0.05)


def test_successive_halving_row_count() -> None:
    """One row is recorded per (trial, rung) evaluation."""
    sim = _ToyPartialEvaluator()
    trials = [{"target": t / 10} for t in range(9)]
    results = run_successive_halving(
        trials,
        sim,
        rungs=[1.0, 3.0, 9.0],
        eta=3.0,
        metric="loss",
        mode="min",
    )
    # Rung sizes: 9, 3, 1 → 13 rows.
    assert len(results.configs) == 9 + 3 + 1
    assert results.scores.shape == (13, 2)


def test_successive_halving_records_metadata() -> None:
    """Every row carries the documented metadata keys and promotion flag."""
    sim = _ToyPartialEvaluator()
    trials = [{"target": 0.1}, {"target": 0.9}]
    results = run_successive_halving(
        trials,
        sim,
        rungs=[1.0, 3.0],
        eta=2.0,
        metric="loss",
        mode="min",
    )
    keys = set(results.metadata[0])
    assert {"rung", "budget", "trial_index", "promoted", "wall_seconds"} <= keys
    promoted_first_rung = [
        m for m in results.metadata if m["rung"] == 0 and m["promoted"]
    ]
    assert len(promoted_first_rung) == 1


def test_successive_halving_max_mode() -> None:
    """In ``"max"`` mode the trial with the higher metric is promoted."""
    sim = _ToyPartialEvaluator()
    trials = [{"target": t} for t in [0.1, 0.9]]
    # In max mode, the higher loss wins.
    results = run_successive_halving(
        trials,
        sim,
        rungs=[1.0, 3.0],
        eta=2.0,
        metric="loss",
        mode="max",
    )
    final = [m for m in results.metadata if m["rung"] == 1]
    assert len(final) == 1
    assert trials[final[0]["trial_index"]]["target"] == pytest.approx(0.9)


def test_successive_halving_validation() -> None:
    """Each invalid argument is rejected with its own error message."""
    sim = _ToyPartialEvaluator()
    trials = [{"target": 0.1}]
    with pytest.raises(ValueError, match="trials must be non-empty"):
        run_successive_halving([], sim, rungs=[1.0], metric="loss")
    with pytest.raises(ValueError, match="ascending"):
        run_successive_halving(trials, sim, rungs=[3.0, 1.0], metric="loss")
    with pytest.raises(ValueError, match="positive"):
        run_successive_halving(trials, sim, rungs=[0.0], metric="loss")
    with pytest.raises(ValueError, match="eta must be > 1"):
        run_successive_halving(trials, sim, rungs=[1.0], eta=1.0, metric="loss")
    with pytest.raises(ValueError, match="metric must be"):
        run_successive_halving(trials, sim, rungs=[1.0], metric="")
    with pytest.raises(ValueError, match="mode must be"):
        run_successive_halving(trials, sim, rungs=[1.0], metric="loss", mode="bogus")
    with pytest.raises(ValueError, match="rungs must contain"):
        run_successive_halving(trials, sim, rungs=[], metric="loss")


def test_successive_halving_missing_metric() -> None:
    """An evaluator that omits the ranking metric triggers ``KeyError``."""

    class _NoMetric:
        def evaluate(self, _config: dict[str, Any], _budget: float) -> dict[str, float]:
            return {"other": 1.0}

    with pytest.raises(KeyError, match="did not return metric"):
        run_successive_halving(
            [{"target": 0.1}],
            _NoMetric(),
            rungs=[1.0],
            metric="loss",
        )


def test_hyperband_runs_all_brackets() -> None:
    """Every bracket of the schedule contributes rows to the result."""
    sim = _ToyPartialEvaluator()

    def factory(bracket_idx: int, n: int) -> list[dict[str, Any]]:
        # Bracket-seeded random targets so brackets explore different points.
        local_rng = np.random.default_rng(bracket_idx + 1)
        return [{"target": float(local_rng.uniform(0.0, 1.0))} for _ in range(n)]

    results = run_hyperband(
        factory,
        sim,
        max_budget=9.0,
        eta=3.0,
        metric="loss",
        mode="min",
    )
    brackets_seen = {m["bracket"] for m in results.metadata}
    # eta=3, R=9 → s_max = 2 → 3 brackets (s = 2, 1, 0).
    assert brackets_seen == {0, 1, 2}
    assert "loss" in results.observable_names


def test_hyperband_validation() -> None:
    """Non-positive ``max_budget`` and ``eta <= 1`` are rejected."""
    sim = _ToyPartialEvaluator()

    def factory(_b: int, _n: int) -> list[dict[str, Any]]:
        return [{"target": 0.1}]

    with pytest.raises(ValueError, match="max_budget must be positive"):
        run_hyperband(factory, sim, max_budget=0.0, metric="loss")
    with pytest.raises(ValueError, match="eta must be > 1"):
        run_hyperband(factory, sim, max_budget=9.0, eta=1.0, metric="loss")


def test_partial_evaluator_protocol_runtime_check() -> None:
    """A duck-typed evaluator passes the runtime protocol check."""
    from trade_study.protocols import PartialEvaluator

    assert isinstance(_ToyPartialEvaluator(), PartialEvaluator)