diff --git a/src/lcm/pandas_utils.py b/src/lcm/pandas_utils.py index 2d696a45..cf1b4cb0 100644 --- a/src/lcm/pandas_utils.py +++ b/src/lcm/pandas_utils.py @@ -153,7 +153,8 @@ def initial_conditions_from_dataframe( # noqa: C901 for col, arr in result_arrays.items() } initial_conditions["regime"] = jnp.array( - df["regime"].map(dict(regime_names_to_ids)).to_numpy() + df["regime"].map(dict(regime_names_to_ids)).to_numpy(), + dtype=jnp.int32, ) return initial_conditions diff --git a/src/lcm/regime_building/argmax.py b/src/lcm/regime_building/argmax.py index 0e48cf2b..a4271e2f 100644 --- a/src/lcm/regime_building/argmax.py +++ b/src/lcm/regime_building/argmax.py @@ -43,7 +43,7 @@ def argmax_and_max( # When there are no dimensions to reduce over, return: # - index 0 (trivial argmax since there's only one element) # - the array itself (already the maximum) - return jnp.array(0), a + return jnp.array(0, dtype=jnp.int32), a # Move axis over which to compute the argmax to the back and flatten last dims # ================================================================================== @@ -65,7 +65,7 @@ def argmax_and_max( max_value_mask = a == _max if where is not None: max_value_mask = jnp.logical_and(max_value_mask, where) - _argmax = jnp.argmax(max_value_mask, axis=-1) + _argmax = jnp.argmax(max_value_mask, axis=-1).astype(jnp.int32) return _argmax, _max.reshape(_argmax.shape) diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 22a8dd2e..127059e1 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -362,7 +362,7 @@ def _collect_structural_errors( active_mask = active_mask & (~in_regime | period_active) if not jnp.all(active_mask): - invalid_indices = jnp.where(~active_mask)[0] + invalid_indices = jnp.where(~active_mask)[0].astype(jnp.int32) invalid_combos = { (ids_to_regime_names[int(regime_id_arr[i])], float(age_values[i])) for i in invalid_indices @@ -406,7 +406,7 @@ def _collect_feasibility_errors( errors: list[str] = [] for regime_name, internal_regime in internal_regimes.items(): regime_id = regime_names_to_ids[regime_name] - idx_arr = jnp.where(regime_id_arr == regime_id)[0] + idx_arr = jnp.where(regime_id_arr == regime_id)[0].astype(jnp.int32) subject_indices = idx_arr.tolist() if idx_arr.size > 0 else [] if not subject_indices: continue diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index d54d2674..d1ab42ab 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -99,7 +99,9 @@ def simulate( starting_periods = _compute_starting_periods( initial_ages=initial_states["age"], ages=ages ) - subject_regime_ids = jnp.full_like(initial_conditions["regime"], MISSING_CAT_CODE) + subject_regime_ids = jnp.full_like( + initial_conditions["regime"], MISSING_CAT_CODE, dtype=jnp.int32 + ) # Forward simulation simulation_results: dict[RegimeName, dict[int, PeriodRegimeSimulationData]] = { diff --git a/src/lcm/typing.py b/src/lcm/typing.py index c73b4c33..62492770 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -4,27 +4,27 @@ import pandas as pd from jax import Array -from jaxtyping import Bool, Float, Int, Scalar +from jaxtyping import Bool, Float, Int32, Scalar from lcm.params import MappingLeaf from lcm.params.sequence_leaf import SequenceLeaf type ContinuousState = Float[Array, "..."] type ContinuousAction = Float[Array, "..."] -type DiscreteState = Int[Array, "..."] -type DiscreteAction = Int[Array, "..."] +type DiscreteState = Int32[Array, "..."] +type DiscreteAction = Int32[Array, "..."] type FloatND = Float[Array, "..."] -type IntND = Int[Array, "..."] +type IntND = Int32[Array, "..."] type BoolND = Bool[Array, "..."] type Float1D = Float[Array, "_"] # noqa: F821 -type Int1D = Int[Array, "_"] # noqa: F821 +type Int1D = Int32[Array, "_"] # noqa: F821 type Bool1D = Bool[Array, "_"] # noqa: F821 # Many JAX functions are designed to work with scalar numerical values. This also # includes zero dimensional jax arrays. -type ScalarInt = int | Int[Scalar, ""] +type ScalarInt = int | Int32[Scalar, ""] type ScalarFloat = float | Float[Scalar, ""] type Period = int | Int1D diff --git a/src/lcm/utils/error_handling.py b/src/lcm/utils/error_handling.py index c7f16008..7f790208 100644 --- a/src/lcm/utils/error_handling.py +++ b/src/lcm/utils/error_handling.py @@ -373,7 +373,7 @@ def _format_sum_violation( {name: jnp.atleast_1d(arr) for name, arr in state_action_values.items()} ) failing_mask = ~jnp.isclose(sum_all, 1.0) - failing_indices = jnp.where(failing_mask)[0] + failing_indices = jnp.where(failing_mask)[0].astype(jnp.int32) failing_sums = sum_all[failing_mask] n_failing = int(failing_indices.shape[0]) n_show = min(n_failing, 5) diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py new file mode 100644 index 00000000..3147de3f --- /dev/null +++ b/tests/test_int_dtype_invariants.py @@ -0,0 +1,41 @@ +"""Integer dtypes are pinned to int32 across pylcm regardless of x64 mode.""" + +import jax.numpy as jnp + +from lcm.simulation.initial_conditions import ( + MISSING_CAT_CODE, + build_initial_states, +) +from tests.test_models.deterministic.regression import get_model + + +def test_discrete_grid_to_jax_is_int32() -> None: + model = get_model(n_periods=3) + for regime in model.regimes.values(): + for grid in {**regime.states, **regime.actions}.values(): + jax_arr = grid.to_jax() + if jax_arr.dtype.kind == "i": + assert jax_arr.dtype == jnp.int32, ( + f"Discrete grid yielded {jax_arr.dtype}, expected int32." + ) + + +def test_build_initial_states_discrete_dtype_is_int32() -> None: + model = get_model(n_periods=3) + initial_states = { + "wealth": jnp.array([20.0, 50.0]), + "age": jnp.array([18.0, 18.0]), + } + flat = build_initial_states( + initial_states=initial_states, + internal_regimes=model.internal_regimes, + ) + for key, arr in flat.items(): + if arr.dtype.kind == "i": + assert arr.dtype == jnp.int32, ( + f"Initial state {key} has dtype {arr.dtype}, expected int32." + ) + + +def test_missing_cat_code_is_int32_minimum() -> None: + assert jnp.iinfo(jnp.int32).min == MISSING_CAT_CODE