Model.n_subjects: AOT-compile simulate, lock integer dtype to int32 #340
Open
hmgaudecker wants to merge 103 commits into main
Extends the existing runtime-points mechanism (previously state-only)
to continuous action grids. With this change, an action declared as
`IrregSpacedGrid(n_points=N)` adds an `{action_name: {"points":
"Float1D"}}` entry to the regime params template, and `state_action_space()`
substitutes the runtime-supplied points into `continuous_actions` at
solve / simulate time.
Motivation: aca-dev's structural retirement model has a `consumption`
action grid whose lower bound is the per-iteration `consumption_floor`
parameter. Without this change the c-grid bounds would have to be
fixed at build time, which forces either an over-wide grid (wasted
density) or model rebuilds per estimation iteration (recompilation).
Mirrors the existing state-grid treatment:
- `regime_template.py`: walks `regime.actions` alongside `regime.states`,
factoring the shared shadowing check into helpers.
- `interfaces.InternalRegime.state_action_space()`: builds both
state and continuous-action replacements in a single sweep over
`self.grids`, then calls `_base_state_action_space.replace(...)`
with whichever side actually had substitutions.
- `pandas_utils._is_runtime_grid_param`: also recognises action grids
so column extraction in `to_dataframe()` keeps working.
Tests (TDD): four new tests in `tests/test_runtime_params.py`,
mirroring the state-grid counterparts — params-template entry,
solve, runtime-vs-fixed equivalence, and a sanity check that
varying runtime points actually changes V.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aca-model now declares `consumption` as `IrregSpacedGrid(n_points=N)` with runtime-supplied points. The bench builder now passes `model=model` to `get_benchmark_params` so consumption gridpoints are injected into params before solving.
aca-model rev: adc8a19 → 4123fe9 (feature/runtime-consumption-points)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`IrregSpacedGrid(n_points=N)` declares a continuous grid whose values are supplied at runtime via `params[regime][grid_name]['points']`. Substitution happens inside `InternalRegime.state_action_space(regime_params=...)` at solve / simulate time. Any code path that calls `to_jax()` on the base grid before substitution silently got `jnp.full(N, jnp.nan)` and went on to compute against the placeholder.
That is exactly what fired in `validate_initial_conditions` for `task_simulate_aca`: the validator built the action grid by calling `internal_regime.grids[name].to_jax()` (placeholder NaNs), then asked `borrowing_constraint(consumption=NaN, wealth=W)` whether each gridpoint was feasible. NaN comparisons are False, so every action was reported infeasible for every subject in every initial regime.
Make the invariant explicit: `IrregSpacedGrid.to_jax()` raises `GridInitializationError` for runtime-supplied grids, with a message pointing the caller at `state_action_space(regime_params=...)` for real values or `.n_points` for shape. Confine the legitimate "placeholder needed for AOT tracing" caller (the base state-action space) to a private helper in `state_action_space.py` that uses NaN explicitly. Reroute `_check_regime_feasibility` through the substituted state-action space.
Add regression tests covering both runtime action and runtime state grids round-tripping `simulate(check_initial_conditions=True)`, and unit tests pinning down the new raise + the existing NaN-source mechanics in `map_coordinates` / `get_irreg_coordinate`.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
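The "every action infeasible" symptom follows directly from IEEE 754 comparison semantics. A minimal sketch in plain Python, assuming a hypothetical `borrowing_constraint` of the form "consumption must not exceed wealth" (the real constraint is not shown in this PR):

```python
import math


def borrowing_constraint(consumption: float, wealth: float) -> bool:
    # Hypothetical stand-in for the model's constraint.
    return consumption <= wealth


wealth = 10.0
placeholder_grid = [math.nan] * 4          # what an unsubstituted runtime grid holds
substituted_grid = [1.0, 2.0, 5.0, 20.0]   # runtime-supplied points

# Any ordered comparison against NaN is False, so nothing is feasible.
feasible_placeholder = [borrowing_constraint(c, wealth) for c in placeholder_grid]
feasible_substituted = [borrowing_constraint(c, wealth) for c in substituted_grid]
```

With the placeholder grid no action passes the check, whereas the substituted grid rejects only the point above `wealth`.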
for more information, see https://pre-commit.ci
Move late `DiscreteGrid`, `map_coordinates`, and `get_irreg_coordinate` imports to the module top level (PLC0415), drop the unnecessary `val` assignment before return (RET504), and mark the unused `wealth` arg in the local `borrow` constraint as `# noqa: ARG001`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
A regime function whose output is then re-indexed by a discrete state inside another consumer (function, constraint, or transition) is a silent footgun: pylcm broadcasts function outputs to per-cell scalars before consumption, so the indexing silently produces NaN at runtime instead of the intended scalar. The aca-baseline benchmark hit this via `bequest(... utility_scale_factor[pref_type])` where `utility_scale_factor` is registered as a regime function; the dead regime's V came back all-NaN with no actionable error.
Adds an AST-walking validator in `validate_logical_consistency` that inspects every consumer (functions, constraints, transition) for a `Subscript(Name=X, slice=Name=Y)` pattern where `X` is in `regime.functions` and `Y` is a `DiscreteGrid` state. If any clash is found, raises `RegimeInitializationError` listing each clash and pointing the user at the safe pattern (function takes the state, returns a scalar; see `discount_factor`).
Three TDD tests in `tests/test_function_output_state_indexing.py`:
- the clash raises (functions case)
- the safe pattern (function takes the state, scalar return) builds
- the check applies to constraints too
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
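The `Subscript(Name=X, slice=Name=Y)` match can be sketched with the standard-library `ast` module. This is an illustration of the pattern only, not pylcm's validator: `find_indexed_function_outputs` is a hypothetical name, the consumer is passed as source text rather than via `inspect.getsource`, and the `node.slice` layout assumed here is the one used by Python 3.9 and later.

```python
import ast
import textwrap


def find_indexed_function_outputs(source, function_names, discrete_names):
    """Return (X, Y) for every `X[Y]` where X is a function output and Y a discrete grid."""
    clashes = []
    for node in ast.walk(ast.parse(textwrap.dedent(source))):
        if (
            isinstance(node, ast.Subscript)
            and isinstance(node.value, ast.Name)   # X is a bare name
            and isinstance(node.slice, ast.Name)   # Y is a bare name (Python >= 3.9)
            and node.value.id in function_names
            and node.slice.id in discrete_names
        ):
            clashes.append((node.value.id, node.slice.id))
    return clashes


consumer_source = """
def bequest(wealth, utility_scale_factor, pref_type):
    return utility_scale_factor[pref_type] * wealth
"""

clashes = find_indexed_function_outputs(
    consumer_source,
    function_names={"utility_scale_factor"},
    discrete_names={"pref_type"},
)
```

A consumer that receives the already-scalar value and multiplies it directly yields an empty list.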
aca-model `feature/runtime-consumption-points` 4123fe9 → 1342861 (refactors `utility_scale_factor` to take `pref_type` and return a scalar, eliminating the regime-function-output / state-indexed-input clash that produced NaN in the dead regime's V). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ion space
`create_regime_state_action_space` (used during forward simulation) was calling `create_state_action_space` directly, which leaves `pass_points_at_runtime=True` IrregSpacedGrid action grids as their NaN placeholder. The placeholder fed straight into `argmax_and_max_Q_over_a` and `_lookup_values_from_indices`, so optimal actions came back NaN, the source regime's `next_state` propagated NaN into every target regime's namespaced state, and `validate_V` raised on the first downstream regime whose utility depended on those states (the dead regime in aca-model: assets/pref_type both NaN).
Route through `internal_regime.state_action_space(regime_params=...)` (the same path solve uses) and overlay the per-subject states. Add a TDD regression test in tests/test_runtime_params.py covering the simulate path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…grid values
`LogSpacedGrid` previously inherited only the generic continuous-grid checks (start < stop, n_points > 0). With `start <= 0`, `to_jax()` silently returned NaN/-inf, and the bug would only surface deep inside an interpolation kernel. Now refuses at construction.
While here, tighten two adjacent silent-failure modes:
- `_validate_continuous_grid` rejects non-finite `start`/`stop`. `start >= stop` is False for NaN, so a NaN bound previously slipped through every check.
- `_validate_irreg_spaced_grid` rejects non-finite points. The ascending-order test uses `>=`, which is False for NaN, so a NaN point previously passed the order check silently.
Both matter for runtime-supplied grids: e.g. `geomspace(consumption_floor, MAX, N)` with a bad `consumption_floor` produces all-NaN points, and we want that caught at the grid layer rather than as a downstream V_arr NaN diagnostic.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
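Both silent-failure modes share one cause: a check that rejects only when `a >= b` is True lets NaN through. A plain-Python sketch of the before/after logic; the function names are hypothetical and do not mirror the bodies of `_validate_continuous_grid` or `_validate_irreg_spaced_grid`:

```python
import math


def naive_bounds_ok(start: float, stop: float) -> bool:
    # Rejects only when `start >= stop` is True; NaN makes that False.
    return not (start >= stop)


def strict_bounds_ok(start: float, stop: float) -> bool:
    # Finite-ness is checked first, then the ordering.
    return math.isfinite(start) and math.isfinite(stop) and start < stop


def naive_ascending_ok(points: list[float]) -> bool:
    # Flags a violation only when a neighbouring pair satisfies `a >= b`.
    return not any(a >= b for a, b in zip(points, points[1:]))


def strict_ascending_ok(points: list[float]) -> bool:
    return all(math.isfinite(p) for p in points) and naive_ascending_ok(points)
```

For example, `naive_bounds_ok(math.nan, 1.0)` accepts the grid while `strict_bounds_ok(math.nan, 1.0)` rejects it, and the same holds for a NaN in the middle of an otherwise ascending point list.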
… banners
- tests/test_single_feasible_action.py: drop three decorative section
banners (AGENTS.md prohibits `# ---...---` separators); fold the
banner prose into the docstrings of the tests/helpers below.
- tests/test_single_feasible_action.py: type-annotate `_crra_bequest`
and `_alive_utility`'s pref_type / consumption_weight /
coefficient_rra arguments (DiscreteState / FloatND).
- tests/test_runtime_params.py: type-annotate `_make_action_grid_model`
and `_make_action_grid_model_with_stateful_dead`.
- src/lcm/simulation/transitions.py: re-run `_validate_all_states_present`
in the new `create_regime_state_action_space` (the substitution
switch from `create_state_action_space(states=...)` to
`base.replace(states=...)` had silently dropped this check).
- src/lcm/params/regime_template.py: docstring on
`_fail_if_runtime_grid_shadows_function`; fix stale phrasing in
`create_regime_params_template` ("matching the state name" →
"matching the state or action name").
- src/lcm/interfaces.py: comment why the `_ShockGrid` substitution
branch is gated on `in_states` only (state-only by design,
AGENTS.md forbids ShockGrids as actions; gate is the explicit
enforcement of that invariant).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The validator's error message already explains why; the class docstring only needs the contract. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…rid path Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Derived categoricals (`regime.derived_categoricals`, function outputs that pylcm treats as categoricals — see https://pylcm.readthedocs.io/en/latest/pandas-interop/#derived-categoricals) suffer the same per-cell broadcast clash as discrete states. Extend `discrete_state_names` in `_validate_function_output_state_indexing` to include them; add a TDD test. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…module) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
pylcm is a general library; references to a particular companion application become stale fast and force readers to know unrelated projects to follow the test rationale. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The variable previously named `discrete_state_names` accumulated state
DiscreteGrids, derived categoricals, and now discrete actions — all
three suffer the same per-cell broadcast clash when a consumer does
`func_output[X]`. Renamed the variable, the two helpers
(`_validate_function_output_grid_indexing`,
`_find_function_output_grid_indexing`), the test module
(`test_function_output_grid_indexing.py`), and the error-message
wording ("discrete state" → "discrete grid"). Added a TDD test for
the discrete-action case.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tion
The previous docstring claimed the indexing 'silently produces NaN', but a disabled-validator probe shows otherwise:
- When the producer takes the discrete grid as input, its output is a per-cell scalar; `func_output[grid]` raises `IndexError: Too many indices` at trace time. This is the real footgun the validator should catch.
- When the producer does NOT take the discrete grid as input, its output stays array-shaped and `func_output[grid]` is correct code that solves to sensible V values.
The previous validator flagged both shapes, including the safe one, as a clash. Tighten: only fire when the producing function also takes the discrete grid as input. Update the description to match observed behaviour (IndexError, not NaN). Add a regression test that exercises the array-valued-producer + state-indexed-consumer shape and asserts it builds without raising.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PR #334 introduced a deferred-diagnostics accumulator that appends every (regime, period) NaN/Inf flag to a Python list, stacks the lists at end of solve, and `.tolist()`s the stacks to host. On a 16 GB V100 at production aca-baseline grid sizes the stacked reduction graph holds the per-period `isnan(V_arr)` / `isinf(V_arr)` intermediates alive simultaneously; the post-loop `.tolist()` then asks XLA to compile the fan-in and OOMs on a ~7.3 GiB allocation on top of the already-resident solution V arrays. Symptom: backward induction reports every age as "finished in ~14 ms" (dispatch-async times), then `JaxRuntimeError: RESOURCE_EXHAUSTED` at the first `.tolist()`.
Fix: replace the per-period list-append with a running scalar OR; add a per-period `block_until_ready()` so each period's reduction kernel finishes (and its intermediate is freed) before the next period dispatches. `block_until_ready` is device-only (no host transfer, no PCIe round-trip), so it doesn't reintroduce the per-period sync that #334 removed; in practice the small reduction has finished by the time `max_Q_over_a` (~14 ms/period) returns.
End of solve: one `.item()` per running scalar. On a healthy solve those two bools are False and we return without materialising any per-row state. Failure paths (`running_any_nan` / `running_any_inf` True) walk `diagnostic_rows` and materialise one bool per row to localise the offender; this is the same total number of host transfers as the prior code, but only on the failure path.
Debug-stats path (`log_level="debug"`) still appends min/max/mean per period; a single per-period `block_until_ready` after the appends frees those intermediates too. The end-of-solve `_log_per_period_stats` keeps the existing per-(regime, period) log line.
`_StackedReductions`, `_emit_deferred_diagnostics`, and the old `_raise_if_nan` / `_warn_if_inf` (taking pre-materialised flag lists) are replaced by `_emit_post_loop_diagnostics` (orchestrator), `_raise_first_nan_row`, `_warn_inf_rows`, and `_log_per_period_stats`.
Tests: new `tests/solution/test_diagnostics.py` covering the four log levels: happy-path warning, NaN-raise with `(regime, age)` in the message, off-level skip, and per-period debug stats.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
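The control flow of the streaming reduction can be sketched without JAX. The snippet below uses plain Python lists in place of device arrays, so it shows only the bookkeeping (two running booleans, with a per-period walk reserved for the failure path), not the memory behaviour or the `block_until_ready()` call; `scan_periods` and `first_nan_period` are hypothetical names.

```python
import math


def scan_periods(values_by_period):
    """Fold per-period NaN/Inf flags into two running booleans."""
    running_any_nan = False
    running_any_inf = False
    for values in values_by_period.values():
        running_any_nan = running_any_nan or any(math.isnan(v) for v in values)
        running_any_inf = running_any_inf or any(math.isinf(v) for v in values)
    return running_any_nan, running_any_inf


def first_nan_period(values_by_period):
    """Failure path only: walk the periods to localise the first offender."""
    for period, values in values_by_period.items():
        if any(math.isnan(v) for v in values):
            return period
    return None


# Periods listed in backward-induction order (last period first).
healthy = {2: [1.0, 2.0], 1: [0.5, 0.1], 0: [0.2, 0.3]}
broken = {2: [1.0, math.inf], 1: [math.nan, 0.1], 0: [0.2, 0.3]}

healthy_flags = scan_periods(healthy)
broken_flags = scan_periods(broken)
offender = first_nan_period(broken) if broken_flags[0] else None
```

On the healthy input the per-period walk never runs, which is the property the commit relies on to keep the happy path cheap.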
Each `_DiagnosticRow` previously held the active-period `state_action_space`, the rolling `next_regime_to_V_arr`, the regime's flat params, and a `compute_intermediates` closure (which itself captured the state_action_space). At production grid sizes (50+ periods × ~6 active regimes) the accumulated references pin every period's full-shape V mapping in device memory, OOMing the V100 16 GB mid-loop on `block_until_ready` (the next allocation that has nowhere to go). The streaming NaN/Inf reduction landed earlier addressed only the per-period reduction buffers; the row-level retention is the larger leak.
Strip `_DiagnosticRow` to the three Python scalars actually needed for failure-path localisation (`regime_name`, `period`, `age`) and reconstruct the heavy bits from `solution`, `internal_regimes`, and `internal_params` inside `_raise_at`. The reconstruction mirrors the loop's roll-forward semantics: for each regime, take the smallest later period in `solution` where the regime was active, falling back to a zeros template, which is the same value the rolling `next_regime_to_V_arr` slot held during the live dispatch.
Also lock the row's shape via a structural test so future changes that re-introduce device-backed fields fail loudly in CI rather than silently regressing OOM behaviour.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
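The roll-forward rule ("smallest later period where the regime was active, else zeros") can be sketched as follows. The layout of `solution` as a mapping `period -> {regime_name: V}` is an assumption made for this illustration, and `reconstruct_next_regime_to_V` is a hypothetical helper, not the code inside `_raise_at`.

```python
def reconstruct_next_regime_to_V(solution, regime_names, period, zeros_template):
    """Rebuild the rolling next-period V mapping as seen when solving `period`."""
    next_regime_to_V = {}
    for name in regime_names:
        # Periods after `period` in which this regime was active.
        later_periods = [p for p in solution if p > period and name in solution[p]]
        if later_periods:
            next_regime_to_V[name] = solution[min(later_periods)][name]
        else:
            next_regime_to_V[name] = zeros_template[name]
    return next_regime_to_V


solution = {
    2: {"working": [1.0, 2.0]},
    3: {"working": [3.0, 4.0], "retired": [5.0]},
    4: {"retired": [6.0], "dead": [0.5]},
}
zeros_template = {"working": [0.0, 0.0], "retired": [0.0], "dead": [0.0]}
regimes = ["working", "retired", "dead"]

at_period_2 = reconstruct_next_regime_to_V(solution, regimes, 2, zeros_template)
at_period_4 = reconstruct_next_regime_to_V(solution, regimes, 4, zeros_template)
```

At the last period every regime falls back to the zeros template, matching the state of the rolling slot before any period has been solved.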
Benchmark comparison (main → HEAD)
4306092 to
10ba5d4
Two changes targeting the NaN-in-V failure path:
1. Fail-fast at age boundary. Adds a per-period `running_any_nan.item()` host transfer right after the existing `block_until_ready`. On True, the loop breaks out and the existing post-loop emitter raises immediately. Cost: one scalar bool transfer per period, negligible next to `max_Q_over_a`. Without this, backward induction would finish the entire age range (potentially ~2h on production grids) before raising at the first-NaN row, leaving the user staring at an idle-looking solve. Inf stays non-fatal; the post-loop warning still fires for any period that flagged it.
2. Drop the misleading "re-solve with debug logging" suggestion from `validate_V`. The diagnostic [NOTE] is added inline by `_enrich_with_diagnostics` whenever `compute_intermediates` is wired up (i.e. on the default path), so suggesting a re-solve to "produce" diagnostics is wrong: they were already produced. Replace with a pointer to the [NOTE] for the per-axis breakdown plus a mention of `log_path=...` for snapshot persistence (the only thing debug-mode actually adds beyond the inline diagnostic).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
3406f06 to
5e09d46
V_arr shape is model-dependent (states × grid points × regimes), so the per-period isnan/isinf intermediate buffer size is too. The qualitative point — these allocations stack up across periods if not freed — is what matters; the size figure was a distraction that implied a fixed scale only true on whichever model the comment was written against. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Codifies three patterns observed across this PR's review cleanups and #339's, where source docstrings and inline comments rehearsed prior implementations, cited PR numbers, and quoted model-specific figures. None of that survives the 9-month-without-PR-context test.
The new "## Docstring Style" section sits next to "## Testing" and adds three subsections, each with good/bad examples drawn from real recent code:
- Describe state, not history (no "earlier", "previously", "the old design", "before the fix").
- No PR numbers, no model-specific magic numbers (PRs rot; "~2 MB at production grid sizes" only holds on one model/box).
- Bulleted lists for enumerated cases (one bullet per case beats running prose for log levels, regime kinds, dispatch strategies).
Cross-references the existing "Test docstrings — describe behavior, not history" subsection rather than duplicating the same rule with test-specific framing.
Also bump .ai-instructions submodule pointer to 609ac4a, which landed the same docstring-style guidance on the canonical version.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…t list
Three follow-ups to the AGENTS.md docstring-style section just inlined on this branch via the cascade-merge:
- Drop "the per-period host stalls that #334 removed" from the diagnostics-accumulator comment. PR refs rot. Restate as "doesn't introduce a host stall" with the same forward-looking explanation.
- Drop "(~2 MB each at production grid sizes)": V_arr shape is model-dependent (states × grid points × regimes), so any fixed byte figure misleads on every other model. Restated as "(V_arr-shaped, so model-dependent)" earlier in the same comment; this commit's change is the second history-framed phrase from the same block.
- Reflow the log-level prose paragraph into a bulleted list. The three cases (`"off"` / `"warning"` ∪ `"progress"` / `"debug"`) read faster as bullets, and the `"off"` qualifier ("skips even the NaN fail-fast") fits inline rather than needing the separate sentence about contracts and estimation loops.
Also drop "without the deferred-stack fan-in that previously OOMed at production sizes" from `tests/solution/test_diagnostics.py` (same "previously" anti-pattern).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ns bump, source cleanups
Brings the #342 → #339 stack onto #340. Net surface:
- New `## Docstring Style` section in AGENTS.md (3 subsections with good/bad examples drawn from #342/#339 cleanups).
- `.ai-instructions` submodule pointer to 609ac4a, picking up the same docstring-style guidance on the canonical version.
- `_build_combined_simulation_function` is gone, replaced by direct `concatenate_functions` on `per_target_funcs`. Output shape changes from flat `<target>__<next_state>` to nested `{target: {next_<state>: array}}`. Consumer `_update_states_for_subjects` updated accordingly.
- `regime_template.create_regime_params_template`: `H_variables` intermediate dropped; single `variables` set with all categories.
- `solve_brute.py` diagnostics-accumulator comment: drop "#334" PR ref, drop "~2 MB at production grid sizes" magic number, reflow log-level prose into bulleted list.
- `tests/test_chained_state_transitions.py`: rewrite to assert concrete chain mathematics (`next_wealth_t = wealth_t - c_t + 0.1 * next_aime_t`) instead of "no NaN".
- `tests/test_next_state.py::test_get_next_state_function_with_simulate_target` updated to assert nested output shape, replacing the prior flat-key version that was added on this branch as `c969b1a`.
Conflict resolution: take #339's version of `next_state.py` and `test_chained_state_transitions.py` (the cascade-target's reshape supersedes the branch's earlier exec-based combined function).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The cascade-merge from #339 changed `get_next_state_function_for_simulation` to return a nested mapping `{target: {next_<state>: array}}` (the flat `<target>__<next_state>` form went away with `_build_combined_simulation_function`). The simulate-target test on this branch was set up against the flat shape (added as `c969b1a` before the merge); update its assertions to the nested form so the test still validates the actual output structure.
…ssage
Replace the "Active constraints: ..." list (which printed *all* regime
constraints whether binding or not) with a per-constraint boolean
column appended to the infeasible-subjects DataFrame. For each
constraint and each infeasible subject, the entry is True when that
constraint individually admits at least one action and False when it
rejects every action by itself.
The distinction matters for diagnosis. With a single binding
constraint, exactly one column reads False. With a joint-rejection
case (each constraint admits some action; their intersection is
empty), every column reads True and the user knows the issue is the
intersection, not any individual constraint.
Implementation: `_per_constraint_feasibility` builds a per-constraint
feasibility function via `_get_feasibility(constraints={name:func})`
and reuses `_batched_feasibility_check` to run it on the infeasible
subjects only.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
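The two diagnostic cases (one binding constraint versus joint rejection) can be reproduced with a toy example. This is a plain-Python sketch with hypothetical constraints and helper names; it does not use `_get_feasibility` or `_batched_feasibility_check`.

```python
def per_constraint_admits_any(constraints, actions, wealth):
    """For each constraint, report whether it alone admits at least one action."""
    return {
        name: any(func(action, wealth) for action in actions)
        for name, func in constraints.items()
    }


def jointly_feasible(constraints, actions, wealth):
    """True if some action satisfies every constraint at once."""
    return any(
        all(func(action, wealth) for func in constraints.values())
        for action in actions
    )


actions = [1.0, 2.0, 3.0]
constraints = {
    "at_most_wealth": lambda consumption, wealth: consumption <= wealth,
    "at_least_floor": lambda consumption, wealth: consumption >= 2.5,
}

# Single binding constraint: no action is affordable at wealth 0.5.
single_binding = per_constraint_admits_any(constraints, actions, 0.5)

# Joint rejection: each constraint admits something at wealth 2.0,
# but no action satisfies both.
joint_rejection = per_constraint_admits_any(constraints, actions, 2.0)
joint_ok = jointly_feasible(constraints, actions, 2.0)
```

In the first case exactly one entry is False; in the second both entries are True even though the subject is infeasible, which points at the intersection.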
Docstring/comment cleanup (drop counterfactuals and history rehashes per AGENTS.md "Describe state, not history"):
- discrete.py to_jax: rephrase "would otherwise produce" as positive constraint
- initial_conditions.py: replace "Without this..." with the invariant the cast enforces
- compile.py module docstring: drop "Mirrors the pattern in solve_brute" cross-reference
- compile.py argmax-over-active-periods comment: drop aca-model-specific forced-canwork / pre-FRA example
- compile.py _build_next_state_args: drop the now-stale comment about Python-int period (all sites use jnp.int32)
Bullet-list dispatch cases (per AGENTS.md "Bulleted lists for enumerated cases"):
- model.py n_subjects field: 4 dispatch cases as bullets, plus an explicit param-shape stability contract for MSM-style use
- model.py _resolve_simulate_internal_regimes: 3 cases as bullets
TDD/test style (per AGENTS.md "Testing"):
- test_simulate_aot.py: split test_n_subjects_none_keeps_lazy_behavior into two single-assertion tests
- test_simulate_aot.py: rename test_simulate_caches_recompiled_size_no_second_warning -> test_simulate_warns_only_once_per_mismatching_size (the implementation only adds to _warned_n_subjects, not the cache)
- test_simulate_aot.py: split first/second-call compile counting into test_simulate_first_matching_call_populates_aot_cache and test_simulate_second_matching_call_does_not_invoke_compile, dropping the weak `n_first > 0` smoke check
- test_simulate_aot.py: docstrings on the validation tests, return-type annotations on counting_compile and _build_initial_conditions
- test_int_dtype_invariants.py: per-test docstrings
- compile.py _get_regime_V_shapes: docstring matching the file's other private helpers
- model.py: class-body annotations and field docstrings for _simulate_compile_cache, _warned_n_subjects, _simulate_compile_lock
x64 + AOT fail-fast (issue 19):
- model_processing._fail_if_x64_with_aot raises ModelInitializationError when n_subjects is set under jax_enable_x64=True; the AOT path pins integer dtypes to int32, x64 promotes to int64, so the cached signature would not match runtime
- AOT tests get an autouse fixture that disables x64 for the duration
period dtype consistency (issue 20):
- _build_next_state_args / _build_crtp_args: lower with period=jnp.int32(0) to match the argmax path
- transitions.calculate_next_states / calculate_next_regime_membership: pass period=jnp.int32(period) at runtime so the runtime call matches the AOT abstract signature
Concurrency (issue 18):
- Model gains a threading.Lock guarding check-then-set on _simulate_compile_cache and _warned_n_subjects
- __getstate__ / __setstate__ exclude the lock and the per-process AOT cache from pickling (compiled programs can't survive process boundaries anyway)
Dedup-key invariant (issue 17):
- _collect_unique_simulate_functions: comment documents the dedup contract: pylcm's process_regimes ships per-regime callables for next_state and crtp, so identity-based dedup is collision-free
- new test_simulate_functions_use_per_regime_callables in test_regime_processing.py pins the invariant against future regression
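The concurrency items follow a common pattern: a lock around check-then-set on a cache, plus pickling hooks that drop the lock and the per-process cache. A minimal stdlib sketch, assuming a hypothetical `CachedCompiler` class; it is not pylcm's `Model`, and unlike the real code it does the (fake) compile while holding the lock.

```python
import threading


class CachedCompiler:
    """Toy cache keyed by batch size, safe to share across threads."""

    def __init__(self):
        self._cache = {}
        self._lock = threading.Lock()
        self.n_compiles = 0

    def get(self, n_subjects):
        # Check-then-set must be atomic, otherwise two threads may both compile.
        with self._lock:
            if n_subjects not in self._cache:
                self.n_compiles += 1
                self._cache[n_subjects] = f"compiled-for-{n_subjects}"
            return self._cache[n_subjects]

    def __getstate__(self):
        # Copy of __dict__ minus the lock, with the per-process cache emptied.
        state = self.__dict__.copy()
        del state["_lock"]
        state["_cache"] = {}
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._lock = threading.Lock()


original = CachedCompiler()
original.get(100)
original.get(100)  # served from the cache, no second compile

# Equivalent to what pickle does on a round-trip, without needing a module path.
restored = CachedCompiler.__new__(CachedCompiler)
restored.__setstate__(original.__getstate__())
```

After the round-trip the restored object has a fresh lock and an empty cache, so its first `get` call compiles again.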
Pin every int that crosses into pylcm to int32, with overflow checks at
the boundary helpers. After this, the AOT simulate path works under
both `jax_enable_x64=True` and `=False` — the int-side hazard is gone,
so the construction-time guard is dropped.
Boundaries:
- New `lcm.dtypes.safe_to_int32(value, *, name)`: host-side cast that
raises `ValueError` with the leaf's qualified name on int32 overflow,
rather than silently wrapping. Reused at every int boundary helper.
- `lcm.params.processing._cast_int_leaves_to_int32`: applied in
`broadcast_to_template` over every resolved params leaf. Walks
`MappingLeaf` / `SequenceLeaf` recursively. Casts typed JAX/numpy
int arrays to int32; leaves Python int/bool scalars alone so JAX's
weak-typing rules can still promote them per call site (e.g.
`discount_factor: 1` keeps working in float-typed functions).
- `lcm.simulation.transitions._update_states_for_subjects`: casts
`next_state_values` to the storage dtype before `jnp.where`, but
only when the kinds match (int->int, float->float). Cross-kind
pairs (int storage + float transition output, possible when a user
passes int initial conditions for a continuous state) keep JAX's
promotion semantics; the cross-kind boundary cast is Package B.
Guard removal:
- `_fail_if_x64_with_aot` deleted. The boundary casts now keep
`internal_params` and the simulate state pool int32 regardless of
`jax_enable_x64`, so the AOT cache signature is stable under x64.
`tests/simulation/test_simulate_aot.py` drops its autouse
`_disable_x64` fixture; AOT tests run under the conftest's default
x64 setting.
Tests:
- New `tests/test_dtypes.py` (5 tests): `safe_to_int32` round-trip,
in-range int64 array, overflow with qualified-name error, underflow.
- Extended `tests/test_int_dtype_invariants.py` with six tests
covering the four new boundaries:
- `_update_states_for_subjects` keeps storage int32 when next-state
output is int64
- `process_params` casts a typed int64 array leaf to int32
- `process_params` raises `ValueError` naming the param qname on
int array overflow
- `process_params` casts int arrays inside a `MappingLeaf` to int32
- `process_params` passes Python int leaves through unchanged for
JAX weak-typing
- `simulate(...)` accepts int64 `regime` initial conditions and
round-trips to the same DataFrame as int32
Out of scope (Package B):
- Continuous-state float dtype normalisation at boundaries.
- Float scalar params -> single canonical float dtype.
- Cross-kind cast in `_update_states_for_subjects` (int storage +
float transition output, when user supplies int initial conditions
for a continuous state — pre-existing behaviour preserved).
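The difference between a checked cast and a silently wrapping one can be shown with the standard library alone. The sketch below is not `lcm.dtypes.safe_to_int32` (its implementation is not shown here); `checked_int32` is a hypothetical helper that handles Python ints only and uses `ctypes.c_int32` to illustrate the unchecked behaviour.

```python
import ctypes

INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1


def checked_int32(value: int, *, name: str) -> int:
    """Return `value` if it fits in int32, else raise with the leaf's name."""
    if not INT32_MIN <= value <= INT32_MAX:
        msg = f"{name}: value {value} does not fit in int32"
        raise ValueError(msg)
    return value


# An unchecked C-style cast wraps around instead of failing.
wrapped = ctypes.c_int32(2**32).value

try:
    checked_int32(2**32, name="working__n_children")
    error_message = None
except ValueError as error:
    error_message = str(error)
```

Carrying the qualified name in the message lets the user find the offending parameter without a debugger.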
The two overflow tests built their fixtures via `jnp.asarray([..., 2**32], dtype=jnp.int64)`. Under `jax_enable_x64=False` (the GPU 32-bit-precision job), JAX truncates the requested int64 to int32 at construction time and raises its own `OverflowError` before `safe_to_int32` ever sees the value — so the test asserts a `ValueError` that is never reached. Use `np.asarray(..., dtype=np.int64)` for these fixtures instead. Numpy honours the explicit dtype regardless of the JAX precision setting, so our boundary helper receives a real int64 and produces its own qualified-name `ValueError`. The same pattern (use numpy for overflow fixtures) will land in Package B for the float-overflow test.
…keys
Docstring style:
- `dtypes.py` module docstring: drop "out of scope for the boundary helpers" / "broken user transition" framing; describe what the module does instead.
- `_cast_int_leaves_to_int32` docstring: drop "would force premature dtype commitment" counterfactual and "Package B" project-jargon reference. Replace with a positive description (cast / pass-through by leaf kind).
- `__getstate__` docstring: lead with what is returned (a copy of `__dict__` minus the per-process AOT compile state) before explaining mechanism.
Public-API docs:
- `process_params` Raises section now lists the `ValueError` that the int boundary cast can surface, plus a paragraph documenting the dtype-normalisation step. Module docstring covers the same.
Concurrency docstring:
- `_simulate_compile_lock` field docstring rephrased to make it explicit that the consequent `log.warning` is intentionally outside the lock.
Test coverage and structure:
- Add `test_process_params_casts_int_array_inside_sequence_leaf_to_int32` to mirror the `MappingLeaf` test for `SequenceLeaf` int contents.
- Add `test_unpickled_model_can_simulate_with_aot`: full pickle round-trip through `cloudpickle`, then re-run simulate to confirm the AOT cache is reset and re-populated post-unpickle.
- Parametrise `test_process_params_casts_int_array_inside_mapping_leaf_to_int32` over `low`/`high` (one assertion per test).
- Parametrise `test_safe_to_int32_*` over `python-int` / `int64-array` inputs and split into `_returns_int32` (dtype) and `_preserves_in_range_values` (value preservation).
- Split `test_simulate_warns_on_n_subjects_mismatch` into four single-assertion tests behind a shared fixture (warning count, declared-N in message, actual-N in message, no cache entry).
- Replace `test_simulate_functions_use_per_regime_callables` 2-assertion body with a parametrised version over `next_state` / `compute_regime_transition_probs`, using a shared fixture. New docstring describes user-visible behaviour rather than the AOT-dedup rationale.
- Round-trip int64-regime test now uses `pd.testing.assert_frame_equal` across all output columns, not `.equals()` on one column.
Structural fixes:
- `_collect_unique_simulate_functions` now keys `next_state` and `compute_regime_transition_probs` by `(kind, regime_name, callable-id)` instead of `(kind, callable-id)`. Two regimes that share a callable identity now still get distinct compiled programs (carried-forward issue 17 from prior review). Removes the need for the comment that pointed at a specific test file path.
- `_cast_int_leaves_to_int32` swaps the `hasattr(value, "dtype")` duck-type check for `isinstance(value, (Array, np.ndarray))`, which closes the PR-#302 review thread on duck-typing arrays.
- `draw_key_from_dict`: pin `regime_ids` to `jnp.int32` so the "all integers are int32" invariant holds throughout simulate, not just at the AOT-traced boundaries.
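The effect of adding `regime_name` to the dedup key is easy to see on a toy example. The tuples below only imitate the key shapes named above; `shared_transition` and the regime names are hypothetical.

```python
def shared_transition(x):
    return x + 1


# Two regimes that happen to reuse the very same callable object.
entries = [
    ("next_state", "working", shared_transition),
    ("next_state", "retired", shared_transition),
]

# Identity-only key: the two regimes collapse into one compiled program.
keys_by_identity = {(kind, id(func)) for kind, _, func in entries}

# Key including the regime name: each regime keeps its own entry.
keys_by_regime = {(kind, regime, id(func)) for kind, regime, func in entries}
```

With the wider key, correctness no longer depends on regimes shipping distinct callable objects.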
`compile_all_simulate_functions` previously held every entry's concrete lower-args (V-shaped templates, per-regime subject-state / action zeros, regime-params view) in `unique[key][1]` until the function returned, and every `Lowered` HLO module in `lowered[key]` until the slowest parallel compile finished. With per-target next-state DAGs and an AOT cache that pins every kernel alive, the overlap of in-flight lower-args + lowered HLO + compiled kernels was the dominant contributor to the simulate-side compile-phase peak. Drop the args from `unique[key]` immediately after each lowering, and `del lowered[k]` as soon as the corresponding `Compiled` lands in `compiled`. The dedup keys, the parallel pool semantics, and the swap-in step are unchanged; existing tests (n_subjects mismatch, unpickled-model AOT round-trip, dtype invariants) remain green.
The local in `_raise_at` holds the regime's params after merging `resolved_fixed_params` in by hand — the same merge the live solve loop performs implicitly via partialled closures. `effective_regime_params` captures that intent; `diag_params` only said "I'm passing this to the diagnostic call". Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
hmgaudecker commented on May 7, 2026
    if not infeasible_indices:
        return None

    per_constraint_admits_any = _per_constraint_feasibility(
This helps a lot for debugging invalid initial conditions. Sorry it landed here, making the PR bigger...
Note:
If this ever were the bottleneck, reducing Plus it seems to be gone in #345. |
Drops the "Python scalar pass-through to keep JAX weak-typing semantics" line from `_cast_int_leaves_to_int32`: a Python `int` arriving at a DAG function as a Python scalar now becomes `jnp.int32(value)` so downstream code sees a JAX-typed scalar rather than a Python int that JAX would promote per call site. This finishes the "no Python scalars inside JIT'd loops" goal that motivated the dtype-barrier work.

`bool` is short-circuited (the float-side cast pass on #345 will handle it; bool is a Python `int` subclass so the bool branch must come before the int one). Module + function docstrings refreshed accordingly.

Test `test_process_params_passes_python_int_through_for_jax_weak_typing` renamed and flipped to assert `dtype == jnp.int32`.

`ScalarInt` keeps the `int | Int32[Scalar, ""]` union: tightening to JAX-only cascades into 13 call-site mismatches at internal Python metadata sites (`n_points`, `period`) that legitimately pass Python `int` outside the JIT'd DAG. Tightening the alias is a separate audit and follow-up.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
hmgaudecker added a commit that referenced this pull request on May 8, 2026
Folds #340's int Python-scalar cast into the broader #345 work:

- `_cast_leaves_to_canonical_dtype` becomes a strict whitelist with explicit raises (no silent fallthrough). Dispatch order: `MappingLeaf`/`SequenceLeaf` → recurse, `pd.Series` → defensive raise (orchestrator must reshape first), `bool` (before `int`), `int`, `float`, `np.ndarray | jax.Array` keyed on `dtype.kind` (`b/i/u/f`). Anything else (`str`, complex/object arrays, custom objects) raises `InvalidParamsError` with the leaf's qualified name.
- `broadcast_to_template` becomes broadcast-only: the inline cast loop moves out into the new top-level `cast_params_to_canonical_dtypes(internal_params)` helper.
- `process_params` chains broadcast + cast for callers that have no `pd.Series` leaves; pd.Series-bearing callers must orchestrate the three steps explicitly.
- `_process_params` (model.py) and `_build_regimes_and_template_with_fixed_params` (model_processing.py) now sequence broadcast → reshape → cast → validate. The cast walks a uniform tree (no `pd.Series` to special-case) which lets the whitelist drop the pass-through branch.
- `bool` casts to `jnp.bool_`, `float` casts to `canonical_float_dtype()`, completing the "no Python scalars inside JIT'd loops" goal: DAG functions now receive JAX-typed scalars end-to-end.
- `ScalarBool` alias added (`bool | Bool[Scalar, ""]`). `ScalarInt` / `ScalarFloat` keep their `int |` / `float |` unions for now; the tighter forms cascade into call-site mismatches at internal Python metadata sites (`n_points`, `period`) and warrant a separate audit.

Tests:

- `test_process_params_passes_python_float_through_for_jax_weak_typing` → `..._casts_python_float_to_canonical`, asserts dtype.
- `test_python_float_param_passed_through_for_weak_typing` → `..._cast_to_canonical_dtype`, asserts dtype.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
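The dispatch order above, in particular why `bool` must be checked before `int`, can be sketched in plain Python. This is a hypothetical illustration, not the PR's `_cast_leaves_to_canonical_dtype`: `cast_leaf` and `InvalidLeafError` are invented names, plain mappings and sequences stand in for `MappingLeaf` / `SequenceLeaf`, string tags stand in for real JAX dtypes, and the `pd.Series` and array branches are omitted.

```python
from collections.abc import Mapping, Sequence


class InvalidLeafError(TypeError):
    """Hypothetical stand-in for the PR's `InvalidParamsError`."""


def cast_leaf(value: object, name: str) -> object:
    """Tag each leaf with its canonical kind; raise on anything not whitelisted."""
    if isinstance(value, Mapping):
        return {k: cast_leaf(v, f"{name}.{k}") for k, v in value.items()}
    # `str` is itself a Sequence, so reject it before the container branch.
    if isinstance(value, str):
        raise InvalidLeafError(f"{name}: unsupported leaf type str")
    if isinstance(value, Sequence):
        return [cast_leaf(v, f"{name}[{i}]") for i, v in enumerate(value)]
    # `bool` is a subclass of `int`, so this branch must come first.
    if isinstance(value, bool):
        return ("bool", value)
    if isinstance(value, int):
        return ("int32", value)
    if isinstance(value, float):
        return ("float", value)
    # No silent fallthrough: everything else is an error with a qualified name.
    raise InvalidLeafError(f"{name}: unsupported leaf type {type(value).__name__}")


tagged = cast_leaf({"flag": True, "grid": [1, 2.0]}, "params")
```

Swapping the `bool` and `int` branches would tag `True` as `int32`, because `isinstance(True, int)` holds in Python.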
When `Model(n_subjects=N)` is set, simulate-side XLA compilation used to run lazily on the first matching `simulate(...)` call, strictly after `solve(...)` returned. On production aca-baseline that adds several minutes to the end-to-end wall clock for nothing: solve is GPU-bound, simulate compile is CPU-bound XLA work, so they overlap trivially.

Add `_maybe_start_simulate_compile_async` and call it from `solve(...)` right after parameters are processed. It spawns a single-worker `ThreadPoolExecutor` that runs `compile_all_simulate_functions` in the background and parks the result on `_simulate_compile_future`. `_resolve_simulate_internal_regimes` awaits the future before populating the cache, so the lazy fallback path (no `solve` call, direct `simulate(...)`) still works.

`__getstate__` / `__setstate__` drop the future on the way out and reset to `None` on the way in: `concurrent.futures.Future` is tied to its originating thread pool and can't survive a process boundary.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
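The background-compile flow described above can be sketched with the standard library alone. Everything here is a toy under stated assumptions: `SketchModel`, `slow_compile`, and the attribute names `_compile_future` / `_compiled` are hypothetical stand-ins for the PR's model class and `_simulate_compile_future`, and a short sleep replaces the real XLA compile.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional


def slow_compile(n_subjects: int) -> str:
    """Stand-in for the CPU-bound simulate compile."""
    time.sleep(0.05)
    return f"compiled-for-{n_subjects}"


class SketchModel:
    def __init__(self, n_subjects: Optional[int] = None) -> None:
        self.n_subjects = n_subjects
        self._compile_future: Optional[Future] = None
        self._compiled: Optional[str] = None

    def solve(self) -> str:
        # Start the compile in the background, then carry on with the solve work.
        if self.n_subjects is not None and self._compile_future is None:
            pool = ThreadPoolExecutor(max_workers=1)
            self._compile_future = pool.submit(slow_compile, self.n_subjects)
            pool.shutdown(wait=False)  # the submitted task still runs to completion
        return "solution"

    def simulate(self) -> str:
        if self.n_subjects is None:
            return "lazy-jit"  # default: no AOT program is built
        if self._compiled is None:
            if self._compile_future is not None:
                # Wait for the background compile started by solve().
                self._compiled = self._compile_future.result()
            else:
                # Lazy fallback: simulate() called without a prior solve().
                self._compiled = slow_compile(self.n_subjects)
        return self._compiled

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        state["_compile_future"] = None  # futures do not survive pickling
        state["_compiled"] = None  # per-process compile state is rebuilt
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)


model = SketchModel(n_subjects=20)
model.solve()
first = model.simulate()

# Emulate a pickle round-trip without depending on module lookup.
state = model.__getstate__()
restored = SketchModel.__new__(SketchModel)
restored.__setstate__(state)
second = restored.simulate()
```

The restored object has no future and no cached program, so its first `simulate()` call takes the lazy fallback path and recompiles.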
Goal
Do the same for simulate as we did for solve before: AOT-compile. The main reasons are to fail early and to get cleaner benchmarks (sim-jit was previously included in the runtime).
This surfaced subtle differences in which int types were used during compilation versus at run time. Explicit casts inside the hot loop turned out to be necessary, but very costly. For speed, memory efficiency, and peace of mind, convert everything to `jnp.int32` at the boundary (no use case will ever require ints > 2_147_483_647).

Summary
Two coupled changes that together let `simulate(...)` ship a single AOT-compiled program per unique callable and batch shape:

1. `Model(n_subjects: int | None = None)`: AOT-compile simulate

- When `n_subjects` is set, the first matching `simulate(...)` call AOT-compiles every unique simulate function (`argmax_and_max_Q_over_a` per period, `next_state` per regime, `compute_regime_transition_probs` per regime) for batch shape `n_subjects` in parallel via `ThreadPoolExecutor`, mirroring solve's existing AOT path in `solve_brute._compile_all_functions`.
- `period_to_regime_to_V_arr` is padded at the entry of `simulate(...)` so every period dispatches with the same pytree (active-regime padding with zeros). Without this the last period's empty `next_regime_to_V_arr` breaks both the AOT pytree signature and JAX's own JIT cache reuse.
- `n_subjects=None` (the default) preserves the previous lazy behaviour.

2. Lock integer dtype to `int32` end-to-end (formerly #341, squash-merged as `6c610d1`)

The lazy JIT cache silently retraces per dtype, so `int32` (no x64) vs `int64` (x64) variants of the same regime compiled into different specialisations. AOT compile via `jax.jit(...).lower(**args).compile()` ships a single signature and broke at runtime with `int32[N] vs int64[N]` mismatches. Fix:

- Narrow `Int1D` / `IntND` / `DiscreteState` / `DiscreteAction` / `ScalarInt` in `src/lcm/typing.py` from `Int` to `Int32`; ~113 internal usages inherit the dtype constraint without further edits.
- Pin `DiscreteGrid.to_jax()` to `int32` regardless of `jax_enable_x64`.
- `build_initial_states` casts discrete states to grid dtype (one-shot fix for the period-0 vs period-1+ dispatch dtype split).
- Explicit casts at sites whose dtype depends on `jax_enable_x64`:
  - `regime_building/argmax.py:46, 68`: `argmax` and scalar fallback (real `int64` → `int32` narrowing under x64).
  - `simulation/initial_conditions.py:365, 409` and `utils/error_handling.py:376`: `where` index extractions in error/validation paths.
  - `pandas_utils.py:155`: regime-id ingestion cast to `int32` to match `DiscreteGrid.to_jax()`.
  - `simulation/simulate.py:102`: `subject_regime_ids` sentinel buffer pinned to `int32` regardless of input regime dtype.

Note: `jnp.searchsorted` already returns `int32` even under x64, and `unravel_index` outputs are immediately consumed by index ops that accept any int dtype, so no inline casts are needed inside hot interpolation/lookup kernels (an earlier draft added them, which broke XLA fusion and roughly doubled GPU peak memory on aca-baseline).

Test plan
- `pixi run -e tests-cpu pytest tests/simulation/test_simulate_aot.py`: 7 new tests cover validation, status-quo regression, AOT-compile-once, mismatch warn, warn-once-per-size.
- `pixi run -e tests-cpu pytest tests/ -n 7`: 895 passed, 5 skipped (full pylcm suite green; +3 dtype-invariant tests in `tests/test_int_dtype_invariants.py`).
- `pixi run -e type-checking ty`: clean.
- `prek run --all-files`: clean.
- `simulate OK, n=20` with all int initial-condition entries reporting `int32`.

Stacked on #339.
🤖 Generated with Claude Code
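The value-function padding mentioned in the Summary can be sketched as follows. This is an illustration under stated assumptions, not the PR's code: `pad_period_to_regime_to_v` and `regime_sizes` are hypothetical names, plain lists stand in for JAX arrays, and the regime names are invented. The point is only that every period ends up with the same set of keys and shapes, which is what a single compiled signature needs.

```python
def pad_period_to_regime_to_v(
    period_to_regime_to_v: dict[int, dict[str, list[float]]],
    regime_sizes: dict[str, int],
) -> dict[int, dict[str, list[float]]]:
    """Give every period the same regime keys, filling gaps with zeros."""
    padded: dict[int, dict[str, list[float]]] = {}
    for period, regime_to_v in period_to_regime_to_v.items():
        padded[period] = {
            # Keep the real array for active regimes, zeros for inactive ones.
            regime: regime_to_v.get(regime, [0.0] * n_points)
            for regime, n_points in regime_sizes.items()
        }
    return padded


# Hypothetical input: the last period has no next-period value arrays at all.
v_arrays = {
    0: {"working": [1.0, 2.0], "retired": [0.5]},
    1: {"retired": [0.7]},
    2: {},
}
padded = pad_period_to_regime_to_v(v_arrays, {"working": 2, "retired": 1})
```

After padding, the empty last-period mapping has the same structure as the others, so a program compiled for one period's pytree can be reused for all of them.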