diff --git a/src/lcm/model.py b/src/lcm/model.py index 7ed2981a..9412d660 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -221,13 +221,14 @@ def solve( ) except InvalidValueFunctionError as exc: if log_path is not None and exc.partial_solution is not None: - save_solve_snapshot( + snap_dir = save_solve_snapshot( model=self, params=params, period_to_regime_to_V_arr=exc.partial_solution, # ty: ignore[invalid-argument-type] log_path=Path(log_path), log_keep_n_latest=log_keep_n_latest, ) + exc.add_note(f"Snapshot saved to {snap_dir}") raise if log_level == "debug" and log_path is not None: save_solve_snapshot( @@ -331,13 +332,14 @@ def simulate( ) except InvalidValueFunctionError as exc: if log_path is not None and exc.partial_solution is not None: - save_solve_snapshot( + snap_dir = save_solve_snapshot( model=self, params=params, period_to_regime_to_V_arr=exc.partial_solution, # ty: ignore[invalid-argument-type] log_path=Path(log_path), log_keep_n_latest=log_keep_n_latest, ) + exc.add_note(f"Snapshot saved to {snap_dir}") raise result = simulate( internal_params=internal_params, diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index eb70efe9..d8152e13 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -12,7 +12,7 @@ from lcm.ages import AgeGrid from lcm.interfaces import InternalRegime -from lcm.typing import FlatRegimeParams, FloatND, InternalParams, RegimeName +from lcm.typing import FloatND, InternalParams, RegimeName from lcm.utils.error_handling import validate_V from lcm.utils.logging import ( format_duration, @@ -73,27 +73,40 @@ def solve( solution: dict[int, MappingProxyType[RegimeName, FloatND]] = {} - # Async diagnostics accumulators: every `jnp.any(isnan)`, - # `jnp.any(isinf)` (and the debug min/max/mean trio) lives here as - # a device-side scalar during the hot loop. No host sync happens - # until the single flush in `_emit_deferred_diagnostics` post-loop. - # This replaces the pre-existing synchronous `log_nan_in_V` + - # `log_V_stats` + `validate_V` triple, which forced one host - # transfer per (regime, period) — ~n_regimes * n_periods stalls - # per solve, a meaningful throughput tax in MSM-style loops. - # Both gates fall out of the public log level: `"off"` ⇒ nothing, - # `"warning"` / `"progress"` ⇒ NaN/Inf only, `"debug"` ⇒ adds the - # min/max/mean trio. `"off"` skips even the NaN fail-fast — that - # is the documented contract of `"off"` (suppress all output) and - # is what makes the level useful for tight estimation loops. + # Async diagnostics accumulators: per-period `jnp.any(isnan)` / + # `jnp.any(isinf)` (and the debug min/max/mean trio) live here as + # device-side scalars during the hot loop. The two NaN/Inf flags + # are folded into single running scalars; the per-period min/max/ + # mean trio is appended to a list (only emitted at debug, where + # we genuinely want every number on host). + # + # Per-period `block_until_ready()` after the running update forces + # the device kernel to finish before the next period dispatches. + # This frees the per-period `isnan(V_arr)` / `isinf(V_arr)` + # intermediate buffers (V_arr-shaped, so model-dependent) so they + # don't stack up across the loop. `block_until_ready` is a + # *device-only* sync — no host transfer, no PCIe round-trip — so + # it doesn't introduce a host stall: if `max_Q_over_a` (the + # dominant per-period kernel) is in flight, the call returns + # immediately when the small reduction is done. + # + # One host transfer per stat at end of solve (`.item()` on the + # running scalars) decides whether to enter the failure-path + # localisation. On a healthy solve no per-row materialisation + # happens. + # + # Gate falls out of the public log level: + # - `"off"` ⇒ nothing (skips even the NaN fail-fast) + # - `"warning"` / `"progress"` ⇒ NaN/Inf only + # - `"debug"` ⇒ adds the min/max/mean trio diagnostics_enabled = logger.isEnabledFor(logging.WARNING) stats_enabled = logger.isEnabledFor(logging.DEBUG) diagnostic_rows: list[_DiagnosticRow] = [] diagnostic_min: list[FloatND] = [] diagnostic_max: list[FloatND] = [] diagnostic_mean: list[FloatND] = [] - diagnostic_any_nan: list[FloatND] = [] - diagnostic_any_inf: list[FloatND] = [] + running_any_nan: FloatND = jnp.zeros((), dtype=bool) + running_any_inf: FloatND = jnp.zeros((), dtype=bool) logger.info("Starting solution") total_start = time.monotonic() @@ -136,36 +149,44 @@ def solve( # Async reductions: gated on log level. `"off"` skips # everything — no kernel launches, no host syncs, no - # NaN fail-fast. `"warning"` / `"progress"` launches the - # two cheap isnan/isinf reductions; `"debug"` adds the - # min/max/mean trio. Each extra full-V read is a - # memory-bandwidth tax on the larger models, so the - # default keeps it to two reductions per (regime, period). + # NaN fail-fast. `"warning"` / `"progress"` folds two + # cheap isnan/isinf reductions into the running scalars; + # `"debug"` adds the min/max/mean trio. Each extra full-V + # read is a memory-bandwidth tax on the larger models, so + # the default keeps it to two reductions per (regime, + # period). if diagnostics_enabled: if stats_enabled: diagnostic_min.append(jnp.min(V_arr)) diagnostic_max.append(jnp.max(V_arr)) diagnostic_mean.append(jnp.mean(V_arr)) - diagnostic_any_nan.append(jnp.any(jnp.isnan(V_arr))) - diagnostic_any_inf.append(jnp.any(jnp.isinf(V_arr))) + running_any_nan = running_any_nan | jnp.any(jnp.isnan(V_arr)) + running_any_inf = running_any_inf | jnp.any(jnp.isinf(V_arr)) diagnostic_rows.append( _DiagnosticRow( regime_name=regime_name, period=period, age=float(ages.values[period]), - state_action_space=state_action_space, - next_regime_to_V_arr=next_regime_to_V_arr, - regime_params=internal_params[regime_name], - compute_intermediates=( - internal_regime.solve_functions.compute_intermediates.get( - period - ) - ), ) ) period_solution[regime_name] = V_arr + # Force the device-side reduction kernels to finish before the + # next period dispatches, so each period's `isnan` / `isinf` + # (and min/max/mean) intermediate buffers can be freed instead + # of stacking up. `block_until_ready` does NOT transfer to host + # — it is a device-side wait, cheap when the dominant + # per-period kernel (`max_Q_over_a`) is the actual bottleneck. + if diagnostics_enabled: + running_any_nan.block_until_ready() + running_any_inf.block_until_ready() + if stats_enabled and diagnostic_mean: + # Blocking on the last-appended stat suffices: XLA + # serialises dispatch order, so a finished `mean` + # implies a finished `min`/`max` too. + diagnostic_mean[-1].block_until_ready() + # Maintain consistent pytree structure: keep all regime keys, # update active regimes with solved V arrays. next_regime_to_V_arr = MappingProxyType( @@ -181,22 +202,28 @@ def solve( elapsed = time.monotonic() - period_start log_period_timing(logger=logger, elapsed=elapsed) - # One flush of the GPU kernel queue: ship the stacked reductions - # to host in two transfers (isnan / isinf) by default, plus three - # more (min / max / mean) when debug stats were enabled. Skipped - # entirely at `log_level="off"` — nothing was accumulated. + # Fail-fast on NaN: surface the offending period immediately + # instead of finishing the whole backward induction. Costs one + # host transfer of a scalar bool per period — negligible next + # to the per-period `max_Q_over_a` kernel, and only paid when + # diagnostics are on. Inf is non-fatal so we don't break on + # it; the post-loop emitter still raises a warning if any + # period flagged Inf. + if diagnostics_enabled and running_any_nan.item(): + break + if diagnostics_enabled: - _emit_deferred_diagnostics( + _emit_post_loop_diagnostics( logger=logger, diagnostic_rows=diagnostic_rows, - reductions=_StackedReductions( - mins=jnp.stack(diagnostic_min) if diagnostic_min else None, - maxs=jnp.stack(diagnostic_max) if diagnostic_max else None, - means=jnp.stack(diagnostic_mean) if diagnostic_mean else None, - any_nan=jnp.stack(diagnostic_any_nan), - any_inf=jnp.stack(diagnostic_any_inf), - ), solution=MappingProxyType(solution), + internal_regimes=internal_regimes, + internal_params=internal_params, + running_any_nan=running_any_nan, + running_any_inf=running_any_inf, + diagnostic_min=diagnostic_min if stats_enabled else None, + diagnostic_max=diagnostic_max if stats_enabled else None, + diagnostic_mean=diagnostic_mean if stats_enabled else None, ) total_elapsed = time.monotonic() - total_start @@ -390,11 +417,12 @@ def _get_regime_V_shapes( class _DiagnosticRow: """Metadata captured during the backward-induction loop. - Stored refs only — no device work — so appending these rows inside - the hot loop costs essentially nothing. The expensive part (NaN - diagnostic enrichment via `compute_intermediates`) runs at most - once per solve, on the first offending row found after the single - post-loop host flush. + Holds only Python-scalar metadata — no device-array references — so + every (regime, period) row stays at a few bytes regardless of grid + size. State-action space, next-period V mapping, regime params, and + the `compute_intermediates` closure are reconstructed lazily on the + failure path from `internal_regimes`, `internal_params`, and the + partial `solution` built up to that point. """ regime_name: RegimeName @@ -403,139 +431,181 @@ class _DiagnosticRow: """Period index in the backward-induction loop.""" age: float """Age corresponding to `period` (pulled off `AgeGrid.values`).""" - state_action_space: object - """Typed as `object` to avoid a heavy import cycle; consumers know - the actual runtime type from the `max_Q_over_a` signature.""" - next_regime_to_V_arr: MappingProxyType[RegimeName, FloatND] - """Incoming next-period V-arrays, passed through unchanged to - `compute_intermediates` when a NaN is detected.""" - regime_params: FlatRegimeParams - """Flat regime parameters used at this (regime, period).""" - compute_intermediates: Callable | None - """Optional closure that recomputes U / F / E[V] / Q for NaN - diagnostic enrichment. `None` when the regime has no - compute-intermediates closure (e.g. terminal periods).""" - - -@dataclass(frozen=True) -class _StackedReductions: - """Per-stat JAX arrays stacked across all diagnostic rows; still on device. - - `mins` / `maxs` / `means` are `None` when the solve ran with a log - level below `debug` — the GPU wasn't asked to compute those - statistics so there's nothing to stack. - """ - - mins: FloatND | None - """Per-row min of V, or `None` below debug log level.""" - maxs: FloatND | None - """Per-row max of V, or `None` below debug log level.""" - means: FloatND | None - """Per-row mean of V, or `None` below debug log level.""" - any_nan: FloatND - """Per-row boolean flag: any NaN in V at this (regime, period).""" - any_inf: FloatND - """Per-row boolean flag: any Inf in V at this (regime, period).""" -def _emit_deferred_diagnostics( +def _emit_post_loop_diagnostics( *, logger: logging.Logger, diagnostic_rows: list[_DiagnosticRow], - reductions: _StackedReductions, solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, + running_any_nan: FloatND, + running_any_inf: FloatND, + diagnostic_min: list[FloatND] | None, + diagnostic_max: list[FloatND] | None, + diagnostic_mean: list[FloatND] | None, ) -> None: - """Flush async diagnostics to host, emit logs, raise on NaN. - - Exactly two host transfers by default (one per stat stack), plus - three more (min / max / mean) when debug stats were enabled. - Ordering: NaN check first so we raise before emitting any stats - lines the user wouldn't see anyway; inf check next (warning only); - per-period stats last at debug log level. The `.tolist()` calls - are what actually block on the GPU queue — everything above this - function ran async. - """ - any_nan = reductions.any_nan.tolist() - any_inf = reductions.any_inf.tolist() - - _raise_if_nan( - diagnostic_rows=diagnostic_rows, - any_nan_per_row=any_nan, - solution=solution, - ) - _warn_if_inf( - logger=logger, - diagnostic_rows=diagnostic_rows, - any_inf_per_row=any_inf, - ) - - if ( - not logger.isEnabledFor(logging.DEBUG) - or reductions.mins is None - or reductions.maxs is None - or reductions.means is None - ): - return + """Flush async diagnostics: raise on NaN, warn on Inf, log debug stats. - mins = reductions.mins.tolist() - maxs = reductions.maxs.tolist() - means = reductions.means.tolist() - for row, v_min, v_max, v_mean in zip( - diagnostic_rows, mins, maxs, means, strict=True - ): - logger.debug( - " %s age %s V min=%.3g max=%.3g mean=%.3g", - row.regime_name, - row.age, - v_min, - v_max, - v_mean, + The two `.item()` calls on the running scalars decide whether to + enter the per-row failure-path localisation. On a healthy solve + neither inner walk runs and no per-row scalar is materialised, so + device memory stays bounded by the V templates currently in flight. + """ + if running_any_nan.item(): + _raise_first_nan_row( + diagnostic_rows=diagnostic_rows, + solution=solution, + internal_regimes=internal_regimes, + internal_params=internal_params, + ) + if running_any_inf.item(): + _warn_inf_rows( + logger=logger, + diagnostic_rows=diagnostic_rows, + solution=solution, + ) + if diagnostic_min is not None and diagnostic_max is not None and diagnostic_mean: + _log_per_period_stats( + logger=logger, + diagnostic_rows=diagnostic_rows, + mins=jnp.stack(diagnostic_min), + maxs=jnp.stack(diagnostic_max), + means=jnp.stack(diagnostic_mean), ) -def _raise_if_nan( +def _raise_first_nan_row( *, diagnostic_rows: list[_DiagnosticRow], - any_nan_per_row: list, # list[bool] solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, ) -> None: - """Find the first NaN-bearing (regime, period) and raise.""" - for row, flag in zip(diagnostic_rows, any_nan_per_row, strict=True): - if flag: - _raise_at(row=row, solution=solution) + """Find the first NaN-bearing (regime, period) and raise. + + Only invoked on the failure path (`running_any_nan` was True). + Materialises one host-side bool per row until the first hit; on + a healthy solve this function is never called. + """ + for row in diagnostic_rows: + V_arr = solution[row.period][row.regime_name] + if jnp.any(jnp.isnan(V_arr)).item(): + _raise_at( + row=row, + solution=solution, + internal_regimes=internal_regimes, + internal_params=internal_params, + ) def _raise_at( *, row: _DiagnosticRow, solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, ) -> None: """Run the enriched NaN diagnostic on a single offending row and raise.""" + internal_regime = internal_regimes[row.regime_name] + regime_params = internal_params[row.regime_name] + state_action_space = internal_regime.state_action_space(regime_params=regime_params) + next_regime_to_V_arr = _reconstruct_next_regime_to_V_arr( + period=row.period, + internal_regimes=internal_regimes, + internal_params=internal_params, + solution=solution, + ) + compute_intermediates = internal_regime.solve_functions.compute_intermediates.get( + row.period + ) V_arr = solution[row.period][row.regime_name] validate_V( V_arr=V_arr, age=row.age, regime_name=row.regime_name, partial_solution=solution, - compute_intermediates=row.compute_intermediates, - state_action_space=row.state_action_space, # ty: ignore[invalid-argument-type] - next_regime_to_V_arr=row.next_regime_to_V_arr, - internal_params=row.regime_params, + compute_intermediates=compute_intermediates, + state_action_space=state_action_space, + next_regime_to_V_arr=next_regime_to_V_arr, + internal_params=regime_params, period=row.period, ) -def _warn_if_inf( +def _reconstruct_next_regime_to_V_arr( + *, + period: int, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, + solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], +) -> MappingProxyType[RegimeName, FloatND]: + """Recreate the rolling `next_regime_to_V_arr` that was used at `period`. + + The hot loop rolls the per-regime V forward via `period_solution.get(name, + next_regime_to_V_arr[name])`, so at iteration `period` each regime's slot + holds its V from the smallest later period where it was active, falling + back to a zeros template otherwise. + + We rebuild the same mapping post-hoc from `solution`. The shapes come from + the regime's state-action space at the supplied params — identical to what + `_get_regime_V_shapes` saw during solve setup. + """ + regime_V_shapes = _get_regime_V_shapes( + internal_regimes=internal_regimes, + internal_params=internal_params, + ) + later_periods = sorted(p for p in solution if p > period) + result: dict[RegimeName, FloatND] = {} + for regime_name, shape in regime_V_shapes.items(): + rolled: FloatND | None = None + for q in later_periods: + if regime_name in solution[q]: + rolled = solution[q][regime_name] + break + result[regime_name] = rolled if rolled is not None else jnp.zeros(shape) + return MappingProxyType(result) + + +def _warn_inf_rows( *, logger: logging.Logger, diagnostic_rows: list[_DiagnosticRow], - any_inf_per_row: list, # list[bool] + solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], ) -> None: - """Emit a warning per (regime, period) with Inf values.""" - for row, flag in zip(diagnostic_rows, any_inf_per_row, strict=True): - if flag: + """Emit a warning per (regime, period) with Inf values. + + Only invoked on the failure path (`running_any_inf` was True). + Materialises one host-side bool per row. + """ + for row in diagnostic_rows: + V_arr = solution[row.period][row.regime_name] + if jnp.any(jnp.isinf(V_arr)).item(): logger.warning( "Inf in V_arr for regime '%s' at age %s", row.regime_name, row.age, ) + + +def _log_per_period_stats( + *, + logger: logging.Logger, + diagnostic_rows: list[_DiagnosticRow], + mins: FloatND, + maxs: FloatND, + means: FloatND, +) -> None: + """Emit one debug log line per (regime, period) with V min/max/mean.""" + for row, v_min, v_max, v_mean in zip( + diagnostic_rows, mins.tolist(), maxs.tolist(), means.tolist(), strict=True + ): + logger.debug( + " %s age %s V min=%.3g max=%.3g mean=%.3g", + row.regime_name, + row.age, + v_min, + v_max, + v_mean, + ) diff --git a/src/lcm/utils/error_handling.py b/src/lcm/utils/error_handling.py index c7f16008..ef631349 100644 --- a/src/lcm/utils/error_handling.py +++ b/src/lcm/utils/error_handling.py @@ -93,12 +93,11 @@ def validate_V( "(e.g. from a NaN survival probability or a NaN fixed param).\n" "- A per-target state_transitions dict omits a reachable target " "(non-zero transition probability to an incomplete target).\n\n" - "To diagnose, re-solve with debug logging:\n\n" - ' model.solve(params=params, log_level="debug", ' - 'log_path="./debug/")\n\n' - "The snapshot saved on failure contains diagnostics that pinpoint " - "where NaN enters (U, E[V], or regime transitions). See the " - "debugging guide:\n" + "See the [NOTE] below for the per-intermediate / per-axis " + "breakdown produced by `compute_intermediates`. When `log_path` " + "is configured, an additional [NOTE] points to the on-disk " + "snapshot directory written before this exception was raised. " + "Debugging guide:\n" "https://pylcm.readthedocs.io/en/latest/user_guide/debugging/" ) exc.partial_solution = partial_solution diff --git a/tests/solution/test_diagnostics.py b/tests/solution/test_diagnostics.py new file mode 100644 index 00000000..262ae0e1 --- /dev/null +++ b/tests/solution/test_diagnostics.py @@ -0,0 +1,125 @@ +"""Tests for the post-loop diagnostics path in `solve_brute.solve`. + +These cover: +- happy path at `log_level="warning"` runs without raising; +- NaN-bearing solves raise `InvalidValueFunctionError` and the message + identifies the offending `(regime, age)`; +- `log_level="debug"` emits one stat line per `(regime, period)`; +- `log_level="off"` emits nothing and skips even the NaN fail-fast. +""" + +import logging +from pathlib import Path + +import jax.numpy as jnp +import pytest + +from lcm import AgeGrid, LinSpacedGrid, Model, Regime, categorical +from lcm.exceptions import InvalidValueFunctionError +from lcm.typing import ContinuousAction, ContinuousState, FloatND + + +@categorical(ordered=False) +class RegimeId: + alive: int + dead: int + + +def _utility(consumption: ContinuousAction, wealth: ContinuousState) -> FloatND: + return jnp.log(consumption + 1) + 0.01 * wealth + + +def _next_wealth( + wealth: ContinuousState, + consumption: ContinuousAction, + interest_rate: float, +) -> ContinuousState: + return (1 + interest_rate) * (wealth - consumption) + + +def _borrowing_constraint( + consumption: ContinuousAction, wealth: ContinuousState +) -> FloatND: + return consumption <= wealth + + +def _next_regime(period: int) -> FloatND: + return jnp.where(period >= 1, RegimeId.dead, RegimeId.alive) + + +def _make_model() -> Model: + alive = Regime( + functions={"utility": _utility}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=5)}, + state_transitions={"wealth": _next_wealth}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5, n_points=5)}, + constraints={"borrowing_constraint": _borrowing_constraint}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + return Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +_HEALTHY_PARAMS = {"discount_factor": 0.95, "interest_rate": 0.05} + + +def test_warning_level_solves_without_per_row_materialisation(): + """Happy-path solve at log_level="warning" returns finite V without + entering the failure-path localisation.""" + model = _make_model() + period_to_regime_to_V_arr = model.solve(params=_HEALTHY_PARAMS, log_level="warning") + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)) + assert not jnp.any(jnp.isinf(V_arr)) + + +def test_nan_failure_raises_with_regime_and_age(): + """A NaN-producing parameter set raises with the offending (regime, age). + + `discount_factor=NaN` poisons the next-V contribution to Q on the + first non-terminal period; the validator must surface the offending + regime in the error message. + """ + model = _make_model() + params = {**_HEALTHY_PARAMS, "discount_factor": float("nan")} + with pytest.raises(InvalidValueFunctionError, match=r"alive"): + model.solve(params=params, log_level="warning") + + +def test_off_level_solves_without_diagnostics(caplog: pytest.LogCaptureFixture): + """log_level="off" emits no diagnostic records and skips the NaN fail-fast. + + Even with a NaN-producing parameter set, solve() returns instead of + raising — the documented contract of `"off"`. + """ + model = _make_model() + params = {**_HEALTHY_PARAMS, "discount_factor": float("nan")} + with caplog.at_level(logging.DEBUG): + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + assert period_to_regime_to_V_arr is not None + assert not [r for r in caplog.records if r.levelno >= logging.WARNING] + + +def test_debug_level_emits_per_period_stats( + caplog: pytest.LogCaptureFixture, tmp_path: Path +): + """log_level="debug" logs a min/max/mean line for every (regime, period).""" + model = _make_model() + with caplog.at_level(logging.DEBUG, logger="lcm"): + model.solve(params=_HEALTHY_PARAMS, log_level="debug", log_path=tmp_path) + debug_stat_lines = [ + r + for r in caplog.records + if "V min=" in r.getMessage() and "max=" in r.getMessage() + ] + assert len(debug_stat_lines) >= 1