Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ Core types and interfaces for trade-study workflows.

::: trade_study.Simulator

::: trade_study.PartialEvaluator

::: trade_study.TrialResult

::: trade_study.ResultsTable
4 changes: 4 additions & 0 deletions docs/api/runner.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ Execute simulations across experimental grids.
::: trade_study.run_grid

::: trade_study.run_adaptive

::: trade_study.run_successive_halving

::: trade_study.run_hyperband
6 changes: 5 additions & 1 deletion src/trade_study/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
Constraint,
Direction,
Observable,
PartialEvaluator,
ResultsTable,
Scorer,
Simulator,
TrialResult,
)
from .runner import run_adaptive, run_grid
from .runner import run_adaptive, run_grid, run_hyperband, run_successive_halving
from .stacking import ensemble_predict, stack_bayesian, stack_scores
from .study import (
Phase,
Expand All @@ -44,6 +45,7 @@
"FactorConstraint",
"FactorType",
"Observable",
"PartialEvaluator",
"Phase",
"ResultsTable",
"Scorer",
Expand All @@ -67,6 +69,8 @@
"reduce_factors",
"run_adaptive",
"run_grid",
"run_hyperband",
"run_successive_halving",
"save_results",
"score",
"screen",
Expand Down
35 changes: 35 additions & 0 deletions src/trade_study/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,41 @@ def generate(self, config: dict[str, Any]) -> tuple[Any, Any]:
...


@runtime_checkable
class PartialEvaluator(Protocol):
"""Protocol for incrementally-evaluable trials.

Used by successive-halving / Hyperband (#104) to discard unpromising
configurations after a small fraction of their full budget. The budget
is opaque to the runner — it may be epochs, MCMC iterations, dataset
fractions, mesh resolutions, or seconds. Implementations should
interpret ``budget`` as "run from scratch up to this much work" so a
trial promoted from rung *r* to rung *r+1* is re-trained at the larger
budget rather than continuing from the smaller one (this matches the
canonical Hyperband formulation; implementations are free to cache
intermediate state internally as an optimization).
"""

def evaluate(
self,
config: dict[str, Any],
budget: float,
) -> dict[str, float]:
"""Evaluate ``config`` at the given ``budget`` and return observables.

Args:
config: Dictionary of factor values defining this trial.
budget: Resource budget (epochs, iterations, dataset fraction,
wall seconds, ...). Larger means a higher-fidelity
evaluation.

Returns:
Mapping from observable name to scalar value, including the
metric used for early-stopping.
"""
...


@runtime_checkable
class Scorer(Protocol):
"""Protocol for scoring model output against truth."""
Expand Down
290 changes: 290 additions & 0 deletions src/trade_study/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Annotation,
Direction,
Observable,
PartialEvaluator,
ResultsTable,
Scorer,
Simulator,
Expand Down Expand Up @@ -188,3 +189,292 @@ def objective(trial: optuna.trial.Trial) -> tuple[float, ...]:
scores=np.array(score_rows),
observable_names=obs_names,
)


# ---------------------------------------------------------------------------
# Multi-fidelity early-stopping (#104)
# ---------------------------------------------------------------------------


def _sh_collect_observables(
rung_results: list[list[tuple[int, float, dict[str, float]]]],
) -> list[str]:
"""Stable union of observable names across all collected rung evaluations.

Args:
rung_results: One list of ``(trial_idx, budget, observables)``
tuples per rung.

Returns:
Sorted list of unique observable names.
"""
names: set[str] = set()
for rung in rung_results:
for _, _, obs in rung:
names.update(obs)
return sorted(names)


