Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lcm/pandas_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def initial_conditions_from_dataframe( # noqa: C901
for col, arr in result_arrays.items()
}
initial_conditions["regime"] = jnp.array(
df["regime"].map(dict(regime_names_to_ids)).to_numpy()
df["regime"].map(dict(regime_names_to_ids)).to_numpy(),
dtype=jnp.int32,
)

return initial_conditions
Expand Down
4 changes: 2 additions & 2 deletions src/lcm/regime_building/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def argmax_and_max(
# When there are no dimensions to reduce over, return:
# - index 0 (trivial argmax since there's only one element)
# - the array itself (already the maximum)
return jnp.array(0), a
return jnp.array(0, dtype=jnp.int32), a

# Move axis over which to compute the argmax to the back and flatten last dims
# ==================================================================================
Expand All @@ -65,7 +65,7 @@ def argmax_and_max(
max_value_mask = a == _max
if where is not None:
max_value_mask = jnp.logical_and(max_value_mask, where)
_argmax = jnp.argmax(max_value_mask, axis=-1)
_argmax = jnp.argmax(max_value_mask, axis=-1).astype(jnp.int32)

return _argmax, _max.reshape(_argmax.shape)

Expand Down
4 changes: 2 additions & 2 deletions src/lcm/simulation/initial_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def _collect_structural_errors(
active_mask = active_mask & (~in_regime | period_active)

if not jnp.all(active_mask):
invalid_indices = jnp.where(~active_mask)[0]
invalid_indices = jnp.where(~active_mask)[0].astype(jnp.int32)
invalid_combos = {
(ids_to_regime_names[int(regime_id_arr[i])], float(age_values[i]))
for i in invalid_indices
Expand Down Expand Up @@ -406,7 +406,7 @@ def _collect_feasibility_errors(
errors: list[str] = []
for regime_name, internal_regime in internal_regimes.items():
regime_id = regime_names_to_ids[regime_name]
idx_arr = jnp.where(regime_id_arr == regime_id)[0]
idx_arr = jnp.where(regime_id_arr == regime_id)[0].astype(jnp.int32)
subject_indices = idx_arr.tolist() if idx_arr.size > 0 else []
if not subject_indices:
continue
Expand Down
4 changes: 3 additions & 1 deletion src/lcm/simulation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def simulate(
starting_periods = _compute_starting_periods(
initial_ages=initial_states["age"], ages=ages
)
subject_regime_ids = jnp.full_like(initial_conditions["regime"], MISSING_CAT_CODE)
subject_regime_ids = jnp.full_like(
initial_conditions["regime"], MISSING_CAT_CODE, dtype=jnp.int32
)

# Forward simulation
simulation_results: dict[RegimeName, dict[int, PeriodRegimeSimulationData]] = {
Expand Down
12 changes: 6 additions & 6 deletions src/lcm/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,27 @@

import pandas as pd
from jax import Array
from jaxtyping import Bool, Float, Int, Scalar
from jaxtyping import Bool, Float, Int32, Scalar

from lcm.params import MappingLeaf
from lcm.params.sequence_leaf import SequenceLeaf

type ContinuousState = Float[Array, "..."]
type ContinuousAction = Float[Array, "..."]
type DiscreteState = Int[Array, "..."]
type DiscreteAction = Int[Array, "..."]
type DiscreteState = Int32[Array, "..."]
type DiscreteAction = Int32[Array, "..."]

type FloatND = Float[Array, "..."]
type IntND = Int[Array, "..."]
type IntND = Int32[Array, "..."]
type BoolND = Bool[Array, "..."]

type Float1D = Float[Array, "_"] # noqa: F821
type Int1D = Int[Array, "_"] # noqa: F821
type Int1D = Int32[Array, "_"] # noqa: F821
type Bool1D = Bool[Array, "_"] # noqa: F821

# Many JAX functions are designed to work with scalar numerical values. This also
# includes zero dimensional jax arrays.
type ScalarInt = int | Int[Scalar, ""]
type ScalarInt = int | Int32[Scalar, ""]
type ScalarFloat = float | Float[Scalar, ""]

type Period = int | Int1D
Expand Down
2 changes: 1 addition & 1 deletion src/lcm/utils/error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def _format_sum_violation(
{name: jnp.atleast_1d(arr) for name, arr in state_action_values.items()}
)
failing_mask = ~jnp.isclose(sum_all, 1.0)
failing_indices = jnp.where(failing_mask)[0]
failing_indices = jnp.where(failing_mask)[0].astype(jnp.int32)
failing_sums = sum_all[failing_mask]
n_failing = int(failing_indices.shape[0])
n_show = min(n_failing, 5)
Expand Down
41 changes: 41 additions & 0 deletions tests/test_int_dtype_invariants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Integer dtypes are pinned to int32 across pylcm regardless of x64 mode."""

import jax.numpy as jnp

from lcm.simulation.initial_conditions import (
MISSING_CAT_CODE,
build_initial_states,
)
from tests.test_models.deterministic.regression import get_model


def test_discrete_grid_to_jax_is_int32() -> None:
model = get_model(n_periods=3)
for regime in model.regimes.values():
for grid in {**regime.states, **regime.actions}.values():
jax_arr = grid.to_jax()
if jax_arr.dtype.kind == "i":
assert jax_arr.dtype == jnp.int32, (
f"Discrete grid yielded {jax_arr.dtype}, expected int32."
)


def test_build_initial_states_discrete_dtype_is_int32() -> None:
model = get_model(n_periods=3)
initial_states = {
"wealth": jnp.array([20.0, 50.0]),
"age": jnp.array([18.0, 18.0]),
}
flat = build_initial_states(
initial_states=initial_states,
internal_regimes=model.internal_regimes,
)
for key, arr in flat.items():
if arr.dtype.kind == "i":
assert arr.dtype == jnp.int32, (
f"Initial state {key} has dtype {arr.dtype}, expected int32."
)


def test_missing_cat_code_is_int32_minimum() -> None:
assert jnp.iinfo(jnp.int32).min == MISSING_CAT_CODE