
solve_brute: lazy diagnostics with per-period fail-fast on NaN#339

Merged
hmgaudecker merged 58 commits into main from improve/lazy-solve-diagnostics
May 6, 2026

Conversation

@hmgaudecker
Member

@hmgaudecker hmgaudecker commented May 1, 2026

Summary

Three coupled fixes to the solve-time diagnostics path. All three were necessary to keep the production aca-baseline solve healthy on a 16 GB V100 and surface NaN failures usefully.

1. Stream NaN/Inf reductions instead of stacking-and-flushing (2cc46ff)

PR #334 introduced a deferred-diagnostics accumulator: every (regime, period) appends jnp.any(jnp.isnan(V_arr)) / jnp.any(jnp.isinf(V_arr)) to a Python list, the lists are stacked, and the stacks are .tolist()-ed to host at end of solve. The win was zero per-period host transfers — important for MSM-style estimation loops that re-solve hundreds of times.

The cost: on a 16 GB V100 at production aca-baseline grid sizes, the stacked reduction graph keeps every period's isnan(V_arr) / isinf(V_arr) intermediate alive simultaneously. The post-loop .tolist() asks XLA to compile the fan-in, which then requests ~7.3 GiB on top of the already-resident solution V arrays and OOMs. Symptom: solve() reports every age as "finished in ~14 ms" (dispatch-async times, not actual compute), then JaxRuntimeError: RESOURCE_EXHAUSTED at the first .tolist().

Fix: replace the list-append with a running scalar OR and add a per-period block_until_ready() so each period's reduction kernel finishes (and its intermediate is freed) before the next period dispatches. End of solve transfers two bools to host; on a healthy solve we return without materialising any per-row state. Failure paths walk diagnostic_rows and .item()-localise the offender.
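The streaming pattern above can be sketched in a few lines. This is an illustrative stand-in, not pylcm's actual code: numpy substitutes for jnp, and the spot where the real loop would call `block_until_ready()` is marked in a comment (only the names `running_any_nan` / `diagnostic_rows` come from the PR).

```python
import numpy as np


def solve_with_streaming_diagnostics(v_arrays_by_period):
    """Fold per-period NaN/Inf reductions into running scalars
    instead of accumulating one flag per (regime, period)."""
    running_any_nan = np.bool_(False)
    running_any_inf = np.bool_(False)
    diagnostic_rows = []
    for period, v_arr in enumerate(v_arrays_by_period):
        # Running scalar OR: nothing per-period is retained on device.
        running_any_nan = running_any_nan | np.any(np.isnan(v_arr))
        running_any_inf = running_any_inf | np.any(np.isinf(v_arr))
        diagnostic_rows.append(period)  # Python scalars only (see commit 2)
        # In JAX: running_any_nan.block_until_ready() here, so this
        # period's isnan/isinf intermediates are freed before the next
        # period dispatches.  No host transfer is involved.
        if bool(running_any_nan):  # fail-fast; the real code uses .item()
            break
    return bool(running_any_nan), bool(running_any_inf), diagnostic_rows
```

On a healthy solve the function returns two `False` bools without materialising per-row state; on a NaN it breaks at the first flagged period.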

2. Stop pinning per-period V templates in diagnostic_rows (365da07)

The streaming reduction that landed first addressed only the per-period reduction buffers; row-level retention was the larger leak. Each _DiagnosticRow was holding the active-period state_action_space, the rolling next_regime_to_V_arr, the regime's flat params, and a compute_intermediates closure (which itself captured the state_action_space). At production grid sizes — 50+ periods × ~6 active regimes — the accumulated references pinned every period's full-shape V mapping in device memory, OOMing the V100 mid-loop on block_until_ready (the next allocation that has nowhere to go).

Fix: strip _DiagnosticRow to the three Python scalars actually needed for failure-path localisation (regime_name, period, age) and reconstruct the heavy bits from solution, internal_regimes, and internal_params inside _raise_at. The reconstruction mirrors the loop's roll-forward semantics: for each regime, take the smallest later period in solution where the regime was active, falling back to a zeros template — the same value the rolling next_regime_to_V_arr slot held during the live dispatch.
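The roll-forward reconstruction can be illustrated with a small sketch. Everything here is hypothetical shape (a `solution` mapping of period to per-regime V arrays); only the "smallest later period, else zeros template" rule is taken from the PR.

```python
def reconstruct_next_regime_V(solution, regime_name, period, n_states=4):
    """Rebuild the value the rolling next_regime_to_V_arr slot held for
    (regime_name, period) during the live dispatch: the regime's V from
    the smallest later period where it was active, else a zeros template.
    """
    later = sorted(
        p for p, regimes in solution.items()
        if p > period and regime_name in regimes
    )
    if later:
        return solution[later[0]][regime_name]
    return [0.0] * n_states  # zeros template fallback
```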

3. Fail-fast on NaN per period; rewrite stale diagnostic hint (bf1cdf4)

Two finishing touches that surfaced once the streaming reduction was working:

  • Fail-fast at age boundary. Adds a per-period running_any_nan.item() host transfer right after the existing block_until_ready. On True, the loop breaks out and the existing post-loop emitter raises immediately. Cost: one scalar bool transfer per period — negligible next to max_Q_over_a. Without this, backward induction would finish the entire age range (potentially ~2h on production grids) before raising at the first-NaN row, leaving the user staring at an idle-looking solve. Inf stays non-fatal; the post-loop warning still fires for any period that flagged it.
  • Drop the misleading "re-solve with debug logging" suggestion from validate_V. The diagnostic [NOTE] is added inline by _enrich_with_diagnostics whenever compute_intermediates is wired up — i.e. on the default path — so suggesting a re-solve to "produce" diagnostics is wrong: they were already produced. Replace with a pointer to the [NOTE] for the per-axis breakdown plus a mention of log_path=... for snapshot persistence (the only thing debug-mode actually adds beyond the inline diagnostic).

Why this preserves PR #334's win

block_until_ready() is a device-only sync. No host transfer, no PCIe round-trip — it only waits for the kernel to finish. PR #334 removed .item() / bool-check syncs that did round-trip; that knob stays off. In practice the small jnp.any reduction has finished by the time max_Q_over_a (~14 ms/period) returns, so the call is near-free. The fail-fast .item() (added in commit 3) only fires when diagnostics are enabled (log_level >= "warning"); log_level="off" skips even that, preserving the tight estimation-loop contract.

Changes

  • src/lcm/solution/solve_brute.py
    • Replace diagnostic_any_nan / diagnostic_any_inf lists with running_any_nan / running_any_inf scalars folded via | per (regime, period).
    • Per-period block_until_ready() on the running scalars (and on the last-appended mean when stats are enabled).
    • Per-period running_any_nan.item() for fail-fast on NaN; loop breaks on True.
    • End-of-loop _emit_post_loop_diagnostics orchestrator: .item() on each running scalar; only on True does it walk diagnostic_rows and .item()-localise the offending (regime, age) for _raise_at / logger.warning.
    • _DiagnosticRow reduced to (regime_name, period, age) Python scalars; _raise_at reconstructs state_action_space / next_regime_to_V_arr / compute_intermediates from solution + internal_regimes + internal_params.
    • Drop _StackedReductions, _emit_deferred_diagnostics, the per-row-flag-list variants of _raise_if_nan / _warn_if_inf. Replace with _raise_first_nan_row, _warn_inf_rows, _log_per_period_stats.
  • src/lcm/utils/error_handling.py
    • Drop the misleading "re-solve with debug" suggestion from the InvalidValueFunctionError message.
  • tests/solution/test_diagnostics.py (new)
    • Happy-path solve at log_level="warning" returns finite V.
    • NaN-bearing solve raises InvalidValueFunctionError with the offending regime in the message.
    • log_level="off" returns and emits zero WARNING+ records even when the params would NaN.
    • log_level="debug" emits ≥1 per-(regime, period) V min=… max=… mean=… log line.

Test plan

  • pixi run -e tests-cpu tests (885 passed, 5 skipped)
  • prek run --all-files
  • On a V100 box: production task_simulate_aca runs without the OOM at end of solve.
  • On a V100 box: a NaN-injecting params variant now raises within ~10 min of the offending age, not after the full ~2h backward induction.
  • ASV benchmark on a clean GPU: confirm Mahler-Yum / Precautionary Savings within run-to-run noise of the baseline from #338 ("Runtime-supplied points on action grids; tighten grid + regime validators"). The aca-baseline benchmark may move slightly because of the per-period block_until_ready; expect single-digit % overhead at most.

Followup (separate change, after this lands)

In aca-estimation/src/aca_estimation/task_simulate_baseline.py and task_simulate_aca.py: drop the log_level="off" workaround that was added when the original OOM blocked production runs.

🤖 Generated with Claude Code

hmgaudecker and others added 22 commits April 29, 2026 06:29
Extends the existing runtime-points mechanism (previously state-only)
to continuous action grids. With this change, an action declared as
`IrregSpacedGrid(n_points=N)` adds an `{action_name: {"points":
"Float1D"}}` entry to the regime params template, and `state_action_space()`
substitutes the runtime-supplied points into `continuous_actions` at
solve / simulate time.

Motivation: aca-dev's structural retirement model has a `consumption`
action grid whose lower bound is the per-iteration `consumption_floor`
parameter. Without this change the c-grid bounds would have to be
fixed at build time, which forces either an over-wide grid (wasted
density) or model rebuilds per estimation iteration (recompilation).

Mirrors the existing state-grid treatment:
- `regime_template.py`: walks `regime.actions` alongside `regime.states`,
  factoring the shared shadowing check into helpers.
- `interfaces.InternalRegime.state_action_space()`: builds both
  state and continuous-action replacements in a single sweep over
  `self.grids`, then calls `_base_state_action_space.replace(...)`
  with whichever side actually had substitutions.
- `pandas_utils._is_runtime_grid_param`: also recognises action grids
  so column extraction in `to_dataframe()` keeps working.

Tests (TDD): four new tests in `tests/test_runtime_params.py`,
mirroring the state-grid counterparts — params-template entry,
solve, runtime-vs-fixed equivalence, and a sanity check that
varying runtime points actually changes V.
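The substitution step described above can be sketched as follows. The function name and dict shapes are illustrative, not pylcm's API; only the `{"points": ...}` params entry and the "substitute at solve/simulate time" semantics come from the commit message.

```python
import numpy as np


def substitute_runtime_points(base_grids, runtime_grid_names, regime_params):
    """Replace runtime-declared grids with the points supplied in
    params[grid_name]['points']; fixed grids pass through untouched."""
    grids = {}
    for name, base in base_grids.items():
        if name in runtime_grid_names:
            grids[name] = np.asarray(regime_params[name]["points"])
        else:
            grids[name] = base
    return grids
```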

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aca-model now declares `consumption` as `IrregSpacedGrid(n_points=N)`
with runtime-supplied points. The bench builder now passes
`model=model` to `get_benchmark_params` so consumption gridpoints
are injected into params before solving.

aca-model rev: adc8a19 → 4123fe9 (feature/runtime-consumption-points)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`IrregSpacedGrid(n_points=N)` declares a continuous grid whose values
are supplied at runtime via `params[regime][grid_name]['points']`.
Substitution happens inside
`InternalRegime.state_action_space(regime_params=...)` at solve /
simulate time. Any code path that calls `to_jax()` on the base grid
before substitution silently got `jnp.full(N, jnp.nan)` and went on to
compute against the placeholder.

That is exactly what fired in `validate_initial_conditions` for
`task_simulate_aca`: the validator built the action grid by calling
`internal_regime.grids[name].to_jax()` (placeholder NaNs), then asked
`borrowing_constraint(consumption=NaN, wealth=W)` whether each
gridpoint was feasible. NaN comparisons are False, so every action
was reported infeasible for every subject in every initial regime.

Make the invariant explicit: `IrregSpacedGrid.to_jax()` raises
`GridInitializationError` for runtime-supplied grids, with a message
pointing the caller at `state_action_space(regime_params=...)` for
real values or `.n_points` for shape. Confine the legitimate
"placeholder needed for AOT tracing" caller (the base state-action
space) to a private helper in `state_action_space.py` that uses NaN
explicitly. Reroute `_check_regime_feasibility` through the
substituted state-action space.

Add regression tests covering both runtime action and runtime state
grids round-tripping `simulate(check_initial_conditions=True)`, and
unit tests pinning down the new raise + the existing NaN-source
mechanics in `map_coordinates` / `get_irreg_coordinate`.
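The invariant can be sketched with a minimal stand-in class (this is not pylcm's `IrregSpacedGrid`; it raises a plain `RuntimeError` where the real code raises `GridInitializationError`):

```python
import numpy as np


class IrregSpacedGridSketch:
    """A runtime-supplied grid refuses to materialise placeholder values."""

    def __init__(self, n_points, pass_points_at_runtime=True):
        self.n_points = n_points
        self.pass_points_at_runtime = pass_points_at_runtime

    def to_jax(self):
        if self.pass_points_at_runtime:
            # pylcm raises GridInitializationError here.
            raise RuntimeError(
                "runtime-supplied grid has no values yet; use "
                "state_action_space(regime_params=...) for real values "
                "or .n_points for shape"
            )
        return np.linspace(0.0, 1.0, self.n_points)  # placeholder impl
```

Any caller that previously consumed the silent `jnp.full(N, jnp.nan)` placeholder now fails loudly at the grid layer.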

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move late `DiscreteGrid`, `map_coordinates`, and `get_irreg_coordinate`
imports to the module top level (PLC0415), drop the unnecessary `val`
assignment before return (RET504), and mark the unused `wealth` arg in
the local `borrow` constraint as `# noqa: ARG001`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
A regime function whose output is then re-indexed by a discrete state
inside another consumer (function, constraint, or transition) is a
silent footgun: pylcm broadcasts function outputs to per-cell scalars
before consumption, so the indexing silently produces NaN at runtime
instead of the intended scalar. The aca-baseline benchmark hit this
via `bequest(... utility_scale_factor[pref_type])` where
`utility_scale_factor` is registered as a regime function — the dead
regime's V came back all-NaN with no actionable error.

Adds an AST-walking validator in `validate_logical_consistency` that
inspects every consumer (functions, constraints, transition) for a
`Subscript(Name=X, slice=Name=Y)` pattern where `X` is in
`regime.functions` and `Y` is a `DiscreteGrid` state. If any clash is
found, raises `RegimeInitializationError` listing each clash and
pointing the user at the safe pattern (function takes the state,
returns a scalar — see `discount_factor`).

Three TDD tests in `tests/test_function_output_state_indexing.py`:
- the clash raises (functions case)
- the safe pattern (function takes the state, scalar return) builds
- the check applies to constraints too

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aca-model `feature/runtime-consumption-points` 4123fe9 → 1342861
(refactors `utility_scale_factor` to take `pref_type` and return a
scalar, eliminating the regime-function-output / state-indexed-input
clash that produced NaN in the dead regime's V).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ion space

`create_regime_state_action_space` (used during forward simulation) was
calling `create_state_action_space` directly, which leaves
`pass_points_at_runtime=True` IrregSpacedGrid action grids as their NaN
placeholder. The placeholder fed straight into
`argmax_and_max_Q_over_a` and `_lookup_values_from_indices`, so optimal
actions came back NaN, the source regime's `next_state` propagated NaN
into every target regime's namespaced state, and `validate_V` raised on
the first downstream regime whose utility depended on those states
(the dead regime in aca-model: assets/pref_type both NaN).

Route through `internal_regime.state_action_space(regime_params=...)`
(the same path solve uses) and overlay the per-subject states. Add a
TDD regression test in tests/test_runtime_params.py covering the
simulate path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…grid values

`LogSpacedGrid` previously inherited only the generic continuous-grid
checks (start < stop, n_points > 0). With `start <= 0`, `to_jax()`
silently returned NaN/-inf, and the bug would only surface deep
inside an interpolation kernel. Now refuses at construction.

While here, tighten two adjacent silent-failure modes:

- `_validate_continuous_grid` rejects non-finite `start`/`stop`.
  `start >= stop` is False for NaN, so a NaN bound previously slipped
  through every check.
- `_validate_irreg_spaced_grid` rejects non-finite points. The
  ascending-order test uses `>=`, which is False for NaN, so a NaN
  point previously passed the order check silently.

Both matter for runtime-supplied grids: e.g. `geomspace(consumption_floor,
MAX, N)` with a bad `consumption_floor` produces all-NaN points, and we
want that caught at the grid layer rather than as a downstream V_arr
NaN diagnostic.
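The NaN-slips-through-comparisons point is worth a concrete sketch (function names are illustrative, not pylcm's validators). The key detail: `start >= stop` is `False` when either bound is NaN, so the finiteness check must run first or a NaN bound passes every test silently.

```python
import math


def validate_continuous_grid(start, stop, n_points):
    """Generic continuous-grid checks, finiteness first."""
    errors = []
    if not (math.isfinite(start) and math.isfinite(stop)):
        errors.append("start/stop must be finite")
    elif start >= stop:
        errors.append("start must be < stop")
    if n_points <= 0:
        errors.append("n_points must be positive")
    return errors


def validate_log_spaced_grid(start, stop, n_points):
    """Log-spaced grids additionally require a strictly positive start,
    otherwise the spacing computation yields NaN/-inf values."""
    errors = validate_continuous_grid(start, stop, n_points)
    if math.isfinite(start) and start <= 0:
        errors.append("log-spaced grid requires start > 0")
    return errors
```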

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… banners

- tests/test_single_feasible_action.py: drop three decorative section
  banners (AGENTS.md prohibits `# ---...---` separators); fold the
  banner prose into the docstrings of the tests/helpers below.
- tests/test_single_feasible_action.py: type-annotate `_crra_bequest`
  and `_alive_utility`'s pref_type / consumption_weight /
  coefficient_rra arguments (DiscreteState / FloatND).
- tests/test_runtime_params.py: type-annotate `_make_action_grid_model`
  and `_make_action_grid_model_with_stateful_dead`.
- src/lcm/simulation/transitions.py: re-run `_validate_all_states_present`
  in the new `create_regime_state_action_space` (the substitution
  switch from `create_state_action_space(states=...)` to
  `base.replace(states=...)` had silently dropped this check).
- src/lcm/params/regime_template.py: docstring on
  `_fail_if_runtime_grid_shadows_function`; fix stale phrasing in
  `create_regime_params_template` ("matching the state name" →
  "matching the state or action name").
- src/lcm/interfaces.py: comment why the `_ShockGrid` substitution
  branch is gated on `in_states` only (state-only by design,
  AGENTS.md forbids ShockGrids as actions; gate is the explicit
  enforcement of that invariant).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The validator's error message already explains why; the class docstring
only needs the contract.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…rid path

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Derived categoricals (`regime.derived_categoricals`, function outputs
that pylcm treats as categoricals — see
https://pylcm.readthedocs.io/en/latest/pandas-interop/#derived-categoricals)
suffer the same per-cell broadcast clash as discrete states. Extend
`discrete_state_names` in `_validate_function_output_state_indexing`
to include them; add a TDD test.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…module)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
pylcm is a general library; references to a particular companion
application become stale fast and force readers to know unrelated
projects to follow the test rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The variable previously named `discrete_state_names` accumulated state
DiscreteGrids, derived categoricals, and now discrete actions — all
three suffer the same per-cell broadcast clash when a consumer does
`func_output[X]`. Renamed the variable, the two helpers
(`_validate_function_output_grid_indexing`,
`_find_function_output_grid_indexing`), the test module
(`test_function_output_grid_indexing.py`), and the error-message
wording ("discrete state" → "discrete grid"). Added a TDD test for
the discrete-action case.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tion

The previous docstring claimed the indexing 'silently produces NaN', but a
disabled-validator probe shows otherwise:

- When the producer takes the discrete grid as input, its output is a
  per-cell scalar; `func_output[grid]` raises `IndexError: Too many indices`
  at trace time. This is the real footgun the validator should catch.
- When the producer does NOT take the discrete grid as input, its output
  stays array-shaped and `func_output[grid]` is correct code that solves
  to sensible V values.

The previous validator flagged both shapes — including the safe one — as a
clash. Tighten: only fire when the producing function also takes the
discrete grid as input. Update the description to match observed behaviour
(IndexError, not NaN). Add a regression test that exercises the
array-valued-producer + state-indexed-consumer shape and asserts it builds
without raising.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PR #334 introduced a deferred-diagnostics accumulator that appends every
(regime, period) NaN/Inf flag to a Python list, stacks the lists at end
of solve, and `.tolist()`s the stacks to host. On a 16 GB V100 at
production aca-baseline grid sizes the stacked reduction graph holds the
per-period `isnan(V_arr)` / `isinf(V_arr)` intermediates alive
simultaneously; the post-loop `.tolist()` then asks XLA to compile the
fan-in and OOMs on a ~7.3 GiB allocation on top of the already-resident
solution V arrays. Symptom: backward induction reports every age as
"finished in ~14 ms" (dispatch-async times), then
`JaxRuntimeError: RESOURCE_EXHAUSTED` at the first `.tolist()`.

Fix: replace the per-period list-append with a running scalar OR; add a
per-period `block_until_ready()` so each period's reduction kernel
finishes (and its intermediate is freed) before the next period
dispatches. `block_until_ready` is device-only — no host transfer, no
PCIe round-trip — so it doesn't reintroduce the per-period sync that
#334 removed; in practice the small reduction has finished by the time
`max_Q_over_a` (~14 ms/period) returns.

End of solve: one `.item()` per running scalar. On a healthy solve those
two bools are False and we return without materialising any per-row
state. Failure paths (`running_any_nan` / `running_any_inf` True) walk
`diagnostic_rows` and materialise one bool per row to localise the
offender — same total host transfers as the prior code, but only on the
failure path.

Debug-stats path (`log_level="debug"`) still appends min/max/mean per
period; a single per-period `block_until_ready` after the appends frees
those intermediates too. The end-of-solve `_log_per_period_stats` keeps
the existing per-(regime, period) log line.

`_StackedReductions`, `_emit_deferred_diagnostics`, and the old
`_raise_if_nan` / `_warn_if_inf` (taking pre-materialised flag lists)
are replaced by `_emit_post_loop_diagnostics` (orchestrator),
`_raise_first_nan_row`, `_warn_inf_rows`, and `_log_per_period_stats`.

Tests: new `tests/solution/test_diagnostics.py` covering the four log
levels — happy-path warning, NaN-raise with `(regime, age)` in the
message, off-level skip, and per-period debug stats.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@read-the-docs-community

read-the-docs-community Bot commented May 1, 2026

@github-actions

github-actions Bot commented May 1, 2026

Benchmark comparison (main → HEAD)

Comparing b880524f (main) → eb694320 (HEAD)

| Benchmark | Statistic | before | after | Ratio | Alert |
|---|---|---|---|---|---|
| aca-baseline | execution time | 47.150 s | 47.910 s | 1.02 | |
| aca-baseline | peak GPU mem | 2.86 GB | 3.38 GB | 1.18 | |
| aca-baseline | compilation time | 422.38 s | 429.47 s | 1.02 | |
| aca-baseline | peak CPU mem | 8.42 GB | 8.31 GB | 0.99 | |
| Mahler-Yum | execution time | 4.769 s | 4.627 s | 0.97 | |
| Mahler-Yum | peak GPU mem | 522 MB | 522 MB | 1.00 | |
| Mahler-Yum | compilation time | 16.19 s | 16.66 s | 1.03 | |
| Mahler-Yum | peak CPU mem | 1.67 GB | 1.68 GB | 1.00 | |
| Precautionary Savings - Solve | execution time | 45.8 ms | 51.6 ms | 1.13 | |
| Precautionary Savings - Solve | peak GPU mem | 101 MB | 101 MB | 1.00 | |
| Precautionary Savings - Solve | compilation time | 2.53 s | 2.61 s | 1.03 | |
| Precautionary Savings - Solve | peak CPU mem | 1.12 GB | 1.12 GB | 1.00 | |
| Precautionary Savings - Simulate | execution time | 115.6 ms | 119.9 ms | 1.04 | |
| Precautionary Savings - Simulate | peak GPU mem | 340 MB | 340 MB | 1.00 | |
| Precautionary Savings - Simulate | compilation time | 6.21 s | 5.88 s | 0.95 | |
| Precautionary Savings - Simulate | peak CPU mem | 1.28 GB | 1.29 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate | execution time | 140.6 ms | 148.7 ms | 1.06 | |
| Precautionary Savings - Solve & Simulate | peak GPU mem | 577 MB | 577 MB | 1.00 | |
| Precautionary Savings - Solve & Simulate | compilation time | 8.05 s | 7.77 s | 0.96 | |
| Precautionary Savings - Solve & Simulate | peak CPU mem | 1.28 GB | 1.28 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate (irreg) | execution time | 280.8 ms | 276.7 ms | 0.99 | |
| Precautionary Savings - Solve & Simulate (irreg) | peak GPU mem | 2.19 GB | 2.19 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate (irreg) | compilation time | 8.60 s | 8.72 s | 1.01 | |
| Precautionary Savings - Solve & Simulate (irreg) | peak CPU mem | 1.33 GB | 1.34 GB | 1.00 | |

Each `_DiagnosticRow` previously held the active-period
`state_action_space`, the rolling `next_regime_to_V_arr`, the regime's
flat params, and a `compute_intermediates` closure (which itself
captured the state_action_space). At production grid sizes — 50+
periods × ~6 active regimes — the accumulated references pin every
period's full-shape V mapping in device memory, OOMing the V100 16 GB
mid-loop on `block_until_ready` (the next allocation that has nowhere
to go).

The streaming NaN/Inf reduction that landed earlier addressed only the
per-period reduction buffers; the row-level retention is the larger
leak. Strip `_DiagnosticRow` to the three Python scalars actually
needed for failure-path localisation (`regime_name`, `period`, `age`)
and reconstruct the heavy bits from `solution`, `internal_regimes`,
and `internal_params` inside `_raise_at`. The reconstruction mirrors
the loop's roll-forward semantics: for each regime, take the smallest
later period in `solution` where the regime was active, falling back
to a zeros template — the same value the rolling
`next_regime_to_V_arr` slot held during the live dispatch.

Also lock the row's shape via a structural test so future changes
that re-introduce device-backed fields fail loudly in CI rather than
silently regressing OOM behaviour.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two changes targeting the NaN-in-V failure path:

1. Fail-fast at age boundary. Adds a per-period
   `running_any_nan.item()` host transfer right after the existing
   `block_until_ready`. On True, the loop breaks out and the existing
   post-loop emitter raises immediately. Cost: one scalar bool transfer
   per period — negligible next to `max_Q_over_a`. Without this,
   backward induction would finish the entire age range (potentially
   ~2h on production grids) before raising at the first-NaN row,
   leaving the user staring at an idle-looking solve.
   Inf stays non-fatal; the post-loop warning still fires for any
   period that flagged it.

2. Drop the misleading "re-solve with debug logging" suggestion from
   `validate_V`. The diagnostic [NOTE] is added inline by
   `_enrich_with_diagnostics` whenever `compute_intermediates` is
   wired up — i.e. on the default path — so suggesting a re-solve to
   "produce" diagnostics is wrong: they were already produced. Replace
   with a pointer to the [NOTE] for the per-axis breakdown plus a
   mention of `log_path=...` for snapshot persistence (the only thing
   debug-mode actually adds beyond the inline diagnostic).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker hmgaudecker changed the title from "solve_brute: stream NaN/Inf reductions instead of stacking-and-flushing" to "solve_brute: lazy diagnostics with per-period fail-fast on NaN" on May 4, 2026
hmgaudecker and others added 2 commits May 4, 2026 07:53
When `log_path` is configured, the failure path already calls
`save_solve_snapshot(...)` (`model.py:223-230` and `:334-341`) before
re-raising — but the path it returns wasn't surfaced anywhere, so the
user saw a generic "pass `log_path=...`" hint pointing them to do
something they had already done. Capture the returned `snap_dir` and
attach it via `exc.add_note(f"Snapshot saved to {snap_dir}")`. The
note appears alongside the diagnostic-summary note that
`_enrich_with_diagnostics` adds, so the user sees both the per-axis
NaN breakdown and the exact `solve_snapshot_NNN/` directory in one
exception.

Drop the now-redundant `log_path=...` suggestion from `validate_V`'s
message. Replace with a short pointer to the [NOTE] block: when
`log_path` is set, the second note has the path; when it isn't, the
inline diagnostic still pinpoints the offending intermediate. The
debugging-guide link stays.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
A state transition that consumes another state transition's output (e.g.
`next_wealth(next_aime, ...)`) was rejected at solve time with
`InvalidParamsError: Missing required parameter:
'<regime>__next_wealth__next_aime'`. The upstream cause is in
`create_regime_params_template`, which classified `next_<state>`
references inside transition signatures as regime-level fixed_params.

`get_next_state_function_for_solution` already merges all transitions
and DAG functions into a single dict before calling
`concatenate_functions`, so dags resolves the chain at evaluation time.
The block is purely the params-template step. Extending the exempt set
with `{f"next_{name}" for name in regime.states}` lets dags do its job.

This unlocks per-target transition factories whose outputs feed each
other — e.g., a `next_assets` that reads `next_aime` to compute a
next-period imputed value (the aca-model pension correction use case).

No JAX parallelism implications: the fix is build-time bookkeeping only.
JIT scope, vmap structure, and scan layout are unchanged; the new
dependency edge runs per-gridpoint inside the same merged DAG.
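Why merging everything into one dict lets the chain resolve can be shown with a toy evaluator. This is a stand-in for what `dags.concatenate_functions` does for pylcm, not its actual implementation: each function's arguments are looked up first among the inputs, then among other functions' outputs, computed on demand.

```python
import inspect


def evaluate_merged_transitions(funcs, inputs):
    """Resolve a dict of transition functions whose outputs may feed
    each other, e.g. next_wealth(next_aime, ...)."""
    cache = dict(inputs)

    def compute(name):
        if name in cache:
            return cache[name]
        fn = funcs[name]
        kwargs = {p: compute(p) for p in inspect.signature(fn).parameters}
        cache[name] = fn(**kwargs)
        return cache[name]

    return {name: compute(name) for name in funcs}
```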
hmgaudecker added a commit that referenced this pull request May 4, 2026
The merge from #339 reverted the pin to 134286108, which is pre-
n_subjects-requirement and would TypeError on the benchmark's
create_benchmark_model(n_subjects=_N_SUBJECTS) call. Restore the
83f22500 pin (post-pension-correction, accepts n_subjects) — only
this branch's pylcm has Model.n_subjects, so the bump is safe here.
The simulate path flattened all transitions into one DAG keyed by
`<target>__<next_state>`, which prevented an unqualified `next_<state>`
parameter on a transition or auxiliary function from resolving across
the per-target boundary. The solve path doesn't have this problem
because it builds one DAG per target with bare `next_<state>` keys.

Switch the simulate compile to mirror that structure: one DAG per
target with unqualified keys, then merge per-target outputs into a
single flat dict keyed by `<target>__next_<state>` for the downstream
`_update_states_for_subjects` consumer.

Stochastic-transition wrappers keep their target-qualified
`key_<target>__next_<state>` and `weight_<target>__next_<state>` arg
names so multi-target callers still draw distinct realisations per
target.

This unblocks the aca-model pension imputation correction, where
`imputed_pension_wealth_next_period(next_aime, ...)` consumes
unqualified `next_aime` from the per-target DAG.

Extends test_chained_state_transitions with a `model.simulate(...)`
case — the original test only exercised solve, which masked this
asymmetry.
Member Author

@hmgaudecker hmgaudecker left a comment


Autoreview

@hmgaudecker hmgaudecker requested review from mj023 and timmens May 5, 2026 18:39
Comment thread tests/test_nan_diagnostics.py Outdated
Comment thread src/lcm/solution/solve_brute.py Outdated
hmgaudecker added a commit that referenced this pull request May 6, 2026
Two cleanups in the per-target simulate path landed via #339/#340.

(1) `_build_combined_simulation_function` previously wrapped a
    `(*args, **kwargs)` shim with `with_signature`, then zipped
    positional args back to names. Functionally fine, but the inner
    function's signature did not match the advertised one, requiring
    a `# ty: ignore` and protocol-mismatch warnings. Synthesise a
    real function with named positional-or-keyword parameters via
    `exec()` (same pattern as `dataclasses` and `attrs`). `vmap_1d`
    and other introspecting callers now see a faithful signature.

(2) `test_get_next_state_function_with_simulate_target` claimed to
    assert qualified output keys (`mock__next_a`, `mock__next_b`)
    but compared against unqualified `{"a": ..., "b": ...}`. The
    test passed only because `pybaum.tree_equal` ignores dict keys
    when leaf counts and values match. Rewrite to assert keys
    directly via `set(got.keys())` and check both regime outputs;
    drop the unused stochastic scaffolding (`f_weight_b`, `f_b`
    returning None, key kwarg).
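The exec()-synthesis pattern from cleanup (1) can be sketched generically (the helper name is illustrative; only the "real named parameters via exec(), as in dataclasses/attrs" idea comes from the commit message):

```python
import inspect


def make_named_function(arg_names, body_fn):
    """Synthesise a function with real named positional-or-keyword
    parameters, so introspecting callers see a faithful signature
    instead of a (*args, **kwargs) shim."""
    args = ", ".join(arg_names)
    src = f"def _generated({args}):\n    return _body({args})\n"
    ns = {"_body": body_fn}
    exec(src, ns)  # same technique dataclasses uses for __init__
    return ns["_generated"]
```

Callers that inspect the signature (as `vmap_1d` does) now see the advertised parameter names rather than `*args, **kwargs`.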
hmgaudecker and others added 14 commits May 6, 2026 05:41
mj023's inline (next_state.py): drop hand-rolled
`_build_combined_simulation_function` (`with_signature` over a
`(*args, **kwargs)` shim that zipped positional args back to names).
Use `concatenate_functions` directly on `per_target_funcs`. The
combined function now returns the nested form
`{target_regime_name: {next_<state>: array}}` instead of the flat
`<target>__<next_state>` shape, which is more natural to construct
and exactly what the consumer needs.

Move the per-target flattening into the consumer
`_update_states_for_subjects`: iterate the outer regime keys, strip
`next_` from inner keys, rebuild the flat `<target>__<state>` lookup
into `all_states`. Update the test fixture and the
`NextStateSimulationFunction` protocol's return type accordingly,
plus broaden the `vmap_1d` `FunctionWithArrayReturn` typevar bound
to admit the new nested-mapping return.

timmens' inline (regime_template.py): fold `next_state_names` into
`H_variables` rather than carrying it as a separate "exemption". The
docstring no longer needs the multi-paragraph rationale; the unified
`H_variables` set documents itself: regime functions, `period`,
`age`, `E_next_V`, and `next_<state>` outputs are all internal
wiring that pylcm resolves at evaluation time, never user-facing
fixed_params.

timmens' inline (next_state.py:99): rename `target_trans` to
`target_transitions` in the loop variable and the
`_extend_target_transitions_for_simulation` signature.

timmens' inline (test_chained_state_transitions.py): rewrite both
tests to assert behavior in user-facing terms instead of rehearsing
the prior bug. `test_solve_..._returns_finite_value_function` checks
that the active regime's V is finite.
`test_simulate_..._yields_expected_next_wealth` checks that
`next_wealth_t = wealth_t - consumption_t + 0.1 * next_aime_t` holds
period-over-period in the simulated DataFrame, which can only succeed
if the chained dependency `next_aime → next_wealth` was wired
correctly.
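The identity the rewritten test checks can be sketched on a toy frame (column names and values are hypothetical; a tolerance replaces exact float equality):

```python
# Sketch: verify next_wealth_t = wealth_t - consumption_t
# + 0.1 * next_aime_t period-over-period in a simulated frame.
import pandas as pd

df = pd.DataFrame(
    {
        "period": [0, 1],
        "wealth": [10.0, 9.0],
        "consumption": [2.0, 1.0],
        "aime": [5.0, 10.0],
    }
)

# aime.shift(-1) is next_aime_t; the last period has no successor, so
# drop it before comparing.
implied = (df["wealth"] - df["consumption"] + 0.1 * df["aime"].shift(-1)).iloc[:-1]
observed = df["wealth"].shift(-1).iloc[:-1]
assert (implied - observed).abs().max() < 1e-12
```

A check of this shape only passes if the chained `next_aime → next_wealth` dependency fed the right values.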

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
5cd2187 on .ai-instructions/main adds a "Test-Driven Development —
always" section to AGENTS.md (with two further subsections on
behavior-focused docstrings and concrete-value assertions). pylcm's
own .ai-instructions submodule is independent of aca-dev's, so this
bump is needed for agents working in pylcm directly to see the new
guidance — prompted by the pylcm #342 review feedback that motivated
the TDD section in the first place.

Also picks up 528a011..135a3cd: pinned-tool-version bump and Tier B
gitignore additions for pytask.lock files.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The TDD guidance landed canonically in .ai-instructions/AGENTS.md
(5cd2187), but reaches pylcm only via the @-include chain through the
.ai-instructions submodule. Inline directly in pylcm/AGENTS.md so the
policy is load-bearing in the file an agent sees first when working
in pylcm, not contingent on the submodule pointer being current.

Restructure: promote `Testing` to its own top-level section before
`Development Notes` (was a `### Testing Style` subsection inside it).
Three new TDD subsections sit above the existing pytest-mechanics
bullets:

- Test-Driven Development — always (red-green-refactor cycle, applied
  to features / bug fixes / refactors).
- Test docstrings — describe behavior, not history (pretend the
  reader has never seen the PR).
- Concrete-value assertions (assert what the result is, not just
  that it didn't crash).

Verbatim from .ai-instructions/AGENTS.md so the two stay in sync.
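A toy illustration of the concrete-value rule (the `utility` function here is hypothetical, not pylcm code):

```python
# Hypothetical model function used only to illustrate the rule.
def utility(consumption, disutility_of_work=0.5):
    return consumption**0.5 - disutility_of_work

# Weak: only checks that the call survives.
assert utility(4.0) is not None

# Better: pins the concrete value the model implies (sqrt(4) - 0.5).
assert utility(4.0) == 1.5
```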

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
timmens' inline (test_nan_diagnostics.py): delete
`test_diagnostic_row_holds_only_python_scalars`. The test asserted
that `_DiagnosticRow.__dataclass_fields__` equals a fixed three-name
set with a docstring rehearsing why those specific fields were
removed. That's testing the implementation, not behavior — an OOM
regression would not be caught by counting dataclass fields. The
load-bearing constraint ("no device-backed references on the row")
already lives in the `_DiagnosticRow` docstring, where readers will
find it.

timmens' inline (solve_brute.py:431): tighten `_DiagnosticRow` and
`_emit_post_loop_diagnostics` docstrings. Both rehearsed the prior
design ("The earlier design captured ...", "16 GB device that was
OOMing on the previous stack-and-flush pattern") in second-paragraph
"before the fix" framing. Drop those paragraphs; keep the
forward-looking constraint ("only Python-scalar metadata, no device
references") and the actual mechanism ("two `.item()` calls decide
whether to enter the per-row failure path"). Same anti-pattern the
new TDD-always section in AGENTS.md just codified for tests — apply
it to source docstrings as well.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Per timmens' #342 review (and a polish pass on the prior fold-into-
H_variables commit): drop the H_variables intermediate. There's no
remaining reason to keep two sets — every name in the union is
"internal wiring that pylcm resolves at evaluation time, never
user-facing fixed_params" — so build `variables` directly with all
six categories (states, actions, regime functions, next_<state>
outputs, period, age, E_next_V).

Also: collapse the function signature to one line, tighten the
docstring (`(period, age, E_next_V)` on one line, backtick-quote
`IrregSpacedGrid`).
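A sketch of the direct construction (category contents here are illustrative; the real sets come from the regime specification):

```python
# Sketch: build `variables` in one expression from all six categories,
# with no H_variables intermediate. Everything in this set is internal
# wiring resolved at evaluation time, never user-facing fixed_params.
states = {"wealth", "aime"}
actions = {"consumption"}
regime_function_names = {"utility", "next_wealth"}
next_state_outputs = {f"next_{s}" for s in states}

variables = (
    states
    | actions
    | regime_function_names
    | next_state_outputs
    | {"period", "age", "E_next_V"}
)
```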

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
V_arr shape is model-dependent (states × grid points × regimes), so
the per-period isnan/isinf intermediate buffer size is too. The
qualitative point — these allocations stack up across periods if not
freed — is what matters; the size figure was a distraction that
implied a fixed scale only true on whichever model the comment was
written against.
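The streaming pattern can be sketched with NumPy standing in for JAX (in the real code the flags are `jnp` scalars combined with a running OR, and the per-period `block_until_ready()` is what lets each V_arr-shaped intermediate be freed before the next period dispatches):

```python
# NumPy stand-in for the streaming reduction: a running scalar OR per
# period instead of stacking one isnan/isinf result per period.
import numpy as np

any_nan = False
any_inf = False
for period in range(3):
    V_arr = np.ones((4, 5))      # placeholder for the period's solution
    if period == 1:
        V_arr[2, 3] = np.nan     # inject a failure to show detection
    any_nan = any_nan or bool(np.isnan(V_arr).any())
    any_inf = any_inf or bool(np.isinf(V_arr).any())
    # jax sketch: nan_flag = jnp.logical_or(nan_flag, jnp.isnan(V_arr).any())
    #             nan_flag.block_until_ready()  # frees the intermediate

# End of loop: two scalars, no per-period state retained.
```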

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Codifies three patterns observed across this PR's review cleanups and
#339's, where source docstrings and inline comments rehearsed prior
implementations, cited PR numbers, and quoted model-specific figures.
None of that survives the 9-month-without-PR-context test.

The new "## Docstring Style" section sits next to "## Testing" and
adds three subsections, each with good/bad examples drawn from real
recent code:

- Describe state, not history (no "earlier", "previously", "the old
  design", "before the fix").
- No PR numbers, no model-specific magic numbers (PRs rot; "~2 MB at
  production grid sizes" only holds on one model/box).
- Bulleted lists for enumerated cases (one bullet per case beats
  running prose for log levels, regime kinds, dispatch strategies).

Cross-references the existing "Test docstrings — describe behavior,
not history" subsection rather than duplicating the same rule with
test-specific framing.

Also bump .ai-instructions submodule pointer to 609ac4a, which
landed the same docstring-style guidance on the canonical version.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…t list

Three follow-ups to the AGENTS.md docstring-style section just
inlined on this branch via the cascade-merge:

- Drop "the per-period host stalls that #334 removed" from the
  diagnostics-accumulator comment. PR refs rot. Restate as "doesn't
  introduce a host stall" with the same forward-looking explanation.
- Drop "(~2 MB each at production grid sizes)" — V_arr shape is
  model-dependent (states × grid points × regimes), so any fixed
  byte figure misleads on every other model. The size point was
  already restated as "(V_arr-shaped, so model-dependent)" earlier in
  the same comment; this commit removes the second history-framed
  phrase from that block.
- Reflow the log-level prose paragraph into a bulleted list. The
  three cases (`"off"` / `"warning"` ∪ `"progress"` / `"debug"`) read
  faster as bullets, and the `"off"` qualifier ("skips even the NaN
  fail-fast") fits inline rather than needing the separate sentence
  about contracts and estimation loops.

Also drop "without the deferred-stack fan-in that previously OOMed at
production sizes" from `tests/solution/test_diagnostics.py` — same
"previously" anti-pattern.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Base automatically changed from feature/next-state-deps-in-transitions to main May 6, 2026 05:45
hmgaudecker added a commit that referenced this pull request May 6, 2026
…ns bump, source cleanups

Brings the #342/#339 stack onto #340. Net surface:

- New `## Docstring Style` section in AGENTS.md (3 subsections with
  good/bad examples drawn from #342/#339 cleanups).
- `.ai-instructions` submodule pointer to 609ac4a, picking up the
  same docstring-style guidance on the canonical version.
- `_build_combined_simulation_function` is gone — replaced by direct
  `concatenate_functions` on `per_target_funcs`. Output shape
  changes from flat `<target>__<next_state>` to nested
  `{target: {next_<state>: array}}`. Consumer
  `_update_states_for_subjects` updated accordingly.
- `regime_template.create_regime_params_template`: `H_variables`
  intermediate dropped; single `variables` set with all categories.
- `solve_brute.py` diagnostics-accumulator comment: drop "#334"
  PR ref, drop "~2 MB at production grid sizes" magic number,
  reflow log-level prose into bulleted list.
- `tests/test_chained_state_transitions.py`: rewrite to assert
  concrete chain mathematics (`next_wealth_t = wealth_t - c_t +
  0.1 * next_aime_t`) instead of "no NaN".
- `tests/test_next_state.py::test_get_next_state_function_with_simulate_target`
  updated to assert nested output shape, replacing the prior
  flat-key version that was added on this branch as `c969b1a`.

Conflict resolution: take #339's version of `next_state.py` and
`test_chained_state_transitions.py` (the cascade-target's reshape
supersedes the branch's earlier exec-based combined function).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
hmgaudecker added a commit that referenced this pull request May 6, 2026
The cascade-merge from #339 changed `get_next_state_function_for_simulation`
to return a nested mapping `{target: {next_<state>: array}}` (the
flat `<target>__<next_state>` form went away with
`_build_combined_simulation_function`). The simulate-target test on
this branch was set up against the flat shape (added as `c969b1a`
before the merge); update its assertions to the nested form so the
test still validates the actual output structure.
@hmgaudecker hmgaudecker merged commit a4eca9b into main May 6, 2026
10 checks passed
@hmgaudecker hmgaudecker deleted the improve/lazy-solve-diagnostics branch May 6, 2026 08:10