def _sh_validate_inputs(
trials: list[dict[str, Any]],
rungs: list[float],
eta: float,
metric: str,
mode: str,
) -> None:
"""Validate successive-halving arguments.

Args:
trials: Candidate configs.
rungs: Budget per rung, ascending.
eta: Halving factor.
metric: Observable name used for ranking.
mode: ``"min"`` or ``"max"``.

Raises:
ValueError: If any argument is invalid.
"""
if not trials:
msg = "run_successive_halving: trials must be non-empty"
raise ValueError(msg)
if len(rungs) < 1:
msg = "run_successive_halving: rungs must contain at least one budget"
raise ValueError(msg)
if any(b <= 0 for b in rungs):
msg = "run_successive_halving: rung budgets must be positive"
raise ValueError(msg)
if list(rungs) != sorted(rungs):
msg = "run_successive_halving: rungs must be ascending"
raise ValueError(msg)
if eta <= 1:
msg = "run_successive_halving: eta must be > 1"
raise ValueError(msg)
if not metric:
msg = "run_successive_halving: metric must be a non-empty string"
raise ValueError(msg)
if mode not in {"min", "max"}:
msg = f"run_successive_halving: mode must be 'min' or 'max', got {mode!r}"
raise ValueError(msg)


def run_successive_halving(
trials: list[dict[str, Any]],
sim: PartialEvaluator,
*,
rungs: list[float],
eta: float = 3.0,
metric: str,
mode: str = "min",
) -> ResultsTable:
"""Successive-halving multi-fidelity early-stopping (#104).

Evaluates every trial at the lowest rung, keeps the top ``1/eta`` by
``metric`` (according to ``mode``), promotes survivors to the next
budget, and repeats until the highest rung. Every (trial, rung)
evaluation is recorded as one row in the returned :class:`ResultsTable`,
with ``rung`` index and ``budget`` stored in per-row metadata.

Args:
trials: Candidate configurations to evaluate.
sim: A :class:`PartialEvaluator` whose ``evaluate(config, budget)``
returns observables including ``metric``.
rungs: Strictly ascending list of budgets (e.g. epochs, iterations).
Length determines the number of halving rounds.
eta: Reduction factor between rungs (>1). Each rung keeps
``ceil(n_prev / eta)`` survivors. Defaults to 3 per Li et al.
(2017).
metric: Observable name used to rank trials at each rung.
mode: ``"min"`` (lower is better) or ``"max"``.

Returns:
:class:`ResultsTable` whose rows are (trial, rung) evaluations.
Per-row metadata contains ``rung`` (0-indexed), ``budget``,
``trial_index`` (position in the input ``trials`` list),
``promoted`` (whether this trial advanced past this rung), and
``wall_seconds``. Propagates :class:`ValueError` from the input
validator when arguments are invalid.

Raises:
KeyError: If ``metric`` is missing from a returned observables
dict.
"""
_sh_validate_inputs(trials, rungs, eta, metric, mode)

# rung_records[r] = list of (trial_idx, budget, observables) at rung r
rung_records: list[list[tuple[int, float, dict[str, float], float]]] = [
[] for _ in rungs
]
survivors: list[int] = list(range(len(trials)))

for r, budget in enumerate(rungs):
for trial_idx in survivors:
t0 = time.perf_counter()
obs = sim.evaluate(trials[trial_idx], budget)
wall = time.perf_counter() - t0
if metric not in obs:
msg = (
f"run_successive_halving: PartialEvaluator did not return "
f"metric {metric!r} at rung {r} for trial {trial_idx}"
)
raise KeyError(msg)
rung_records[r].append((trial_idx, budget, obs, wall))

if r < len(rungs) - 1:
ranked = sorted(
rung_records[r],
key=lambda row: row[2][metric],
reverse=(mode == "max"),
)
n_keep = max(1, int(np.ceil(len(ranked) / eta)))
survivors = [row[0] for row in ranked[:n_keep]]

obs_names = _sh_collect_observables([
[(idx, b, o) for idx, b, o, _w in rung] for rung in rung_records
])

promoted_at_rung: list[set[int]] = [set() for _ in rungs]
for r in range(len(rungs) - 1):
ranked = sorted(
rung_records[r],
key=lambda row: row[2][metric],
reverse=(mode == "max"),
)
n_keep = max(1, int(np.ceil(len(ranked) / eta)))
promoted_at_rung[r] = {row[0] for row in ranked[:n_keep]}

configs: list[dict[str, Any]] = []
score_rows: list[list[float]] = []
metadata: list[dict[str, Any]] = []
for r, rung in enumerate(rung_records):
for trial_idx, budget, obs, wall in rung:
configs.append(trials[trial_idx])
score_rows.append([obs.get(name, float("nan")) for name in obs_names])
metadata.append({
"rung": r,
"budget": budget,
"trial_index": trial_idx,
"promoted": trial_idx in promoted_at_rung[r],
"wall_seconds": wall,
})

