From 39dc40747f53d736af10eae0e7de933ddff9fd71 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 3 May 2026 10:48:17 +0200 Subject: [PATCH 1/2] Lock integer dtype to int32 end-to-end Tighten Int1D/IntND/DiscreteState/DiscreteAction/ScalarInt to Int32 in typing.py, and cast searchsorted/argmax/unravel_index/where outputs to int32 at every site where their width depended on jax_enable_x64. This prevents the JIT cache from silently splitting into per-period int32/int64 variants and breaking the AOT-compiled simulate program that ships a single signature. Adds a regression test asserting discrete grids, build_initial_states discrete entries, and MISSING_CAT_CODE match int32. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/coordinates.py | 2 +- src/lcm/grids/piecewise.py | 8 +++-- src/lcm/pandas_utils.py | 3 +- src/lcm/regime_building/argmax.py | 4 +-- src/lcm/simulation/initial_conditions.py | 4 +-- src/lcm/simulation/simulate.py | 8 +++-- src/lcm/typing.py | 12 +++---- src/lcm/utils/error_handling.py | 2 +- tests/test_int_dtype_invariants.py | 41 ++++++++++++++++++++++++ 9 files changed, 66 insertions(+), 18 deletions(-) create mode 100644 tests/test_int_dtype_invariants.py diff --git a/src/lcm/grids/coordinates.py b/src/lcm/grids/coordinates.py index 0edce4bd..583b9b4d 100644 --- a/src/lcm/grids/coordinates.py +++ b/src/lcm/grids/coordinates.py @@ -188,7 +188,7 @@ def get_irreg_coordinate( n_points = len(points) # Find the index of the first point greater than value - idx_upper = jnp.searchsorted(points, value, side="right") + idx_upper = jnp.searchsorted(points, value, side="right").astype(jnp.int32) # Clamp to valid range for interpolation idx_upper = jnp.clip(idx_upper, 1, n_points - 1) diff --git a/src/lcm/grids/piecewise.py b/src/lcm/grids/piecewise.py index e3a252f2..aae8a0f3 100644 --- a/src/lcm/grids/piecewise.py +++ b/src/lcm/grids/piecewise.py @@ -87,7 +87,9 @@ def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... 
def get_coordinate(self, value: Array) -> Array: ... def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" - piece_idx = jnp.searchsorted(self._breakpoints, value, side="right") + piece_idx = jnp.searchsorted(self._breakpoints, value, side="right").astype( + jnp.int32 + ) local_coord = grid_coordinates.get_linspace_coordinate( value=value, start=self._piece_starts[piece_idx], @@ -158,7 +160,9 @@ def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... def get_coordinate(self, value: Array) -> Array: ... def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" - piece_idx = jnp.searchsorted(self._breakpoints, value, side="right") + piece_idx = jnp.searchsorted(self._breakpoints, value, side="right").astype( + jnp.int32 + ) local_coord = grid_coordinates.get_logspace_coordinate( value=value, start=self._piece_starts[piece_idx], diff --git a/src/lcm/pandas_utils.py b/src/lcm/pandas_utils.py index 2d696a45..cf1b4cb0 100644 --- a/src/lcm/pandas_utils.py +++ b/src/lcm/pandas_utils.py @@ -153,7 +153,8 @@ def initial_conditions_from_dataframe( # noqa: C901 for col, arr in result_arrays.items() } initial_conditions["regime"] = jnp.array( - df["regime"].map(dict(regime_names_to_ids)).to_numpy() + df["regime"].map(dict(regime_names_to_ids)).to_numpy(), + dtype=jnp.int32, ) return initial_conditions diff --git a/src/lcm/regime_building/argmax.py b/src/lcm/regime_building/argmax.py index 0e48cf2b..a4271e2f 100644 --- a/src/lcm/regime_building/argmax.py +++ b/src/lcm/regime_building/argmax.py @@ -43,7 +43,7 @@ def argmax_and_max( # When there are no dimensions to reduce over, return: # - index 0 (trivial argmax since there's only one element) # - the array itself (already the maximum) - return jnp.array(0), a + return jnp.array(0, dtype=jnp.int32), a # Move axis over which to compute the argmax to the 
back and flatten last dims # ================================================================================== @@ -65,7 +65,7 @@ def argmax_and_max( max_value_mask = a == _max if where is not None: max_value_mask = jnp.logical_and(max_value_mask, where) - _argmax = jnp.argmax(max_value_mask, axis=-1) + _argmax = jnp.argmax(max_value_mask, axis=-1).astype(jnp.int32) return _argmax, _max.reshape(_argmax.shape) diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 22a8dd2e..127059e1 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -362,7 +362,7 @@ def _collect_structural_errors( active_mask = active_mask & (~in_regime | period_active) if not jnp.all(active_mask): - invalid_indices = jnp.where(~active_mask)[0] + invalid_indices = jnp.where(~active_mask)[0].astype(jnp.int32) invalid_combos = { (ids_to_regime_names[int(regime_id_arr[i])], float(age_values[i])) for i in invalid_indices @@ -406,7 +406,7 @@ def _collect_feasibility_errors( errors: list[str] = [] for regime_name, internal_regime in internal_regimes.items(): regime_id = regime_names_to_ids[regime_name] - idx_arr = jnp.where(regime_id_arr == regime_id)[0] + idx_arr = jnp.where(regime_id_arr == regime_id)[0].astype(jnp.int32) subject_indices = idx_arr.tolist() if idx_arr.size > 0 else [] if not subject_indices: continue diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index d54d2674..4ed08140 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -99,7 +99,9 @@ def simulate( starting_periods = _compute_starting_periods( initial_ages=initial_states["age"], ages=ages ) - subject_regime_ids = jnp.full_like(initial_conditions["regime"], MISSING_CAT_CODE) + subject_regime_ids = jnp.full_like( + initial_conditions["regime"], MISSING_CAT_CODE, dtype=jnp.int32 + ) # Forward simulation simulation_results: dict[RegimeName, dict[int, PeriodRegimeSimulationData]] 
= { @@ -353,7 +355,7 @@ def _lookup_values_from_indices( nd_indices = vmapped_unravel_index(flat_indices, grids_shapes) return MappingProxyType( { - name: grid[index] + name: grid[index.astype(jnp.int32)] for (name, grid), index in zip(grids.items(), nd_indices, strict=True) } ) @@ -383,7 +385,7 @@ def _compute_starting_periods( """ age_values = jnp.asarray(ages.values) - starting_periods = jnp.searchsorted(age_values, initial_ages) + starting_periods = jnp.searchsorted(age_values, initial_ages).astype(jnp.int32) # Clamp indices to valid range before accessing age_values. searchsorted can # return len(age_values) for ages beyond the grid maximum. diff --git a/src/lcm/typing.py b/src/lcm/typing.py index c73b4c33..62492770 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -4,27 +4,27 @@ import pandas as pd from jax import Array -from jaxtyping import Bool, Float, Int, Scalar +from jaxtyping import Bool, Float, Int32, Scalar from lcm.params import MappingLeaf from lcm.params.sequence_leaf import SequenceLeaf type ContinuousState = Float[Array, "..."] type ContinuousAction = Float[Array, "..."] -type DiscreteState = Int[Array, "..."] -type DiscreteAction = Int[Array, "..."] +type DiscreteState = Int32[Array, "..."] +type DiscreteAction = Int32[Array, "..."] type FloatND = Float[Array, "..."] -type IntND = Int[Array, "..."] +type IntND = Int32[Array, "..."] type BoolND = Bool[Array, "..."] type Float1D = Float[Array, "_"] # noqa: F821 -type Int1D = Int[Array, "_"] # noqa: F821 +type Int1D = Int32[Array, "_"] # noqa: F821 type Bool1D = Bool[Array, "_"] # noqa: F821 # Many JAX functions are designed to work with scalar numerical values. This also # includes zero dimensional jax arrays. 
-type ScalarInt = int | Int[Scalar, ""] +type ScalarInt = int | Int32[Scalar, ""] type ScalarFloat = float | Float[Scalar, ""] type Period = int | Int1D diff --git a/src/lcm/utils/error_handling.py b/src/lcm/utils/error_handling.py index c7f16008..7f790208 100644 --- a/src/lcm/utils/error_handling.py +++ b/src/lcm/utils/error_handling.py @@ -373,7 +373,7 @@ def _format_sum_violation( {name: jnp.atleast_1d(arr) for name, arr in state_action_values.items()} ) failing_mask = ~jnp.isclose(sum_all, 1.0) - failing_indices = jnp.where(failing_mask)[0] + failing_indices = jnp.where(failing_mask)[0].astype(jnp.int32) failing_sums = sum_all[failing_mask] n_failing = int(failing_indices.shape[0]) n_show = min(n_failing, 5) diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py new file mode 100644 index 00000000..3147de3f --- /dev/null +++ b/tests/test_int_dtype_invariants.py @@ -0,0 +1,41 @@ +"""Integer dtypes are pinned to int32 across pylcm regardless of x64 mode.""" + +import jax.numpy as jnp + +from lcm.simulation.initial_conditions import ( + MISSING_CAT_CODE, + build_initial_states, +) +from tests.test_models.deterministic.regression import get_model + + +def test_discrete_grid_to_jax_is_int32() -> None: + model = get_model(n_periods=3) + for regime in model.regimes.values(): + for grid in {**regime.states, **regime.actions}.values(): + jax_arr = grid.to_jax() + if jax_arr.dtype.kind == "i": + assert jax_arr.dtype == jnp.int32, ( + f"Discrete grid yielded {jax_arr.dtype}, expected int32." 
+ ) + + +def test_build_initial_states_discrete_dtype_is_int32() -> None: + model = get_model(n_periods=3) + initial_states = { + "wealth": jnp.array([20.0, 50.0]), + "age": jnp.array([18.0, 18.0]), + } + flat = build_initial_states( + initial_states=initial_states, + internal_regimes=model.internal_regimes, + ) + for key, arr in flat.items(): + if arr.dtype.kind == "i": + assert arr.dtype == jnp.int32, ( + f"Initial state {key} has dtype {arr.dtype}, expected int32." + ) + + +def test_missing_cat_code_is_int32_minimum() -> None: + assert jnp.iinfo(jnp.int32).min == MISSING_CAT_CODE From 0be8a60c5f21190168155641e312ab8b47dd05d5 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 3 May 2026 15:05:49 +0200 Subject: [PATCH 2/2] Drop redundant searchsorted/unravel_index int32 casts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `jnp.searchsorted` already returns int32 even with `jax_enable_x64`, so the four `.astype(jnp.int32)` casts in `grids/coordinates.py`, `grids/piecewise.py` (×2), and `simulation/simulate.py:_compute_starting_periods` were no-ops at the dtype level — but they sat between an integer-producing op and its index-consumer inside vmap'd interpolation kernels, breaking XLA's fusion and forcing the intermediate to materialise as a top-level GPU buffer per (period, regime, state). Likewise, the `unravel_index` output in `_lookup_values_from_indices` is consumed immediately by `grid[index]`, which accepts int64 fine — the cast served no purpose. Keeps the argmax cast on the solve path (real int64→int32 narrowing), the boundary casts at error/validation paths, and the AOT-relevant casts in `pandas_utils` and the `subject_regime_ids` sentinel. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/coordinates.py | 2 +- src/lcm/grids/piecewise.py | 8 ++------ src/lcm/simulation/simulate.py | 4 ++-- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/lcm/grids/coordinates.py b/src/lcm/grids/coordinates.py index 583b9b4d..0edce4bd 100644 --- a/src/lcm/grids/coordinates.py +++ b/src/lcm/grids/coordinates.py @@ -188,7 +188,7 @@ def get_irreg_coordinate( n_points = len(points) # Find the index of the first point greater than value - idx_upper = jnp.searchsorted(points, value, side="right").astype(jnp.int32) + idx_upper = jnp.searchsorted(points, value, side="right") # Clamp to valid range for interpolation idx_upper = jnp.clip(idx_upper, 1, n_points - 1) diff --git a/src/lcm/grids/piecewise.py b/src/lcm/grids/piecewise.py index aae8a0f3..e3a252f2 100644 --- a/src/lcm/grids/piecewise.py +++ b/src/lcm/grids/piecewise.py @@ -87,9 +87,7 @@ def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... def get_coordinate(self, value: Array) -> Array: ... def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" - piece_idx = jnp.searchsorted(self._breakpoints, value, side="right").astype( - jnp.int32 - ) + piece_idx = jnp.searchsorted(self._breakpoints, value, side="right") local_coord = grid_coordinates.get_linspace_coordinate( value=value, start=self._piece_starts[piece_idx], @@ -160,9 +158,7 @@ def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... def get_coordinate(self, value: Array) -> Array: ... 
def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" - piece_idx = jnp.searchsorted(self._breakpoints, value, side="right").astype( - jnp.int32 - ) + piece_idx = jnp.searchsorted(self._breakpoints, value, side="right") local_coord = grid_coordinates.get_logspace_coordinate( value=value, start=self._piece_starts[piece_idx], diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 4ed08140..d1ab42ab 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -355,7 +355,7 @@ def _lookup_values_from_indices( nd_indices = vmapped_unravel_index(flat_indices, grids_shapes) return MappingProxyType( { - name: grid[index.astype(jnp.int32)] + name: grid[index] for (name, grid), index in zip(grids.items(), nd_indices, strict=True) } ) @@ -385,7 +385,7 @@ def _compute_starting_periods( """ age_values = jnp.asarray(ages.values) - starting_periods = jnp.searchsorted(age_values, initial_ages).astype(jnp.int32) + starting_periods = jnp.searchsorted(age_values, initial_ages) # Clamp indices to valid range before accessing age_values. searchsorted can # return len(age_values) for ages beyond the grid maximum.