Lock integer dtype to int32 end-to-end #341

Merged
hmgaudecker merged 2 commits into feat/simulate-aot-n-subjects from chore/int32-everywhere
May 3, 2026

@hmgaudecker
Member

Summary

Stacks on #340. Finishes the int dtype lock-in started by the pinhole fixes
on feat/simulate-aot-n-subjects (DiscreteGrid.to_jax → int32,
build_initial_states casts to grid dtype). Now every internal integer
JAX array is int32 regardless of jax_enable_x64, and the type aliases
in src/lcm/typing.py advertise that contract so ty flags any future
regression at edit time.

Why this matters: the lazy JIT cache silently retraces per dtype, so
int32 (no x64) vs int64 (x64) variants of the same regime compiled
into different specializations. AOT compile via
jax.jit(...).lower(**args).compile() ships a single signature and
broke at runtime with int32[N] vs int64[N] mismatches.
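
A minimal sketch of the failure mode described above, not taken from the PR. The `gather` kernel and the array names are hypothetical; only `jax.jit`, `.lower(...)`, and `.compile()` are real JAX calls. The lazy path retraces once per index dtype, while the AOT executable accepts only the dtype it was lowered with.

```python
import jax
import jax.numpy as jnp

# Enable x64 so that int64 arrays can exist alongside int32 ones.
jax.config.update("jax_enable_x64", True)

trace_log = []  # grows only while JAX traces the function


def gather(values, idx):
    """Hypothetical toy kernel: index a float array with an integer array."""
    trace_log.append(str(idx.dtype))  # Python side effect, runs at trace time only
    return values[idx]


values = jnp.linspace(0.0, 1.0, 5)
idx32 = jnp.arange(3, dtype=jnp.int32)
idx64 = jnp.arange(3, dtype=jnp.int64)

# Lazy JIT: one trace per distinct input dtype, silently.
lazy = jax.jit(gather)
lazy(values, idx32)
lazy(values, idx64)
lazy(values, idx32)  # cache hit, no new trace
n_lazy_traces = len(trace_log)

# AOT: a single signature is baked in at lowering time.
compiled = jax.jit(gather).lower(values, idx32).compile()
ok = compiled(values, idx32)

aot_error = None
try:
    compiled(values, idx64)  # dtype differs from the lowered signature
except Exception as err:  # JAX reports an argument-type mismatch here
    aot_error = err
```

Pinning every integer input to int32 before it reaches the compiled program removes the second case entirely.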

Changes

  • src/lcm/typing.py — replace Int with Int32 for Int1D, IntND,
    DiscreteState, DiscreteAction, ScalarInt. ~113 internal usages
    inherit the dtype constraint without further edits.
  • Boundary casts at every site whose integer dtype depended on
    jax_enable_x64:
    • grids/coordinates.py:191: searchsorted
    • grids/piecewise.py:90, 161: searchsorted
    • simulation/simulate.py:386: searchsorted (starting periods)
    • simulation/simulate.py:357: unravel_index outputs
    • regime_building/argmax.py:46, 68 — scalar fallback + argmax
    • simulation/initial_conditions.py:365, 409: where index extractions
    • utils/error_handling.py:376: where index extraction
  • pandas_utils.py:155 — regime-id ingestion cast to int32 to match
    DiscreteGrid.to_jax().
  • simulation/simulate.py:102: subject_regime_ids sentinel buffer
    pinned to int32 regardless of input regime dtype.
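
For illustration, the boundary-cast pattern listed above might look roughly like the following. The helper names `argmax_int32` and `true_indices_int32` are hypothetical and do not appear in the PR; the sketch only shows an index-producing op followed by an explicit narrowing to int32.

```python
import jax
import jax.numpy as jnp

# With x64 enabled, index-producing ops such as argmax default to int64.
jax.config.update("jax_enable_x64", True)


def argmax_int32(values):
    """Hypothetical helper: argmax with the result pinned to int32."""
    return jnp.argmax(values).astype(jnp.int32)


def true_indices_int32(mask):
    """Hypothetical helper: positions of True entries, pinned to int32.

    Uses the one-argument form of jnp.where, so it is meant for
    validation or error paths outside jit (the output shape is data-dependent).
    """
    (idx,) = jnp.where(mask)
    return idx.astype(jnp.int32)


values = jnp.array([0.1, 0.9, 0.4])
raw_dtype = jnp.argmax(values).dtype  # int64 because x64 is enabled
best = argmax_int32(values)
hits = true_indices_int32(values > 0.3)
```

The cast is a no-op when x64 is off, so the same code path yields int32 under both settings.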

Test plan

  • pixi run -e tests-cpu pytest tests/ -n 7 — 895 passed (3 new
    dtype-invariant tests added)
  • pixi run -e type-checking ty — clean
  • prek run --all-files — clean
  • aca-dev simulate-AOT smoke (/tmp/smoke8.py) — simulate OK, n=20
    with all int initial-condition entries reporting int32

🤖 Generated with Claude Code

Tighten Int1D/IntND/DiscreteState/DiscreteAction/ScalarInt to Int32 in
typing.py, and cast searchsorted/argmax/unravel_index/where outputs to
int32 at every site where their width depended on jax_enable_x64. This
prevents the JIT cache from silently splitting into per-period
int32/int64 variants and keeps the AOT-compiled simulate program, which
ships a single signature, from breaking. Adds a regression test asserting
that discrete grids, build_initial_states discrete entries, and
MISSING_CAT_CODE all match int32.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
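
A dtype-invariant check in the spirit of the regression test mentioned above could be sketched as follows. This is not the PR's test: `discrete_codes` is a hypothetical stand-in for a grid-to-array conversion such as `DiscreteGrid.to_jax`, and `default_codes` shows the unpinned behaviour for contrast.

```python
import jax
import jax.numpy as jnp


def discrete_codes(n_categories):
    """Hypothetical stand-in: category codes with an explicit int32 dtype."""
    return jnp.arange(n_categories, dtype=jnp.int32)


def default_codes(n_categories):
    """Same codes without a pinned dtype; width follows jax_enable_x64."""
    return jnp.arange(n_categories)


observed = {}
for flag in (False, True):
    jax.config.update("jax_enable_x64", flag)
    observed[flag] = (
        str(discrete_codes(3).dtype),
        str(default_codes(3).dtype),
    )

# Restore the default so later code is unaffected.
jax.config.update("jax_enable_x64", False)
```

Only the pinned variant reports the same dtype under both settings, which is the invariant such a test would assert.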
github-actions Bot commented May 3, 2026

Benchmark comparison (main → HEAD)

Comparing 8f2a4cfc (main) → 0be8a60c (HEAD)

| Benchmark | Statistic | before | after | Ratio | Alert |
|---|---|---|---|---|---|
| aca-baseline | execution time | 41.879 s | 29.010 s | 0.69 | |
| | peak GPU mem | 539 MB | 579 MB | 1.07 | |
| | compilation time | 385.10 s | 399.76 s | 1.04 | |
| | peak CPU mem | 8.08 GB | 9.96 GB | 1.23 | |
| Mahler-Yum | execution time | 5.261 s | 4.655 s | 0.88 | |
| | peak GPU mem | 522 MB | 522 MB | 1.00 | |
| | compilation time | 17.01 s | 16.73 s | 0.98 | |
| | peak CPU mem | 1.69 GB | 1.67 GB | 0.99 | |
| Precautionary Savings - Solve | execution time | 52.7 ms | 49.9 ms | 0.95 | |
| | peak GPU mem | 101 MB | 101 MB | 1.00 | |
| | compilation time | 2.61 s | 2.61 s | 1.00 | |
| | peak CPU mem | 1.12 GB | 1.12 GB | 1.00 | |
| Precautionary Savings - Simulate | execution time | 146.7 ms | 113.8 ms | 0.78 | |
| | peak GPU mem | 340 MB | 340 MB | 1.00 | |
| | compilation time | 6.08 s | 6.23 s | 1.03 | |
| | peak CPU mem | 1.28 GB | 1.29 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate | execution time | 166.1 ms | 139.8 ms | 0.84 | |
| | peak GPU mem | 577 MB | 577 MB | 1.00 | |
| | compilation time | 8.30 s | 7.92 s | 0.95 | |
| | peak CPU mem | 1.28 GB | 1.28 GB | 1.01 | |
| Precautionary Savings - Solve & Simulate (irreg) | execution time | 299.6 ms | 285.1 ms | 0.95 | |
| | peak GPU mem | 2.19 GB | 2.19 GB | 1.00 | |
| | compilation time | 8.72 s | 8.80 s | 1.01 | |
| | peak CPU mem | 1.33 GB | 1.33 GB | 1.00 | |

`jnp.searchsorted` already returns int32 even with `jax_enable_x64`, so
the four `.astype(jnp.int32)` casts in `grids/coordinates.py`,
`grids/piecewise.py` (×2), and `simulation/simulate.py:_compute_starting_periods`
were no-ops at the dtype level — but they sat between an integer-producing
op and its index-consumer inside vmap'd interpolation kernels, breaking
XLA's fusion and forcing the intermediate to materialise as a top-level
GPU buffer per (period, regime, state). Likewise, the `unravel_index`
output in `_lookup_values_from_indices` is consumed immediately by
`grid[index]`, which accepts int64 fine — the cast served no purpose.

Keeps the argmax cast on the solve path (real int64→int32 narrowing),
the boundary casts at error/validation paths, and the AOT-relevant
casts in `pandas_utils` and the `subject_regime_ids` sentinel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
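
The dtype claims in the commit message above can be checked with a short sketch like the one below. The grid and table values are made up for illustration; the snippet says nothing about fusion or GPU buffers, only about the dtypes involved.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# searchsorted: the index dtype stays int32 for ordinary-sized grids,
# so an extra .astype(jnp.int32) does not change the dtype.
grid = jnp.linspace(0.0, 1.0, 11)
idx = jnp.searchsorted(grid, jnp.array([0.25, 0.7]))
searchsorted_dtype = str(idx.dtype)

# unravel_index: the flat argmax index is int64 under x64, and indexing
# with the resulting tuple works without narrowing it first.
table = jnp.array([[1.0, 5.0], [3.0, 2.0]])
flat = jnp.argmax(table)
multi = jnp.unravel_index(flat, table.shape)
best = table[multi]
```

Since the consumer here is plain array indexing, the integer width of `multi` never leaves the kernel, which is why dropping those casts is safe for the AOT signature.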
hmgaudecker merged commit 6c610d1 into feat/simulate-aot-n-subjects on May 3, 2026
10 checks passed
hmgaudecker deleted the chore/int32-everywhere branch on May 3, 2026, 14:15
hmgaudecker added a commit that referenced this pull request May 4, 2026
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>