return ResultsTable(
configs=configs,
scores=np.array(score_rows) if score_rows else np.zeros((0, len(obs_names))),
observable_names=obs_names,
metadata=metadata,
)


def _hyperband_brackets(
max_budget: float,
eta: float,
) -> list[tuple[int, float]]:
"""Compute the (n_trials, min_budget) per Hyperband bracket.

Implements the bracket schedule from Li et al. (2017), Algorithm 1.

Args:
max_budget: Maximum resource ``R`` allocated to a single trial.
eta: Reduction factor (>1).

Returns:
List of ``(n_initial_trials, min_budget)`` tuples, one per bracket.
"""
s_max = int(np.floor(np.log(max_budget) / np.log(eta)))
budget_total = (s_max + 1) * max_budget
brackets: list[tuple[int, float]] = []
for s in range(s_max, -1, -1):
n = int(np.ceil(budget_total / max_budget * eta**s / (s + 1)))
r = max_budget * eta ** (-s)
brackets.append((n, r))
return brackets


def run_hyperband(
trial_factory: Callable[[int, int], list[dict[str, Any]]],
sim: PartialEvaluator,
*,
max_budget: float,
eta: float = 3.0,
metric: str,
mode: str = "min",
) -> ResultsTable:
"""Hyperband: multi-bracket successive-halving (#104).

Wraps :func:`run_successive_halving` with the bracket schedule from
Li et al. (2017). Each bracket trades off the number of initial
trials against the minimum budget per trial; together they hedge
against picking either ratio wrong.

Args:
trial_factory: Callable ``(bracket_index, n_trials) -> trials`` that
returns a fresh list of candidate configs for each bracket.
Typically this wraps :func:`trade_study.build_grid` with a
bracket-derived seed so brackets sample different points.
sim: A :class:`PartialEvaluator`.
max_budget: Maximum resource ``R`` per trial.
eta: Reduction factor (>1). Defaults to 3.
metric: Observable used for ranking within each bracket.
mode: ``"min"`` or ``"max"``.

Returns:
Concatenated :class:`ResultsTable` across all brackets, with an
additional ``bracket`` field in each row's metadata.

Raises:
ValueError: If ``max_budget <= 0`` or ``eta <= 1``.
"""
if max_budget <= 0:
msg = "run_hyperband: max_budget must be positive"
raise ValueError(msg)
if eta <= 1:
msg = "run_hyperband: eta must be > 1"
raise ValueError(msg)

brackets = _hyperband_brackets(max_budget, eta)

all_configs: list[dict[str, Any]] = []
all_scores: list[list[float]] = []
all_metadata: list[dict[str, Any]] = []
obs_names: list[str] = []

for bracket_idx, (n_initial, r_min) in enumerate(brackets):
trials = trial_factory(bracket_idx, n_initial)
if not trials:
continue
s = len(brackets) - 1 - bracket_idx
rungs = [r_min * eta**i for i in range(s + 1)]
bracket_results = run_successive_halving(
trials,
sim,
rungs=rungs,
eta=eta,
metric=metric,
mode=mode,
)
if not obs_names:
obs_names = bracket_results.observable_names
elif obs_names != bracket_results.observable_names:
# Pad / reorder so all brackets share the column layout.
union = sorted(set(obs_names) | set(bracket_results.observable_names))
obs_names = union

for i, cfg in enumerate(bracket_results.configs):
all_configs.append(cfg)
row = bracket_results.scores[i]
all_scores.append([
float(row[bracket_results.observable_names.index(n)])
if n in bracket_results.observable_names
else float("nan")
for n in obs_names
])
meta = dict(bracket_results.metadata[i])
meta["bracket"] = bracket_idx
all_metadata.append(meta)

return ResultsTable(
configs=all_configs,
scores=np.array(all_scores) if all_scores else np.zeros((0, len(obs_names))),
observable_names=obs_names,
metadata=all_metadata,
)
Loading
Loading