
Model.n_subjects: AOT-compile simulate, lock integer dtype to int32 #340

Open
hmgaudecker wants to merge 103 commits into main from
feat/simulate-aot-n-subjects

Conversation


@hmgaudecker hmgaudecker commented May 1, 2026

Goal

Do the same for simulate as was done for solve earlier: AOT-compile. The main reasons are failing early and getting cleaner benchmarks (simulate's JIT time was previously included in the measured runtime).

This surfaced subtle differences in which integer dtypes were used during compilation versus at run time. Explicit casts inside the hot loop turned out to be necessary, but very costly. For speed, memory efficiency, and peace of mind, convert everything to jnp.int32 at the boundary (no use case will ever require ints > 2_147_483_647).
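The boundary rule can be sketched as a one-shot pytree walk. This is a hypothetical standalone helper (not pylcm's actual code, which applies the casts at the specific sites listed in the summary below), written with numpy for clarity:

```python
import numpy as np


def cast_int_leaves_to_int32(tree):
    """Narrow every integer array leaf to int32 once, at the boundary.

    Hypothetical sketch: casting once on entry means the hot loop never
    needs per-iteration casts, and int32 covers any realistic index or
    id range (up to 2_147_483_647). Float leaves pass through untouched.
    """
    if isinstance(tree, dict):
        return {k: cast_int_leaves_to_int32(v) for k, v in tree.items()}
    arr = np.asarray(tree)
    if np.issubdtype(arr.dtype, np.integer):
        return arr.astype(np.int32)
    return arr
```
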

Summary

Two coupled changes that together let simulate(...) ship a single AOT-compiled program per unique callable and batch shape:

1. Model(n_subjects: int | None = None) — AOT-compile simulate

  • When n_subjects is set, the first matching simulate(...) call AOT-compiles every unique simulate function (argmax_and_max_Q_over_a per period, next_state per regime, compute_regime_transition_probs per regime) for batch shape n_subjects in parallel via ThreadPoolExecutor, mirroring solve's existing AOT path in solve_brute._compile_all_functions.
  • Subsequent calls with the same size hit the cache; calls with a mismatching size warn once per size and fall back to the runtime-traced path.
  • Side benefit: normalises period_to_regime_to_V_arr at the entry of simulate(...) so every period dispatches with the same pytree (active-regime padding with zeros). Without this the last period's empty next_regime_to_V_arr breaks both the AOT pytree signature and JAX's own JIT cache reuse.
  • n_subjects=None (the default) preserves the previous lazy behaviour.

2. Lock integer dtype to int32 end-to-end (formerly #341, squash-merged as 6c610d1)

The lazy JIT cache silently retraces per dtype, so the int32 (no x64) and int64 (x64) variants of the same regime compiled into different specialisations. AOT compilation via jax.jit(...).lower(**args).compile() ships a single signature, which then broke at runtime with int32[N] vs int64[N] mismatches. Fix:

  • Tighten Int1D/IntND/DiscreteState/DiscreteAction/ScalarInt in src/lcm/typing.py from Int to Int32 — ~113 internal usages inherit the dtype constraint without further edits.
  • Pin DiscreteGrid.to_jax() to int32 regardless of jax_enable_x64.
  • build_initial_states casts discrete states to grid dtype (one-shot fix for the period-0 vs period-1+ dispatch dtype split).
  • Boundary casts at every site whose integer dtype actually depends on jax_enable_x64:
    • regime_building/argmax.py:46, 68 — argmax and scalar fallback (real int64→int32 narrowing under x64).
    • simulation/initial_conditions.py:365, 409 and utils/error_handling.py:376 — where-index extractions in error/validation paths.
  • Lock the AOT signature inputs:
    • pandas_utils.py:155 — regime-id ingestion cast to int32 to match DiscreteGrid.to_jax().
    • simulation/simulate.py:102 — subject_regime_ids sentinel buffer pinned to int32 regardless of input regime dtype.

Note: jnp.searchsorted already returns int32 even under x64, and unravel_index outputs are immediately consumed by index ops that accept any int dtype — so no inline casts are needed inside hot interpolation/lookup kernels (an earlier draft added them, which broke XLA fusion and roughly doubled GPU peak memory on aca-baseline).
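The single-signature behaviour is easy to demonstrate with JAX's AOT API directly. A minimal sketch, unrelated to pylcm's internals: a program lowered and compiled for one input signature rejects non-matching arguments instead of silently retracing the way a lazy jax.jit call would:

```python
import jax
import jax.numpy as jnp

# AOT-compile for a fixed input signature: int32[3].
compiled = jax.jit(lambda x: x * 2).lower(jnp.zeros(3, jnp.int32)).compile()

# A matching dtype/shape runs against the cached executable.
out = compiled(jnp.arange(3, dtype=jnp.int32))

# A float32 argument does not match the compiled signature and raises,
# whereas a lazy `jax.jit` would transparently retrace and compile a
# second specialisation.
try:
    compiled(jnp.zeros(3, jnp.float32))
    mismatch_raised = False
except Exception:  # the exact error type is a JAX implementation detail
    mismatch_raised = True
```
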

Test plan

  • pixi run -e tests-cpu pytest tests/simulation/test_simulate_aot.py — 7 new tests cover validation, status-quo regression, AOT-compile-once, mismatch warn, warn-once-per-size.
  • pixi run -e tests-cpu pytest tests/ -n 7 — 895 passed, 5 skipped (full pylcm suite green; +3 dtype-invariant tests in tests/test_int_dtype_invariants.py).
  • pixi run -e type-checking ty — clean.
  • prek run --all-files — clean.
  • aca-dev simulate-AOT smoke — simulate OK, n=20 with all int initial-condition entries reporting int32.

Stacked on #339.

🤖 Generated with Claude Code

hmgaudecker and others added 23 commits April 29, 2026 06:29
Extends the existing runtime-points mechanism (previously state-only)
to continuous action grids. With this change, an action declared as
`IrregSpacedGrid(n_points=N)` adds an `{action_name: {"points":
"Float1D"}}` entry to the regime params template, and `state_action_space()`
substitutes the runtime-supplied points into `continuous_actions` at
solve / simulate time.

Motivation: aca-dev's structural retirement model has a `consumption`
action grid whose lower bound is the per-iteration `consumption_floor`
parameter. Without this change the c-grid bounds would have to be
fixed at build time, which forces either an over-wide grid (wasted
density) or model rebuilds per estimation iteration (recompilation).

Mirrors the existing state-grid treatment:
- `regime_template.py`: walks `regime.actions` alongside `regime.states`,
  factoring the shared shadowing check into helpers.
- `interfaces.InternalRegime.state_action_space()`: builds both
  state and continuous-action replacements in a single sweep over
  `self.grids`, then calls `_base_state_action_space.replace(...)`
  with whichever side actually had substitutions.
- `pandas_utils._is_runtime_grid_param`: also recognises action grids
  so column extraction in `to_dataframe()` keeps working.

Tests (TDD): four new tests in `tests/test_runtime_params.py`,
mirroring the state-grid counterparts — params-template entry,
solve, runtime-vs-fixed equivalence, and a sanity check that
varying runtime points actually changes V.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aca-model now declares `consumption` as `IrregSpacedGrid(n_points=N)`
with runtime-supplied points. The bench builder now passes
`model=model` to `get_benchmark_params` so consumption gridpoints
are injected into params before solving.

aca-model rev: adc8a19 → 4123fe9 (feature/runtime-consumption-points)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`IrregSpacedGrid(n_points=N)` declares a continuous grid whose values
are supplied at runtime via `params[regime][grid_name]['points']`.
Substitution happens inside
`InternalRegime.state_action_space(regime_params=...)` at solve /
simulate time. Any code path that calls `to_jax()` on the base grid
before substitution silently got `jnp.full(N, jnp.nan)` and went on to
compute against the placeholder.

That is exactly what fired in `validate_initial_conditions` for
`task_simulate_aca`: the validator built the action grid by calling
`internal_regime.grids[name].to_jax()` (placeholder NaNs), then asked
`borrowing_constraint(consumption=NaN, wealth=W)` whether each
gridpoint was feasible. NaN comparisons are False, so every action
was reported infeasible for every subject in every initial regime.

Make the invariant explicit: `IrregSpacedGrid.to_jax()` raises
`GridInitializationError` for runtime-supplied grids, with a message
pointing the caller at `state_action_space(regime_params=...)` for
real values or `.n_points` for shape. Confine the legitimate
"placeholder needed for AOT tracing" caller (the base state-action
space) to a private helper in `state_action_space.py` that uses NaN
explicitly. Reroute `_check_regime_feasibility` through the
substituted state-action space.

Add regression tests covering both runtime action and runtime state
grids round-tripping `simulate(check_initial_conditions=True)`, and
unit tests pinning down the new raise + the existing NaN-source
mechanics in `map_coordinates` / `get_irreg_coordinate`.
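The invariant can be sketched in isolation. These are hypothetical, heavily simplified classes (the real `IrregSpacedGrid` in pylcm carries far more machinery); the point is only the fail-loud contract:

```python
import numpy as np


class GridInitializationError(Exception):
    """Raised when a grid cannot produce concrete values."""


class IrregSpacedGrid:
    """Simplified sketch: points=None marks a runtime-supplied grid."""

    def __init__(self, n_points, points=None):
        self.n_points = n_points
        self.points = points  # None => values supplied at runtime

    def to_jax(self):
        if self.points is None:
            # Fail loudly instead of returning a NaN placeholder that
            # downstream feasibility checks would silently compare against.
            raise GridInitializationError(
                "Grid values are supplied at runtime; use "
                "state_action_space(regime_params=...) for real values "
                "or .n_points for the shape."
            )
        return np.asarray(self.points)
```
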

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move late `DiscreteGrid`, `map_coordinates`, and `get_irreg_coordinate`
imports to the module top level (PLC0415), drop the unnecessary `val`
assignment before return (RET504), and mark the unused `wealth` arg in
the local `borrow` constraint as `# noqa: ARG001`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
A regime function whose output is then re-indexed by a discrete state
inside another consumer (function, constraint, or transition) is a
silent footgun: pylcm broadcasts function outputs to per-cell scalars
before consumption, so the indexing silently produces NaN at runtime
instead of the intended scalar. The aca-baseline benchmark hit this
via `bequest(... utility_scale_factor[pref_type])` where
`utility_scale_factor` is registered as a regime function — the dead
regime's V came back all-NaN with no actionable error.

Adds an AST-walking validator in `validate_logical_consistency` that
inspects every consumer (functions, constraints, transition) for a
`Subscript(Name=X, slice=Name=Y)` pattern where `X` is in
`regime.functions` and `Y` is a `DiscreteGrid` state. If any clash is
found, raises `RegimeInitializationError` listing each clash and
pointing the user at the safe pattern (function takes the state,
returns a scalar — see `discount_factor`).

Three TDD tests in `tests/test_function_output_state_indexing.py`:
- the clash raises (functions case)
- the safe pattern (function takes the state, scalar return) builds
- the check applies to constraints too
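The core of the pattern match can be reproduced with the stdlib `ast` module. A standalone sketch (the real validator walks pylcm's registered consumers rather than raw source strings):

```python
import ast


def find_function_output_grid_indexing(source, function_names, grid_names):
    """Find every `X[Y]` where X is a regime-function output and Y a
    discrete grid name. Hypothetical standalone form of the check."""
    clashes = []
    for node in ast.walk(ast.parse(source)):
        if (
            isinstance(node, ast.Subscript)
            and isinstance(node.value, ast.Name)
            # On Python >= 3.9 the slice of `x[y]` is the Name node directly.
            and isinstance(node.slice, ast.Name)
            and node.value.id in function_names
            and node.slice.id in grid_names
        ):
            clashes.append((node.value.id, node.slice.id))
    return clashes
```
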

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aca-model `feature/runtime-consumption-points` 4123fe9 → 1342861
(refactors `utility_scale_factor` to take `pref_type` and return a
scalar, eliminating the regime-function-output / state-indexed-input
clash that produced NaN in the dead regime's V).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ion space

`create_regime_state_action_space` (used during forward simulation) was
calling `create_state_action_space` directly, which leaves
`pass_points_at_runtime=True` IrregSpacedGrid action grids as their NaN
placeholder. The placeholder fed straight into
`argmax_and_max_Q_over_a` and `_lookup_values_from_indices`, so optimal
actions came back NaN, the source regime's `next_state` propagated NaN
into every target regime's namespaced state, and `validate_V` raised on
the first downstream regime whose utility depended on those states
(the dead regime in aca-model: assets/pref_type both NaN).

Route through `internal_regime.state_action_space(regime_params=...)`
(the same path solve uses) and overlay the per-subject states. Add a
TDD regression test in tests/test_runtime_params.py covering the
simulate path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…grid values

`LogSpacedGrid` previously inherited only the generic continuous-grid
checks (start < stop, n_points > 0). With `start <= 0`, `to_jax()`
silently returned NaN/-inf, and the bug would only surface deep
inside an interpolation kernel. Now refuses at construction.

While here, tighten two adjacent silent-failure modes:

- `_validate_continuous_grid` rejects non-finite `start`/`stop`.
  `start >= stop` is False for NaN, so a NaN bound previously slipped
  through every check.
- `_validate_irreg_spaced_grid` rejects non-finite points. The
  ascending-order test uses `>=`, which is False for NaN, so a NaN
  point previously passed the order check silently.

Both matter for runtime-supplied grids: e.g. `geomspace(consumption_floor,
MAX, N)` with a bad `consumption_floor` produces all-NaN points, and we
want that caught at the grid layer rather than as a downstream V_arr
NaN diagnostic.
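The NaN-slips-through mechanics are plain Python semantics: NaN fails every comparison, so `start >= stop` is False for a NaN bound. A minimal sketch of the tightened check (hypothetical helper name; the real validators live in pylcm's grid module):

```python
import math


def validate_bounds(start, stop):
    """Reject non-finite bounds before the ordering check.

    `start >= stop` is False whenever either bound is NaN, so an
    ordering-only check lets a NaN bound through silently. Checking
    finiteness first makes the failure explicit at construction.
    """
    if not (math.isfinite(start) and math.isfinite(stop)):
        raise ValueError(f"Grid bounds must be finite, got start={start}, stop={stop}.")
    if start >= stop:
        raise ValueError(f"Grid requires start < stop, got {start} >= {stop}.")
```
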

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… banners

- tests/test_single_feasible_action.py: drop three decorative section
  banners (AGENTS.md prohibits `# ---...---` separators); fold the
  banner prose into the docstrings of the tests/helpers below.
- tests/test_single_feasible_action.py: type-annotate `_crra_bequest`
  and `_alive_utility`'s pref_type / consumption_weight /
  coefficient_rra arguments (DiscreteState / FloatND).
- tests/test_runtime_params.py: type-annotate `_make_action_grid_model`
  and `_make_action_grid_model_with_stateful_dead`.
- src/lcm/simulation/transitions.py: re-run `_validate_all_states_present`
  in the new `create_regime_state_action_space` (the substitution
  switch from `create_state_action_space(states=...)` to
  `base.replace(states=...)` had silently dropped this check).
- src/lcm/params/regime_template.py: docstring on
  `_fail_if_runtime_grid_shadows_function`; fix stale phrasing in
  `create_regime_params_template` ("matching the state name" →
  "matching the state or action name").
- src/lcm/interfaces.py: comment why the `_ShockGrid` substitution
  branch is gated on `in_states` only (state-only by design,
  AGENTS.md forbids ShockGrids as actions; gate is the explicit
  enforcement of that invariant).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The validator's error message already explains why; the class docstring
only needs the contract.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…rid path

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Derived categoricals (`regime.derived_categoricals`, function outputs
that pylcm treats as categoricals — see
https://pylcm.readthedocs.io/en/latest/pandas-interop/#derived-categoricals)
suffer the same per-cell broadcast clash as discrete states. Extend
`discrete_state_names` in `_validate_function_output_state_indexing`
to include them; add a TDD test.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…module)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
pylcm is a general library; references to a particular companion
application become stale fast and force readers to know unrelated
projects to follow the test rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The variable previously named `discrete_state_names` accumulated state
DiscreteGrids, derived categoricals, and now discrete actions — all
three suffer the same per-cell broadcast clash when a consumer does
`func_output[X]`. Renamed the variable, the two helpers
(`_validate_function_output_grid_indexing`,
`_find_function_output_grid_indexing`), the test module
(`test_function_output_grid_indexing.py`), and the error-message
wording ("discrete state" → "discrete grid"). Added a TDD test for
the discrete-action case.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tion

The previous docstring claimed the indexing 'silently produces NaN', but a
disabled-validator probe shows otherwise:

- When the producer takes the discrete grid as input, its output is a
  per-cell scalar; `func_output[grid]` raises `IndexError: Too many indices`
  at trace time. This is the real footgun the validator should catch.
- When the producer does NOT take the discrete grid as input, its output
  stays array-shaped and `func_output[grid]` is correct code that solves
  to sensible V values.

The previous validator flagged both shapes — including the safe one — as a
clash. Tighten: only fire when the producing function also takes the
discrete grid as input. Update the description to match observed behaviour
(IndexError, not NaN). Add a regression test that exercises the
array-valued-producer + state-indexed-consumer shape and asserts it builds
without raising.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PR #334 introduced a deferred-diagnostics accumulator that appends every
(regime, period) NaN/Inf flag to a Python list, stacks the lists at end
of solve, and `.tolist()`s the stacks to host. On a 16 GB V100 at
production aca-baseline grid sizes the stacked reduction graph holds the
per-period `isnan(V_arr)` / `isinf(V_arr)` intermediates alive
simultaneously; the post-loop `.tolist()` then asks XLA to compile the
fan-in and OOMs on a ~7.3 GiB allocation on top of the already-resident
solution V arrays. Symptom: backward induction reports every age as
"finished in ~14 ms" (dispatch-async times), then
`JaxRuntimeError: RESOURCE_EXHAUSTED` at the first `.tolist()`.

Fix: replace the per-period list-append with a running scalar OR; add a
per-period `block_until_ready()` so each period's reduction kernel
finishes (and its intermediate is freed) before the next period
dispatches. `block_until_ready` is device-only — no host transfer, no
PCIe round-trip — so it doesn't reintroduce the per-period sync that
#334 removed; in practice the small reduction has finished by the time
`max_Q_over_a` (~14 ms/period) returns.

End of solve: one `.item()` per running scalar. On a healthy solve those
two bools are False and we return without materialising any per-row
state. Failure paths (`running_any_nan` / `running_any_inf` True) walk
`diagnostic_rows` and materialise one bool per row to localise the
offender — same total host transfers as the prior code, but only on the
failure path.
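Stripped of the surrounding solve loop, the accumulator pattern looks roughly like this (a sketch under the assumption that each period produces one V array; not the actual solve_brute code):

```python
import jax.numpy as jnp


def any_nan_across_periods(period_arrays):
    """Running scalar OR over per-period NaN reductions.

    One device-resident bool replaces a growing list of per-period
    flags; `block_until_ready` is a device-only sync that lets each
    period's reduction intermediate be freed before the next period
    dispatches, without any host transfer inside the loop.
    """
    running = jnp.asarray(False)
    for v_arr in period_arrays:
        running = running | jnp.isnan(v_arr).any()
        running.block_until_ready()
    # Exactly one host transfer, at the end of the loop.
    return bool(running.item())
```
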

Debug-stats path (`log_level="debug"`) still appends min/max/mean per
period; a single per-period `block_until_ready` after the appends frees
those intermediates too. The end-of-solve `_log_per_period_stats` keeps
the existing per-(regime, period) log line.

`_StackedReductions`, `_emit_deferred_diagnostics`, and the old
`_raise_if_nan` / `_warn_if_inf` (taking pre-materialised flag lists)
are replaced by `_emit_post_loop_diagnostics` (orchestrator),
`_raise_first_nan_row`, `_warn_inf_rows`, and `_log_per_period_stats`.

Tests: new `tests/solution/test_diagnostics.py` covering the four log
levels — happy-path warning, NaN-raise with `(regime, age)` in the
message, off-level skip, and per-period debug stats.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each `_DiagnosticRow` previously held the active-period
`state_action_space`, the rolling `next_regime_to_V_arr`, the regime's
flat params, and a `compute_intermediates` closure (which itself
captured the state_action_space). At production grid sizes — 50+
periods × ~6 active regimes — the accumulated references pin every
period's full-shape V mapping in device memory, OOMing the V100 16 GB
mid-loop on `block_until_ready` (the next allocation that has nowhere
to go).

The streaming NaN/Inf reduction landed earlier addressed only the
per-period reduction buffers; the row-level retention is the larger
leak. Strip `_DiagnosticRow` to the three Python scalars actually
needed for failure-path localisation (`regime_name`, `period`, `age`)
and reconstruct the heavy bits from `solution`, `internal_regimes`,
and `internal_params` inside `_raise_at`. The reconstruction mirrors
the loop's roll-forward semantics: for each regime, take the smallest
later period in `solution` where the regime was active, falling back
to a zeros template — the same value the rolling
`next_regime_to_V_arr` slot held during the live dispatch.

Also lock the row's shape via a structural test so future changes
that re-introduce device-backed fields fail loudly in CI rather than
silently regressing OOM behaviour.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

read-the-docs-community Bot commented May 1, 2026


github-actions Bot commented May 1, 2026

Benchmark comparison (main → HEAD)

Comparing a4eca9bf (main) → 14f81fc7 (HEAD)

| Benchmark | Statistic | Before | After | Ratio | Alert |
|---|---|---|---|---|---|
| aca-baseline | execution time | 46.946 s | 29.561 s | 0.63 | |
| aca-baseline | peak GPU mem | 671 MB | 614 MB | 0.91 | |
| aca-baseline | compilation time | 427.20 s | 424.90 s | 0.99 | |
| aca-baseline | peak CPU mem | 8.39 GB | 10.36 GB | 1.23 | |
| Mahler-Yum | execution time | 4.756 s | 4.767 s | 1.00 | |
| Mahler-Yum | peak GPU mem | 522 MB | 522 MB | 1.00 | |
| Mahler-Yum | compilation time | 16.46 s | 16.41 s | 1.00 | |
| Mahler-Yum | peak CPU mem | 1.68 GB | 1.68 GB | 1.00 | |
| Precautionary Savings - Solve | execution time | 45.5 ms | 46.4 ms | 1.02 | |
| Precautionary Savings - Solve | peak GPU mem | 101 MB | 101 MB | 1.00 | |
| Precautionary Savings - Solve | compilation time | 2.52 s | 2.44 s | 0.97 | |
| Precautionary Savings - Solve | peak CPU mem | 1.12 GB | 1.12 GB | 1.00 | |
| Precautionary Savings - Simulate | execution time | 119.6 ms | 118.1 ms | 0.99 | |
| Precautionary Savings - Simulate | peak GPU mem | 340 MB | 340 MB | 1.00 | |
| Precautionary Savings - Simulate | compilation time | 5.99 s | 5.93 s | 0.99 | |
| Precautionary Savings - Simulate | peak CPU mem | 1.29 GB | 1.29 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate | execution time | 136.7 ms | 160.5 ms | 1.17 | |
| Precautionary Savings - Solve & Simulate | peak GPU mem | 577 MB | 577 MB | 1.00 | |
| Precautionary Savings - Solve & Simulate | compilation time | 7.78 s | 7.68 s | 0.99 | |
| Precautionary Savings - Solve & Simulate | peak CPU mem | 1.28 GB | 1.28 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate (irreg) | execution time | 280.2 ms | 287.6 ms | 1.03 | |
| Precautionary Savings - Solve & Simulate (irreg) | peak GPU mem | 2.19 GB | 2.19 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate (irreg) | compilation time | 8.59 s | 8.40 s | 0.98 | |
| Precautionary Savings - Solve & Simulate (irreg) | peak CPU mem | 1.33 GB | 1.34 GB | 1.00 | |

@hmgaudecker hmgaudecker changed the title Model.n_subjects: AOT-compile simulate functions for fixed batch shape Model.n_subjects: AOT-compile simulate, lock integer dtype to int32 May 3, 2026
@hmgaudecker hmgaudecker force-pushed the feat/simulate-aot-n-subjects branch from 4306092 to 10ba5d4 Compare May 3, 2026 15:36
Two changes targeting the NaN-in-V failure path:

1. Fail-fast at age boundary. Adds a per-period
   `running_any_nan.item()` host transfer right after the existing
   `block_until_ready`. On True, the loop breaks out and the existing
   post-loop emitter raises immediately. Cost: one scalar bool transfer
   per period — negligible next to `max_Q_over_a`. Without this,
   backward induction would finish the entire age range (potentially
   ~2h on production grids) before raising at the first-NaN row,
   leaving the user staring at an idle-looking solve.
   Inf stays non-fatal; the post-loop warning still fires for any
   period that flagged it.

2. Drop the misleading "re-solve with debug logging" suggestion from
   `validate_V`. The diagnostic [NOTE] is added inline by
   `_enrich_with_diagnostics` whenever `compute_intermediates` is
   wired up — i.e. on the default path — so suggesting a re-solve to
   "produce" diagnostics is wrong: they were already produced. Replace
   with a pointer to the [NOTE] for the per-axis breakdown plus a
   mention of `log_path=...` for snapshot persistence (the only thing
   debug-mode actually adds beyond the inline diagnostic).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker hmgaudecker force-pushed the feat/simulate-aot-n-subjects branch from 3406f06 to 5e09d46 Compare May 4, 2026 05:15
hmgaudecker and others added 11 commits May 6, 2026 06:40
V_arr shape is model-dependent (states × grid points × regimes), so
the per-period isnan/isinf intermediate buffer size is too. The
qualitative point — these allocations stack up across periods if not
freed — is what matters; the size figure was a distraction that
implied a fixed scale only true on whichever model the comment was
written against.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Codifies three patterns observed across this PR's review cleanups and
#339's, where source docstrings and inline comments rehearsed prior
implementations, cited PR numbers, and quoted model-specific figures.
None of that survives the 9-month-without-PR-context test.

The new "## Docstring Style" section sits next to "## Testing" and
adds three subsections, each with good/bad examples drawn from real
recent code:

- Describe state, not history (no "earlier", "previously", "the old
  design", "before the fix").
- No PR numbers, no model-specific magic numbers (PRs rot; "~2 MB at
  production grid sizes" only holds on one model/box).
- Bulleted lists for enumerated cases (one bullet per case beats
  running prose for log levels, regime kinds, dispatch strategies).

Cross-references the existing "Test docstrings — describe behavior,
not history" subsection rather than duplicating the same rule with
test-specific framing.

Also bump .ai-instructions submodule pointer to 609ac4a, which
landed the same docstring-style guidance on the canonical version.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…t list

Three follow-ups to the AGENTS.md docstring-style section just
inlined on this branch via the cascade-merge:

- Drop "the per-period host stalls that #334 removed" from the
  diagnostics-accumulator comment. PR refs rot. Restate as "doesn't
  introduce a host stall" with the same forward-looking explanation.
- Drop "(~2 MB each at production grid sizes)" — V_arr shape is
  model-dependent (states × grid points × regimes), so any fixed
  byte figure misleads on every other model. Restated as
  "(V_arr-shaped, so model-dependent)" earlier in the same comment;
  this commit's change is the second history-framed phrase from the
  same block.
- Reflow the log-level prose paragraph into a bulleted list. The
  three cases (`"off"` / `"warning"` ∪ `"progress"` / `"debug"`) read
  faster as bullets, and the `"off"` qualifier ("skips even the NaN
  fail-fast") fits inline rather than needing the separate sentence
  about contracts and estimation loops.

Also drop "without the deferred-stack fan-in that previously OOMed at
production sizes" from `tests/solution/test_diagnostics.py` — same
"previously" anti-pattern.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ns bump, source cleanups

Brings the #342/#339 stack onto #340. Net surface:

- New `## Docstring Style` section in AGENTS.md (3 subsections with
  good/bad examples drawn from #342/#339 cleanups).
- `.ai-instructions` submodule pointer to 609ac4a, picking up the
  same docstring-style guidance on the canonical version.
- `_build_combined_simulation_function` is gone — replaced by direct
  `concatenate_functions` on `per_target_funcs`. Output shape
  changes from flat `<target>__<next_state>` to nested
  `{target: {next_<state>: array}}`. Consumer
  `_update_states_for_subjects` updated accordingly.
- `regime_template.create_regime_params_template`: `H_variables`
  intermediate dropped; single `variables` set with all categories.
- `solve_brute.py` diagnostics-accumulator comment: drop "#334"
  PR ref, drop "~2 MB at production grid sizes" magic number,
  reflow log-level prose into bulleted list.
- `tests/test_chained_state_transitions.py`: rewrite to assert
  concrete chain mathematics (`next_wealth_t = wealth_t - c_t +
  0.1 * next_aime_t`) instead of "no NaN".
- `tests/test_next_state.py::test_get_next_state_function_with_simulate_target`
  updated to assert nested output shape, replacing the prior
  flat-key version that was added on this branch as `c969b1a`.

Conflict resolution: take #339's version of `next_state.py` and
`test_chained_state_transitions.py` (the cascade-target's reshape
supersedes the branch's earlier exec-based combined function).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The cascade-merge from #339 changed `get_next_state_function_for_simulation`
to return a nested mapping `{target: {next_<state>: array}}` (the
flat `<target>__<next_state>` form went away with
`_build_combined_simulation_function`). The simulate-target test on
this branch was set up against the flat shape (added as `c969b1a`
before the merge); update its assertions to the nested form so the
test still validates the actual output structure.
…ssage

Replace the "Active constraints: ..." list (which printed *all* regime
constraints whether binding or not) with a per-constraint boolean
column appended to the infeasible-subjects DataFrame. For each
constraint and each infeasible subject, the entry is True when that
constraint individually admits at least one action and False when it
rejects every action by itself.

The distinction matters for diagnosis. With a single binding
constraint, exactly one column reads False. With a joint-rejection
case (each constraint admits some action; their intersection is
empty), every column reads True and the user knows the issue is the
intersection, not any individual constraint.

Implementation: `_per_constraint_feasibility` builds a per-constraint
feasibility function via `_get_feasibility(constraints={name:func})`
and reuses `_batched_feasibility_check` to run it on the infeasible
subjects only.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Base automatically changed from improve/lazy-solve-diagnostics to main May 6, 2026 08:10
Docstring/comment cleanup (drop counterfactuals and history rehashes per
AGENTS.md "Describe state, not history"):

- discrete.py to_jax: rephrase "would otherwise produce" as positive
  constraint
- initial_conditions.py: replace "Without this..." with the invariant the
  cast enforces
- compile.py module docstring: drop "Mirrors the pattern in solve_brute"
  cross-reference
- compile.py argmax-over-active-periods comment: drop aca-model-specific
  forced-canwork / pre-FRA example
- compile.py _build_next_state_args: drop the now-stale comment about
  Python-int period (all sites use jnp.int32)

Bullet-list dispatch cases (per AGENTS.md "Bulleted lists for enumerated
cases"):

- model.py n_subjects field: 4 dispatch cases as bullets, plus an explicit
  param-shape stability contract for MSM-style use
- model.py _resolve_simulate_internal_regimes: 3 cases as bullets

TDD/test style (per AGENTS.md "Testing"):

- test_simulate_aot.py: split test_n_subjects_none_keeps_lazy_behavior
  into two single-assertion tests
- test_simulate_aot.py: rename test_simulate_caches_recompiled_size_no_
  second_warning -> test_simulate_warns_only_once_per_mismatching_size
  (the implementation only adds to _warned_n_subjects, not the cache)
- test_simulate_aot.py: split first/second-call compile counting into
  test_simulate_first_matching_call_populates_aot_cache and
  test_simulate_second_matching_call_does_not_invoke_compile, dropping
  the weak `n_first > 0` smoke check
- test_simulate_aot.py: docstrings on the validation tests, return-type
  annotations on counting_compile and _build_initial_conditions
- test_int_dtype_invariants.py: per-test docstrings
- compile.py _get_regime_V_shapes: docstring matching the file's other
  private helpers
- model.py: class-body annotations and field docstrings for
  _simulate_compile_cache, _warned_n_subjects, _simulate_compile_lock

x64 + AOT fail-fast (issue 19):

- model_processing._fail_if_x64_with_aot raises
  ModelInitializationError when n_subjects is set under
  jax_enable_x64=True; the AOT path pins integer dtypes to int32, x64
  promotes to int64, so the cached signature would not match runtime
- AOT tests get an autouse fixture that disables x64 for the duration

period dtype consistency (issue 20):

- _build_next_state_args / _build_crtp_args: lower with
  period=jnp.int32(0) to match the argmax path
- transitions.calculate_next_states / calculate_next_regime_membership:
  pass period=jnp.int32(period) at runtime so the runtime call matches
  the AOT abstract signature

Concurrency (issue 18):

- Model gains a threading.Lock guarding check-then-set on
  _simulate_compile_cache and _warned_n_subjects
- __getstate__ / __setstate__ exclude the lock and the per-process AOT
  cache from pickling (compiled programs can't survive process boundaries
  anyway)

Dedup-key invariant (issue 17):

- _collect_unique_simulate_functions: comment documents the dedup
  contract — pylcm's process_regimes ships per-regime callables for
  next_state and crtp, so identity-based dedup is collision-free
- new test_simulate_functions_use_per_regime_callables in
  test_regime_processing.py pins the invariant against future
  regression
Pin every int that crosses into pylcm to int32, with overflow checks at
the boundary helpers. After this, the AOT simulate path works under
both `jax_enable_x64=True` and `=False` — the int-side hazard is gone,
so the construction-time guard is dropped.

Boundaries:

- New `lcm.dtypes.safe_to_int32(value, *, name)`: host-side cast that
  raises `ValueError` with the leaf's qualified name on int32 overflow,
  rather than silently wrapping. Reused at every int boundary helper.
- `lcm.params.processing._cast_int_leaves_to_int32`: applied in
  `broadcast_to_template` over every resolved params leaf. Walks
  `MappingLeaf` / `SequenceLeaf` recursively. Casts typed JAX/numpy
  int arrays to int32; leaves Python int/bool scalars alone so JAX's
  weak-typing rules can still promote them per call site (e.g.
  `discount_factor: 1` keeps working in float-typed functions).
- `lcm.simulation.transitions._update_states_for_subjects`: casts
  `next_state_values` to the storage dtype before `jnp.where`, but
  only when the kinds match (int->int, float->float). Cross-kind
  pairs (int storage + float transition output, possible when a user
  passes int initial conditions for a continuous state) keep JAX's
  promotion semantics; the cross-kind boundary cast is Package B.
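A sketch of the `safe_to_int32` boundary helper, using numpy arrays in place of JAX ones; the signature follows the description above, the body is an assumption.

```python
import numpy as np

_I32 = np.iinfo(np.int32)

def safe_to_int32(value, *, name):
    """Host-side cast to int32; raise with the leaf's qualified name on
    overflow or underflow instead of silently wrapping."""
    arr = np.asarray(value)
    if np.any(arr < _I32.min) or np.any(arr > _I32.max):
        raise ValueError(
            f"{name}: value outside int32 range [{_I32.min}, {_I32.max}]"
        )
    return arr.astype(np.int32)
```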

Guard removal:

- `_fail_if_x64_with_aot` deleted. The boundary casts now keep
  `internal_params` and the simulate state pool int32 regardless of
  `jax_enable_x64`, so the AOT cache signature is stable under x64.
  `tests/simulation/test_simulate_aot.py` drops its autouse
  `_disable_x64` fixture; AOT tests run under the conftest's default
  x64 setting.

Tests:

- New `tests/test_dtypes.py` (5 tests): `safe_to_int32` round-trip,
  in-range int64 array, overflow with qualified-name error, underflow.
- Extended `tests/test_int_dtype_invariants.py` with six tests
  covering the four new boundaries:
  - `_update_states_for_subjects` keeps storage int32 when next-state
    output is int64
  - `process_params` casts a typed int64 array leaf to int32
  - `process_params` raises `ValueError` naming the param qname on
    int array overflow
  - `process_params` casts int arrays inside a `MappingLeaf` to int32
  - `process_params` passes Python int leaves through unchanged for
    JAX weak-typing
  - `simulate(...)` accepts int64 `regime` initial conditions and
    round-trips to the same DataFrame as int32

Out of scope (Package B):

- Continuous-state float dtype normalisation at boundaries.
- Float scalar params -> single canonical float dtype.
- Cross-kind cast in `_update_states_for_subjects` (int storage +
  float transition output, when user supplies int initial conditions
  for a continuous state — pre-existing behaviour preserved).
hmgaudecker and others added 4 commits May 6, 2026 18:26
The two overflow tests built their fixtures via
`jnp.asarray([..., 2**32], dtype=jnp.int64)`. Under `jax_enable_x64=False`
(the GPU 32-bit-precision job), JAX truncates the requested int64 to
int32 at construction time and raises its own `OverflowError` before
`safe_to_int32` ever sees the value — so the test asserts a `ValueError`
that is never reached.

Use `np.asarray(..., dtype=np.int64)` for these fixtures instead. Numpy
honours the explicit dtype regardless of the JAX precision setting, so
our boundary helper receives a real int64 and produces its own
qualified-name `ValueError`.

The same pattern (use numpy for overflow fixtures) will land in
Package B for the float-overflow test.
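The fixture change boils down to one construction call; numpy commits to the requested width regardless of any JAX precision flag.

```python
import numpy as np

# jnp.asarray([2**32], dtype=jnp.int64) raises JAX's own OverflowError when
# jax_enable_x64=False; numpy honours the explicit dtype unconditionally.
fixture = np.asarray([0, 2**32], dtype=np.int64)
assert fixture.dtype == np.int64
assert fixture[1] == 2**32  # a real int64 reaches the boundary helper
```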
…keys

Docstring style:

- `dtypes.py` module docstring: drop "out of scope for the boundary
  helpers" / "broken user transition" framing — describe what the
  module does instead.
- `_cast_int_leaves_to_int32` docstring: drop "would force premature
  dtype commitment" counterfactual and "Package B" project-jargon
  reference. Replace with a positive description (cast / pass-through
  by leaf kind).
- `__getstate__` docstring: lead with what is returned (a copy of
  `__dict__` minus the per-process AOT compile state) before
  explaining mechanism.

Public-API docs:

- `process_params` Raises section now lists the `ValueError` that the
  int boundary cast can surface, plus a paragraph documenting the
  dtype-normalisation step. Module docstring covers the same.

Concurrency docstring:

- `_simulate_compile_lock` field docstring rephrased to make it
  explicit that the consequent `log.warning` is intentionally
  outside the lock.

Test coverage and structure:

- Add `test_process_params_casts_int_array_inside_sequence_leaf_to_int32`
  to mirror the `MappingLeaf` test for `SequenceLeaf` int contents.
- Add `test_unpickled_model_can_simulate_with_aot`: full pickle round-
  trip through `cloudpickle`, then re-run simulate to confirm the AOT
  cache is reset and re-populated post-unpickle.
- Parametrise `test_process_params_casts_int_array_inside_mapping_leaf_to_int32`
  over `low`/`high` (one assertion per test).
- Parametrise `test_safe_to_int32_*` over `python-int` / `int64-array`
  inputs and split into `_returns_int32` (dtype) and
  `_preserves_in_range_values` (value preservation).
- Split `test_simulate_warns_on_n_subjects_mismatch` into four single-
  assertion tests behind a shared fixture (warning count, declared-N
  in message, actual-N in message, no cache entry).
- Replace `test_simulate_functions_use_per_regime_callables` 2-assertion
  body with a parametrised version over `next_state` /
  `compute_regime_transition_probs`, using a shared fixture. New
  docstring describes user-visible behaviour rather than the AOT-dedup
  rationale.
- Round-trip int64-regime test now uses `pd.testing.assert_frame_equal`
  across all output columns, not `.equals()` on one column.

Structural fixes:

- `_collect_unique_simulate_functions` now keys `next_state` and
  `compute_regime_transition_probs` by `(kind, regime_name,
  callable-id)` instead of `(kind, callable-id)`. Two regimes that
  share a callable identity now still get distinct compiled programs
  (carried-forward issue 17 from prior review). Removes the need for
  the comment that pointed at a specific test file path.
- `_cast_int_leaves_to_int32` swaps the `hasattr(value, "dtype")`
  duck-type check for `isinstance(value, (Array, np.ndarray))` —
  closes the PR-#302 review thread on duck-typing arrays.
- `draw_key_from_dict`: pin `regime_ids` to `jnp.int32` so the
  "all integers are int32" invariant holds throughout simulate, not
  just at the AOT-traced boundaries.
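The new dedup keying can be sketched as follows; the regime-mapping shape is an assumption for illustration, pylcm's internal containers differ.

```python
def collect_unique_simulate_functions(regimes):
    """Key per-regime callables by (kind, regime_name, id(fn)): two regimes
    that share a callable identity still get distinct compiled programs."""
    unique = {}
    for regime_name, fns in regimes.items():
        for kind in ("next_state", "compute_regime_transition_probs"):
            unique[(kind, regime_name, id(fns[kind]))] = fns[kind]
    return unique

shared = lambda *a: None  # one callable reused by both regimes
regimes = {
    "work": {"next_state": shared, "compute_regime_transition_probs": shared},
    "retired": {"next_state": shared, "compute_regime_transition_probs": shared},
}
# Keying on (kind, id) alone would collapse these four entries into two.
assert len(collect_unique_simulate_functions(regimes)) == 4
```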
`compile_all_simulate_functions` previously held every entry's
concrete lower-args (V-shaped templates, per-regime subject-state /
action zeros, regime-params view) in `unique[key][1]` until the
function returned, and every `Lowered` HLO module in `lowered[key]`
until the slowest parallel compile finished. With per-target
next-state DAGs and an AOT cache that pins every kernel alive, the
overlap of in-flight lower-args + lowered HLO + compiled kernels was
the dominant contributor to the simulate-side compile-phase peak.

Drop the args from `unique[key]` immediately after each lowering, and
`del lowered[k]` as soon as the corresponding `Compiled` lands in
`compiled`. The dedup keys, the parallel pool semantics, and the
swap-in step are unchanged; existing tests (n_subjects mismatch,
unpickled-model AOT round-trip, dtype invariants) remain green.
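The shape of the fix, with stand-in objects for JAX's `Lowered` / `Compiled` stages (the real entries hold V-shaped templates and HLO modules): free each intermediate as soon as its successor exists, instead of holding everything until the slowest compile finishes.

```python
from concurrent.futures import ThreadPoolExecutor

class FakeLowered:
    def __init__(self, tag):
        self.tag = tag
    def compile(self):
        return f"compiled:{self.tag}"

class FakeJitted:
    def __init__(self, tag):
        self.tag = tag
    def lower(self, *args):
        return FakeLowered(self.tag)

def compile_all(unique, max_workers=2):
    lowered = {}
    for key, (fn, lower_args) in unique.items():
        lowered[key] = fn.lower(*lower_args)
        unique[key] = (fn, None)  # drop concrete lower-args right away
    compiled = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {key: pool.submit(lowered[key].compile) for key in lowered}
        for key, fut in futures.items():
            compiled[key] = fut.result()
            del lowered[key]  # free each HLO module as soon as it compiles
    return compiled

unique = {k: (FakeJitted(k), ("args",)) for k in ("a", "b")}
result = compile_all(unique)
```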
The local variable in `_raise_at` holds the regime's params after manually
merging in `resolved_fixed_params`, the same merge the live solve loop
performs implicitly via partialled closures. The new name
`effective_regime_params` captures that intent; the old name `diag_params`
only said "I'm passing this to the diagnostic call".

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker left a comment
Autoreview.

```python
if not infeasible_indices:
    return None

per_constraint_admits_any = _per_constraint_feasibility(
```
This helps a lot for debugging invalid initial conditions. Sorry it landed here, making the PR bigger...

@hmgaudecker hmgaudecker requested review from mj023 and timmens and removed request for mj023 May 7, 2026 07:07
@hmgaudecker commented May 7, 2026

Note:

| Benchmark statistic | before | after | Ratio | Alert |
| --- | --- | --- | --- | --- |
| peak CPU mem | 8.39 GB | 10.36 GB | 1.23 | |

If this ever became the bottleneck, reducing `max_compilation_workers` would bring it down at a one-off compile-time cost.

Plus it seems to be gone in #345.

Drops the "Python scalar pass-through to keep JAX weak-typing
semantics" line from `_cast_int_leaves_to_int32`: a Python `int`
arriving at a DAG function as a Python scalar now becomes
`jnp.int32(value)` so downstream code sees a JAX-typed scalar
rather than a Python int that JAX would promote per call site.
This finishes the "no Python scalars inside JIT'd loops" goal
that motivated the dtype-barrier work.

`bool` is short-circuited (the float-side cast pass on #345 will
handle it; bool is a Python `int` subclass so the bool branch must
come before the int one). Module + function docstrings refreshed
accordingly.

Test `test_process_params_passes_python_int_through_for_jax_weak_typing`
renamed and flipped to assert `dtype == jnp.int32`.

`ScalarInt` keeps the `int | Int32[Scalar, ""]` union: tightening
to JAX-only cascades into 13 call-site mismatches at internal
Python metadata sites (`n_points`, `period`) that legitimately
pass Python `int` outside the JIT'd DAG. Tightening the alias is
a separate audit and follow-up.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
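The dispatch-order pitfall in a few lines: `bool` is a subclass of `int` in Python, so the bool branch must precede the int branch or `True`/`False` would be cast to `int32(1)`/`int32(0)`. Numpy stands in for JAX here; the helper name is illustrative.

```python
import numpy as np

def cast_int_leaf(value):
    if isinstance(value, bool):   # must come first: bool subclasses int
        return value              # left to the float/bool cast pass (#345)
    if isinstance(value, int):
        return np.int32(value)    # Python int scalar -> typed int32 scalar
    return value

assert isinstance(True, int)                 # the pitfall
assert cast_int_leaf(True) is True           # bool short-circuited
assert cast_int_leaf(7).dtype == np.int32    # int becomes a typed scalar
```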
hmgaudecker added a commit that referenced this pull request May 8, 2026
Folds #340's int Python-scalar cast into the broader #345 work:

- `_cast_leaves_to_canonical_dtype` becomes a strict whitelist with
  explicit raises (no silent fallthrough). Dispatch order:
  `MappingLeaf`/`SequenceLeaf` → recurse, `pd.Series` → defensive
  raise (orchestrator must reshape first), `bool` (before `int`),
  `int`, `float`, `np.ndarray | jax.Array` keyed on `dtype.kind`
  (`b/i/u/f`). Anything else (`str`, complex/object arrays, custom
  objects) raises `InvalidParamsError` with the leaf's qualified name.
- `broadcast_to_template` becomes broadcast-only — the inline cast
  loop moves out into the new top-level
  `cast_params_to_canonical_dtypes(internal_params)` helper.
- `process_params` chains broadcast + cast for callers that have no
  `pd.Series` leaves; pd.Series-bearing callers must orchestrate the
  three steps explicitly.
- `_process_params` (model.py) and `_build_regimes_and_template_with_fixed_params`
  (model_processing.py) now sequence broadcast → reshape → cast →
  validate. The cast walks a uniform tree (no `pd.Series` to special-
  case) which lets the whitelist drop the pass-through branch.
- `bool` casts to `jnp.bool_`, `float` casts to
  `canonical_float_dtype()`, completing the "no Python scalars inside
  JIT'd loops" goal — DAG functions now receive JAX-typed scalars
  end-to-end.
- `ScalarBool` alias added (`bool | Bool[Scalar, ""]`). `ScalarInt` /
  `ScalarFloat` keep their `int |` / `float |` unions for now; the
  tighter forms cascade into call-site mismatches at internal Python
  metadata sites (`n_points`, `period`) and warrant a separate audit.
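The whitelist dispatch can be sketched as follows, with numpy in place of JAX arrays; the `MappingLeaf`/`SequenceLeaf` recursion and the defensive `pd.Series` raise are omitted, and the canonical float dtype is assumed configurable.

```python
import numpy as np

class InvalidParamsError(ValueError):
    pass

def cast_leaf_to_canonical(value, *, qname, float_dtype=np.float32):
    """Strict whitelist with explicit raises; no silent fallthrough."""
    if isinstance(value, bool):        # before int: bool subclasses int
        return np.bool_(value)
    if isinstance(value, int):
        return np.int32(value)
    if isinstance(value, float):
        return np.dtype(float_dtype).type(value)
    if isinstance(value, np.ndarray):
        kind = value.dtype.kind
        if kind == "b":
            return value.astype(np.bool_)
        if kind in "iu":
            return value.astype(np.int32)
        if kind == "f":
            return value.astype(float_dtype)
    # str, complex/object arrays, custom objects: raise with qualified name.
    raise InvalidParamsError(f"{qname}: unsupported leaf {type(value)!r}")
```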

Tests:

- `test_process_params_passes_python_float_through_for_jax_weak_typing`
  → `..._casts_python_float_to_canonical`, asserts dtype.
- `test_python_float_param_passed_through_for_weak_typing` →
  `..._cast_to_canonical_dtype`, asserts dtype.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When `Model(n_subjects=N)` is set, simulate-side XLA compilation used
to run lazily on the first matching `simulate(...)` call — strictly
after `solve(...)` returned. On production aca-baseline that adds
several minutes to the end-to-end wall clock for nothing: solve is
GPU-bound, simulate compile is CPU-bound XLA work, so they overlap
trivially.

Add `_maybe_start_simulate_compile_async` and call it from `solve(...)`
right after parameters are processed. It spawns a single-worker
`ThreadPoolExecutor` that runs `compile_all_simulate_functions` in the
background and parks the result on `_simulate_compile_future`.
`_resolve_simulate_internal_regimes` awaits the future before
populating the cache, so the lazy fallback path (no `solve` call,
direct `simulate(...)`) still works.

`__getstate__` / `__setstate__` drop the future on the way out and
reset to `None` on the way in — `concurrent.futures.Future` is tied to
its originating thread pool and can't survive a process boundary.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
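The overlap described above can be sketched with stdlib futures; field and method names follow the commit message, the body is an assumption and the compile function is a stand-in.

```python
from concurrent.futures import ThreadPoolExecutor

class Model:
    def __init__(self, n_subjects=None):
        self.n_subjects = n_subjects
        self._simulate_compile_future = None
        self._simulate_compile_cache = {}

    def _maybe_start_simulate_compile_async(self, compile_fn):
        # Kick off simulate-side XLA compilation in the background while
        # solve() runs: solve is GPU-bound, compilation is CPU-bound.
        if self.n_subjects is None or self._simulate_compile_future is not None:
            return
        pool = ThreadPoolExecutor(max_workers=1)
        self._simulate_compile_future = pool.submit(compile_fn, self.n_subjects)
        pool.shutdown(wait=False)  # worker finishes the queued task

    def _resolve_compiled(self):
        # Await the background future before populating the cache; the lazy
        # fallback (simulate() without a prior solve()) compiles here instead.
        if self._simulate_compile_future is not None:
            self._simulate_compile_cache = self._simulate_compile_future.result()
            self._simulate_compile_future = None
        return self._simulate_compile_cache

m = Model(n_subjects=10)
m._maybe_start_simulate_compile_async(lambda n: {"n_subjects": n})
cache = m._resolve_compiled()
```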