From e8813136e7ce04c7f998244421c5759fae36ce08 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 16:44:11 +0200 Subject: [PATCH 01/22] Package B: float dtype barriers at the API boundary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pin every float that crosses into pylcm to a single canonical dtype derived from `jax.config.jax_enable_x64`. Adds the `canonical_float_dtype()` and `safe_to_float_dtype(value, *, name)` helpers next to Package A's int counterparts, and applies them at the same boundaries. Helpers (`src/lcm/dtypes.py`): - `canonical_float_dtype()`: `jnp.float64` if x64, else `jnp.float32`. Read at call time so toggling JAX config between tests is honoured. - `safe_to_float_dtype(value, *, name)`: host-side cast with overflow check on float64 -> float32 down-casts. Up-casts and same-width casts skip the range check; precision loss within range is *not* an error (it's an inherent consequence of the user's x64 choice). Boundaries: - `_cast_int_leaves_to_int32` becomes `_cast_leaves_to_canonical_dtype`: one pass that handles both int (via `safe_to_int32`) and float (via `safe_to_float_dtype`). Python `int` / `float` / `bool` scalars pass through to keep JAX weak-typing semantics. `pd.Series` leaves pass through too — `convert_series_in_params` reshapes them later based on their multi-index. - `build_initial_states` casts continuous user arrays to `canonical_float_dtype()` and pins the missing-state NaN fallback to the same dtype. After this, both discrete (int32) and continuous (canonical float) state pools have stable dtypes across all simulate periods. - `_update_states_for_subjects` now unconditionally casts `next_state_values` to the storage dtype. Package A's cross-kind guard is no longer needed: with the continuous-state cast in place, storage dtype is always the canonical one for that kind. 
Tests: - New `tests/test_float_dtype_invariants.py` (10 tests): - `canonical_float_dtype()` follows `jax_enable_x64` - `safe_to_float_dtype` round-trip / down-cast / up-cast / overflow - `build_initial_states` continuous-state casts (float64 input, int input, missing-state NaN fallback) - `process_params` casts typed float arrays, leaves Python floats weak-typed, raises on float-array overflow with qualified name, casts inside `MappingLeaf` - `LinSpacedGrid` / `LogSpacedGrid` / `IrregSpacedGrid` `to_jax` materialise at canonical dtype - `model.solve(...)` V-arrays at canonical dtype - Multi-period simulate: every state's dtype is stable across periods (no silent promotion mid-run) - `test_validate_param_types`: `numpy_array_param_rejected` -> `numpy_array_param_accepted_and_normalised`. With the boundary cast in place, numpy arrays are auto-converted; the historical rejection-by-isinstance is no longer needed. - Extended `tests/test_dtypes.py` with 7 float-helper tests. After Package B, the simulate-AOT path traces against a single abstract signature (int32 + canonical float) regardless of how users supply their inputs (Python scalars, mixed-precision JAX arrays, numpy arrays). --- src/lcm/dtypes.py | 48 +++++ src/lcm/params/processing.py | 41 ++-- src/lcm/simulation/initial_conditions.py | 10 +- src/lcm/simulation/transitions.py | 16 +- tests/test_dtypes.py | 84 ++++++++- tests/test_float_dtype_invariants.py | 226 +++++++++++++++++++++++ tests/test_validate_param_types.py | 10 +- 7 files changed, 399 insertions(+), 36 deletions(-) create mode 100644 tests/test_float_dtype_invariants.py diff --git a/src/lcm/dtypes.py b/src/lcm/dtypes.py index 75dd591a..c556a1e1 100644 --- a/src/lcm/dtypes.py +++ b/src/lcm/dtypes.py @@ -8,12 +8,24 @@ which is out of scope for the boundary helpers. 
""" +import jax import jax.numpy as jnp import numpy as np from jax import Array _INT32_MIN = int(np.iinfo(np.int32).min) _INT32_MAX = int(np.iinfo(np.int32).max) +_FLOAT32_MAX = float(np.finfo(np.float32).max) + + +def canonical_float_dtype() -> jnp.dtype: + """Return pylcm's canonical float dtype, derived from `jax_enable_x64`. + + Returns `jnp.float64` if `jax.config.jax_enable_x64` is True, + otherwise `jnp.float32`. The value is read at call time, not at + import, so toggling the JAX config (e.g. between tests) is honoured. + """ + return jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32 def safe_to_int32(value: object, *, name: str) -> Array: @@ -44,3 +56,39 @@ def safe_to_int32(value: object, *, name: str) -> Array: ) raise ValueError(msg) return jnp.asarray(np_value, dtype=jnp.int32) + + +def safe_to_float_dtype(value: object, *, name: str) -> Array: + """Cast a scalar, sequence, or array to the canonical float dtype. + + When the cast is *down* (float64 -> float32 under `jax_enable_x64=False`), + check that no element exceeds `float32` magnitude — raising + `OverflowError` if so rather than letting JAX silently saturate to + `±inf`. Up-casts and same-width casts skip the range check; precision + loss within range is *not* an error (it is an inherent consequence of + `jax_enable_x64=False`). + + Args: + value: A Python float, numpy/JAX scalar, or array-like. + name: Qualified name of the leaf — surfaced in the error message. + + Returns: + A JAX array at `canonical_float_dtype()` (0-d if `value` was a + scalar). + + Raises: + OverflowError: If down-casting to `float32` would saturate any + element to `±inf`. The message names the leaf via `name`. 
+ + """ + target_dtype = canonical_float_dtype() + np_value = np.asarray(value) + if target_dtype == jnp.float32 and np_value.size > 0: + max_mag = float(np.max(np.abs(np_value))) + if max_mag > _FLOAT32_MAX: + msg = ( + f"{name}: float32 overflow — max |value| {max_mag:g} " + f"exceeds float32 max {_FLOAT32_MAX:g}." + ) + raise OverflowError(msg) + return jnp.asarray(np_value, dtype=target_dtype) diff --git a/src/lcm/params/processing.py b/src/lcm/params/processing.py index 03060407..f0c9c3ce 100644 --- a/src/lcm/params/processing.py +++ b/src/lcm/params/processing.py @@ -5,9 +5,10 @@ from typing import Any, cast import numpy as np +import pandas as pd from dags.tree import QNAME_DELIMITER, qname_from_tree_path, tree_path_from_qname -from lcm.dtypes import safe_to_int32 +from lcm.dtypes import safe_to_float_dtype, safe_to_int32 from lcm.exceptions import InvalidNameError, InvalidParamsError from lcm.interfaces import InternalRegime from lcm.params.mapping_leaf import MappingLeaf @@ -116,7 +117,7 @@ def broadcast_to_template( for regime, leaves in result.items(): for param_qname, value in leaves.items(): - leaves[param_qname] = _cast_int_leaves_to_int32( + leaves[param_qname] = _cast_leaves_to_canonical_dtype( value, name=f"{regime}{QNAME_DELIMITER}{param_qname}" ) @@ -126,36 +127,42 @@ def broadcast_to_template( ) -def _cast_int_leaves_to_int32(value: Any, *, name: str) -> Any: # noqa: ANN401 - """Normalise typed integer arrays in a params value to `jnp.int32`. +def _cast_leaves_to_canonical_dtype(value: Any, *, name: str) -> Any: # noqa: ANN401 + """Normalise typed numeric arrays in a params value to canonical pylcm dtypes. - Only typed JAX or numpy integer arrays are cast — Python `int` / `bool` - leaves stay unmodified. JAX's weak-typing rules promote raw Python ints - correctly to whichever dtype the surrounding operation needs (e.g. 
- `discount_factor: 1` works in a float-typed function), so casting them - eagerly to `int32` would force premature dtype commitment. Typed - arrays (`jnp.array(..., dtype=jnp.int64)`) are strongly typed by JAX - and would otherwise leak their dtype into the AOT signature. + Casts typed JAX or numpy integer arrays to `jnp.int32`, and typed float + arrays to `canonical_float_dtype()`. Python `int` / `float` / `bool` + leaves stay unmodified — JAX's weak-typing rules promote them to + whichever dtype the surrounding operation needs (e.g. `discount_factor: 1` + works in a float-typed function), so eager casting would force + premature dtype commitment. - Walks `MappingLeaf` and `SequenceLeaf` recursively. Float and - non-numeric leaves pass through — float normalisation is Package B. + Walks `MappingLeaf` and `SequenceLeaf` recursively. Non-numeric + typed leaves pass through. """ if isinstance(value, MappingLeaf): return MappingLeaf( { - k: _cast_int_leaves_to_int32(v, name=f"{name}.{k}") + k: _cast_leaves_to_canonical_dtype(v, name=f"{name}.{k}") for k, v in value.data.items() } ) if isinstance(value, SequenceLeaf): return SequenceLeaf( [ - _cast_int_leaves_to_int32(v, name=f"{name}[{i}]") + _cast_leaves_to_canonical_dtype(v, name=f"{name}[{i}]") for i, v in enumerate(value.data) ] ) - if hasattr(value, "dtype") and np.issubdtype(value.dtype, np.integer): - return safe_to_int32(value, name=name) + # `pd.Series` leaves are reshaped by `convert_series_in_params` based on + # their multi-index; flattening here via `np.asarray` would lose that. 
+ if isinstance(value, pd.Series): + return value + if hasattr(value, "dtype"): + if np.issubdtype(value.dtype, np.integer): + return safe_to_int32(value, name=name) + if np.issubdtype(value.dtype, np.floating): + return safe_to_float_dtype(value, name=name) return value diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 69ca8d9d..3db3063e 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -16,6 +16,7 @@ from jax import numpy as jnp from lcm.ages import PSEUDO_STATE_NAMES, AgeGrid +from lcm.dtypes import canonical_float_dtype, safe_to_float_dtype from lcm.exceptions import ( InvalidInitialConditionsError, format_messages, @@ -77,9 +78,14 @@ def build_initial_states( n_subjects, MISSING_CAT_CODE, dtype=target_dtype ) elif state_name in initial_states: - flat[key] = initial_states[state_name] + # Cast user-supplied continuous states to the canonical float + # dtype so the simulate state pool has one signature across + # periods regardless of the user-supplied dtype. + flat[key] = safe_to_float_dtype( + initial_states[state_name], name=f"initial_states.{state_name}" + ) else: - flat[key] = jnp.full(n_subjects, jnp.nan) + flat[key] = jnp.full(n_subjects, jnp.nan, dtype=canonical_float_dtype()) return MappingProxyType(flat) diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index db1fec82..95415db2 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -287,18 +287,14 @@ def _update_states_for_subjects( for next_state_name, next_state_values in target_next_states.items(): state_name = f"{target}__{next_state_name.removeprefix('next_')}" target_dtype = all_states[state_name].dtype - # Preserve storage dtype only when the transition output is the - # same numeric kind. Across kinds (e.g. 
int storage + float - # transition output) leave JAX's promotion in place; the - # cross-kind boundary cast belongs to Package B. - new_values = ( - next_state_values.astype(target_dtype) - if next_state_values.dtype.kind == target_dtype.kind - else next_state_values - ) + # Pin transition outputs to the storage dtype before `jnp.where`. + # Initial-condition boundary casts ensure storage already reflects + # the canonical dtype (int32 for discrete, `canonical_float_dtype()` + # for continuous), so this cast is value-preserving for any + # well-typed user transition. updated_states[state_name] = jnp.where( subject_indices, - new_values, + next_state_values.astype(target_dtype), all_states[state_name], ) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 77b0c69f..4908049d 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -1,10 +1,33 @@ """Tests for `lcm.dtypes` boundary-cast helpers.""" +from collections.abc import Iterator + import jax.numpy as jnp import numpy as np import pytest +from jax import config as jax_config + +from lcm.dtypes import canonical_float_dtype, safe_to_float_dtype, safe_to_int32 + + +@pytest.fixture(name="x64_disabled") +def _fixture_x64_disabled() -> Iterator[None]: + previous = jax_config.read("jax_enable_x64") + jax_config.update("jax_enable_x64", val=False) + try: + yield + finally: + jax_config.update("jax_enable_x64", val=previous) + -from lcm.dtypes import safe_to_int32 +@pytest.fixture(name="x64_enabled") +def _fixture_x64_enabled() -> Iterator[None]: + previous = jax_config.read("jax_enable_x64") + jax_config.update("jax_enable_x64", val=True) + try: + yield + finally: + jax_config.update("jax_enable_x64", val=previous) def test_safe_to_int32_casts_python_int_in_range() -> None: @@ -39,3 +62,62 @@ def test_safe_to_int32_raises_on_underflow() -> None: """A Python int below int32 min raises `ValueError` naming the leaf.""" with pytest.raises(ValueError, match="offset"): safe_to_int32(-(2**40), name="offset") + 
+ +def test_canonical_float_dtype_is_float32_under_no_x64( + x64_disabled: None, # noqa: ARG001 +) -> None: + """`canonical_float_dtype()` is `float32` when `jax_enable_x64=False`.""" + assert canonical_float_dtype() == jnp.float32 + + +def test_canonical_float_dtype_is_float64_under_x64( + x64_enabled: None, # noqa: ARG001 +) -> None: + """`canonical_float_dtype()` is `float64` when `jax_enable_x64=True`.""" + assert canonical_float_dtype() == jnp.float64 + + +def test_safe_to_float_dtype_casts_python_float_to_canonical( + x64_disabled: None, # noqa: ARG001 +) -> None: + """A Python float lands at `float32` under no-x64.""" + out = safe_to_float_dtype(0.5, name="x") + assert out.dtype == jnp.float32 + assert float(out) == 0.5 + + +def test_safe_to_float_dtype_casts_float64_array_to_float32( + x64_disabled: None, # noqa: ARG001 +) -> None: + """A `float64` array within float32 range is downcast to `float32`.""" + arr = jnp.asarray([0.1, 0.2, 0.3], dtype=jnp.float64) + out = safe_to_float_dtype(arr, name="x") + assert out.dtype == jnp.float32 + + +def test_safe_to_float_dtype_passes_array_through_under_x64( + x64_enabled: None, # noqa: ARG001 +) -> None: + """Under x64, a `float64` array is preserved (no down-cast required).""" + arr = jnp.asarray([0.1, 0.2, 0.3], dtype=jnp.float64) + out = safe_to_float_dtype(arr, name="x") + assert out.dtype == jnp.float64 + + +def test_safe_to_float_dtype_raises_on_overflow_when_downcasting( + x64_disabled: None, # noqa: ARG001 +) -> None: + """A `float64` value above float32 max raises `OverflowError`, naming the leaf.""" + big = 1e40 + with pytest.raises(OverflowError, match="big_param"): + safe_to_float_dtype(big, name="big_param") + + +def test_safe_to_float_dtype_no_overflow_check_when_upcasting( + x64_enabled: None, # noqa: ARG001 +) -> None: + """Casting `float32` -> `float64` (up) skips the overflow check.""" + arr = jnp.asarray([0.1, 0.2], dtype=jnp.float32) + out = safe_to_float_dtype(arr, name="x") + assert out.dtype 
== jnp.float64 diff --git a/tests/test_float_dtype_invariants.py b/tests/test_float_dtype_invariants.py new file mode 100644 index 00000000..e1edbc69 --- /dev/null +++ b/tests/test_float_dtype_invariants.py @@ -0,0 +1,226 @@ +"""Float dtypes follow `canonical_float_dtype()` across pylcm boundaries.""" + +from collections.abc import Iterator +from types import MappingProxyType + +import jax.numpy as jnp +import pytest +from jax import config as jax_config + +from lcm.dtypes import canonical_float_dtype +from lcm.grids import IrregSpacedGrid, LinSpacedGrid, LogSpacedGrid +from lcm.params import MappingLeaf +from lcm.params.processing import process_params +from lcm.simulation.initial_conditions import build_initial_states +from tests.test_models.deterministic.regression import ( + RegimeId, + get_model, + get_params, +) + + +@pytest.fixture(name="x64_disabled") +def _fixture_x64_disabled() -> Iterator[None]: + previous = jax_config.read("jax_enable_x64") + jax_config.update("jax_enable_x64", val=False) + try: + yield + finally: + jax_config.update("jax_enable_x64", val=previous) + + +def test_build_initial_states_continuous_state_cast_to_canonical_dtype( + x64_disabled: None, # noqa: ARG001 +) -> None: + """Continuous initial states land at `canonical_float_dtype()` for any input.""" + model = get_model(n_periods=3) + # User passes float64 arrays under x64=False — should be cast to float32. + initial_states = { + "wealth": jnp.asarray([20.0, 50.0], dtype=jnp.float64), + "age": jnp.asarray([18.0, 18.0], dtype=jnp.float64), + } + flat = build_initial_states( + initial_states=initial_states, + internal_regimes=model.internal_regimes, + ) + target = canonical_float_dtype() + for key, arr in flat.items(): + if arr.dtype.kind == "f": + assert arr.dtype == target, ( + f"Initial state {key} has dtype {arr.dtype}, expected {target}." 
+ ) + + +def test_build_initial_states_int_input_for_continuous_state_cast_to_canonical( + x64_disabled: None, # noqa: ARG001 +) -> None: + """Int initial-condition arrays for continuous states land at canonical float.""" + model = get_model(n_periods=3) + initial_states = { + "wealth": jnp.asarray([20, 50], dtype=jnp.int32), + "age": jnp.asarray([18, 18], dtype=jnp.int32), + } + flat = build_initial_states( + initial_states=initial_states, + internal_regimes=model.internal_regimes, + ) + target = canonical_float_dtype() + # All non-discrete state entries should be canonical-float now. + for key, arr in flat.items(): + if "wealth" in key or "age" in key: + assert arr.dtype == target, ( + f"Continuous state {key} has dtype {arr.dtype}, expected {target}." + ) + + +def test_build_initial_states_missing_continuous_fallback_dtype_is_canonical( + x64_disabled: None, # noqa: ARG001 +) -> None: + """Missing continuous states fall back to `nan` at the canonical float dtype.""" + model = get_model(n_periods=3) + initial_states = { + "wealth": jnp.asarray([20.0, 50.0]), + "age": jnp.asarray([18.0, 18.0]), + } + flat = build_initial_states( + initial_states=initial_states, + internal_regimes=model.internal_regimes, + ) + # Find a fallback-NaN entry and check its dtype. 
+ target = canonical_float_dtype() + nan_entries = [arr for arr in flat.values() if arr.dtype.kind == "f"] + assert nan_entries # sanity + for arr in nan_entries: + assert arr.dtype == target + + +def test_process_params_casts_float64_array_to_canonical_under_no_x64( + x64_disabled: None, # noqa: ARG001 +) -> None: + """A `float64` array param is downcast to `float32` under `jax_enable_x64=False`.""" + template = MappingProxyType({"regime_a": MappingProxyType({"schedule": "Array"})}) + user_params = { + "regime_a": {"schedule": jnp.asarray([0.1, 0.2, 0.3], dtype=jnp.float64)} + } + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + schedule = out["regime_a"]["schedule"] + assert schedule.dtype == jnp.float32 # ty: ignore[unresolved-attribute] + + +def test_process_params_passes_python_float_through_for_jax_weak_typing( + x64_disabled: None, # noqa: ARG001 +) -> None: + """Python `float` params stay weak-typed so JAX promotes them per call site.""" + template = MappingProxyType( + {"regime_a": MappingProxyType({"discount_factor": "float"})} + ) + user_params = {"regime_a": {"discount_factor": 0.95}} + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + # Python float stays Python float; JAX weak-typing handles promotion at JIT. 
+ assert out["regime_a"]["discount_factor"] == 0.95 + assert isinstance(out["regime_a"]["discount_factor"], float) + + +def test_process_params_float_array_overflow_raises_with_qualified_name( + x64_disabled: None, # noqa: ARG001 +) -> None: + """An out-of-float32 float64 *array* raises naming the qualified leaf.""" + template = MappingProxyType({"regime_a": MappingProxyType({"schedule": "Array"})}) + user_params = { + "regime_a": {"schedule": jnp.asarray([0.0, 1e40], dtype=jnp.float64)} + } + + with pytest.raises(OverflowError, match="schedule"): + process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + +def test_simulate_state_pool_dtype_stable_across_periods( + x64_disabled: None, # noqa: ARG001 +) -> None: + """A multi-period simulate keeps every state's dtype stable across periods.""" + n_periods = 4 + model = get_model(n_periods=n_periods) + params = get_params(n_periods=n_periods) + initial = { + "wealth": jnp.asarray([20.0, 50.0, 80.0]), + "age": jnp.asarray([18.0, 18.0, 18.0]), + "regime": jnp.asarray([RegimeId.working_life] * 3), + } + + result = model.simulate( + params=params, period_to_regime_to_V_arr=None, initial_conditions=initial + ) + + # Build the per-period state-dtype matrix and assert stability. 
+ seen: dict[str, set] = {} + for period_data in result.raw_results.values(): + for snap in period_data.values(): + for state_name, arr in snap.states.items(): + seen.setdefault(state_name, set()).add(arr.dtype) + for state_name, dtypes in seen.items(): + assert len(dtypes) == 1, f"State {state_name} drifted across periods: {dtypes}" + + +def test_solve_v_arrays_at_canonical_float_dtype( + x64_disabled: None, # noqa: ARG001 +) -> None: + """Every V-array returned by `model.solve()` is at `canonical_float_dtype()`.""" + model = get_model(n_periods=3) + period_to_regime_to_V_arr = model.solve(params=get_params(n_periods=3)) + target = canonical_float_dtype() + for period_v in period_to_regime_to_V_arr.values(): + for regime_name, v_arr in period_v.items(): + assert v_arr.dtype == target, ( + f"V[{regime_name}] dtype is {v_arr.dtype}, expected {target}." + ) + + +def test_continuous_grid_to_jax_dtype_is_canonical( + x64_disabled: None, # noqa: ARG001 +) -> None: + """Continuous grid `to_jax()` returns canonical-float arrays.""" + target = canonical_float_dtype() + assert LinSpacedGrid(start=0, stop=1, n_points=5).to_jax().dtype == target + assert LogSpacedGrid(start=1, stop=10, n_points=5).to_jax().dtype == target + assert IrregSpacedGrid(points=(0.0, 0.5, 1.0)).to_jax().dtype == target + + +def test_process_params_casts_float_array_inside_mapping_leaf_to_canonical( + x64_disabled: None, # noqa: ARG001 +) -> None: + """`MappingLeaf` float arrays land at `canonical_float_dtype()`.""" + template = MappingProxyType( + {"regime_a": MappingProxyType({"sched": "MappingLeaf"})} + ) + user_params = { + "regime_a": { + "sched": MappingLeaf( + { + "low": jnp.asarray([0.1, 0.2], dtype=jnp.float64), + "high": jnp.asarray([0.5, 0.7], dtype=jnp.float64), + } + ) + } + } + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + leaf = out["regime_a"]["sched"] + assert leaf.data["low"].dtype == jnp.float32 # ty: 
ignore[unresolved-attribute] + assert leaf.data["high"].dtype == jnp.float32 # ty: ignore[unresolved-attribute] diff --git a/tests/test_validate_param_types.py b/tests/test_validate_param_types.py index 939950ce..841b02b9 100644 --- a/tests/test_validate_param_types.py +++ b/tests/test_validate_param_types.py @@ -2,10 +2,8 @@ import jax.numpy as jnp import numpy as np -import pytest from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical -from lcm.exceptions import InvalidParamsError @categorical(ordered=True) @@ -49,11 +47,11 @@ def _make_model() -> Model: ) -def test_numpy_array_param_rejected() -> None: - """Passing a numpy array as a param should raise InvalidParamsError.""" +def test_numpy_array_param_accepted_and_normalised() -> None: + """Numpy arrays are auto-converted to JAX at the params boundary.""" model = _make_model() - with pytest.raises(InvalidParamsError, match=r"numpy\.ndarray"): - model.solve(params={"bonus": np.array(1.0), "discount_factor": 0.95}) # ty: ignore[invalid-argument-type] + # Should solve cleanly; the boundary cast normalises numpy -> JAX. + model.solve(params={"bonus": np.array(1.0), "discount_factor": 0.95}) # ty: ignore[invalid-argument-type] def test_jax_array_param_accepted() -> None: From ef180d037697d5b0417ba4e07dedca310eb9b47f Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 18:30:08 +0200 Subject: [PATCH 02/22] Fix Package B 32-bit precision test: build float overflow fixture with numpy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Same fix as Package A's int side. Under `jax_enable_x64=False`, `jnp.asarray(..., dtype=jnp.float64)` of `1e40` saturates to `±inf` at construction time before `safe_to_float_dtype` ever sees it. Use `np.asarray(..., dtype=np.float64)` so the value reaches the boundary helper as a real float64 and the helper produces its own qualified-name `OverflowError`. 
--- tests/test_float_dtype_invariants.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_float_dtype_invariants.py b/tests/test_float_dtype_invariants.py index e1edbc69..53b81994 100644 --- a/tests/test_float_dtype_invariants.py +++ b/tests/test_float_dtype_invariants.py @@ -4,6 +4,7 @@ from types import MappingProxyType import jax.numpy as jnp +import numpy as np import pytest from jax import config as jax_config @@ -136,13 +137,13 @@ def test_process_params_float_array_overflow_raises_with_qualified_name( ) -> None: """An out-of-float32 float64 *array* raises naming the qualified leaf.""" template = MappingProxyType({"regime_a": MappingProxyType({"schedule": "Array"})}) - user_params = { - "regime_a": {"schedule": jnp.asarray([0.0, 1e40], dtype=jnp.float64)} - } + # Numpy here: under `jax_enable_x64=False`, `jnp.asarray(..., dtype=float64)` + # of an out-of-float32 value saturates to ±inf at construction time. + user_params = {"regime_a": {"schedule": np.asarray([0.0, 1e40], dtype=np.float64)}} with pytest.raises(OverflowError, match="schedule"): process_params( - params=user_params, + params=user_params, # ty: ignore[invalid-argument-type] params_template=template, # ty: ignore[invalid-argument-type] ) From 09f3d034f393742e0e54f5eadcc79dd4ab3bd939 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 19:42:48 +0200 Subject: [PATCH 03/22] Address PR #345 review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Source: - `_update_states_for_subjects`: drop `next_state_values.astype(target_dtype)`. With every input boundary pinned to the canonical dtype, a pure-JAX user transition consuming canonical inputs returns canonical outputs and the cast is a no-op. The only case it covered — a user transition that explicitly produces a non-canonical dtype — is now surfaced loudly via AOT cache mismatch instead of silently coerced. 
- `safe_to_float_dtype` docstring: bullet list for the cast-direction enumeration (down-cast vs up/same-width). - `process_params` module + function docstring: extend to mention float cast and `OverflowError` alongside the int cast and `ValueError`. - `_cast_leaves_to_canonical_dtype`: rephrase `pd.Series` justification to describe the *property* (multi-index structure) rather than the internal helper that handles it. - `convert_series_in_params` and `initial_conditions_from_dataframe`: route every `pd.Series` -> JAX-array conversion through `canonical_float_dtype()` so the boundary contract holds for pandas-backed params and pandas-backed initial conditions. - `dtypes.py` module docstring: drop the contradictory note about downstream `.astype` casts (downstream no longer casts). Tests: - Move `x64_disabled` / `x64_enabled` fixtures to `tests/conftest.py` (were duplicated across two test files). - `test_safe_to_float_dtype_casts_float64_array_to_float32`: switch to `np.asarray(..., dtype=np.float64)`. With `jnp.asarray`, JAX silently truncated to `float32` at construction time under no-x64, so the helper's down-cast path was never exercised. - Same fix applied to every `tests/test_float_dtype_invariants.py` test that built a float64 input under the `x64_disabled` fixture. - Split / parametrise multi-assertion tests in `test_float_dtype_invariants.py`: continuous-grid `to_jax` over `LinSpacedGrid` / `LogSpacedGrid` / `IrregSpacedGrid`; `MappingLeaf` float keys parametrised over `["low", "high"]`. - Add `test_process_params_casts_float_array_inside_sequence_leaf_to_canonical` to mirror the `MappingLeaf` test (parametrised over `[0, 1]`). - Add `test_build_initial_states_missing_continuous_fallback_values_are_nan` asserting the fallback is actually NaN (not just at canonical dtype). 
- `test_continuous_grid_to_jax_dtype_is_canonical_under_no_x64`: assert against the literal `jnp.float32` instead of `canonical_float_dtype()` so a future grid implementation that hardcodes `float64` would surface here (current form passed trivially because both sides are driven by the same x64 flag). - `test_simulate_state_pool_dtype_stable_across_periods` and `test_solve_v_arrays_at_canonical_float_dtype`: collect violations into a single dict and assert that it is empty, so failures still name the offending state / V-array but the test has one assertion. - `tests/test_validate_param_types.py`: rewrite the three numpy / JAX / Python-scalar tests to assert the *normalised* leaf type and dtype via `_process_params`, not just "no exception raised". Update module docstring to describe the current behaviour. - `tests/test_int_dtype_invariants.py::test_update_states_for_subjects_*`: rewrite as a positive same-dtype round-trip — the previous form pinned a guard that the cast removal eliminated. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/dtypes.py | 22 ++-- src/lcm/pandas_utils.py | 11 +- src/lcm/params/processing.py | 37 ++++-- src/lcm/simulation/transitions.py | 8 +- tests/conftest.py | 26 +++- tests/test_dtypes.py | 34 ++--- tests/test_float_dtype_invariants.py | 185 ++++++++++++++++----------- tests/test_int_dtype_invariants.py | 18 ++- tests/test_validate_param_types.py | 40 ++++-- 9 files changed, 233 insertions(+), 148 deletions(-) diff --git a/src/lcm/dtypes.py b/src/lcm/dtypes.py index 4edc7bff..7345df1f 100644 --- a/src/lcm/dtypes.py +++ b/src/lcm/dtypes.py @@ -3,11 +3,9 @@ Used at every API boundary that accepts user data (params, initial conditions, regime-id arrays) — always called from Python, never inside JIT. Each helper validates that the value fits the target dtype and -raises a clearly-named error if not. - -Casts further down the simulate stack (e.g. 
transition outputs landing -in the state pool) use plain `.astype` and rely on the boundary cast -above them having already pinned the canonical dtype. +raises a clearly-named error if not. Once an input has crossed the +boundary it carries the canonical dtype unchanged through the simulate +stack; downstream code does not re-cast. """ import jax @@ -63,12 +61,14 @@ def safe_to_int32(value: object, *, name: str) -> Array: def safe_to_float_dtype(value: object, *, name: str) -> Array: """Cast a scalar, sequence, or array to the canonical float dtype. - When the cast is *down* (float64 -> float32 under `jax_enable_x64=False`), - check that no element exceeds `float32` magnitude — raising - `OverflowError` if so rather than letting JAX silently saturate to - `±inf`. Up-casts and same-width casts skip the range check; precision - loss within range is *not* an error (it is an inherent consequence of - `jax_enable_x64=False`). + Range check fires only on a down-cast: + + - Down-cast (float64 → float32 under `jax_enable_x64=False`): raise + `OverflowError` if any element exceeds float32 magnitude rather + than letting JAX silently saturate to `±inf`. + - Up-cast or same-width cast: skip the range check. Precision loss + within range is not an error — it is an inherent consequence of + `jax_enable_x64=False`. Args: value: A Python float, numpy/JAX scalar, or array-like. 
diff --git a/src/lcm/pandas_utils.py b/src/lcm/pandas_utils.py index cf1b4cb0..adc04eac 100644 --- a/src/lcm/pandas_utils.py +++ b/src/lcm/pandas_utils.py @@ -12,6 +12,7 @@ from jax import Array from lcm.ages import PSEUDO_STATE_NAMES, AgeGrid +from lcm.dtypes import canonical_float_dtype from lcm.grids import DiscreteGrid, IrregSpacedGrid from lcm.params import MappingLeaf from lcm.params.sequence_leaf import SequenceLeaf @@ -149,7 +150,7 @@ def initial_conditions_from_dataframe( # noqa: C901 initial_conditions: dict[str, Array] = { col: jnp.array(arr, dtype=jnp.int32) if col in discrete_state_names - else jnp.array(arr) + else jnp.array(arr, dtype=canonical_float_dtype()) for col, arr in result_arrays.items() } initial_conditions["regime"] = jnp.array( @@ -371,12 +372,12 @@ def array_from_series( """ if func is None: - return jnp.array(sr.to_numpy(), dtype=float) + return jnp.array(sr.to_numpy(), dtype=canonical_float_dtype()) indexing_params = _get_func_indexing_params(func=func, array_param_name=param_name) if not indexing_params: - return jnp.array(sr.to_numpy(), dtype=float) + return jnp.array(sr.to_numpy(), dtype=canonical_float_dtype()) grids = _resolve_categoricals( regimes=regimes, @@ -706,7 +707,7 @@ def _scatter_series( shape = [m.size for m in level_mappings] if len(series) == 0: - return jnp.full(shape, fill_value) + return jnp.full(shape, fill_value, dtype=canonical_float_dtype()) index_arrays = [ _map_level( @@ -717,7 +718,7 @@ def _scatter_series( result = np.full(shape, fill_value) result[tuple(index_arrays)] = series.to_numpy() - return jnp.array(result) + return jnp.array(result, dtype=canonical_float_dtype()) def _map_level(*, mapping: _LevelMapping, level_values: pd.Index) -> np.ndarray: diff --git a/src/lcm/params/processing.py b/src/lcm/params/processing.py index 3020283f..eecd7229 100644 --- a/src/lcm/params/processing.py +++ b/src/lcm/params/processing.py @@ -1,10 +1,17 @@ """Process user-provided params into internal params. 
`process_params` resolves user-supplied parameters against the model's -template, then runs a boundary-cast pass that normalises typed integer -leaves to `jnp.int32` (and integer arrays inside `MappingLeaf` / -`SequenceLeaf`). Out-of-range values surface as `ValueError` with the -offending leaf's qualified name. +template, then runs a boundary-cast pass that normalises typed numeric +leaves to canonical pylcm dtypes: + +- Typed integer leaves (and integer arrays inside `MappingLeaf` / + `SequenceLeaf`) cast to `jnp.int32`. Out-of-range values surface as + `ValueError`. +- Typed float leaves (and float arrays inside `MappingLeaf` / + `SequenceLeaf`) cast to `canonical_float_dtype()`. Down-cast overflow + surfaces as `OverflowError`. + +Both errors name the offending leaf via its qualified path. """ from collections.abc import Mapping @@ -47,11 +54,17 @@ def process_params( - Regime level: `{"regime_0": {"arg_0": 0.0}}` — propagates within regime_0 - Function level: `{"regime_0": {"func": {"arg_0": 0.0}}}` — direct specification - The output always matches the params_template skeleton. Typed integer + The output always matches the params_template skeleton. Typed numeric arrays in the user input — including those inside `MappingLeaf` / - `SequenceLeaf` containers — are cast to `jnp.int32` so the AOT signature - is stable across calls; Python scalars pass through to keep JAX weak- - typing semantics. + `SequenceLeaf` containers — are cast to canonical pylcm dtypes so the + AOT signature is stable across calls: + + - Integer arrays cast to `jnp.int32`. + - Float arrays cast to `canonical_float_dtype()`. + + Python scalars (`int` / `float` / `bool`) pass through to keep JAX + weak-typing semantics, and `pd.Series` leaves pass through to their + dedicated multi-index reshaper. Args: params: User-provided parameters dictionary. @@ -65,6 +78,8 @@ def process_params( InvalidNameError: If the same parameter is specified at multiple levels. 
ValueError: If a typed integer leaf carries a value outside the int32 range; the message names the offending parameter qname. + OverflowError: If a typed float leaf would saturate to `±inf` on + down-cast to `float32`; the message names the offending qname. """ return broadcast_to_template(params=params, template=params_template, required=True) @@ -157,9 +172,9 @@ def _cast_leaves_to_canonical_dtype(value: Any, *, name: str) -> Any: # noqa: A - Python `int` / `float` / `bool` scalars — JAX's weak-typing rules let them promote to whatever dtype the surrounding operation needs (e.g. `discount_factor: 1` in a float-typed function). - - `pd.Series` leaves — reshaped by `convert_series_in_params` based - on their multi-index, so flattening here via `np.asarray` would - lose that structure. + - `pd.Series` leaves — they carry multi-index structure that + `np.asarray` would flatten away, so the dtype cast is handled by + the dedicated multi-index reshaper instead. - Non-numeric typed leaves. """ if isinstance(value, MappingLeaf): diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index ffd77415..b7ad15f6 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -287,15 +287,9 @@ def _update_states_for_subjects( for target, target_next_states in computed_next_states.items(): for next_state_name, next_state_values in target_next_states.items(): state_name = f"{target}__{next_state_name.removeprefix('next_')}" - target_dtype = all_states[state_name].dtype - # Pin transition outputs to the storage dtype before `jnp.where`. - # Initial-condition boundary casts ensure storage already reflects - # the canonical dtype (int32 for discrete, `canonical_float_dtype()` - # for continuous), so this cast is value-preserving for any - # well-typed user transition. 
updated_states[state_name] = jnp.where( subject_indices, - next_state_values.astype(target_dtype), + next_state_values, all_states[state_name], ) diff --git a/tests/conftest.py b/tests/conftest.py index 5d9e627a..b889345f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ +from collections.abc import Iterator from dataclasses import make_dataclass import pytest +from jax import config as jax_config # Module-level precision settings (updated by pytest_configure based on --precision) X64_ENABLED: bool = True @@ -28,11 +30,31 @@ def pytest_configure(config): X64_ENABLED = config.getoption("--precision") == "64" DECIMAL_PRECISION = 12 if X64_ENABLED else 5 - from jax import config as jax_config # noqa: PLC0415 - jax_config.update("jax_enable_x64", val=X64_ENABLED) @pytest.fixture(scope="session") def binary_category_class(): return make_dataclass("BinaryCategoryClass", [("cat0", int, 0), ("cat1", int, 1)]) + + +@pytest.fixture(name="x64_disabled") +def _fixture_x64_disabled() -> Iterator[None]: + """Run the test with `jax_enable_x64=False`, restoring afterwards.""" + previous = jax_config.read("jax_enable_x64") + jax_config.update("jax_enable_x64", val=False) + try: + yield + finally: + jax_config.update("jax_enable_x64", val=previous) + + +@pytest.fixture(name="x64_enabled") +def _fixture_x64_enabled() -> Iterator[None]: + """Run the test with `jax_enable_x64=True`, restoring afterwards.""" + previous = jax_config.read("jax_enable_x64") + jax_config.update("jax_enable_x64", val=True) + try: + yield + finally: + jax_config.update("jax_enable_x64", val=previous) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 49dec722..e42e4ff0 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -1,35 +1,12 @@ """Tests for `lcm.dtypes` boundary-cast helpers.""" -from collections.abc import Iterator - import jax.numpy as jnp import numpy as np import pytest -from jax import config as jax_config from lcm.dtypes import canonical_float_dtype, 
safe_to_float_dtype, safe_to_int32 -@pytest.fixture(name="x64_disabled") -def _fixture_x64_disabled() -> Iterator[None]: - previous = jax_config.read("jax_enable_x64") - jax_config.update("jax_enable_x64", val=False) - try: - yield - finally: - jax_config.update("jax_enable_x64", val=previous) - - -@pytest.fixture(name="x64_enabled") -def _fixture_x64_enabled() -> Iterator[None]: - previous = jax_config.read("jax_enable_x64") - jax_config.update("jax_enable_x64", val=True) - try: - yield - finally: - jax_config.update("jax_enable_x64", val=previous) - - @pytest.mark.parametrize( "value", [7, np.asarray([0, 1, -3], dtype=np.int64)], @@ -105,8 +82,15 @@ def test_safe_to_float_dtype_casts_python_float_to_canonical( def test_safe_to_float_dtype_casts_float64_array_to_float32( x64_disabled: None, # noqa: ARG001 ) -> None: - """A `float64` array within float32 range is downcast to `float32`.""" - arr = jnp.asarray([0.1, 0.2, 0.3], dtype=jnp.float64) + """A `float64` array within float32 range is downcast to `float32`. + + Build the input with `np.asarray` rather than `jnp.asarray` — under + `jax_enable_x64=False`, JAX silently truncates a `float64` request + to `float32` at construction time, so a JAX-built input would never + reach the helper as `float64` and the down-cast path would not be + exercised. 
+ """ + arr = np.asarray([0.1, 0.2, 0.3], dtype=np.float64) out = safe_to_float_dtype(arr, name="x") assert out.dtype == jnp.float32 diff --git a/tests/test_float_dtype_invariants.py b/tests/test_float_dtype_invariants.py index 53b81994..d9899cc7 100644 --- a/tests/test_float_dtype_invariants.py +++ b/tests/test_float_dtype_invariants.py @@ -1,17 +1,16 @@ """Float dtypes follow `canonical_float_dtype()` across pylcm boundaries.""" -from collections.abc import Iterator from types import MappingProxyType import jax.numpy as jnp import numpy as np import pytest -from jax import config as jax_config from lcm.dtypes import canonical_float_dtype from lcm.grids import IrregSpacedGrid, LinSpacedGrid, LogSpacedGrid from lcm.params import MappingLeaf from lcm.params.processing import process_params +from lcm.params.sequence_leaf import SequenceLeaf from lcm.simulation.initial_conditions import build_initial_states from tests.test_models.deterministic.regression import ( RegimeId, @@ -20,42 +19,26 @@ ) -@pytest.fixture(name="x64_disabled") -def _fixture_x64_disabled() -> Iterator[None]: - previous = jax_config.read("jax_enable_x64") - jax_config.update("jax_enable_x64", val=False) - try: - yield - finally: - jax_config.update("jax_enable_x64", val=previous) - - -def test_build_initial_states_continuous_state_cast_to_canonical_dtype( +def test_build_initial_states_casts_user_float64_to_canonical( x64_disabled: None, # noqa: ARG001 ) -> None: - """Continuous initial states land at `canonical_float_dtype()` for any input.""" + """A float64 continuous initial state lands at `canonical_float_dtype()`.""" model = get_model(n_periods=3) - # User passes float64 arrays under x64=False — should be cast to float32. 
initial_states = { - "wealth": jnp.asarray([20.0, 50.0], dtype=jnp.float64), - "age": jnp.asarray([18.0, 18.0], dtype=jnp.float64), + "wealth": np.asarray([20.0, 50.0], dtype=np.float64), + "age": np.asarray([18.0, 18.0], dtype=np.float64), } flat = build_initial_states( - initial_states=initial_states, + initial_states=initial_states, # ty: ignore[invalid-argument-type] internal_regimes=model.internal_regimes, ) - target = canonical_float_dtype() - for key, arr in flat.items(): - if arr.dtype.kind == "f": - assert arr.dtype == target, ( - f"Initial state {key} has dtype {arr.dtype}, expected {target}." - ) + assert flat["working_life__wealth"].dtype == canonical_float_dtype() -def test_build_initial_states_int_input_for_continuous_state_cast_to_canonical( +def test_build_initial_states_casts_user_int_to_canonical( x64_disabled: None, # noqa: ARG001 ) -> None: - """Int initial-condition arrays for continuous states land at canonical float.""" + """A continuous initial state given as int32 lands at `canonical_float_dtype()`.""" model = get_model(n_periods=3) initial_states = { "wealth": jnp.asarray([20, 50], dtype=jnp.int32), @@ -65,47 +48,54 @@ def test_build_initial_states_int_input_for_continuous_state_cast_to_canonical( initial_states=initial_states, internal_regimes=model.internal_regimes, ) - target = canonical_float_dtype() - # All non-discrete state entries should be canonical-float now. - for key, arr in flat.items(): - if "wealth" in key or "age" in key: - assert arr.dtype == target, ( - f"Continuous state {key} has dtype {arr.dtype}, expected {target}." 
- ) + assert flat["working_life__wealth"].dtype == canonical_float_dtype() def test_build_initial_states_missing_continuous_fallback_dtype_is_canonical( x64_disabled: None, # noqa: ARG001 ) -> None: - """Missing continuous states fall back to `nan` at the canonical float dtype.""" + """A missing continuous state falls back to a canonical-dtype array.""" model = get_model(n_periods=3) - initial_states = { - "wealth": jnp.asarray([20.0, 50.0]), - "age": jnp.asarray([18.0, 18.0]), - } + # Supply a placeholder state to set n_subjects without touching `wealth`. flat = build_initial_states( - initial_states=initial_states, + initial_states={"placeholder": jnp.asarray([0.0, 0.0])}, internal_regimes=model.internal_regimes, ) - # Find a fallback-NaN entry and check its dtype. - target = canonical_float_dtype() - nan_entries = [arr for arr in flat.values() if arr.dtype.kind == "f"] - assert nan_entries # sanity - for arr in nan_entries: - assert arr.dtype == target + assert flat["working_life__wealth"].dtype == canonical_float_dtype() + + +def test_build_initial_states_missing_continuous_fallback_values_are_nan( + x64_disabled: None, # noqa: ARG001 +) -> None: + """A missing continuous state falls back to an all-NaN array. + + Pinning only the dtype would let a regression that fills the fallback + with zeros (or anything else representable) pass; assert the values. + """ + model = get_model(n_periods=3) + flat = build_initial_states( + initial_states={"placeholder": jnp.asarray([0.0, 0.0])}, + internal_regimes=model.internal_regimes, + ) + assert bool(jnp.all(jnp.isnan(flat["working_life__wealth"]))) def test_process_params_casts_float64_array_to_canonical_under_no_x64( x64_disabled: None, # noqa: ARG001 ) -> None: - """A `float64` array param is downcast to `float32` under `jax_enable_x64=False`.""" + """A `float64` array param is downcast to `float32` under `jax_enable_x64=False`. 
+ + Build with `np.asarray` rather than `jnp.asarray` — the JAX builder + silently truncates to `float32` under no-x64 at construction time, so a + JAX-built input would never reach the helper as `float64`. + """ template = MappingProxyType({"regime_a": MappingProxyType({"schedule": "Array"})}) user_params = { - "regime_a": {"schedule": jnp.asarray([0.1, 0.2, 0.3], dtype=jnp.float64)} + "regime_a": {"schedule": np.asarray([0.1, 0.2, 0.3], dtype=np.float64)} } out = process_params( - params=user_params, + params=user_params, # ty: ignore[invalid-argument-type] params_template=template, # ty: ignore[invalid-argument-type] ) @@ -127,7 +117,6 @@ def test_process_params_passes_python_float_through_for_jax_weak_typing( params_template=template, # ty: ignore[invalid-argument-type] ) - # Python float stays Python float; JAX weak-typing handles promotion at JIT. assert out["regime_a"]["discount_factor"] == 0.95 assert isinstance(out["regime_a"]["discount_factor"], float) @@ -135,10 +124,8 @@ def test_process_params_passes_python_float_through_for_jax_weak_typing( def test_process_params_float_array_overflow_raises_with_qualified_name( x64_disabled: None, # noqa: ARG001 ) -> None: - """An out-of-float32 float64 *array* raises naming the qualified leaf.""" + """An out-of-float32 float64 array raises naming the qualified leaf.""" template = MappingProxyType({"regime_a": MappingProxyType({"schedule": "Array"})}) - # Numpy here: under `jax_enable_x64=False`, `jnp.asarray(..., dtype=float64)` - # of an out-of-float32 value saturates to ±inf at construction time. 
user_params = {"regime_a": {"schedule": np.asarray([0.0, 1e40], dtype=np.float64)}} with pytest.raises(OverflowError, match="schedule"): @@ -151,7 +138,12 @@ def test_process_params_float_array_overflow_raises_with_qualified_name( def test_simulate_state_pool_dtype_stable_across_periods( x64_disabled: None, # noqa: ARG001 ) -> None: - """A multi-period simulate keeps every state's dtype stable across periods.""" + """A multi-period simulate keeps every state's dtype stable across periods. + + The intended invariant is per-state stability; failing on any single + state still gives an actionable signal because the assertion message + names the offending state and its observed dtypes. + """ n_periods = 4 model = get_model(n_periods=n_periods) params = get_params(n_periods=n_periods) @@ -165,14 +157,13 @@ def test_simulate_state_pool_dtype_stable_across_periods( params=params, period_to_regime_to_V_arr=None, initial_conditions=initial ) - # Build the per-period state-dtype matrix and assert stability. seen: dict[str, set] = {} for period_data in result.raw_results.values(): for snap in period_data.values(): for state_name, arr in snap.states.items(): seen.setdefault(state_name, set()).add(arr.dtype) - for state_name, dtypes in seen.items(): - assert len(dtypes) == 1, f"State {state_name} drifted across periods: {dtypes}" + drifted = {name: dtypes for name, dtypes in seen.items() if len(dtypes) != 1} + assert not drifted, f"States drifted across periods: {drifted}" def test_solve_v_arrays_at_canonical_float_dtype( @@ -182,24 +173,42 @@ def test_solve_v_arrays_at_canonical_float_dtype( model = get_model(n_periods=3) period_to_regime_to_V_arr = model.solve(params=get_params(n_periods=3)) target = canonical_float_dtype() - for period_v in period_to_regime_to_V_arr.values(): - for regime_name, v_arr in period_v.items(): - assert v_arr.dtype == target, ( - f"V[{regime_name}] dtype is {v_arr.dtype}, expected {target}." 
- ) + wrong = { + (period, regime_name): v_arr.dtype + for period, period_v in period_to_regime_to_V_arr.items() + for regime_name, v_arr in period_v.items() + if v_arr.dtype != target + } + assert not wrong, f"V-arrays not at {target}: {wrong}" -def test_continuous_grid_to_jax_dtype_is_canonical( +@pytest.mark.parametrize( + "grid", + [ + LinSpacedGrid(start=0, stop=1, n_points=5), + LogSpacedGrid(start=1, stop=10, n_points=5), + IrregSpacedGrid(points=(0.0, 0.5, 1.0)), + ], + ids=["linspaced", "logspaced", "irregspaced"], +) +def test_continuous_grid_to_jax_dtype_is_canonical_under_no_x64( + grid: LinSpacedGrid | LogSpacedGrid | IrregSpacedGrid, x64_disabled: None, # noqa: ARG001 ) -> None: - """Continuous grid `to_jax()` returns canonical-float arrays.""" - target = canonical_float_dtype() - assert LinSpacedGrid(start=0, stop=1, n_points=5).to_jax().dtype == target - assert LogSpacedGrid(start=1, stop=10, n_points=5).to_jax().dtype == target - assert IrregSpacedGrid(points=(0.0, 0.5, 1.0)).to_jax().dtype == target + """Continuous grid `to_jax()` materialises at `float32` under no-x64. + Asserts the concrete target dtype rather than `canonical_float_dtype()` + so the test fails if a future grid implementation hardcodes `float64` + (which JAX would silently truncate to `float32` under no-x64; the + helper-side comparison would mask that, the literal-side comparison + surfaces it). 
+ """ + assert grid.to_jax().dtype == jnp.float32 + +@pytest.mark.parametrize("key", ["low", "high"]) def test_process_params_casts_float_array_inside_mapping_leaf_to_canonical( + key: str, x64_disabled: None, # noqa: ARG001 ) -> None: """`MappingLeaf` float arrays land at `canonical_float_dtype()`.""" @@ -210,8 +219,8 @@ def test_process_params_casts_float_array_inside_mapping_leaf_to_canonical( "regime_a": { "sched": MappingLeaf( { - "low": jnp.asarray([0.1, 0.2], dtype=jnp.float64), - "high": jnp.asarray([0.5, 0.7], dtype=jnp.float64), + "low": np.asarray([0.1, 0.2], dtype=np.float64), + "high": np.asarray([0.5, 0.7], dtype=np.float64), } ) } @@ -222,6 +231,38 @@ def test_process_params_casts_float_array_inside_mapping_leaf_to_canonical( params_template=template, # ty: ignore[invalid-argument-type] ) - leaf = out["regime_a"]["sched"] - assert leaf.data["low"].dtype == jnp.float32 # ty: ignore[unresolved-attribute] - assert leaf.data["high"].dtype == jnp.float32 # ty: ignore[unresolved-attribute] + assert ( + out["regime_a"]["sched"].data[key].dtype # ty: ignore[unresolved-attribute] + == jnp.float32 + ) + + +@pytest.mark.parametrize("index", [0, 1]) +def test_process_params_casts_float_array_inside_sequence_leaf_to_canonical( + index: int, + x64_disabled: None, # noqa: ARG001 +) -> None: + """`SequenceLeaf` float arrays land at `canonical_float_dtype()`.""" + template = MappingProxyType( + {"regime_a": MappingProxyType({"sched": "SequenceLeaf"})} + ) + user_params = { + "regime_a": { + "sched": SequenceLeaf( + [ + np.asarray([0.1, 0.2], dtype=np.float64), + np.asarray([0.5, 0.7], dtype=np.float64), + ] + ) + } + } + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + assert ( + out["regime_a"]["sched"].data[index].dtype # ty: ignore[unresolved-attribute] + == jnp.float32 + ) diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py index d174d515..5072cdde 100644 --- 
a/tests/test_int_dtype_invariants.py +++ b/tests/test_int_dtype_invariants.py @@ -61,13 +61,23 @@ def test_missing_cat_code_is_int32_minimum() -> None: assert jnp.iinfo(jnp.int32).min == MISSING_CAT_CODE -def test_update_states_for_subjects_preserves_storage_dtype() -> None: - """A transition that returns int64 cannot promote the storage pool to int64.""" +def test_update_states_for_subjects_keeps_same_dtype_round_trip() -> None: + """Canonical-dtype transition outputs round-trip through the state pool. + + With every input boundary pinned to the canonical dtype, a well-typed user + transition returns canonical-dtype outputs and `_update_states_for_subjects` + writes them through `jnp.where` without dtype change. This test pins the + contract for the int side; mixed-dtype inputs are out of scope — the + function does not defend against transitions that violate the canonical- + dtype invariant. + """ all_states = MappingProxyType( {"work__health": jnp.asarray([0, 1, 0, 1], dtype=jnp.int32)} ) - int64_next = jnp.asarray([1, 1, 1, 1], dtype=jnp.int64) - computed = MappingProxyType({"work": MappingProxyType({"next_health": int64_next})}) + next_values = jnp.asarray([1, 1, 1, 1], dtype=jnp.int32) + computed = MappingProxyType( + {"work": MappingProxyType({"next_health": next_values})} + ) subjects = jnp.asarray([True, False, True, False]) updated = _update_states_for_subjects( diff --git a/tests/test_validate_param_types.py b/tests/test_validate_param_types.py index 841b02b9..a37eee52 100644 --- a/tests/test_validate_param_types.py +++ b/tests/test_validate_param_types.py @@ -1,9 +1,17 @@ -"""Test that numpy arrays in params are rejected after processing.""" +"""Tests for params accepted at the boundary by `_validate_param_types`. + +After `process_params` casts typed numeric arrays to canonical pylcm +dtypes, every supported user input form (numpy arrays, JAX arrays, +Python scalars) reaches the validator as a JAX array or Python scalar +and is accepted. 
+""" import jax.numpy as jnp import numpy as np +from jax import Array from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical +from lcm.dtypes import canonical_float_dtype @categorical(ordered=True) @@ -47,20 +55,30 @@ def _make_model() -> Model: ) -def test_numpy_array_param_accepted_and_normalised() -> None: - """Numpy arrays are auto-converted to JAX at the params boundary.""" +def test_numpy_array_param_normalised_to_canonical_jax_array() -> None: + """A numpy array param is cast to a JAX array at `canonical_float_dtype()`.""" model = _make_model() - # Should solve cleanly; the boundary cast normalises numpy -> JAX. - model.solve(params={"bonus": np.array(1.0), "discount_factor": 0.95}) # ty: ignore[invalid-argument-type] + internal = model._process_params( + params={"bonus": np.asarray(1.0, dtype=np.float64), "discount_factor": 0.95} # ty: ignore[invalid-argument-type] + ) + bonus = internal["working"]["utility__bonus"] + assert isinstance(bonus, Array) + assert bonus.dtype == canonical_float_dtype() -def test_jax_array_param_accepted() -> None: - """JAX arrays should be accepted.""" +def test_jax_array_param_kept_at_canonical_dtype() -> None: + """A typed JAX array param is kept (or cast) at `canonical_float_dtype()`.""" model = _make_model() - model.solve(params={"bonus": jnp.array(1.0), "discount_factor": 0.95}) + internal = model._process_params( + params={"bonus": jnp.asarray(1.0), "discount_factor": 0.95} + ) + bonus = internal["working"]["utility__bonus"] + assert bonus.dtype == canonical_float_dtype() # ty: ignore[unresolved-attribute] -def test_python_scalar_param_accepted() -> None: - """Python scalars should be accepted.""" +def test_python_float_param_passed_through_for_weak_typing() -> None: + """A Python `float` param survives processing as a Python `float`.""" model = _make_model() - model.solve(params={"bonus": 1.0, "discount_factor": 0.95}) + internal = model._process_params(params={"bonus": 1.0, "discount_factor": 
0.95}) + assert internal["working"]["utility__bonus"] == 1.0 + assert isinstance(internal["working"]["utility__bonus"], float) From b8dc4901c7a506c5c81e382559a7199e592ff034 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 22:19:30 +0200 Subject: [PATCH 04/22] bench_aca_baseline: pass pref_type_grid to create_benchmark_model aca-model dropped the `pref_type_grid` default on `create_benchmark_model`. Forward `DiscreteGrid(BenchmarkPrefType)` explicitly to keep the benchmark on its 2-type pref-type axis. --- benchmarks/bench_aca_baseline.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index a9364879..49705308 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -48,13 +48,19 @@ def _build() -> tuple[object, object, object]: """Build the aca-baseline model, params, and initial conditions.""" + from lcm import DiscreteGrid + + from aca_model.agent.preferences import BenchmarkPrefType from aca_model.benchmark import ( create_benchmark_model, get_benchmark_initial_conditions, get_benchmark_params, ) - model = create_benchmark_model(n_subjects=_N_SUBJECTS) + model = create_benchmark_model( + n_subjects=_N_SUBJECTS, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) _, model_params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=_N_SUBJECTS, seed=0 From cbe65bed56f76cf6142ecaf20c754341240322d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 May 2026 20:20:09 +0000 Subject: [PATCH 05/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- benchmarks/bench_aca_baseline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index 
49705308..ca56b962 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -48,8 +48,6 @@ def _build() -> tuple[object, object, object]: """Build the aca-baseline model, params, and initial conditions.""" - from lcm import DiscreteGrid - from aca_model.agent.preferences import BenchmarkPrefType from aca_model.benchmark import ( create_benchmark_model, @@ -57,6 +55,8 @@ def _build() -> tuple[object, object, object]: get_benchmark_params, ) + from lcm import DiscreteGrid + model = create_benchmark_model( n_subjects=_N_SUBJECTS, pref_type_grid=DiscreteGrid(BenchmarkPrefType), From a6b0ac1b56faf840d17e276818050d59d4b9863a Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 09:50:38 +0200 Subject: [PATCH 06/22] bench_aca_baseline: hoist aca_model + lcm imports to module top Per AGENTS.md: no in-function imports. --- benchmarks/bench_aca_baseline.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index ca56b962..129923fd 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -40,23 +40,21 @@ import time import cloudpickle +from aca_model.agent.preferences import BenchmarkPrefType +from aca_model.benchmark import ( + create_benchmark_model, + get_benchmark_initial_conditions, + get_benchmark_params, +) from benchmarks import _gpu_mem +from lcm import DiscreteGrid _N_SUBJECTS = 1000 def _build() -> tuple[object, object, object]: """Build the aca-baseline model, params, and initial conditions.""" - from aca_model.agent.preferences import BenchmarkPrefType - from aca_model.benchmark import ( - create_benchmark_model, - get_benchmark_initial_conditions, - get_benchmark_params, - ) - - from lcm import DiscreteGrid - model = create_benchmark_model( n_subjects=_N_SUBJECTS, pref_type_grid=DiscreteGrid(BenchmarkPrefType), From 9d26643508dac00984e6f99e4d7eff5c97cbc078 Mon Sep 17 00:00:00 
2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 10:38:40 +0200 Subject: [PATCH 07/22] _validate_param_types: drop dead branches post-whitelist After `cast_params_to_canonical_dtypes`, every leaf is either a JAX `Array` or a `MappingLeaf` / `SequenceLeaf`. The validator's Python scalar branch and duck-typed-array branch can never fire, so collapse the dispatch to the three live cases. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/model_processing.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/src/lcm/model_processing.py b/src/lcm/model_processing.py index 832f5e4b..4f149736 100644 --- a/src/lcm/model_processing.py +++ b/src/lcm/model_processing.py @@ -423,11 +423,11 @@ def _filter_kwargs_for_func( def _validate_param_types(internal_params: InternalParams) -> None: - """Raise if any param leaf is not a Python scalar or JAX array. + """Raise if any param leaf is not a JAX `Array` or container leaf. - After processing, every leaf value (including inside MappingLeaf / - SequenceLeaf containers) must be a Python scalar (float, int, bool) or a - JAX array. Notably, numpy arrays and pandas Series are not accepted. + Defense-in-depth check after `cast_params_to_canonical_dtypes`: by the + time this runs, every leaf must be a JAX `Array`, or a `MappingLeaf` / + `SequenceLeaf` whose contents recursively satisfy the same rule. 
""" for regime_name, regime_params in internal_params.items(): for key, value in regime_params.items(): @@ -435,7 +435,7 @@ def _validate_param_types(internal_params: InternalParams) -> None: def _check_leaf(value: object, path: str) -> None: - """Check a single leaf value, recursing into MappingLeaf/SequenceLeaf.""" + """Check a single leaf, recursing into `MappingLeaf` / `SequenceLeaf`.""" if isinstance(value, MappingLeaf): for k, v in value.data.items(): _check_leaf(v, f"{path}.{k}") @@ -444,17 +444,8 @@ def _check_leaf(value: object, path: str) -> None: for i, v in enumerate(value.data): _check_leaf(v, f"{path}[{i}]") return - if isinstance(value, (float, int, bool)): + if isinstance(value, Array): return - if hasattr(value, "dtype") and hasattr(value, "shape"): - if isinstance(value, Array): - return - type_name = type(value).__module__ + "." + type(value).__name__ - msg = ( - f"Parameter '{path}' is a {type_name} (shape {value.shape}). " - f"Use jnp.array() or pass a pd.Series with a named index." - ) - raise InvalidParamsError(msg) type_name = type(value).__module__ + "." + type(value).__name__ - msg = f"Parameter '{path}' has unexpected type {type_name}." + msg = f"Parameter {path!r} is a {type_name}, expected a JAX Array." raise InvalidParamsError(msg) From 530d50c3c17bd7fb8e2d0cea581ed1523f2eba39 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 11:14:25 +0200 Subject: [PATCH 08/22] Tighten Scalar* aliases to JAX-only; convert grid endpoints at construction ScalarFloat, ScalarInt, and ScalarBool now stand for JAX scalars only, so downstream annotations (e.g. aca-model DAG functions) carry the "post-cast invariant" guarantee accurately. Changes that follow from the tightening: - UniformContinuousGrid (LinSpacedGrid, LogSpacedGrid) and IrregSpacedGrid use a manual __init__ to accept Python literals at the user-facing API and store start/stop/points as JAX scalars at canonical_float_dtype(). 
Grid dtype is now sticky to construction time x64 mode. - Coordinate helpers (linspace, logspace, get_*_coordinate, Grid.get_coordinate) widen each numeric slot to `float | ScalarFloat` / `int | ScalarInt` so they remain callable from setup-time Python code as well as the JIT'd DAG. - simulate.py replaces `enumerate(ages.values)` with index-based iteration so `age` carries a proper JAX-scalar type; transitions.py follows. - Display/diagnostic age parameters in error_handling.py and logging.py widen to `int | float | ScalarInt | ScalarFloat` so Python literals from `_DiagnosticRow` keep working. Test changes: parametrised dtype-invariant test now constructs grids inside the test body so the x64_disabled fixture is in effect; the returning-int test in test_regime_state_mismatch flips to `-> int`. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 169 +++++++++++++++++++-------- src/lcm/grids/coordinates.py | 90 +++++++------- src/lcm/grids/piecewise.py | 8 +- src/lcm/params/processing.py | 2 +- src/lcm/shocks/_base.py | 4 +- src/lcm/simulation/simulate.py | 7 +- src/lcm/simulation/transitions.py | 6 +- src/lcm/typing.py | 6 +- src/lcm/utils/error_handling.py | 8 +- src/lcm/utils/logging.py | 4 +- tests/test_float_dtype_invariants.py | 37 +++++- tests/test_regime_state_mismatch.py | 4 +- 12 files changed, 221 insertions(+), 124 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 9a147c44..40a58f59 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -7,6 +7,7 @@ import jax.numpy as jnp from jax import Array +from lcm.dtypes import canonical_float_dtype from lcm.exceptions import GridInitializationError, format_messages from lcm.grids import coordinates as grid_coordinates from lcm.grids.base import Grid @@ -29,32 +30,47 @@ class ContinuousGrid(Grid): """Size of the batches that are looped over during the solution.""" @overload - def get_coordinate(self, value: ScalarFloat) -> 
ScalarFloat: ... + def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... @abstractmethod - def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True, kw_only=True, init=False) class UniformContinuousGrid(ContinuousGrid, ABC): - """Grid with start/stop/n_points for linearly or logarithmically spaced values.""" + """Grid with start/stop/n_points for linearly or logarithmically spaced values. + + `start` and `stop` are stored as JAX scalars at `canonical_float_dtype()`, + converted from the Python literals supplied at construction. `n_points` + stays a Python `int` so it can size JAX arrays statically. + """ - start: int | float - """The start value of the grid.""" + start: ScalarFloat + """The start value of the grid (JAX scalar at `canonical_float_dtype()`).""" - stop: int | float - """The stop value of the grid.""" + stop: ScalarFloat + """The stop value of the grid (JAX scalar at `canonical_float_dtype()`).""" n_points: int """The number of points in the grid.""" - def __post_init__(self) -> None: - _validate_continuous_grid( - start=self.start, - stop=self.stop, - n_points=self.n_points, + def __init__( + self, + *, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int, + batch_size: int = 0, + ) -> None: + _init_uniform_grid( + self, + start=start, + stop=stop, + n_points=n_points, + batch_size=batch_size, + requires_positive_start=False, ) @abstractmethod @@ -62,11 +78,11 @@ def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" @overload - def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... 
@overload def get_coordinate(self, value: Array) -> Array: ... @abstractmethod - def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" def replace(self, **kwargs: float) -> UniformContinuousGrid: @@ -103,10 +119,10 @@ def to_jax(self) -> Float1D: ) @overload - def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... - def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" return grid_coordinates.get_linspace_coordinate( value=value, start=self.start, stop=self.stop, n_points=self.n_points @@ -124,11 +140,20 @@ class LogSpacedGrid(UniformContinuousGrid): """ - def __post_init__(self) -> None: - _validate_continuous_grid( - start=self.start, - stop=self.stop, - n_points=self.n_points, + def __init__( + self, + *, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int, + batch_size: int = 0, + ) -> None: + _init_uniform_grid( + self, + start=start, + stop=stop, + n_points=n_points, + batch_size=batch_size, requires_positive_start=True, ) @@ -139,25 +164,55 @@ def to_jax(self) -> Float1D: ) @overload - def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... 
- def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" return grid_coordinates.get_logspace_coordinate( value=value, start=self.start, stop=self.stop, n_points=self.n_points ) -@dataclass(frozen=True, kw_only=True) +def _init_uniform_grid( + grid: UniformContinuousGrid, + *, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int, + batch_size: int, + requires_positive_start: bool, +) -> None: + """Validate the user input and store fields on `grid`. + + Validation runs on the original Python values; once they pass, `start` + and `stop` are converted to JAX scalars at `canonical_float_dtype()` + so downstream code reads typed scalars. + """ + _validate_continuous_grid( + start=start, + stop=stop, + n_points=n_points, + requires_positive_start=requires_positive_start, + ) + dtype = canonical_float_dtype() + object.__setattr__(grid, "start", jnp.asarray(start, dtype=dtype)) + object.__setattr__(grid, "stop", jnp.asarray(stop, dtype=dtype)) + object.__setattr__(grid, "n_points", n_points) + object.__setattr__(grid, "batch_size", batch_size) + + +@dataclass(frozen=True, kw_only=True, init=False) class IrregSpacedGrid(ContinuousGrid): """A grid of continuous values at irregular (user-specified) points. This grid type is useful for representing non-uniformly spaced points such as Gauss-Hermite quadrature nodes. - When `points` is omitted and only `n_points` is given, the `points` must be - supplied at runtime via the params. + `points` is stored as a JAX array at `canonical_float_dtype()`, converted + from the Python sequence supplied at construction. When `points` is + omitted and only `n_points` is given, the points must be supplied at + runtime via the params. 
Example: -------- @@ -166,31 +221,44 @@ class IrregSpacedGrid(ContinuousGrid): """ - points: Sequence[float] | Float1D | None = None + points: Float1D | None """The grid points in ascending order, or `None` for runtime-supplied points.""" - n_points: int | None = None + n_points: int """Number of points. Derived from `len(points)` when points are given.""" - def __post_init__(self) -> None: - if self.points is not None: - _validate_irreg_spaced_grid(self.points) - # Derive n_points from points if not explicitly set - if self.n_points is None: - object.__setattr__(self, "n_points", len(self.points)) - elif self.n_points != len(self.points): + def __init__( + self, + *, + points: Sequence[float] | Float1D | None = None, + n_points: int | None = None, + batch_size: int = 0, + ) -> None: + if points is not None: + _validate_irreg_spaced_grid(points) + derived_n = len(points) + if n_points is None: + n_points = derived_n + elif n_points != derived_n: raise GridInitializationError( - f"n_points ({self.n_points}) does not match " - f"len(points) ({len(self.points)})" + f"n_points ({n_points}) does not match len(points) ({derived_n})" ) - elif self.n_points is None: + stored_points: Float1D | None = jnp.asarray( + points, dtype=canonical_float_dtype() + ) + elif n_points is None: raise GridInitializationError( "Either points or n_points must be specified for IrregSpacedGrid." ) - elif self.n_points < 2: # noqa: PLR2004 + elif n_points < 2: # noqa: PLR2004 raise GridInitializationError( - f"n_points must be at least 2, got {self.n_points}" + f"n_points must be at least 2, got {n_points}" ) + else: + stored_points = None + object.__setattr__(self, "points", stored_points) + object.__setattr__(self, "n_points", n_points) + object.__setattr__(self, "batch_size", batch_size) @property def pass_points_at_runtime(self) -> bool: @@ -215,13 +283,13 @@ def to_jax(self) -> Float1D: f"read from `.states[name]` or `.continuous_actions[name]`. 
" f"Use `.n_points` if only the shape is needed." ) - return jnp.asarray(self.points) + return self.points @overload - def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... - def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" if self.points is None: raise GridInitializationError( @@ -229,18 +297,21 @@ def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: "initialization or use IrregSpacedGrid(n_points=...) and " "supply points at runtime via params." ) - return grid_coordinates.get_irreg_coordinate(value=value, points=self.to_jax()) + return grid_coordinates.get_irreg_coordinate(value=value, points=self.points) def _validate_continuous_grid( *, - start: float, - stop: float, + start: float | ScalarFloat, + stop: float | ScalarFloat, n_points: int, requires_positive_start: bool = False, ) -> None: """Validate the continuous grid parameters. + Accepts Python ints/floats from user construction and JAX scalars from + `dataclasses.replace` round-trips on already-constructed grids. + Args: start: The start value of the grid. stop: The stop value of the grid. 
@@ -254,11 +325,11 @@ def _validate_continuous_grid( """ error_messages = [] - valid_start_type = isinstance(start, int | float) + valid_start_type = isinstance(start, int | float | Array) if not valid_start_type: error_messages.append("start must be a scalar int or float value") - valid_stop_type = isinstance(stop, int | float) + valid_stop_type = isinstance(stop, int | float | Array) if not valid_stop_type: error_messages.append("stop must be a scalar int or float value") diff --git a/src/lcm/grids/coordinates.py b/src/lcm/grids/coordinates.py index 0edce4bd..98717e52 100644 --- a/src/lcm/grids/coordinates.py +++ b/src/lcm/grids/coordinates.py @@ -1,26 +1,10 @@ """Functions to generate and work with different kinds of grids. -Grid generation functions must have the following signature: - - Signature (start: ScalarFloat, stop: ScalarFloat, n_points: int) -> jax.Array - -They take start and end points and create a grid of points between them. - - -Interpolation info functions must have the following signature: - - Signature ( - value: ScalarFloat, - start: ScalarFloat, - stop: ScalarFloat, - n_points: int - ) -> ScalarInt - -They take the information required to generate a grid, and return an index corresponding -to the value, which is a point in the space but not necessarily a grid point. - -Some of the arguments will not be used by all functions but the aligned interface makes -it easy to call functions interchangeably. +Grid generation and interpolation helpers accept Python `float` literals +alongside JAX scalars at every numeric slot, so they're usable both at +setup time (Python literals from user code) and inside the JIT'd DAG +(JAX scalars). `n_points` accepts a Python `int` or a JAX integer scalar, +the latter for piecewise grids that select a piece via `searchsorted`. 
""" @@ -32,7 +16,12 @@ from lcm.typing import Float1D, ScalarFloat, ScalarInt -def linspace(*, start: ScalarFloat, stop: ScalarFloat, n_points: int) -> Float1D: +def linspace( + *, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int, +) -> Float1D: """Wrapper around jnp.linspace. Returns a linearly spaced grid between start and stop with n_points, including both @@ -45,32 +34,37 @@ def linspace(*, start: ScalarFloat, stop: ScalarFloat, n_points: int) -> Float1D @overload def get_linspace_coordinate( *, - value: ScalarFloat, - start: ScalarFloat, - stop: ScalarFloat, - n_points: ScalarInt, + value: float | ScalarFloat, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int | ScalarInt, ) -> ScalarFloat: ... @overload def get_linspace_coordinate( *, value: Array, - start: ScalarFloat, - stop: ScalarFloat, - n_points: ScalarInt, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int | ScalarInt, ) -> Array: ... def get_linspace_coordinate( *, - value: ScalarFloat | Array, - start: ScalarFloat, - stop: ScalarFloat, - n_points: ScalarInt, -) -> ScalarFloat | Array: + value: float | ScalarFloat | Array, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int | ScalarInt, +) -> float | ScalarFloat | Array: """Map a value into the input needed for jax.scipy.ndimage.map_coordinates.""" step_length = (stop - start) / (n_points - 1) return (value - start) / step_length -def logspace(*, start: ScalarFloat, stop: ScalarFloat, n_points: int) -> Float1D: +def logspace( + *, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int, +) -> Float1D: """Wrapper around jnp.logspace. 
Returns a logarithmically spaced grid between start and stop with n_points, @@ -94,25 +88,25 @@ def logspace(*, start: ScalarFloat, stop: ScalarFloat, n_points: int) -> Float1D @overload def get_logspace_coordinate( *, - value: ScalarFloat, - start: ScalarFloat, - stop: ScalarFloat, - n_points: ScalarInt, + value: float | ScalarFloat, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int | ScalarInt, ) -> ScalarFloat: ... @overload def get_logspace_coordinate( *, value: Array, - start: ScalarFloat, - stop: ScalarFloat, - n_points: ScalarInt, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int | ScalarInt, ) -> Array: ... def get_logspace_coordinate( *, - value: ScalarFloat | Array, - start: ScalarFloat, - stop: ScalarFloat, - n_points: ScalarInt, + value: float | ScalarFloat | Array, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int | ScalarInt, ) -> ScalarFloat | Array: """Map a value into the input needed for jax.scipy.ndimage.map_coordinates.""" # Transform start, stop, and value to linear scale @@ -154,7 +148,7 @@ def get_logspace_coordinate( @overload def get_irreg_coordinate( *, - value: ScalarFloat, + value: float | ScalarFloat, points: Float1D, ) -> ScalarFloat: ... @overload @@ -165,7 +159,7 @@ def get_irreg_coordinate( ) -> Array: ... def get_irreg_coordinate( *, - value: ScalarFloat | Array, + value: float | ScalarFloat | Array, points: Float1D, ) -> ScalarFloat | Array: """Return the generalized coordinate of a value in an irregularly spaced grid. diff --git a/src/lcm/grids/piecewise.py b/src/lcm/grids/piecewise.py index e3a252f2..aec8e7a4 100644 --- a/src/lcm/grids/piecewise.py +++ b/src/lcm/grids/piecewise.py @@ -82,10 +82,10 @@ def to_jax(self) -> Float1D: return jnp.concatenate(piece_arrays) @overload - def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... 
@overload def get_coordinate(self, value: Array) -> Array: ... - def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" piece_idx = jnp.searchsorted(self._breakpoints, value, side="right") local_coord = grid_coordinates.get_linspace_coordinate( @@ -153,10 +153,10 @@ def to_jax(self) -> Float1D: return jnp.concatenate(piece_arrays) @overload - def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... - def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" piece_idx = jnp.searchsorted(self._breakpoints, value, side="right") local_coord = grid_coordinates.get_logspace_coordinate( diff --git a/src/lcm/params/processing.py b/src/lcm/params/processing.py index 9b6f1d59..c4cb2c93 100644 --- a/src/lcm/params/processing.py +++ b/src/lcm/params/processing.py @@ -198,7 +198,7 @@ def cast_params_to_canonical_dtypes(internal_params: InternalParams) -> Internal ) -def _cast_leaves_to_canonical_dtype(value: Any, *, name: str) -> Any: # noqa: ANN401 +def _cast_leaves_to_canonical_dtype(value: Any, *, name: str) -> Any: # noqa: ANN401, C901, PLR0911 """Cast a single params leaf to its canonical pylcm dtype. Strict whitelist — every code path either casts or raises. diff --git a/src/lcm/shocks/_base.py b/src/lcm/shocks/_base.py index 55d4ac79..2ce5fcee 100644 --- a/src/lcm/shocks/_base.py +++ b/src/lcm/shocks/_base.py @@ -107,10 +107,10 @@ def to_jax(self) -> Float1D: return self.get_gridpoints() @overload - def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... 
+ def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... - def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" if not self.is_fully_specified: raise GridInitializationError( diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index d1ab42ab..51547465 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -31,6 +31,8 @@ IntND, RegimeName, RegimeNamesToIds, + ScalarFloat, + ScalarInt, ) from lcm.utils.error_handling import validate_V from lcm.utils.logging import ( @@ -110,7 +112,8 @@ def simulate( # Build reverse lookup for regime transition logging ids_to_names: dict[int, RegimeName] = {v: k for k, v in regime_names_to_ids.items()} - for period, age in enumerate(ages.values): + for period in range(ages.n_periods): + age = ages.values[period] # noqa: PD011 period_start = time.monotonic() # Activate subjects whose starting period matches the current period @@ -199,7 +202,7 @@ def _simulate_regime_in_period( regime_name: RegimeName, internal_regime: InternalRegime, period: int, - age: float, + age: ScalarInt | ScalarFloat, states: MappingProxyType[str, Array], subject_regime_ids: Int1D, new_subject_regime_ids: Int1D, diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index b7ad15f6..74f54896 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -25,6 +25,8 @@ Int1D, RegimeName, RegimeNamesToIds, + ScalarFloat, + ScalarInt, ) from lcm.utils.namespace import flatten_regime_namespace @@ -70,7 +72,7 @@ def calculate_next_states( internal_regime: InternalRegime, optimal_actions: MappingProxyType[ActionName, Array], period: int, - age: float, + age: ScalarInt | ScalarFloat, regime_params: 
FlatRegimeParams, states: MappingProxyType[str, Array], state_action_space: StateActionSpace, @@ -148,7 +150,7 @@ def calculate_next_regime_membership( state_action_space: StateActionSpace, optimal_actions: MappingProxyType[ActionName, Array], period: int, - age: float, + age: ScalarInt | ScalarFloat, regime_params: FlatRegimeParams, regime_names_to_ids: MappingProxyType[RegimeName, int], new_subject_regime_ids: Int1D, diff --git a/src/lcm/typing.py b/src/lcm/typing.py index 5eea083c..f509bbec 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -24,9 +24,9 @@ # Many JAX functions are designed to work with scalar numerical values. This also # includes zero dimensional jax arrays. -type ScalarInt = int | Int32[Scalar, ""] -type ScalarFloat = float | Float[Scalar, ""] -type ScalarBool = bool | Bool[Scalar, ""] +type ScalarInt = Int32[Scalar, ""] +type ScalarFloat = Float[Scalar, ""] +type ScalarBool = Bool[Scalar, ""] type Period = int | Int1D type Age = int | float diff --git a/src/lcm/utils/error_handling.py b/src/lcm/utils/error_handling.py index faeed9fe..5d89a054 100644 --- a/src/lcm/utils/error_handling.py +++ b/src/lcm/utils/error_handling.py @@ -39,7 +39,7 @@ def validate_V( *, V_arr: Array, - age: ScalarInt | ScalarFloat, + age: float | ScalarInt | ScalarFloat, regime_name: RegimeName | None = None, partial_solution: object = None, compute_intermediates: Callable | None = None, @@ -287,8 +287,8 @@ def validate_regime_transition_probs( regime_transition_probs: MappingProxyType[str, Array], active_regimes_next_period: tuple[RegimeName, ...], regime_name: RegimeName, - age: ScalarInt | ScalarFloat, - next_age: ScalarInt | ScalarFloat, + age: float | ScalarInt | ScalarFloat, + next_age: float | ScalarInt | ScalarFloat, state_action_values: MappingProxyType[str, Array] | None = None, ) -> None: """Validate regime transition probabilities. 
@@ -543,7 +543,7 @@ def _validate_no_reachable_incomplete_targets( regime_transition_probs: MappingProxyType[str, Array], active_regimes_next_period: tuple[RegimeName, ...], regime_name: RegimeName, - age: ScalarInt | ScalarFloat, + age: float | ScalarInt | ScalarFloat, ) -> None: """Check that targets with incomplete stochastic transitions are unreachable. diff --git a/src/lcm/utils/logging.py b/src/lcm/utils/logging.py index a339af67..c924efe8 100644 --- a/src/lcm/utils/logging.py +++ b/src/lcm/utils/logging.py @@ -60,7 +60,7 @@ def log_nan_in_V( *, logger: logging.Logger, regime_name: str, - age: ScalarInt | ScalarFloat, + age: float | ScalarInt | ScalarFloat, V_arr: FloatND, ) -> None: """Log a warning if V_arr contains NaN or Inf values. @@ -79,7 +79,7 @@ def log_nan_in_V( def log_period_header( *, logger: logging.Logger, - age: ScalarInt | ScalarFloat, + age: float | ScalarInt | ScalarFloat, n_active_regimes: int, ) -> None: """Log the start of a period. diff --git a/tests/test_float_dtype_invariants.py b/tests/test_float_dtype_invariants.py index 144c3033..8d793821 100644 --- a/tests/test_float_dtype_invariants.py +++ b/tests/test_float_dtype_invariants.py @@ -1,5 +1,6 @@ """Float dtypes follow `canonical_float_dtype()` across pylcm boundaries.""" +from collections.abc import Callable from types import MappingProxyType import jax.numpy as jnp @@ -184,16 +185,16 @@ def test_solve_v_arrays_at_canonical_float_dtype( @pytest.mark.parametrize( - "grid", + "make_grid", [ - LinSpacedGrid(start=0, stop=1, n_points=5), - LogSpacedGrid(start=1, stop=10, n_points=5), - IrregSpacedGrid(points=(0.0, 0.5, 1.0)), + lambda: LinSpacedGrid(start=0, stop=1, n_points=5), + lambda: LogSpacedGrid(start=1, stop=10, n_points=5), + lambda: IrregSpacedGrid(points=(0.0, 0.5, 1.0)), ], ids=["linspaced", "logspaced", "irregspaced"], ) def test_continuous_grid_to_jax_dtype_is_canonical_under_no_x64( - grid: LinSpacedGrid | LogSpacedGrid | IrregSpacedGrid, + make_grid: Callable[[], 
LinSpacedGrid | LogSpacedGrid | IrregSpacedGrid], x64_disabled: None, # noqa: ARG001 ) -> None: """Continuous grid `to_jax()` materialises at `float32` under no-x64. @@ -203,10 +204,36 @@ def test_continuous_grid_to_jax_dtype_is_canonical_under_no_x64( (which JAX would silently truncate to `float32` under no-x64; the helper-side comparison would mask that, the literal-side comparison surfaces it). + + Grids are constructed inside the test body so the `x64_disabled` + fixture is in effect; grid dtype is now sticky to construction-time + `jax_enable_x64`. """ + grid = make_grid() assert grid.to_jax().dtype == jnp.float32 +@pytest.mark.parametrize("attr", ["start", "stop"]) +def test_uniform_grid_stores_endpoints_as_canonical_jax_scalar( + attr: str, + x64_disabled: None, # noqa: ARG001 +) -> None: + """`LinSpacedGrid` stores `start`/`stop` as JAX scalars at canonical dtype.""" + grid = LinSpacedGrid(start=0.0, stop=100.0, n_points=10) + value = getattr(grid, attr) + assert isinstance(value, jnp.ndarray) + assert value.dtype == canonical_float_dtype() + + +def test_irreg_grid_stores_points_as_canonical_jax_array( + x64_disabled: None, # noqa: ARG001 +) -> None: + """`IrregSpacedGrid` stores `points` as a JAX array at canonical dtype.""" + grid = IrregSpacedGrid(points=(0.0, 0.5, 1.0)) + assert isinstance(grid.points, jnp.ndarray) + assert grid.points.dtype == canonical_float_dtype() + + @pytest.mark.parametrize("key", ["low", "high"]) def test_process_params_casts_float_array_inside_mapping_leaf_to_canonical( key: str, diff --git a/tests/test_regime_state_mismatch.py b/tests/test_regime_state_mismatch.py index 6332c1ae..11abb8e9 100644 --- a/tests/test_regime_state_mismatch.py +++ b/tests/test_regime_state_mismatch.py @@ -461,7 +461,7 @@ class _RegimeId: b: int dead: int - def next_regime() -> ScalarInt: + def next_regime() -> int: return _RegimeId.dead a = Regime( @@ -505,7 +505,7 @@ class _RegimeId: b: int dead: int - def next_regime() -> ScalarInt: + def 
next_regime() -> int: return _RegimeId.dead a = Regime( From f4515ecfee370051c7ae3e4af7286f8284911599 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 11:22:29 +0200 Subject: [PATCH 09/22] Keep coordinate helpers strict; convert at Grid.get_coordinate boundary `linspace`, `logspace`, `get_*_coordinate` are pylcm-internal: every production caller (Grid methods, piecewise dispatchers) hands them JAX scalars. Drop the `float | ScalarFloat` widening on `start` / `stop` / `value` so the helpers pin the post-cast contract. Conversion of user input now happens once at the public-API boundary, inside `Grid.get_coordinate`, via a small `_to_jax_scalar` helper. The helper-direct tests in test_grid_helpers.py wrap their literals with `jnp.asarray` to match. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 21 ++++++++++-- src/lcm/grids/coordinates.py | 56 +++++++++++++++--------------- src/lcm/grids/piecewise.py | 4 ++- src/lcm/shocks/_base.py | 5 ++- tests/test_grid_helpers.py | 66 +++++++++++++++++++++--------------- 5 files changed, 91 insertions(+), 61 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 40a58f59..9881af4e 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -38,6 +38,13 @@ def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Ar """Return the generalized coordinate of a value in the grid.""" +def _to_jax_scalar(value: float | ScalarFloat | Array) -> ScalarFloat | Array: + """Lift a Python `float` to a canonical-dtype JAX scalar; pass arrays through.""" + if isinstance(value, (int, float)): + return jnp.asarray(value, dtype=canonical_float_dtype()) + return value + + @dataclass(frozen=True, kw_only=True, init=False) class UniformContinuousGrid(ContinuousGrid, ABC): """Grid with start/stop/n_points for linearly or logarithmically spaced values. 
@@ -125,7 +132,10 @@ def get_coordinate(self, value: Array) -> Array: ... def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" return grid_coordinates.get_linspace_coordinate( - value=value, start=self.start, stop=self.stop, n_points=self.n_points + value=_to_jax_scalar(value), + start=self.start, + stop=self.stop, + n_points=self.n_points, ) @@ -170,7 +180,10 @@ def get_coordinate(self, value: Array) -> Array: ... def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" return grid_coordinates.get_logspace_coordinate( - value=value, start=self.start, stop=self.stop, n_points=self.n_points + value=_to_jax_scalar(value), + start=self.start, + stop=self.stop, + n_points=self.n_points, ) @@ -297,7 +310,9 @@ def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Ar "initialization or use IrregSpacedGrid(n_points=...) and " "supply points at runtime via params." ) - return grid_coordinates.get_irreg_coordinate(value=value, points=self.points) + return grid_coordinates.get_irreg_coordinate( + value=_to_jax_scalar(value), points=self.points + ) def _validate_continuous_grid( diff --git a/src/lcm/grids/coordinates.py b/src/lcm/grids/coordinates.py index 98717e52..2c7802d1 100644 --- a/src/lcm/grids/coordinates.py +++ b/src/lcm/grids/coordinates.py @@ -1,11 +1,9 @@ """Functions to generate and work with different kinds of grids. -Grid generation and interpolation helpers accept Python `float` literals -alongside JAX scalars at every numeric slot, so they're usable both at -setup time (Python literals from user code) and inside the JIT'd DAG -(JAX scalars). `n_points` accepts a Python `int` or a JAX integer scalar, -the latter for piecewise grids that select a piece via `searchsorted`. 
- +These helpers operate on JAX scalars (`ScalarFloat` for endpoints/values, +`ScalarInt` or Python `int` for `n_points`) — every production caller is +either a `Grid` method that has already converted user input to a JAX +scalar, or a piecewise dispatch that selects pieces via `searchsorted`. """ from typing import overload @@ -18,8 +16,8 @@ def linspace( *, - start: float | ScalarFloat, - stop: float | ScalarFloat, + start: ScalarFloat, + stop: ScalarFloat, n_points: int, ) -> Float1D: """Wrapper around jnp.linspace. @@ -34,26 +32,26 @@ def linspace( @overload def get_linspace_coordinate( *, - value: float | ScalarFloat, - start: float | ScalarFloat, - stop: float | ScalarFloat, + value: ScalarFloat, + start: ScalarFloat, + stop: ScalarFloat, n_points: int | ScalarInt, ) -> ScalarFloat: ... @overload def get_linspace_coordinate( *, value: Array, - start: float | ScalarFloat, - stop: float | ScalarFloat, + start: ScalarFloat, + stop: ScalarFloat, n_points: int | ScalarInt, ) -> Array: ... def get_linspace_coordinate( *, - value: float | ScalarFloat | Array, - start: float | ScalarFloat, - stop: float | ScalarFloat, + value: ScalarFloat | Array, + start: ScalarFloat, + stop: ScalarFloat, n_points: int | ScalarInt, -) -> float | ScalarFloat | Array: +) -> ScalarFloat | Array: """Map a value into the input needed for jax.scipy.ndimage.map_coordinates.""" step_length = (stop - start) / (n_points - 1) return (value - start) / step_length @@ -61,8 +59,8 @@ def get_linspace_coordinate( def logspace( *, - start: float | ScalarFloat, - stop: float | ScalarFloat, + start: ScalarFloat, + stop: ScalarFloat, n_points: int, ) -> Float1D: """Wrapper around jnp.logspace. @@ -88,24 +86,24 @@ def logspace( @overload def get_logspace_coordinate( *, - value: float | ScalarFloat, - start: float | ScalarFloat, - stop: float | ScalarFloat, + value: ScalarFloat, + start: ScalarFloat, + stop: ScalarFloat, n_points: int | ScalarInt, ) -> ScalarFloat: ... 
@overload def get_logspace_coordinate( *, value: Array, - start: float | ScalarFloat, - stop: float | ScalarFloat, + start: ScalarFloat, + stop: ScalarFloat, n_points: int | ScalarInt, ) -> Array: ... def get_logspace_coordinate( *, - value: float | ScalarFloat | Array, - start: float | ScalarFloat, - stop: float | ScalarFloat, + value: ScalarFloat | Array, + start: ScalarFloat, + stop: ScalarFloat, n_points: int | ScalarInt, ) -> ScalarFloat | Array: """Map a value into the input needed for jax.scipy.ndimage.map_coordinates.""" @@ -148,7 +146,7 @@ def get_logspace_coordinate( @overload def get_irreg_coordinate( *, - value: float | ScalarFloat, + value: ScalarFloat, points: Float1D, ) -> ScalarFloat: ... @overload @@ -159,7 +157,7 @@ def get_irreg_coordinate( ) -> Array: ... def get_irreg_coordinate( *, - value: float | ScalarFloat | Array, + value: ScalarFloat | Array, points: Float1D, ) -> ScalarFloat | Array: """Return the generalized coordinate of a value in an irregularly spaced grid. diff --git a/src/lcm/grids/piecewise.py b/src/lcm/grids/piecewise.py index aec8e7a4..934c10a1 100644 --- a/src/lcm/grids/piecewise.py +++ b/src/lcm/grids/piecewise.py @@ -8,7 +8,7 @@ from lcm.exceptions import GridInitializationError, format_messages from lcm.grids import coordinates as grid_coordinates -from lcm.grids.continuous import ContinuousGrid +from lcm.grids.continuous import ContinuousGrid, _to_jax_scalar from lcm.typing import ( Float1D, Int1D, @@ -87,6 +87,7 @@ def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... def get_coordinate(self, value: Array) -> Array: ... 
def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" + value = _to_jax_scalar(value) piece_idx = jnp.searchsorted(self._breakpoints, value, side="right") local_coord = grid_coordinates.get_linspace_coordinate( value=value, @@ -158,6 +159,7 @@ def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... def get_coordinate(self, value: Array) -> Array: ... def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" + value = _to_jax_scalar(value) piece_idx = jnp.searchsorted(self._breakpoints, value, side="right") local_coord = grid_coordinates.get_logspace_coordinate( value=value, diff --git a/src/lcm/shocks/_base.py b/src/lcm/shocks/_base.py index 2ce5fcee..f8b8dcb4 100644 --- a/src/lcm/shocks/_base.py +++ b/src/lcm/shocks/_base.py @@ -11,6 +11,7 @@ from lcm.exceptions import GridInitializationError from lcm.grids import ContinuousGrid from lcm.grids import coordinates as grid_coordinates +from lcm.grids.continuous import _to_jax_scalar from lcm.typing import Float1D, FloatND, ScalarFloat @@ -116,7 +117,9 @@ def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Ar raise GridInitializationError( "Cannot compute coordinate for a ShockGrid without all shock params." 
) - return grid_coordinates.get_irreg_coordinate(value=value, points=self.to_jax()) + return grid_coordinates.get_irreg_coordinate( + value=_to_jax_scalar(value), points=self.to_jax() + ) def _validate_gauss_hermite_grid( diff --git a/tests/test_grid_helpers.py b/tests/test_grid_helpers.py index 2b698d48..12eff21e 100644 --- a/tests/test_grid_helpers.py +++ b/tests/test_grid_helpers.py @@ -15,18 +15,20 @@ def test_linspace(): - calculated = linspace(start=1, stop=2, n_points=6) + calculated = linspace(start=jnp.asarray(1.0), stop=jnp.asarray(2.0), n_points=6) expected = np.array([1, 1.2, 1.4, 1.6, 1.8, 2]) aaae(calculated, expected, decimal=DECIMAL_PRECISION) def test_linspace_mapped_value(): """For reference of the grid values, see expected grid in `test_linspace`.""" + start = jnp.asarray(1.0) + stop = jnp.asarray(2.0) # Get position corresponding to a value in the grid calculated = get_linspace_coordinate( - value=1.2, - start=1, - stop=2, + value=jnp.asarray(1.2), + start=start, + stop=stop, n_points=6, ) assert np.allclose(calculated, 1.0) @@ -36,25 +38,25 @@ def test_linspace_mapped_value(): # Here, the value is 1.3, that is in the middle of 1.2 and 1.4, which have the # positions 1 and 2, respectively. Therefore, we want the position to be 1.5. 
calculated = get_linspace_coordinate( - value=1.3, - start=1, - stop=2, + value=jnp.asarray(1.3), + start=start, + stop=stop, n_points=6, ) assert np.allclose(calculated, 1.5) # Get position corresponding to a value that is outside the grid calculated = get_linspace_coordinate( - value=0.6, - start=1, - stop=2, + value=jnp.asarray(0.6), + start=start, + stop=stop, n_points=6, ) assert np.allclose(calculated, -2.0) def test_logspace(): - calculated = logspace(start=1, stop=100, n_points=7) + calculated = logspace(start=jnp.asarray(1.0), stop=jnp.asarray(100.0), n_points=7) expected = np.array( [ 1.0, @@ -72,9 +74,9 @@ def test_logspace(): def test_logspace_mapped_value(): """For reference of the grid values, see expected grid in `test_logspace`.""" calculated = get_logspace_coordinate( - value=(2.15443469 + 4.64158883) / 2, - start=1, - stop=100, + value=jnp.asarray((2.15443469 + 4.64158883) / 2), + start=jnp.asarray(1.0), + stop=jnp.asarray(100.0), n_points=7, ) assert np.allclose(calculated, 1.5) @@ -84,8 +86,8 @@ def test_logspace_mapped_value(): def test_map_coordinates_linear(): """Illustrative test on how the output of get_linspace_coordinate can be used.""" grid_info = { - "start": 0, - "stop": 1, + "start": jnp.asarray(0.0), + "stop": jnp.asarray(1.0), "n_points": 3, } @@ -96,7 +98,7 @@ def test_map_coordinates_linear(): # We choose a coordinate that is exactly in the middle between the first and second # entry of the grid. 
coordinate = get_linspace_coordinate( - value=0.25, + value=jnp.asarray(0.25), **grid_info, ) @@ -109,8 +111,8 @@ def test_map_coordinates_linear(): def test_map_coordinates_logarithmic(): """Illustrative test on how the output of get_logspace_coordinate can be used.""" grid_info = { - "start": 1, - "stop": 2, + "start": jnp.asarray(1.0), + "stop": jnp.asarray(2.0), "n_points": 3, } @@ -121,7 +123,7 @@ def test_map_coordinates_logarithmic(): # We choose a coordinate that is exactly in the middle between the first and second # entry of the grid. coordinate = get_logspace_coordinate( - value=(1.0 + 1.414213562373095) / 2, + value=jnp.asarray((1.0 + 1.414213562373095) / 2), **grid_info, ) @@ -134,8 +136,8 @@ def test_map_coordinates_logarithmic(): def test_map_coordinates_linear_outside_grid(): """Illustrative test on what happens to values outside the grid.""" grid_info = { - "start": 1, - "stop": 2, + "start": jnp.asarray(1.0), + "stop": jnp.asarray(2.0), "n_points": 2, } @@ -146,8 +148,8 @@ def test_map_coordinates_linear_outside_grid(): # Get coordinates corresponding to values outside the grid [1, 2] coordinates = jnp.array( [ - get_linspace_coordinate(value=grid_val, **grid_info) - for grid_val in [-1, 0, 3] + get_linspace_coordinate(value=jnp.asarray(grid_val), **grid_info) # ty: ignore[no-matching-overload] + for grid_val in [-1.0, 0.0, 3.0] ] ) @@ -158,16 +160,26 @@ def test_map_coordinates_linear_outside_grid(): def test_get_linspace_coordinate_with_array(): values = jnp.array([1.0, 1.2, 1.5]) - coords = get_linspace_coordinate(value=values, start=1, stop=2, n_points=6) + coords = get_linspace_coordinate( + value=values, + start=jnp.asarray(1.0), + stop=jnp.asarray(2.0), + n_points=6, + ) expected = jnp.array([0.0, 1.0, 2.5]) aaae(coords, expected, decimal=DECIMAL_PRECISION) def test_get_logspace_coordinate_with_array(): - grid = logspace(start=1, stop=100, n_points=7) + grid = logspace(start=jnp.asarray(1.0), stop=jnp.asarray(100.0), n_points=7) mid = 
(float(grid[1]) + float(grid[2])) / 2 values = jnp.array([mid]) - coords = get_logspace_coordinate(value=values, start=1, stop=100, n_points=7) + coords = get_logspace_coordinate( + value=values, + start=jnp.asarray(1.0), + stop=jnp.asarray(100.0), + n_points=7, + ) aaae(coords, jnp.array([1.5]), decimal=DECIMAL_PRECISION) From 9fc2f493369160458a00af80e19987be8f23bd07 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 11:33:06 +0200 Subject: [PATCH 10/22] bench_aca_baseline: build on CPU to keep parent process CUDA-free `Model.__init__` lifts `fixed_params` Python scalars to JAX arrays via the boundary dtype cast, which initialises CUDA in the parent process when running under cuda12. ASV forks the benchmark worker from that parent; the inherited CUDA context is unusable in the child and surfaces as `CUDA_ERROR_NOT_INITIALIZED` on the first device op. Wrap `_build()` in `jax.default_device(cpu)` so all setup-time array creations stay on CPU. The worker process initialises CUDA freshly when `simulate(...)` runs in `setup`/method bodies; JAX moves the deserialised arrays to GPU on demand. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- benchmarks/bench_aca_baseline.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index 129923fd..68185bd3 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -40,6 +40,7 @@ import time import cloudpickle +import jax from aca_model.agent.preferences import BenchmarkPrefType from aca_model.benchmark import ( create_benchmark_model, @@ -54,15 +55,24 @@ def _build() -> tuple[object, object, object]: - """Build the aca-baseline model, params, and initial conditions.""" - model = create_benchmark_model( - n_subjects=_N_SUBJECTS, - pref_type_grid=DiscreteGrid(BenchmarkPrefType), - ) - _, model_params = get_benchmark_params(model=model) - initial_conditions = get_benchmark_initial_conditions( - model=model, n_subjects=_N_SUBJECTS, seed=0 - ) + """Build the aca-baseline model, params, and initial conditions. + + Wrapped in `jax.default_device(cpu)` so the boundary dtype casts in + `Model.__init__` (which lift `fixed_params` Python scalars to JAX + arrays via `jnp.asarray`) don't initialise CUDA in the parent + process. ASV forks the benchmark worker from the parent; an + inherited CUDA context is unusable in the child and surfaces as + `CUDA_ERROR_NOT_INITIALIZED` on the first device op. 
+ """ + with jax.default_device(jax.devices("cpu")[0]): + model = create_benchmark_model( + n_subjects=_N_SUBJECTS, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + _, model_params = get_benchmark_params(model=model) + initial_conditions = get_benchmark_initial_conditions( + model=model, n_subjects=_N_SUBJECTS, seed=0 + ) return model, model_params, initial_conditions From 88a85aebace1a2f7be9965bd7e655b95f8749010 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 11:44:49 +0200 Subject: [PATCH 11/22] save_simulate_snapshot: strip AOT-compiled regimes before pickling result MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When `Model(n_subjects=N)` triggers an AOT compile, every `InternalRegime.simulate_functions` field carries a `jax.stages.Compiled` that holds an unpicklable `LoadedExecutable`. The snapshot already side-loads the V-array via HDF5; widen the strip pass to overwrite `SimulationResult._internal_regimes` with `model.internal_regimes` (the lazy regimes — same metadata, JIT'd `PjitFunction`s pickle cleanly, which is why `model.pkl` survives the same round-trip). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/persistence.py | 18 +++++++++++++----- tests/test_persistence.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/lcm/persistence.py b/src/lcm/persistence.py index 6f746173..0409e53d 100644 --- a/src/lcm/persistence.py +++ b/src/lcm/persistence.py @@ -219,7 +219,7 @@ def save_simulate_snapshot( _save_pkl(snap_dir / "model.pkl", model) _save_pkl(snap_dir / "params.pkl", params) _save_pkl(snap_dir / "initial_conditions.pkl", initial_conditions) - _save_pkl(snap_dir / "result.pkl", _strip_V_arr_from_result(result)) + _save_pkl(snap_dir / "result.pkl", _strip_V_arr_from_result(result, model=model)) _save_h5(snap_dir / "arrays.h5", period_to_regime_to_V_arr) _write_metadata( snap_dir, @@ -290,14 +290,22 @@ def _find_project_root() -> Path | None: return None -def _strip_V_arr_from_result(result: SimulationResult) -> SimulationResult: - """Create a copy of result with value arrays replaced by an empty mapping. - - Avoid storing period_to_regime_to_V_arr both in the pickle and in the HDF5 file. +def _strip_V_arr_from_result( + result: SimulationResult, *, model: Model +) -> SimulationResult: + """Create a copy of result with value arrays and compiled callables stripped. + `period_to_regime_to_V_arr` is dropped to avoid storing it both in the + pickle and in the HDF5 file. `_internal_regimes` is overwritten with the + model's lazy-path `internal_regimes`: when `Model(n_subjects=N)` is set + the result carries the AOT-compiled regimes, whose + `jax.stages.Compiled` callables hold a `LoadedExecutable` that cannot + be pickled. The lazy regimes carry the same metadata and cloud-pickle + cleanly (model.pkl uses the same set). 
""" stripped = copy.copy(result) object.__setattr__(stripped, "_period_to_regime_to_V_arr", MappingProxyType({})) + object.__setattr__(stripped, "_internal_regimes", model.internal_regimes) return stripped diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 68cb88d0..8a96cbd0 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -29,7 +29,7 @@ def _retired_utility(wealth: ContinuousState) -> FloatND: return jnp.log(wealth) -def _build_tiny_model(): +def _build_tiny_model(*, enable_jit: bool = False, n_subjects: int | None = None): def utility(consumption: ContinuousAction, wealth: ContinuousState) -> FloatND: return jnp.log(consumption + wealth) @@ -60,7 +60,8 @@ def next_regime(period: int) -> ScalarInt: regimes={"working": working, "retired": retired}, ages=ages, regime_id_class=_RegimeId, - enable_jit=False, + enable_jit=enable_jit, + n_subjects=n_subjects, ) params = {"discount_factor": 0.95} return model, params @@ -171,6 +172,30 @@ def test_simulate_with_solve_debug_persists_snapshot(tmp_path, model_and_params) assert snapshot.result is not None +def test_simulate_debug_persists_snapshot_with_aot_compiled_regimes(tmp_path): + """Debug snapshot saves successfully when `n_subjects` triggers AOT compile. + + AOT compilation produces `jax.stages.Compiled` callables on each + `InternalRegime.simulate_functions`; their backing `LoadedExecutable` + cannot be pickled. The snapshot path must strip those before pickling + `result.pkl`. 
+ """ + model, params = _build_tiny_model(enable_jit=True, n_subjects=2) + model.simulate( + params=params, + initial_conditions=_initial_conditions(), + period_to_regime_to_V_arr=None, + log_level="debug", + log_path=tmp_path, + ) + + dirs = sorted(tmp_path.glob("simulate_snapshot_*/")) + assert len(dirs) == 1 + snapshot = load_snapshot(dirs[0]) + assert isinstance(snapshot, SimulateSnapshot) + assert snapshot.result is not None + + def test_solve_no_persistence_when_not_debug(tmp_path, model_and_params): model, params = model_and_params model.solve(params=params, log_level="progress", log_path=tmp_path) From f2d18faf074ad5bca4e99aba42c9ac8c3988f31c Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 12:57:25 +0200 Subject: [PATCH 12/22] bench_aca_baseline: defer aca_model + lcm imports back into _build ASV's forkserver runs `preimport` to discover benchmarks across every `bench_*.py` module before forking workers. Importing JAX at module top loads the multithreaded XLA backend into the forkserver; every subsequent `os.fork()` (for any benchmark, not just this one) inherits a corrupted CUDA context and the first device op in the worker aborts with `CUDA_ERROR_NOT_INITIALIZED`. Per-call imports keep JAX out of the forkserver and confine it to the worker process. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- benchmarks/bench_aca_baseline.py | 48 +++++++++++++++++--------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index 68185bd3..8b15efab 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -40,16 +40,8 @@ import time import cloudpickle -import jax -from aca_model.agent.preferences import BenchmarkPrefType -from aca_model.benchmark import ( - create_benchmark_model, - get_benchmark_initial_conditions, - get_benchmark_params, -) from benchmarks import _gpu_mem -from lcm import DiscreteGrid _N_SUBJECTS = 1000 @@ -57,22 +49,32 @@ def _build() -> tuple[object, object, object]: """Build the aca-baseline model, params, and initial conditions. - Wrapped in `jax.default_device(cpu)` so the boundary dtype casts in - `Model.__init__` (which lift `fixed_params` Python scalars to JAX - arrays via `jnp.asarray`) don't initialise CUDA in the parent - process. ASV forks the benchmark worker from the parent; an - inherited CUDA context is unusable in the child and surfaces as - `CUDA_ERROR_NOT_INITIALIZED` on the first device op. + aca_model and lcm imports are deferred to the function body — ASV's + forkserver runs `preimport` to discover benchmarks across every + `bench_*.py` module before forking workers. Importing JAX at module + top loads the multithreaded XLA backend into the forkserver; every + subsequent `os.fork()` inherits a corrupted CUDA context and the + first device op in the worker aborts with + `CUDA_ERROR_NOT_INITIALIZED`. Per-call imports keep JAX out of the + forkserver and confine it to the worker process. 
""" - with jax.default_device(jax.devices("cpu")[0]): - model = create_benchmark_model( - n_subjects=_N_SUBJECTS, - pref_type_grid=DiscreteGrid(BenchmarkPrefType), - ) - _, model_params = get_benchmark_params(model=model) - initial_conditions = get_benchmark_initial_conditions( - model=model, n_subjects=_N_SUBJECTS, seed=0 - ) + from aca_model.agent.preferences import BenchmarkPrefType + from aca_model.benchmark import ( + create_benchmark_model, + get_benchmark_initial_conditions, + get_benchmark_params, + ) + + from lcm import DiscreteGrid + + model = create_benchmark_model( + n_subjects=_N_SUBJECTS, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + _, model_params = get_benchmark_params(model=model) + initial_conditions = get_benchmark_initial_conditions( + model=model, n_subjects=_N_SUBJECTS, seed=0 + ) return model, model_params, initial_conditions From fff1537db8c5093f728e4bcaa00d7ac308798ed5 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 13:51:14 +0200 Subject: [PATCH 13/22] Tighten internal types: ScalarInt n_points, JAX-only Period/Age, kw-only MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Continues the dtype-barrier work by promoting internal scalar metadata to JAX-typed forms wherever it lives strictly inside pylcm: - `UniformContinuousGrid.n_points` and `Piece.n_points` are stored as `jnp.int32` JAX scalars, converted from the Python literals at construction. `_init_uniform_grid` casts `start` / `stop` / `n_points` at the boundary before validation; the validator can then assume strict `ScalarFloat` / `ScalarInt` arguments and only check value invariants. Coordinate helpers (`linspace`, `logspace`, `get_*_coordinate`) tighten `n_points` to `ScalarInt` so the conversion happens once at the boundary instead of at every call. - `Grid.get_coordinate` reverts to `ScalarFloat | Array` (no Python float). 
The single production caller in `regime_building/V.py` always passes a JAX array; tests that called the helpers with Python literals wrap them with `jnp.asarray` / `jnp.int32`. - `Period` aliases `ScalarInt` and `Age` aliases `ScalarInt | ScalarFloat` for the JIT-internal scalar contexts. `AgeGrid.period_to_age` and `age_to_period` use plain `int | float` directly since they are user-facing API methods returning Python values. - `_simulate_regime_in_period` and the `transitions.py` helpers now take `period: ScalarInt`. The simulation loop derives `period = jnp.int32(period_idx)` once per iteration and passes it through; dict-key lookups (`argmax_and_max_Q_over_a[period_idx]`, `period_to_regime_to_V_arr.get(period_idx + 1)`) keep using the Python int. - `FlatRegimeParams` tightens to `MappingProxyType[str, Array]` — post-whitelist every leaf is a JAX array; the prior `bool | float | Array` union was stale. - `safe_to_int32` renamed to `safe_to_int_dtype` to mirror `safe_to_float_dtype`. - `_strip_V_arr_from_result` made fully kw-only. - `pyproject.toml` ignores `ARG001` for `tests/test_float_dtype_invariants.py` so per-test `# noqa: ARG001` comments drop out and signatures collapse to a single line. - `Piece` becomes `init=False` with a manual `__init__` that lifts `n_points` to `jnp.int32`, mirroring `UniformContinuousGrid`. Test-side fallout addressed in the same commit: literals wrapped with `jnp.asarray` / `jnp.int32` where helpers tightened, redundant `# ty: ignore` comments dropped, and three "validator rejects non-numeric" tests reframed to assert the boundary cast catches them.
Co-Authored-By: Claude Opus 4.7 (1M context) --- pixi.lock | 4 +- pyproject.toml | 3 + src/lcm/ages.py | 14 ++-- src/lcm/dtypes.py | 2 +- src/lcm/grids/continuous.py | 113 ++++++++++++--------------- src/lcm/grids/coordinates.py | 20 ++--- src/lcm/grids/piecewise.py | 46 +++++++---- src/lcm/interfaces.py | 8 +- src/lcm/params/processing.py | 10 +-- src/lcm/persistence.py | 7 +- src/lcm/shocks/_base.py | 9 +-- src/lcm/simulation/simulate.py | 19 ++--- src/lcm/simulation/transitions.py | 8 +- src/lcm/typing.py | 9 +-- tests/simulation/test_simulate.py | 8 +- tests/solution/test_solve_brute.py | 2 +- tests/test_dtypes.py | 28 +++---- tests/test_float_dtype_invariants.py | 42 ++++------ tests/test_grid_helpers.py | 32 +++++--- tests/test_grids.py | 61 +++++++++------ tests/test_int_dtype_invariants.py | 4 +- tests/test_next_state.py | 8 +- tests/test_pandas_utils.py | 18 ++--- tests/test_persistence.py | 4 +- tests/test_validate_param_types.py | 4 +- 25 files changed, 248 insertions(+), 235 deletions(-) diff --git a/pixi.lock b/pixi.lock index 4115222a..cedae98d 100644 --- a/pixi.lock +++ b/pixi.lock @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev195+ga908c8405.d20260505 - sha256: 44c6bd65422fdc0a7d3167cf852107aeca15bf6687a44b57a6749ad553943f11 + version: 0.0.2.dev247+gf2d18faf0.d20260508 + sha256: bb850950f6f17a2050320baee51568923488f2d994ebd9c4f5ec1232da3f9434 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index bb10a893..f7f90756 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -242,6 +242,9 @@ per-file-ignores."tests/*" = [ "S301", # Use of pickle "SLF001", # Private member access ] +per-file-ignores."tests/test_float_dtype_invariants.py" = [ + "ARG001", # Unused function argument (x64_disabled fixture) +] per-file-ignores."tests/test_next_state.py" = [ "ARG001", # Unused function argument "ARG005", # Unused lambda argument diff --git a/src/lcm/ages.py 
b/src/lcm/ages.py index 8dee9c73..11d2c88c 100644 --- a/src/lcm/ages.py +++ b/src/lcm/ages.py @@ -10,7 +10,7 @@ import jax.numpy as jnp from lcm.exceptions import GridInitializationError, format_messages -from lcm.typing import Age, Float1D, Int1D +from lcm.typing import Float1D, Int1D STEP_UNITS: MappingProxyType[str, Fraction] = MappingProxyType( { @@ -129,7 +129,7 @@ def exact_step_size(self) -> int | Fraction | None: """ return self._exact_step_size - def period_to_age(self, period: int) -> Age: + def period_to_age(self, period: int) -> int | float: """Convert a period index to the corresponding age. Args: @@ -151,7 +151,7 @@ def period_to_age(self, period: int) -> Age: return int(self._values[period]) return float(self._values[period]) - def age_to_period(self, age: Age) -> int: + def age_to_period(self, age: float) -> int: """Convert an age to the corresponding period index. Args: @@ -172,12 +172,14 @@ def age_to_period(self, age: Age) -> int: raise ValueError(msg) from None @functools.cached_property - def _age_to_period_map(self) -> dict[Age, int]: + def _age_to_period_map(self) -> dict[int | float, int]: if self._is_integer: return {int(v): i for i, v in enumerate(self._exact_values)} return {float(v): i for i, v in enumerate(self._exact_values)} - def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...]: + def get_periods_where( + self, predicate: Callable[[int | float], bool] + ) -> tuple[int, ...]: """Get period indices where predicate is True. Args: @@ -187,7 +189,7 @@ def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...] Tuple of period indices where predicate(age) is True. 
""" - _convert: Callable[[object], Age] = int if self._is_integer else float # ty: ignore[invalid-assignment] + _convert: Callable[[object], int | float] = int if self._is_integer else float # ty: ignore[invalid-assignment] return tuple( period for period in range(self.n_periods) diff --git a/src/lcm/dtypes.py b/src/lcm/dtypes.py index 7345df1f..f2cf3170 100644 --- a/src/lcm/dtypes.py +++ b/src/lcm/dtypes.py @@ -28,7 +28,7 @@ def canonical_float_dtype() -> jnp.dtype: return jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32 -def safe_to_int32(value: object, *, name: str) -> Array: +def safe_to_int_dtype(value: object, *, name: str) -> Array: """Cast a scalar, sequence, or array to `jnp.int32`, checking int32 range. Args: diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 9881af4e..839d4cef 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -14,6 +14,7 @@ from lcm.typing import ( Float1D, ScalarFloat, + ScalarInt, ) @@ -30,28 +31,21 @@ class ContinuousGrid(Grid): """Size of the batches that are looped over during the solution.""" @overload - def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... 
@abstractmethod - def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" -def _to_jax_scalar(value: float | ScalarFloat | Array) -> ScalarFloat | Array: - """Lift a Python `float` to a canonical-dtype JAX scalar; pass arrays through.""" - if isinstance(value, (int, float)): - return jnp.asarray(value, dtype=canonical_float_dtype()) - return value - - @dataclass(frozen=True, kw_only=True, init=False) class UniformContinuousGrid(ContinuousGrid, ABC): """Grid with start/stop/n_points for linearly or logarithmically spaced values. `start` and `stop` are stored as JAX scalars at `canonical_float_dtype()`, - converted from the Python literals supplied at construction. `n_points` - stays a Python `int` so it can size JAX arrays statically. + `n_points` as a `jnp.int32` JAX scalar — converted from the Python + literals (or other numeric inputs) supplied at construction. """ start: ScalarFloat @@ -60,15 +54,15 @@ class UniformContinuousGrid(ContinuousGrid, ABC): stop: ScalarFloat """The stop value of the grid (JAX scalar at `canonical_float_dtype()`).""" - n_points: int - """The number of points in the grid.""" + n_points: ScalarInt + """The number of points in the grid (`jnp.int32` JAX scalar).""" def __init__( self, *, start: float | ScalarFloat, stop: float | ScalarFloat, - n_points: int, + n_points: int | ScalarInt, batch_size: int = 0, ) -> None: _init_uniform_grid( @@ -85,11 +79,11 @@ def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" @overload - def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... 
@abstractmethod - def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" def replace(self, **kwargs: float) -> UniformContinuousGrid: @@ -126,13 +120,13 @@ def to_jax(self) -> Float1D: ) @overload - def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... - def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" return grid_coordinates.get_linspace_coordinate( - value=_to_jax_scalar(value), + value=value, start=self.start, stop=self.stop, n_points=self.n_points, @@ -155,7 +149,7 @@ def __init__( *, start: float | ScalarFloat, stop: float | ScalarFloat, - n_points: int, + n_points: int | ScalarInt, batch_size: int = 0, ) -> None: _init_uniform_grid( @@ -174,13 +168,13 @@ def to_jax(self) -> Float1D: ) @overload - def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... 
- def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" return grid_coordinates.get_logspace_coordinate( - value=_to_jax_scalar(value), + value=value, start=self.start, stop=self.stop, n_points=self.n_points, @@ -192,26 +186,32 @@ def _init_uniform_grid( *, start: float | ScalarFloat, stop: float | ScalarFloat, - n_points: int, + n_points: int | ScalarInt, batch_size: int, requires_positive_start: bool, ) -> None: - """Validate the user input and store fields on `grid`. - - Validation runs on the original Python values; once they pass, `start` - and `stop` are converted to JAX scalars at `canonical_float_dtype()` - so downstream code reads typed scalars. + """Cast `start` / `stop` / `n_points` to canonical JAX scalars, validate, store. + + `jnp.asarray(..., dtype=canonical_float_dtype())` and `jnp.int32(...)` lift + every numeric input at the boundary: Python literals from user + construction, JAX scalars from `dataclasses.replace` round-trips, anything + else raises here. The validator can then assume strict `ScalarFloat` / + `ScalarInt` types and only check value invariants (finiteness, ordering, + positivity). 
""" + dtype = canonical_float_dtype() + start_jax = jnp.asarray(start, dtype=dtype) + stop_jax = jnp.asarray(stop, dtype=dtype) + n_points_jax = jnp.int32(n_points) _validate_continuous_grid( - start=start, - stop=stop, - n_points=n_points, + start=start_jax, + stop=stop_jax, + n_points=n_points_jax, requires_positive_start=requires_positive_start, ) - dtype = canonical_float_dtype() - object.__setattr__(grid, "start", jnp.asarray(start, dtype=dtype)) - object.__setattr__(grid, "stop", jnp.asarray(stop, dtype=dtype)) - object.__setattr__(grid, "n_points", n_points) + object.__setattr__(grid, "start", start_jax) + object.__setattr__(grid, "stop", stop_jax) + object.__setattr__(grid, "n_points", n_points_jax) object.__setattr__(grid, "batch_size", batch_size) @@ -299,10 +299,10 @@ def to_jax(self) -> Float1D: return self.points @overload - def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... - def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" if self.points is None: raise GridInitializationError( @@ -310,22 +310,21 @@ def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Ar "initialization or use IrregSpacedGrid(n_points=...) and " "supply points at runtime via params." ) - return grid_coordinates.get_irreg_coordinate( - value=_to_jax_scalar(value), points=self.points - ) + return grid_coordinates.get_irreg_coordinate(value=value, points=self.points) def _validate_continuous_grid( *, - start: float | ScalarFloat, - stop: float | ScalarFloat, - n_points: int, + start: ScalarFloat, + stop: ScalarFloat, + n_points: ScalarInt, requires_positive_start: bool = False, ) -> None: """Validate the continuous grid parameters. 
- Accepts Python ints/floats from user construction and JAX scalars from - `dataclasses.replace` round-trips on already-constructed grids. + `start` and `stop` are post-cast canonical-dtype JAX scalars (the + boundary cast in `_init_uniform_grid` already rejects non-numeric + inputs); the checks here cover only value invariants. Args: start: The start value of the grid. @@ -340,32 +339,24 @@ def _validate_continuous_grid( """ error_messages = [] - valid_start_type = isinstance(start, int | float | Array) - if not valid_start_type: - error_messages.append("start must be a scalar int or float value") - - valid_stop_type = isinstance(stop, int | float | Array) - if not valid_stop_type: - error_messages.append("stop must be a scalar int or float value") - # Reject NaN/inf early — `start >= stop` returns False for NaN, so an # un-finite start would otherwise pass silently and produce a broken grid. - if valid_start_type and not jnp.isfinite(start): + start_finite = bool(jnp.isfinite(start)) + if not start_finite: error_messages.append(f"start must be finite, got {start}") - valid_start_type = False - if valid_stop_type and not jnp.isfinite(stop): + stop_finite = bool(jnp.isfinite(stop)) + if not stop_finite: error_messages.append(f"stop must be finite, got {stop}") - valid_stop_type = False - if not isinstance(n_points, int) or n_points < 1: + if n_points < 1: error_messages.append( f"n_points must be an int greater than 0 but is {n_points}", ) - if valid_start_type and valid_stop_type and start >= stop: + if start_finite and stop_finite and start >= stop: error_messages.append("start must be less than stop") - if valid_start_type and requires_positive_start and start <= 0: + if start_finite and requires_positive_start and start <= 0: error_messages.append( f"start must be > 0 for a log-spaced grid (got {start}); " f"`log(x)` is undefined for `x <= 0`." 
diff --git a/src/lcm/grids/coordinates.py b/src/lcm/grids/coordinates.py index 2c7802d1..8f7045ad 100644 --- a/src/lcm/grids/coordinates.py +++ b/src/lcm/grids/coordinates.py @@ -18,7 +18,7 @@ def linspace( *, start: ScalarFloat, stop: ScalarFloat, - n_points: int, + n_points: ScalarInt, ) -> Float1D: """Wrapper around jnp.linspace. @@ -26,7 +26,7 @@ def linspace( endpoints. """ - return jnp.linspace(start, stop, n_points) + return jnp.linspace(start, stop, int(n_points)) @overload @@ -35,7 +35,7 @@ def get_linspace_coordinate( value: ScalarFloat, start: ScalarFloat, stop: ScalarFloat, - n_points: int | ScalarInt, + n_points: ScalarInt, ) -> ScalarFloat: ... @overload def get_linspace_coordinate( @@ -43,14 +43,14 @@ def get_linspace_coordinate( value: Array, start: ScalarFloat, stop: ScalarFloat, - n_points: int | ScalarInt, + n_points: ScalarInt, ) -> Array: ... def get_linspace_coordinate( *, value: ScalarFloat | Array, start: ScalarFloat, stop: ScalarFloat, - n_points: int | ScalarInt, + n_points: ScalarInt, ) -> ScalarFloat | Array: """Map a value into the input needed for jax.scipy.ndimage.map_coordinates.""" step_length = (stop - start) / (n_points - 1) @@ -61,7 +61,7 @@ def logspace( *, start: ScalarFloat, stop: ScalarFloat, - n_points: int, + n_points: ScalarInt, ) -> Float1D: """Wrapper around jnp.logspace. @@ -79,7 +79,7 @@ def logspace( """ start_linear = jnp.log(start) stop_linear = jnp.log(stop) - grid = jnp.logspace(start_linear, stop_linear, n_points, base=jnp.e) + grid = jnp.logspace(start_linear, stop_linear, int(n_points), base=jnp.e) return grid.at[0].set(start).at[-1].set(stop) @@ -89,7 +89,7 @@ def get_logspace_coordinate( value: ScalarFloat, start: ScalarFloat, stop: ScalarFloat, - n_points: int | ScalarInt, + n_points: ScalarInt, ) -> ScalarFloat: ... 
@overload def get_logspace_coordinate( @@ -97,14 +97,14 @@ def get_logspace_coordinate( value: Array, start: ScalarFloat, stop: ScalarFloat, - n_points: int | ScalarInt, + n_points: ScalarInt, ) -> Array: ... def get_logspace_coordinate( *, value: ScalarFloat | Array, start: ScalarFloat, stop: ScalarFloat, - n_points: int | ScalarInt, + n_points: ScalarInt, ) -> ScalarFloat | Array: """Map a value into the input needed for jax.scipy.ndimage.map_coordinates.""" # Transform start, stop, and value to linear scale diff --git a/src/lcm/grids/piecewise.py b/src/lcm/grids/piecewise.py index 934c10a1..fa286472 100644 --- a/src/lcm/grids/piecewise.py +++ b/src/lcm/grids/piecewise.py @@ -8,17 +8,22 @@ from lcm.exceptions import GridInitializationError, format_messages from lcm.grids import coordinates as grid_coordinates -from lcm.grids.continuous import ContinuousGrid, _to_jax_scalar +from lcm.grids.continuous import ContinuousGrid from lcm.typing import ( Float1D, Int1D, ScalarFloat, + ScalarInt, ) -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True, kw_only=True, init=False) class Piece: - """A piece of a piecewise linearly spaced grid.""" + """A piece of a piecewise linearly spaced grid. + + `n_points` is stored as a `jnp.int32` JAX scalar, converted from the + Python literal supplied at construction. + """ interval: str | portion.Interval """The interval for this piece. @@ -26,8 +31,17 @@ class Piece: Can be a string like "[1, 4)" or a `portion.Interval`. 
""" - n_points: int - """The number of grid points in this piece.""" + n_points: ScalarInt + """The number of grid points in this piece (`jnp.int32` JAX scalar).""" + + def __init__( + self, + *, + interval: str | portion.Interval, + n_points: int | ScalarInt, + ) -> None: + object.__setattr__(self, "interval", interval) + object.__setattr__(self, "n_points", jnp.int32(n_points)) @dataclass(frozen=True, kw_only=True) @@ -69,25 +83,24 @@ def __post_init__(self) -> None: _init_piecewise_grid_cache(self) @property - def n_points(self) -> int: + def n_points(self) -> ScalarInt: """Return the total number of points in the grid.""" - return sum(p.n_points for p in self.pieces) + return sum((p.n_points for p in self.pieces), start=jnp.int32(0)) def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" piece_arrays = [ - jnp.linspace(self._piece_starts[i], self._piece_stops[i], p.n_points) + jnp.linspace(self._piece_starts[i], self._piece_stops[i], int(p.n_points)) for i, p in enumerate(self.pieces) ] return jnp.concatenate(piece_arrays) @overload - def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... 
- def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" - value = _to_jax_scalar(value) piece_idx = jnp.searchsorted(self._breakpoints, value, side="right") local_coord = grid_coordinates.get_linspace_coordinate( value=value, @@ -137,9 +150,9 @@ def __post_init__(self) -> None: _init_piecewise_grid_cache(self) @property - def n_points(self) -> int: + def n_points(self) -> ScalarInt: """Return the total number of points in the grid.""" - return sum(p.n_points for p in self.pieces) + return sum((p.n_points for p in self.pieces), start=jnp.int32(0)) def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" @@ -154,12 +167,11 @@ def to_jax(self) -> Float1D: return jnp.concatenate(piece_arrays) @overload - def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... 
- def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" - value = _to_jax_scalar(value) piece_idx = jnp.searchsorted(self._breakpoints, value, side="right") local_coord = grid_coordinates.get_logspace_coordinate( value=value, @@ -269,7 +281,7 @@ def _validate_piecewise_lin_spaced_grid( # noqa: C901, PLR0912 ) continue - if not isinstance(piece.n_points, int) or piece.n_points < 2: # noqa: PLR2004 + if piece.n_points < 2: # noqa: PLR2004 error_messages.append( f"pieces[{i}].n_points must be an int >= 2, but is {piece.n_points}" ) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 59617eb6..c74842e6 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -271,13 +271,9 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac if points_key not in all_params: continue if in_states: - state_replacements[name] = cast( - "ContinuousState", all_params[points_key] - ) + state_replacements[name] = all_params[points_key] else: - action_replacements[name] = cast( - "ContinuousAction", all_params[points_key] - ) + action_replacements[name] = all_params[points_key] # `_ShockGrid` is state-only by construction (intrinsic # transitions, forbidden as actions per AGENTS.md). 
The # `in_states` gate makes that invariant explicit — a diff --git a/src/lcm/params/processing.py b/src/lcm/params/processing.py index c4cb2c93..d6afbccd 100644 --- a/src/lcm/params/processing.py +++ b/src/lcm/params/processing.py @@ -31,7 +31,7 @@ from dags.tree import QNAME_DELIMITER, qname_from_tree_path, tree_path_from_qname from jax import Array -from lcm.dtypes import safe_to_float_dtype, safe_to_int32 +from lcm.dtypes import safe_to_float_dtype, safe_to_int_dtype from lcm.exceptions import InvalidNameError, InvalidParamsError from lcm.interfaces import InternalRegime from lcm.params.mapping_leaf import MappingLeaf @@ -208,11 +208,11 @@ def _cast_leaves_to_canonical_dtype(value: Any, *, name: str) -> Any: # noqa: A - `MappingLeaf` / `SequenceLeaf`: recurse on contents. - Python `bool`: `jnp.bool_(value)` (must come before `int` — `True` is a Python `int` subclass). - - Python `int`: `safe_to_int32(value)` → `jnp.int32`. + - Python `int`: `safe_to_int_dtype(value)` → `jnp.int32`. - Python `float`: `safe_to_float_dtype(value)` → canonical float. - JAX or numpy array, dispatch on `dtype.kind`: - `"b"` (bool) → `jnp.asarray(..., dtype=jnp.bool_)`. - - `"i"` / `"u"` (signed/unsigned int) → `safe_to_int32`. + - `"i"` / `"u"` (signed/unsigned int) → `safe_to_int_dtype`. - `"f"` (float) → `safe_to_float_dtype`. 
Raises `InvalidParamsError` for: @@ -249,7 +249,7 @@ def _cast_leaves_to_canonical_dtype(value: Any, *, name: str) -> Any: # noqa: A if isinstance(value, bool): return jnp.bool_(value) if isinstance(value, int): - return safe_to_int32(value, name=name) + return safe_to_int_dtype(value, name=name) if isinstance(value, float): return safe_to_float_dtype(value, name=name) if isinstance(value, (Array, np.ndarray)): @@ -257,7 +257,7 @@ def _cast_leaves_to_canonical_dtype(value: Any, *, name: str) -> Any: # noqa: A if kind == "b": return jnp.asarray(value, dtype=jnp.bool_) if kind in ("i", "u"): - return safe_to_int32(value, name=name) + return safe_to_int_dtype(value, name=name) if kind == "f": return safe_to_float_dtype(value, name=name) msg = ( diff --git a/src/lcm/persistence.py b/src/lcm/persistence.py index 0409e53d..c8498a19 100644 --- a/src/lcm/persistence.py +++ b/src/lcm/persistence.py @@ -219,7 +219,10 @@ def save_simulate_snapshot( _save_pkl(snap_dir / "model.pkl", model) _save_pkl(snap_dir / "params.pkl", params) _save_pkl(snap_dir / "initial_conditions.pkl", initial_conditions) - _save_pkl(snap_dir / "result.pkl", _strip_V_arr_from_result(result, model=model)) + _save_pkl( + snap_dir / "result.pkl", + _strip_V_arr_from_result(result=result, model=model), + ) _save_h5(snap_dir / "arrays.h5", period_to_regime_to_V_arr) _write_metadata( snap_dir, @@ -291,7 +294,7 @@ def _find_project_root() -> Path | None: def _strip_V_arr_from_result( - result: SimulationResult, *, model: Model + *, result: SimulationResult, model: Model ) -> SimulationResult: """Create a copy of result with value arrays and compiled callables stripped. 
diff --git a/src/lcm/shocks/_base.py b/src/lcm/shocks/_base.py index f8b8dcb4..55d4ac79 100644 --- a/src/lcm/shocks/_base.py +++ b/src/lcm/shocks/_base.py @@ -11,7 +11,6 @@ from lcm.exceptions import GridInitializationError from lcm.grids import ContinuousGrid from lcm.grids import coordinates as grid_coordinates -from lcm.grids.continuous import _to_jax_scalar from lcm.typing import Float1D, FloatND, ScalarFloat @@ -108,18 +107,16 @@ def to_jax(self) -> Float1D: return self.get_gridpoints() @overload - def get_coordinate(self, value: float | ScalarFloat) -> ScalarFloat: ... + def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... @overload def get_coordinate(self, value: Array) -> Array: ... - def get_coordinate(self, value: float | ScalarFloat | Array) -> ScalarFloat | Array: + def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" if not self.is_fully_specified: raise GridInitializationError( "Cannot compute coordinate for a ShockGrid without all shock params." 
) - return grid_coordinates.get_irreg_coordinate( - value=_to_jax_scalar(value), points=self.to_jax() - ) + return grid_coordinates.get_irreg_coordinate(value=value, points=self.to_jax()) def _validate_gauss_hermite_grid( diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 51547465..e07c03fd 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -112,8 +112,8 @@ def simulate( # Build reverse lookup for regime transition logging ids_to_names: dict[int, RegimeName] = {v: k for k, v in regime_names_to_ids.items()} - for period in range(ages.n_periods): - age = ages.values[period] # noqa: PD011 + for period_idx, age in enumerate(ages.values): + period = jnp.int32(period_idx) period_start = time.monotonic() # Activate subjects whose starting period matches the current period @@ -129,13 +129,13 @@ def simulate( active_regimes = { regime_name: regime for regime_name, regime in internal_regimes.items() - if period in regime.active_periods + if period_idx in regime.active_periods } active_regimes_next_period = tuple( regime_name for regime_name, regime in internal_regimes.items() - if period + 1 in regime.active_periods + if period_idx + 1 in regime.active_periods ) log_period_header(logger=logger, age=age, n_active_regimes=len(active_regimes)) @@ -158,7 +158,7 @@ def simulate( ) ) states = new_states - simulation_results[regime_name][period] = result + simulation_results[regime_name][period_idx] = result log_nan_in_V( logger=logger, regime_name=regime_name, age=age, V_arr=result.V_arr @@ -201,7 +201,7 @@ def _simulate_regime_in_period( *, regime_name: RegimeName, internal_regime: InternalRegime, - period: int, + period: ScalarInt, age: ScalarInt | ScalarFloat, states: MappingProxyType[str, Array], subject_regime_ids: Int1D, @@ -255,15 +255,16 @@ def _simulate_regime_in_period( # We need to pass the value function array of the next period to the # argmax_and_max_Q_over_a function, as the current Q-function requires 
the # next period's value function. In the last period, we pass an empty dict. + period_idx = int(period) next_regime_to_V_arr = period_to_regime_to_V_arr.get( - period + 1, MappingProxyType({}) + period_idx + 1, MappingProxyType({}) ) # The Q-function values contain the information of how much value each # action combination is worth. To find the optimal discrete action, we # therefore only need to maximize the Q-function values over all actions. argmax_and_max_Q_over_a = ( - internal_regime.simulate_functions.argmax_and_max_Q_over_a[period] + internal_regime.simulate_functions.argmax_and_max_Q_over_a[period_idx] ) indices_optimal_actions, V_arr = argmax_and_max_Q_over_a( @@ -272,7 +273,7 @@ def _simulate_regime_in_period( **state_action_space.continuous_actions, next_regime_to_V_arr=next_regime_to_V_arr, **internal_params[regime_name], - period=jnp.int32(period), + period=period, age=age, ) validate_V(V_arr=V_arr, age=age, regime_name=regime_name) diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index 74f54896..00a5c174 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -71,7 +71,7 @@ def calculate_next_states( *, internal_regime: InternalRegime, optimal_actions: MappingProxyType[ActionName, Array], - period: int, + period: ScalarInt, age: ScalarInt | ScalarFloat, regime_params: FlatRegimeParams, states: MappingProxyType[str, Array], @@ -128,7 +128,7 @@ def calculate_next_states( **state_action_space.states, **optimal_actions, **stochastic_variables_keys, - period=jnp.int32(period), + period=period, age=age, **regime_params, ) @@ -149,7 +149,7 @@ def calculate_next_regime_membership( internal_regime: InternalRegime, state_action_space: StateActionSpace, optimal_actions: MappingProxyType[ActionName, Array], - period: int, + period: ScalarInt, age: ScalarInt | ScalarFloat, regime_params: FlatRegimeParams, regime_names_to_ids: MappingProxyType[RegimeName, int], @@ -189,7 +189,7 @@ def 
calculate_next_regime_membership( internal_regime.simulate_functions.compute_regime_transition_probs( # ty: ignore[call-non-callable] **state_action_space.states, **optimal_actions, - period=jnp.int32(period), + period=period, age=age, **regime_params, ) diff --git a/src/lcm/typing.py b/src/lcm/typing.py index f509bbec..9553ebea 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -22,14 +22,13 @@ type Int1D = Int32[Array, "_"] # noqa: F821 type Bool1D = Bool[Array, "_"] # noqa: F821 -# Many JAX functions are designed to work with scalar numerical values. This also -# includes zero dimensional jax arrays. +# Zero-dimensional JAX scalars — pylcm's canonical scalar form post boundary cast. type ScalarInt = Int32[Scalar, ""] type ScalarFloat = Float[Scalar, ""] type ScalarBool = Bool[Scalar, ""] -type Period = int | Int1D -type Age = int | float +type Period = ScalarInt +type Age = ScalarInt | ScalarFloat type RegimeName = str type StateName = str type ActionName = str @@ -55,7 +54,7 @@ # Internal regime parameters: A flat mapping with function-qualified names. # Keys are always function-qualified (e.g., "utility__risk_aversion", # "H__discount_factor"). Values are scalars or arrays. 
-type FlatRegimeParams = MappingProxyType[str, bool | float | Array] +type FlatRegimeParams = MappingProxyType[str, Array] type InternalParams = MappingProxyType[RegimeName, FlatRegimeParams] # Immutable templates, used internally diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 1ead8fa2..cb434a72 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -64,10 +64,10 @@ def test_simulate_using_raw_inputs(simulate_inputs): { "working_life": MappingProxyType( { - "H__discount_factor": 1.0, - "utility__disutility_of_work": 1.0, - "next_wealth__interest_rate": 0.05, - "next_regime__final_age_alive": 0, + "H__discount_factor": jnp.asarray(1.0), + "utility__disutility_of_work": jnp.asarray(1.0), + "next_wealth__interest_rate": jnp.asarray(0.05), + "next_regime__final_age_alive": jnp.asarray(0), } ), "dead": MappingProxyType({}), diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index 5538e5e1..9b6095d1 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -52,7 +52,7 @@ def test_solve_brute(): # ================================================================================== # create the params # ================================================================================== - internal_params = MappingProxyType({"discount_factor": 0.9}) + internal_params = MappingProxyType({"discount_factor": jnp.asarray(0.9)}) # ================================================================================== # create the list of state_action_spaces diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index e42e4ff0..4106f47c 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from lcm.dtypes import canonical_float_dtype, safe_to_float_dtype, safe_to_int32 +from lcm.dtypes import canonical_float_dtype, safe_to_float_dtype, safe_to_int_dtype @pytest.mark.parametrize( @@ -12,9 
+12,9 @@ [7, np.asarray([0, 1, -3], dtype=np.int64)], ids=["python-int", "int64-array"], ) -def test_safe_to_int32_returns_int32(value: object) -> None: - """`safe_to_int32` returns a `jnp.int32` array for any in-range int input.""" - out = safe_to_int32(value, name="x") +def test_safe_to_int_dtype_returns_int32(value: object) -> None: + """`safe_to_int_dtype` returns a `jnp.int32` array for any in-range int input.""" + out = safe_to_int_dtype(value, name="x") assert out.dtype == jnp.int32 @@ -26,34 +26,34 @@ def test_safe_to_int32_returns_int32(value: object) -> None: ], ids=["python-int", "int64-array"], ) -def test_safe_to_int32_preserves_in_range_values( +def test_safe_to_int_dtype_preserves_in_range_values( value: object, expected: object ) -> None: - """`safe_to_int32` preserves element values for in-range inputs.""" - out = safe_to_int32(value, name="x") + """`safe_to_int_dtype` preserves element values for in-range inputs.""" + out = safe_to_int_dtype(value, name="x") np.testing.assert_array_equal(np.asarray(out), expected) -def test_safe_to_int32_raises_on_python_int_overflow() -> None: +def test_safe_to_int_dtype_raises_on_python_int_overflow() -> None: """A Python int above int32 max raises `ValueError` naming the leaf.""" with pytest.raises(ValueError, match="my_param"): - safe_to_int32(2**32, name="my_param") + safe_to_int_dtype(2**32, name="my_param") -def test_safe_to_int32_raises_on_array_overflow() -> None: +def test_safe_to_int_dtype_raises_on_array_overflow() -> None: """An int64 array containing values above int32 max raises with the leaf name.""" # Use numpy here: `jnp.asarray(..., dtype=jnp.int64)` truncates to int32 # under `jax_enable_x64=False` and trips JAX's own overflow guard before - # `safe_to_int32` ever sees the value. + # `safe_to_int_dtype` ever sees the value. 
arr = np.asarray([1, 2, 2**32], dtype=np.int64) with pytest.raises(ValueError, match="regime"): - safe_to_int32(arr, name="regime") + safe_to_int_dtype(arr, name="regime") -def test_safe_to_int32_raises_on_underflow() -> None: +def test_safe_to_int_dtype_raises_on_underflow() -> None: """A Python int below int32 min raises `ValueError` naming the leaf.""" with pytest.raises(ValueError, match="offset"): - safe_to_int32(-(2**40), name="offset") + safe_to_int_dtype(-(2**40), name="offset") def test_canonical_float_dtype_is_float32_under_no_x64( diff --git a/tests/test_float_dtype_invariants.py b/tests/test_float_dtype_invariants.py index 8d793821..d2367661 100644 --- a/tests/test_float_dtype_invariants.py +++ b/tests/test_float_dtype_invariants.py @@ -21,7 +21,7 @@ def test_build_initial_states_casts_user_float64_to_canonical( - x64_disabled: None, # noqa: ARG001 + x64_disabled: None, ) -> None: """A float64 continuous initial state lands at `canonical_float_dtype()`.""" model = get_model(n_periods=3) @@ -36,9 +36,7 @@ def test_build_initial_states_casts_user_float64_to_canonical( assert flat["working_life__wealth"].dtype == canonical_float_dtype() -def test_build_initial_states_casts_user_int_to_canonical( - x64_disabled: None, # noqa: ARG001 -) -> None: +def test_build_initial_states_casts_user_int_to_canonical(x64_disabled: None) -> None: """A continuous initial state given as int32 lands at `canonical_float_dtype()`.""" model = get_model(n_periods=3) initial_states = { @@ -53,7 +51,7 @@ def test_build_initial_states_casts_user_int_to_canonical( def test_build_initial_states_missing_continuous_fallback_dtype_is_canonical( - x64_disabled: None, # noqa: ARG001 + x64_disabled: None, ) -> None: """A missing continuous state falls back to a canonical-dtype array.""" model = get_model(n_periods=3) @@ -66,7 +64,7 @@ def test_build_initial_states_missing_continuous_fallback_dtype_is_canonical( def test_build_initial_states_missing_continuous_fallback_values_are_nan( - 
x64_disabled: None, # noqa: ARG001 + x64_disabled: None, ) -> None: """A missing continuous state falls back to an all-NaN array. @@ -82,7 +80,7 @@ def test_build_initial_states_missing_continuous_fallback_values_are_nan( def test_process_params_casts_float64_array_to_canonical_under_no_x64( - x64_disabled: None, # noqa: ARG001 + x64_disabled: None, ) -> None: """A `float64` array param is downcast to `float32` under `jax_enable_x64=False`. @@ -101,12 +99,10 @@ def test_process_params_casts_float64_array_to_canonical_under_no_x64( ) schedule = out["regime_a"]["schedule"] - assert schedule.dtype == jnp.float32 # ty: ignore[unresolved-attribute] + assert schedule.dtype == jnp.float32 -def test_process_params_casts_python_float_to_canonical( - x64_disabled: None, # noqa: ARG001 -) -> None: +def test_process_params_casts_python_float_to_canonical(x64_disabled: None) -> None: """A Python `float` param leaf is cast to `canonical_float_dtype()`.""" template = MappingProxyType( {"regime_a": MappingProxyType({"discount_factor": "float"})} @@ -120,11 +116,11 @@ def test_process_params_casts_python_float_to_canonical( discount_factor = out["regime_a"]["discount_factor"] np.testing.assert_allclose(float(discount_factor), 0.95, rtol=1e-6) - assert discount_factor.dtype == canonical_float_dtype() # ty: ignore[unresolved-attribute] + assert discount_factor.dtype == canonical_float_dtype() def test_process_params_float_array_overflow_raises_with_qualified_name( - x64_disabled: None, # noqa: ARG001 + x64_disabled: None, ) -> None: """An out-of-float32 float64 array raises naming the qualified leaf.""" template = MappingProxyType({"regime_a": MappingProxyType({"schedule": "Array"})}) @@ -137,9 +133,7 @@ def test_process_params_float_array_overflow_raises_with_qualified_name( ) -def test_simulate_state_pool_dtype_stable_across_periods( - x64_disabled: None, # noqa: ARG001 -) -> None: +def test_simulate_state_pool_dtype_stable_across_periods(x64_disabled: None) -> None: """A 
multi-period simulate keeps every state's dtype stable across periods. The intended invariant is per-state stability; failing on any single @@ -168,9 +162,7 @@ def test_simulate_state_pool_dtype_stable_across_periods( assert not drifted, f"States drifted across periods: {drifted}" -def test_solve_v_arrays_at_canonical_float_dtype( - x64_disabled: None, # noqa: ARG001 -) -> None: +def test_solve_v_arrays_at_canonical_float_dtype(x64_disabled: None) -> None: """Every V-array returned by `model.solve()` is at `canonical_float_dtype()`.""" model = get_model(n_periods=3) period_to_regime_to_V_arr = model.solve(params=get_params(n_periods=3)) @@ -195,7 +187,7 @@ def test_solve_v_arrays_at_canonical_float_dtype( ) def test_continuous_grid_to_jax_dtype_is_canonical_under_no_x64( make_grid: Callable[[], LinSpacedGrid | LogSpacedGrid | IrregSpacedGrid], - x64_disabled: None, # noqa: ARG001 + x64_disabled: None, ) -> None: """Continuous grid `to_jax()` materialises at `float32` under no-x64. @@ -216,7 +208,7 @@ def test_continuous_grid_to_jax_dtype_is_canonical_under_no_x64( @pytest.mark.parametrize("attr", ["start", "stop"]) def test_uniform_grid_stores_endpoints_as_canonical_jax_scalar( attr: str, - x64_disabled: None, # noqa: ARG001 + x64_disabled: None, ) -> None: """`LinSpacedGrid` stores `start`/`stop` as JAX scalars at canonical dtype.""" grid = LinSpacedGrid(start=0.0, stop=100.0, n_points=10) @@ -225,9 +217,7 @@ def test_uniform_grid_stores_endpoints_as_canonical_jax_scalar( assert value.dtype == canonical_float_dtype() -def test_irreg_grid_stores_points_as_canonical_jax_array( - x64_disabled: None, # noqa: ARG001 -) -> None: +def test_irreg_grid_stores_points_as_canonical_jax_array(x64_disabled: None) -> None: """`IrregSpacedGrid` stores `points` as a JAX array at canonical dtype.""" grid = IrregSpacedGrid(points=(0.0, 0.5, 1.0)) assert isinstance(grid.points, jnp.ndarray) @@ -237,7 +227,7 @@ def test_irreg_grid_stores_points_as_canonical_jax_array( 
@pytest.mark.parametrize("key", ["low", "high"]) def test_process_params_casts_float_array_inside_mapping_leaf_to_canonical( key: str, - x64_disabled: None, # noqa: ARG001 + x64_disabled: None, ) -> None: """`MappingLeaf` float arrays land at `canonical_float_dtype()`.""" template = MappingProxyType( @@ -268,7 +258,7 @@ def test_process_params_casts_float_array_inside_mapping_leaf_to_canonical( @pytest.mark.parametrize("index", [0, 1]) def test_process_params_casts_float_array_inside_sequence_leaf_to_canonical( index: int, - x64_disabled: None, # noqa: ARG001 + x64_disabled: None, ) -> None: """`SequenceLeaf` float arrays land at `canonical_float_dtype()`.""" template = MappingProxyType( diff --git a/tests/test_grid_helpers.py b/tests/test_grid_helpers.py index 12eff21e..a3994c1b 100644 --- a/tests/test_grid_helpers.py +++ b/tests/test_grid_helpers.py @@ -15,7 +15,9 @@ def test_linspace(): - calculated = linspace(start=jnp.asarray(1.0), stop=jnp.asarray(2.0), n_points=6) + calculated = linspace( + start=jnp.asarray(1.0), stop=jnp.asarray(2.0), n_points=jnp.int32(6) + ) expected = np.array([1, 1.2, 1.4, 1.6, 1.8, 2]) aaae(calculated, expected, decimal=DECIMAL_PRECISION) @@ -29,7 +31,7 @@ def test_linspace_mapped_value(): value=jnp.asarray(1.2), start=start, stop=stop, - n_points=6, + n_points=jnp.int32(6), ) assert np.allclose(calculated, 1.0) @@ -41,7 +43,7 @@ def test_linspace_mapped_value(): value=jnp.asarray(1.3), start=start, stop=stop, - n_points=6, + n_points=jnp.int32(6), ) assert np.allclose(calculated, 1.5) @@ -50,13 +52,15 @@ def test_linspace_mapped_value(): value=jnp.asarray(0.6), start=start, stop=stop, - n_points=6, + n_points=jnp.int32(6), ) assert np.allclose(calculated, -2.0) def test_logspace(): - calculated = logspace(start=jnp.asarray(1.0), stop=jnp.asarray(100.0), n_points=7) + calculated = logspace( + start=jnp.asarray(1.0), stop=jnp.asarray(100.0), n_points=jnp.int32(7) + ) expected = np.array( [ 1.0, @@ -77,7 +81,7 @@ def 
test_logspace_mapped_value(): value=jnp.asarray((2.15443469 + 4.64158883) / 2), start=jnp.asarray(1.0), stop=jnp.asarray(100.0), - n_points=7, + n_points=jnp.int32(7), ) assert np.allclose(calculated, 1.5) @@ -88,7 +92,7 @@ def test_map_coordinates_linear(): grid_info = { "start": jnp.asarray(0.0), "stop": jnp.asarray(1.0), - "n_points": 3, + "n_points": jnp.int32(3), } grid = linspace(**grid_info) # [0, 0.5, 1] @@ -113,7 +117,7 @@ def test_map_coordinates_logarithmic(): grid_info = { "start": jnp.asarray(1.0), "stop": jnp.asarray(2.0), - "n_points": 3, + "n_points": jnp.int32(3), } grid = logspace(**grid_info) # [1.0, 1.414213562373095, 2.0] @@ -138,7 +142,7 @@ def test_map_coordinates_linear_outside_grid(): grid_info = { "start": jnp.asarray(1.0), "stop": jnp.asarray(2.0), - "n_points": 2, + "n_points": jnp.int32(2), } grid = linspace(**grid_info) # [1, 2] @@ -148,7 +152,7 @@ def test_map_coordinates_linear_outside_grid(): # Get coordinates corresponding to values outside the grid [1, 2] coordinates = jnp.array( [ - get_linspace_coordinate(value=jnp.asarray(grid_val), **grid_info) # ty: ignore[no-matching-overload] + get_linspace_coordinate(value=jnp.asarray(grid_val), **grid_info) for grid_val in [-1.0, 0.0, 3.0] ] ) @@ -164,21 +168,23 @@ def test_get_linspace_coordinate_with_array(): value=values, start=jnp.asarray(1.0), stop=jnp.asarray(2.0), - n_points=6, + n_points=jnp.int32(6), ) expected = jnp.array([0.0, 1.0, 2.5]) aaae(coords, expected, decimal=DECIMAL_PRECISION) def test_get_logspace_coordinate_with_array(): - grid = logspace(start=jnp.asarray(1.0), stop=jnp.asarray(100.0), n_points=7) + grid = logspace( + start=jnp.asarray(1.0), stop=jnp.asarray(100.0), n_points=jnp.int32(7) + ) mid = (float(grid[1]) + float(grid[2])) / 2 values = jnp.array([mid]) coords = get_logspace_coordinate( value=values, start=jnp.asarray(1.0), stop=jnp.asarray(100.0), - n_points=7, + n_points=jnp.int32(7), ) aaae(coords, jnp.array([1.5]), decimal=DECIMAL_PRECISION) diff --git 
a/tests/test_grids.py b/tests/test_grids.py index da2370d9..cd9cbfb1 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -167,34 +167,38 @@ class UnorderedCat: assert grid.ordered is False -def test_validate_continuous_grid_invalid_start(): - error_msg = "start must be a scalar int or float value" - with pytest.raises(GridInitializationError, match=error_msg): - _validate_continuous_grid(start="a", stop=1, n_points=10) # ty: ignore[invalid-argument-type] +def test_lin_spaced_grid_rejects_non_numeric_start(): + """Non-numeric `start` is rejected at the boundary cast.""" + with pytest.raises((TypeError, ValueError)): + LinSpacedGrid(start="a", stop=1, n_points=10) # ty: ignore[invalid-argument-type] -def test_validate_continuous_grid_invalid_stop(): - error_msg = "stop must be a scalar int or float value" - with pytest.raises(GridInitializationError, match=error_msg): - _validate_continuous_grid(start=1, stop="a", n_points=10) # ty: ignore[invalid-argument-type] +def test_lin_spaced_grid_rejects_non_numeric_stop(): + """Non-numeric `stop` is rejected at the boundary cast.""" + with pytest.raises((TypeError, ValueError)): + LinSpacedGrid(start=1, stop="a", n_points=10) # ty: ignore[invalid-argument-type] -def test_validate_continuous_grid_invalid_n_points(): - error_msg = "n_points must be an int greater than 0 but is a" - with pytest.raises(GridInitializationError, match=error_msg): - _validate_continuous_grid(start=1, stop=2, n_points="a") # ty: ignore[invalid-argument-type] +def test_lin_spaced_grid_rejects_non_numeric_n_points(): + """Non-numeric `n_points` is rejected at the boundary cast.""" + with pytest.raises((TypeError, ValueError)): + LinSpacedGrid(start=1, stop=2, n_points="a") # ty: ignore[invalid-argument-type] def test_validate_continuous_grid_negative_n_points(): error_msg = "n_points must be an int greater than 0 but is -1" with pytest.raises(GridInitializationError, match=error_msg): - _validate_continuous_grid(start=1, stop=2, n_points=-1) 
+ _validate_continuous_grid( + start=jnp.asarray(1.0), stop=jnp.asarray(2.0), n_points=jnp.int32(-1) + ) def test_validate_continuous_grid_start_greater_than_stop(): error_msg = "start must be less than stop" with pytest.raises(GridInitializationError, match=error_msg): - _validate_continuous_grid(start=2, stop=1, n_points=10) + _validate_continuous_grid( + start=jnp.asarray(2.0), stop=jnp.asarray(1.0), n_points=jnp.int32(10) + ) def test_linspace_grid_creation(): @@ -229,12 +233,20 @@ def test_logspace_grid_rejects_negative_start(): def test_validate_continuous_grid_rejects_nan_start(): with pytest.raises(GridInitializationError, match="start must be finite"): - _validate_continuous_grid(start=float("nan"), stop=10, n_points=5) + _validate_continuous_grid( + start=jnp.asarray(float("nan")), + stop=jnp.asarray(10.0), + n_points=jnp.int32(5), + ) def test_validate_continuous_grid_rejects_inf_stop(): with pytest.raises(GridInitializationError, match="stop must be finite"): - _validate_continuous_grid(start=1, stop=float("inf"), n_points=5) + _validate_continuous_grid( + start=jnp.asarray(1.0), + stop=jnp.asarray(float("inf")), + n_points=jnp.int32(5), + ) def test_irreg_spaced_grid_rejects_nan_points(): @@ -332,8 +344,9 @@ def test_linspaced_coordinates_match_other_grid_types( rtol = base_rtol * max_magnitude for value in all_test_values: - lin_coord = float(lin_grid.get_coordinate(value)) - other_coord = float(other_grid.get_coordinate(value)) + value_jax = jnp.asarray(value) + lin_coord = float(lin_grid.get_coordinate(value_jax)) + other_coord = float(other_grid.get_coordinate(value_jax)) assert np.isclose(lin_coord, other_coord, rtol=rtol), ( f"Mismatch at value {value} for {grid_type} vs LinSpacedGrid " f"({start}, {stop}, {n_points}): " @@ -583,7 +596,7 @@ def test_piecewise_log_spaced_grid_coordinate_at_gridpoints(): grid = PiecewiseLogSpacedGrid(pieces=(Piece(interval="[1, 100]", n_points=3),)) points = grid.to_jax() for i, p in enumerate(points): - coord = 
float(grid.get_coordinate(float(p))) + coord = float(grid.get_coordinate(p)) assert coord == pytest.approx(i) @@ -595,8 +608,8 @@ def test_piecewise_log_spaced_grid_coordinate_multi_piece(): Piece(interval="[10, 100]", n_points=2), ) ) - assert float(grid.get_coordinate(10.0)) == pytest.approx(2.0) - assert float(grid.get_coordinate(100.0)) == pytest.approx(3.0) + assert float(grid.get_coordinate(jnp.asarray(10.0))) == pytest.approx(2.0) + assert float(grid.get_coordinate(jnp.asarray(100.0))) == pytest.approx(3.0) def _create_boundary_test_grid(grid_cls, boundary_style: str): @@ -671,9 +684,9 @@ def test_piecewise_boundary_conditions(grid_cls, boundary_style: str): def test_piecewise_single_piece(): """Test piecewise grid with single piece works correctly.""" grid = PiecewiseLinSpacedGrid(pieces=(Piece(interval="[0, 10]", n_points=11),)) - assert float(grid.get_coordinate(0.0)) == pytest.approx(0.0) - assert float(grid.get_coordinate(5.0)) == pytest.approx(5.0) - assert float(grid.get_coordinate(10.0)) == pytest.approx(10.0) + assert float(grid.get_coordinate(jnp.asarray(0.0))) == pytest.approx(0.0) + assert float(grid.get_coordinate(jnp.asarray(5.0))) == pytest.approx(5.0) + assert float(grid.get_coordinate(jnp.asarray(10.0))) == pytest.approx(10.0) def test_lin_spaced_grid_get_coordinate_with_array(): diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py index b9eb1f03..e089c0e6 100644 --- a/tests/test_int_dtype_invariants.py +++ b/tests/test_int_dtype_invariants.py @@ -101,7 +101,7 @@ def test_process_params_casts_python_int_to_int32() -> None: final_age = out["regime_a"]["final_age"] assert int(final_age) == 65 - assert final_age.dtype == jnp.int32 # ty: ignore[unresolved-attribute] + assert final_age.dtype == jnp.int32 def test_process_params_casts_int64_array_to_int32() -> None: @@ -115,7 +115,7 @@ def test_process_params_casts_int64_array_to_int32() -> None: ) schedule = out["regime_a"]["schedule"] - assert schedule.dtype == 
jnp.int32 # ty: ignore[unresolved-attribute] + assert schedule.dtype == jnp.int32 def test_process_params_int_array_overflow_raises_with_qualified_name() -> None: diff --git a/tests/test_next_state.py b/tests/test_next_state.py index a1f688c5..778bc195 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -46,10 +46,10 @@ def test_get_next_state_function_with_solve_target(): got = got_func( **action, - **state, - period=1, - age=1.0, - **flat_regime_params, + **state, # ty: ignore[invalid-argument-type] + period=1, # ty: ignore[invalid-argument-type] + age=1.0, # ty: ignore[invalid-argument-type] + **flat_regime_params, # ty: ignore[invalid-argument-type] ) assert got == {"next_wealth": 1.05 * (20 - 10)} diff --git a/tests/test_pandas_utils.py b/tests/test_pandas_utils.py index e1c67139..a8dba7b6 100644 --- a/tests/test_pandas_utils.py +++ b/tests/test_pandas_utils.py @@ -1443,8 +1443,8 @@ def test_convert_series_function_level_series() -> None: regime_names_to_ids=model.regime_names_to_ids, ) arr = result["working_life"]["next_partner__probs_array"] - assert arr.shape == (3, 2, 2, 2) # ty: ignore[unresolved-attribute] - assert float(arr[0, 0, 0, 0]) == pytest.approx(1.0) # ty: ignore[not-subscriptable] + assert arr.shape == (3, 2, 2, 2) + assert float(arr[0, 0, 0, 0]) == pytest.approx(1.0) def test_convert_series_model_level_scalar_passthrough() -> None: @@ -1484,7 +1484,7 @@ def test_convert_series_regime_level_series() -> None: regime_names_to_ids=model.regime_names_to_ids, ) arr = result["working_life"]["next_partner__probs_array"] - assert arr.shape == (3, 2, 2, 2) # ty: ignore[unresolved-attribute] + assert arr.shape == (3, 2, 2, 2) def test_convert_series_mixed_dict() -> None: @@ -1511,7 +1511,7 @@ def test_convert_series_mixed_dict() -> None: ) assert result["working_life"]["H__discount_factor"] == 0.95 assert result["working_life"]["utility__disutility_of_work"] == 0.5 - assert result["working_life"]["next_partner__probs_array"].shape == 
(3, 2, 2, 2) # ty: ignore[unresolved-attribute] + assert result["working_life"]["next_partner__probs_array"].shape == (3, 2, 2, 2) assert result["working_life"]["next_wealth__interest_rate"] == 0.05 np.testing.assert_allclose( result["working_life"]["labor_income__wage"], jnp.array([10.0]) @@ -1655,7 +1655,7 @@ def test_convert_series_with_derived_categoricals() -> None: regime_names_to_ids=model.regime_names_to_ids, ) arr = result["retirement"]["next_partner__probs_array"] - assert arr.shape == (3, 2, 2, 2) # ty: ignore[unresolved-attribute] + assert arr.shape == (3, 2, 2, 2) def test_convert_series_per_target_transition() -> None: @@ -1735,7 +1735,7 @@ def _next_wealth(wealth: float) -> float: regime_names_to_ids=model.regime_names_to_ids, ) arr = result["working"]["to_working_next_health__probs_array"] - assert arr.shape == (3, 2, 2) # ty: ignore[unresolved-attribute] + assert arr.shape == (3, 2, 2) def test_build_outcome_mapping_qualified_func_name() -> None: @@ -1833,8 +1833,8 @@ def _next_wealth_sc(wealth: float) -> float: ages=model.ages, regime_names_to_ids=model.regime_names_to_ids, ) - assert result_both["regime_a"]["utility__rates"].shape == (2,) # ty: ignore[unresolved-attribute] - assert result_both["regime_b"]["utility__rates"].shape == (3,) # ty: ignore[unresolved-attribute] + assert result_both["regime_a"]["utility__rates"].shape == (2,) + assert result_both["regime_b"]["utility__rates"].shape == (3,) def test_convert_series_runtime_grid_param() -> None: @@ -2021,7 +2021,7 @@ def _health_probs_cross( arr = result["pre65"]["to_post65_next_health__health_trans_probs_cross"] # Shape: (n_ages=2, n_source_health=3, n_target_health=2) # n_ages=2 because AgeGrid has ages [0, 1]; missing age 1 is NaN-filled. 
- assert arr.shape == (2, 3, 2) # ty: ignore[unresolved-attribute] + assert arr.shape == (2, 3, 2) def test_resolve_categoricals_includes_derived_when_no_regime_name() -> None: diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 8a96cbd0..561fc1a5 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -29,7 +29,7 @@ def _retired_utility(wealth: ContinuousState) -> FloatND: return jnp.log(wealth) -def _build_tiny_model(*, enable_jit: bool = False, n_subjects: int | None = None): +def _build_tiny_model(*, enable_jit: bool, n_subjects: int): def utility(consumption: ContinuousAction, wealth: ContinuousState) -> FloatND: return jnp.log(consumption + wealth) @@ -77,7 +77,7 @@ def _initial_conditions(): @pytest.fixture def model_and_params(): - return _build_tiny_model() + return _build_tiny_model(enable_jit=False, n_subjects=2) @pytest.fixture diff --git a/tests/test_validate_param_types.py b/tests/test_validate_param_types.py index db634a2a..17428606 100644 --- a/tests/test_validate_param_types.py +++ b/tests/test_validate_param_types.py @@ -73,7 +73,7 @@ def test_jax_array_param_kept_at_canonical_dtype() -> None: params={"bonus": jnp.asarray(1.0), "discount_factor": 0.95} ) bonus = internal["working"]["utility__bonus"] - assert bonus.dtype == canonical_float_dtype() # ty: ignore[unresolved-attribute] + assert bonus.dtype == canonical_float_dtype() def test_python_float_param_cast_to_canonical_dtype() -> None: @@ -82,4 +82,4 @@ def test_python_float_param_cast_to_canonical_dtype() -> None: internal = model._process_params(params={"bonus": 1.0, "discount_factor": 0.95}) bonus = internal["working"]["utility__bonus"] assert float(bonus) == 1.0 - assert bonus.dtype == canonical_float_dtype() # ty: ignore[unresolved-attribute] + assert bonus.dtype == canonical_float_dtype() From aea8735e5965e4a80276f7035ea34c82f97e2ad7 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 14:09:15 +0200 Subject: [PATCH 14/22] 
linspace/logspace: drop int(n_points) cast in favour of ty:ignore `jnp.linspace`/`jnp.logspace`'s `num` parameter is annotated `int` in JAX's stubs but accepts `jnp.int32` JAX scalars in eager mode (verified on cuda12). Pass `n_points: ScalarInt` through directly and silence the type-check mismatch at the single call site rather than materialising the JAX scalar to a Python int. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/coordinates.py | 4 ++-- src/lcm/grids/piecewise.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lcm/grids/coordinates.py b/src/lcm/grids/coordinates.py index 8f7045ad..dc95b40c 100644 --- a/src/lcm/grids/coordinates.py +++ b/src/lcm/grids/coordinates.py @@ -26,7 +26,7 @@ def linspace( endpoints. """ - return jnp.linspace(start, stop, int(n_points)) + return jnp.linspace(start, stop, n_points) # ty: ignore[no-matching-overload] @overload @@ -79,7 +79,7 @@ def logspace( """ start_linear = jnp.log(start) stop_linear = jnp.log(stop) - grid = jnp.logspace(start_linear, stop_linear, int(n_points), base=jnp.e) + grid = jnp.logspace(start_linear, stop_linear, n_points, base=jnp.e) # ty: ignore[invalid-argument-type] return grid.at[0].set(start).at[-1].set(stop) diff --git a/src/lcm/grids/piecewise.py b/src/lcm/grids/piecewise.py index fa286472..f7762d94 100644 --- a/src/lcm/grids/piecewise.py +++ b/src/lcm/grids/piecewise.py @@ -90,7 +90,7 @@ def n_points(self) -> ScalarInt: def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" piece_arrays = [ - jnp.linspace(self._piece_starts[i], self._piece_stops[i], int(p.n_points)) + jnp.linspace(self._piece_starts[i], self._piece_stops[i], p.n_points) # ty: ignore[no-matching-overload] for i, p in enumerate(self.pieces) ] return jnp.concatenate(piece_arrays) From f4069c1cbc698be99250d93bd3a4f12762fae839 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 16:49:52 +0200 Subject: [PATCH 15/22] Piecewise n_points: sum the cached 
_piece_n_points array Replace the Python `sum(generator, start=jnp.int32(0))` with a single `_piece_n_points.sum()` reduction. The cached `Int1D` is already populated by `_init_piecewise_grid_cache`, the property is read after `__post_init__`, and the result is the same `ScalarInt`. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/piecewise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lcm/grids/piecewise.py b/src/lcm/grids/piecewise.py index f7762d94..43ff2a91 100644 --- a/src/lcm/grids/piecewise.py +++ b/src/lcm/grids/piecewise.py @@ -85,7 +85,7 @@ def __post_init__(self) -> None: @property def n_points(self) -> ScalarInt: """Return the total number of points in the grid.""" - return sum((p.n_points for p in self.pieces), start=jnp.int32(0)) + return self._piece_n_points.sum() def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" @@ -152,7 +152,7 @@ def __post_init__(self) -> None: @property def n_points(self) -> ScalarInt: """Return the total number of points in the grid.""" - return sum((p.n_points for p in self.pieces), start=jnp.int32(0)) + return self._piece_n_points.sum() def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" From bf12b619c1414fc86bb0b116834edf99acd8a061 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 16:58:29 +0200 Subject: [PATCH 16/22] benchmarks: bump aca-model pin to 67edfe0f Pull in the consumption-grid pinning, borrowing-constraint kink fix, and precision-workaround cleanups so the GPU benchmark CI runs the benchmark-aca-baseline kernel that aca-dev currently tracks. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index cedae98d..9d89dfe9 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=67edfe0f54a305c23297f17ec53aee07b7d90496#67edfe0f54a305c23297f17ec53aee07b7d90496 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=67edfe0f54a305c23297f17ec53aee07b7d90496#67edfe0f54a305c23297f17ec53aee07b7d90496 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev247+gf2d18faf0.d20260508 - sha256: bb850950f6f17a2050320baee51568923488f2d994ebd9c4f5ec1232da3f9434 + version: 0.0.2.dev250+gf4069c1cb.d20260508 + sha256: 
bb7f61b5587c8a260a3b971cf0a5b9cbce6b176334917373a47b8496fce0d40f requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index f7f90756..62b42835 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "f09b5e34102ff42f739b95be5a9d388795b734a1" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "67edfe0f54a305c23297f17ec53aee07b7d90496" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From 1bae7898c95a2f5f6bea73813c9bb7f8bc23afc3 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 17:06:23 +0200 Subject: [PATCH 17/22] simulate: keep period: int through the loop, cast at the JIT boundary The period_idx / period split was noisy: every loop iteration computed both a Python int (for dict-key indexing and `period in active_periods`) and a JAX scalar (for the JIT'd compute call). Drop the JAX-scalar shadow; iterate `for period, age in enumerate(ages.values)` once. `_simulate_regime_in_period(period: int)` keeps the integer through dict lookups and casts to `jnp.int32(period)` only at the `argmax_and_max_Q_over_a` / next-state JIT boundaries. Same pattern for transitions.py. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/simulation/simulate.py | 18 ++++++++---------- src/lcm/simulation/transitions.py | 8 ++++---- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index e07c03fd..12040623 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -112,8 +112,7 @@ def simulate( # Build reverse lookup for regime transition logging ids_to_names: dict[int, RegimeName] = {v: k for k, v in regime_names_to_ids.items()} - for period_idx, age in enumerate(ages.values): - period = jnp.int32(period_idx) + for period, age in enumerate(ages.values): period_start = time.monotonic() # Activate subjects whose starting period matches the current period @@ -129,13 +128,13 @@ def simulate( active_regimes = { regime_name: regime for regime_name, regime in internal_regimes.items() - if period_idx in regime.active_periods + if period in regime.active_periods } active_regimes_next_period = tuple( regime_name for regime_name, regime in internal_regimes.items() - if period_idx + 1 in regime.active_periods + if period + 1 in regime.active_periods ) log_period_header(logger=logger, age=age, n_active_regimes=len(active_regimes)) @@ -158,7 +157,7 @@ def simulate( ) ) states = new_states - simulation_results[regime_name][period_idx] = result + simulation_results[regime_name][period] = result log_nan_in_V( logger=logger, regime_name=regime_name, age=age, V_arr=result.V_arr @@ -201,7 +200,7 @@ def _simulate_regime_in_period( *, regime_name: RegimeName, internal_regime: InternalRegime, - period: ScalarInt, + period: int, age: ScalarInt | ScalarFloat, states: MappingProxyType[str, Array], subject_regime_ids: Int1D, @@ -255,16 +254,15 @@ def _simulate_regime_in_period( # We need to pass the value function array of the next period to the # argmax_and_max_Q_over_a function, as the current Q-function requires the # next period's value function. 
In the last period, we pass an empty dict. - period_idx = int(period) next_regime_to_V_arr = period_to_regime_to_V_arr.get( - period_idx + 1, MappingProxyType({}) + period + 1, MappingProxyType({}) ) # The Q-function values contain the information of how much value each # action combination is worth. To find the optimal discrete action, we # therefore only need to maximize the Q-function values over all actions. argmax_and_max_Q_over_a = ( - internal_regime.simulate_functions.argmax_and_max_Q_over_a[period_idx] + internal_regime.simulate_functions.argmax_and_max_Q_over_a[period] ) indices_optimal_actions, V_arr = argmax_and_max_Q_over_a( @@ -273,7 +271,7 @@ def _simulate_regime_in_period( **state_action_space.continuous_actions, next_regime_to_V_arr=next_regime_to_V_arr, **internal_params[regime_name], - period=period, + period=jnp.int32(period), age=age, ) validate_V(V_arr=V_arr, age=age, regime_name=regime_name) diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index 00a5c174..74f54896 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -71,7 +71,7 @@ def calculate_next_states( *, internal_regime: InternalRegime, optimal_actions: MappingProxyType[ActionName, Array], - period: ScalarInt, + period: int, age: ScalarInt | ScalarFloat, regime_params: FlatRegimeParams, states: MappingProxyType[str, Array], @@ -128,7 +128,7 @@ def calculate_next_states( **state_action_space.states, **optimal_actions, **stochastic_variables_keys, - period=period, + period=jnp.int32(period), age=age, **regime_params, ) @@ -149,7 +149,7 @@ def calculate_next_regime_membership( internal_regime: InternalRegime, state_action_space: StateActionSpace, optimal_actions: MappingProxyType[ActionName, Array], - period: ScalarInt, + period: int, age: ScalarInt | ScalarFloat, regime_params: FlatRegimeParams, regime_names_to_ids: MappingProxyType[RegimeName, int], @@ -189,7 +189,7 @@ def calculate_next_regime_membership( 
internal_regime.simulate_functions.compute_regime_transition_probs( # ty: ignore[call-non-callable] **state_action_space.states, **optimal_actions, - period=period, + period=jnp.int32(period), age=age, **regime_params, ) From c419377686c1b111bbbf9b73906108f7c729bafd Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 17:56:37 +0200 Subject: [PATCH 18/22] benchmarks: bump aca-model pin to d9339ab Pulls in the aca-model CI workflow's matching pylcm pin so the GPU benchmark CI runs the same aca-model rev that aca-dev now tracks. Co-Authored-By: Claude Opus 4.7 (1M context) --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index 9d89dfe9..1df48761 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=67edfe0f54a305c23297f17ec53aee07b7d90496#67edfe0f54a305c23297f17ec53aee07b7d90496 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6#d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: 
git+https://github.com/OpenSourceEconomics/aca-model.git?rev=67edfe0f54a305c23297f17ec53aee07b7d90496#67edfe0f54a305c23297f17ec53aee07b7d90496 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6#d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev250+gf4069c1cb.d20260508 - sha256: bb7f61b5587c8a260a3b971cf0a5b9cbce6b176334917373a47b8496fce0d40f + version: 0.0.2.dev254+g2f486dc36.d20260508 + sha256: ba7a3b94073af0c32e3b385199c34d2af9facf7a39cd8d04753b62726f23e044 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 62b42835..2e26b46a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "67edfe0f54a305c23297f17ec53aee07b7d90496" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From 61c2436b67ecd9df1c70e80b770be77681c5df63 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 18:17:08 +0200 Subject: [PATCH 19/22] simulate orchestrates simulate-AOT compile, not solve MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit solve() no longer touches simulate-side compile state. 
simulate() is the sole driver: spawns the AOT compile in a background thread when n_subjects is set and the batch shape matches, then runs solve (if period_to_regime_to_V_arr is None) and awaits the future at the state-action-space dispatch point. Both public methods share an internal _solve_compiled() body for the snapshot/error handling. Drops _simulate_compile_future from instance state — the future lives in a local variable on the simulate() stack, so there's no per-process state to gate against. The lock keeps protecting _simulate_compile_cache and _warned_n_subjects; the rest of the "maybe spawn" logic collapses into a single inline check at the simulate() call site. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/model.py | 176 ++++++++++++-------------- tests/simulation/test_simulate_aot.py | 13 +- 2 files changed, 84 insertions(+), 105 deletions(-) diff --git a/src/lcm/model.py b/src/lcm/model.py index fab1f499..27e0fb50 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -119,21 +119,9 @@ class Model: _warned_n_subjects: set[int] """Mismatching `actual_n_subjects` already warned about (one warning each).""" - _simulate_compile_future: ( - Future[MappingProxyType[RegimeName, InternalRegime]] | None - ) - """Pending background AOT compile started by `solve(...)`, or `None`. - - `solve(...)` kicks off `compile_all_simulate_functions` in a single - background thread so XLA compilation overlaps with the GPU-bound - backward induction. `simulate(...)` awaits the future before - dispatching the AOT-compiled program. Cleared after the result lands - in `_simulate_compile_cache`. - """ - _simulate_compile_lock: threading.Lock - """Serialises mutations of `_simulate_compile_cache`, `_warned_n_subjects`, - and `_simulate_compile_future`. + """Serialises mutations of `_simulate_compile_cache` and + `_warned_n_subjects`. The check-then-set on each container is held under this lock. 
The consequent `log.warning` call sits outside the lock so concurrent @@ -181,7 +169,6 @@ def __init__( self.n_subjects = n_subjects self._simulate_compile_cache = {} self._warned_n_subjects = set() - self._simulate_compile_future = None self._simulate_compile_lock = threading.Lock() validate_model_inputs( @@ -217,16 +204,13 @@ def __getstate__(self) -> dict[str, object]: Drops `_simulate_compile_lock` (a `threading.Lock`, not pickleable), `_simulate_compile_cache` (compiled XLA programs that can't survive - a process boundary), `_warned_n_subjects` (its companion set), and - `_simulate_compile_future` (a `Future` tied to the originating thread - pool). + a process boundary), and `_warned_n_subjects` (its companion set). `__setstate__` restores all three to their fresh state. """ state = self.__dict__.copy() state.pop("_simulate_compile_lock", None) state.pop("_simulate_compile_cache", None) state.pop("_warned_n_subjects", None) - state.pop("_simulate_compile_future", None) return state def __setstate__(self, state: dict[str, object]) -> None: @@ -234,7 +218,6 @@ def __setstate__(self, state: dict[str, object]) -> None: self.__dict__.update(state) self._simulate_compile_cache = {} self._warned_n_subjects = set() - self._simulate_compile_future = None self._simulate_compile_lock = threading.Lock() def get_params_template(self) -> UserFacingParamsTemplate: @@ -298,17 +281,34 @@ def solve( internal_params=internal_params, ages=self.ages, ) - self._maybe_start_simulate_compile_async( + return self._solve_compiled( internal_params=internal_params, + params=params, + log=get_logger(log_level=log_level), + log_level=log_level, + log_path=log_path, + log_keep_n_latest=log_keep_n_latest, max_compilation_workers=max_compilation_workers, - logger=get_logger(log_level=log_level), ) + + def _solve_compiled( + self, + *, + internal_params: InternalParams, + params: UserParams, + log: logging.Logger, + log_level: LogLevel, + log_path: str | Path | None, + log_keep_n_latest: int, 
+ max_compilation_workers: int | None, + ) -> MappingProxyType[int, MappingProxyType[RegimeName, FloatND]]: + """Run backward induction, persisting a snapshot on debug or NaN failure.""" try: period_to_regime_to_V_arr = solve( internal_params=internal_params, ages=self.ages, internal_regimes=self.internal_regimes, - logger=get_logger(log_level=log_level), + logger=log, enable_jit=self.enable_jit, max_compilation_workers=max_compilation_workers, ) @@ -333,61 +333,55 @@ def solve( ) return period_to_regime_to_V_arr - def _maybe_start_simulate_compile_async( + def _spawn_simulate_compile( self, *, + n_subjects: int, internal_params: InternalParams, max_compilation_workers: int | None, logger: logging.Logger, - ) -> None: - """Spawn `compile_all_simulate_functions` in a background thread. + ) -> Future[MappingProxyType[RegimeName, InternalRegime]]: + """Submit `compile_all_simulate_functions` to a single-thread executor. - Called from `solve(...)` so the simulate-side XLA compilation runs in - parallel with the GPU-bound backward induction. No-op when - `n_subjects is None`, when the cache for this size is already - populated, or when a compile is already in flight. + Caller decides whether to spawn (`n_subjects` set, batch shape + matches, no cache hit). The returned `Future` runs in parallel with + whatever the caller does next — typically `_solve_compiled(...)`. 
""" - if self.n_subjects is None: - return - with self._simulate_compile_lock: - if self.n_subjects in self._simulate_compile_cache: - return - if self._simulate_compile_future is not None: - return - executor = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="lcm-simulate-compile" - ) - self._simulate_compile_future = executor.submit( - compile_all_simulate_functions, - internal_regimes=self.internal_regimes, - internal_params=internal_params, - ages=self.ages, - n_subjects=self.n_subjects, - max_compilation_workers=max_compilation_workers, - logger=logger, - ) - executor.shutdown(wait=False) + executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="lcm-simulate-compile" + ) + future = executor.submit( + compile_all_simulate_functions, + internal_regimes=self.internal_regimes, + internal_params=internal_params, + ages=self.ages, + n_subjects=n_subjects, + max_compilation_workers=max_compilation_workers, + logger=logger, + ) + executor.shutdown(wait=False) + return future def _resolve_simulate_internal_regimes( self, *, + compile_future: Future[MappingProxyType[RegimeName, InternalRegime]] | None, actual_n_subjects: int, - internal_params: InternalParams, log: logging.Logger, - max_compilation_workers: int | None, ) -> MappingProxyType[RegimeName, InternalRegime]: """Return internal_regimes to use for simulate; AOT cache when matching. - Three dispatch cases: + Dispatch by `n_subjects` and batch-shape match: - `n_subjects is None`: return the original `internal_regimes` (purely lazy path). - - `actual_n_subjects != n_subjects`: return the original - `internal_regimes` and log a warning the first time each - mismatching size is seen. - - `actual_n_subjects == n_subjects`: return the cached AOT-compiled - regimes. If `solve(...)` started a background compile, await it - here; otherwise compile synchronously. + - `actual_n_subjects != n_subjects`: warn once per mismatching size, + return the original `internal_regimes`. 
+ - `actual_n_subjects == n_subjects`, `compile_future is not None`: + await it and cache the result. + - `actual_n_subjects == n_subjects`, `compile_future is None`: cache + must already hold the entry (caller spawned only on cache miss); + return the cached compiled regimes. """ if self.n_subjects is None: return self.internal_regimes @@ -404,28 +398,12 @@ def _resolve_simulate_internal_regimes( self.n_subjects, ) return self.internal_regimes - with self._simulate_compile_lock: - if self.n_subjects in self._simulate_compile_cache: - return self._simulate_compile_cache[self.n_subjects] - future = self._simulate_compile_future - if future is not None: - compiled = future.result() + if compile_future is not None: + compiled = compile_future.result() with self._simulate_compile_lock: self._simulate_compile_cache[self.n_subjects] = compiled - self._simulate_compile_future = None return compiled with self._simulate_compile_lock: - if self.n_subjects not in self._simulate_compile_cache: - self._simulate_compile_cache[self.n_subjects] = ( - compile_all_simulate_functions( - internal_regimes=self.internal_regimes, - internal_params=internal_params, - ages=self.ages, - n_subjects=self.n_subjects, - max_compilation_workers=max_compilation_workers, - logger=log, - ) - ) return self._simulate_compile_cache[self.n_subjects] def simulate( @@ -508,33 +486,35 @@ def simulate( ages=self.ages, ) log = get_logger(log_level=log_level) - if period_to_regime_to_V_arr is None: - try: - period_to_regime_to_V_arr = solve( + actual_n_subjects = len(next(iter(initial_conditions.values()))) + n_subjects = self.n_subjects + compile_future: Future[MappingProxyType[RegimeName, InternalRegime]] | None = ( + None + ) + if n_subjects is not None and n_subjects == actual_n_subjects: + with self._simulate_compile_lock: + needs_compile = n_subjects not in self._simulate_compile_cache + if needs_compile: + compile_future = self._spawn_simulate_compile( + n_subjects=n_subjects, 
internal_params=internal_params, - ages=self.ages, - internal_regimes=self.internal_regimes, - logger=log, - enable_jit=self.enable_jit, max_compilation_workers=max_compilation_workers, + logger=log, ) - except InvalidValueFunctionError as exc: - if log_path is not None and exc.partial_solution is not None: - snap_dir = save_solve_snapshot( - model=self, - params=params, - period_to_regime_to_V_arr=exc.partial_solution, # ty: ignore[invalid-argument-type] - log_path=Path(log_path), - log_keep_n_latest=log_keep_n_latest, - ) - exc.add_note(f"Snapshot saved to {snap_dir}") - raise - actual_n_subjects = len(next(iter(initial_conditions.values()))) + if period_to_regime_to_V_arr is None: + period_to_regime_to_V_arr = self._solve_compiled( + internal_params=internal_params, + params=params, + log=log, + log_level=log_level, + log_path=log_path, + log_keep_n_latest=log_keep_n_latest, + max_compilation_workers=max_compilation_workers, + ) simulate_internal_regimes = self._resolve_simulate_internal_regimes( + compile_future=compile_future, actual_n_subjects=actual_n_subjects, - internal_params=internal_params, log=log, - max_compilation_workers=max_compilation_workers, ) result = simulate( internal_params=internal_params, diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py index 660d99db..77f01814 100644 --- a/tests/simulation/test_simulate_aot.py +++ b/tests/simulation/test_simulate_aot.py @@ -156,21 +156,20 @@ def test_simulate_first_matching_call_populates_aot_cache() -> None: assert n_subjects in model._simulate_compile_cache -def test_solve_with_n_subjects_kicks_off_background_simulate_compile() -> None: - """`solve(...)` spawns the simulate AOT compile in the background. +def test_solve_does_not_populate_simulate_compile_cache() -> None: + """`solve(...)` does not touch simulate-side compile state. 
- The follow-on `simulate(...)` then awaits the in-flight `Future` instead - of compiling synchronously, so XLA compilation overlaps with the - GPU-bound backward induction in production. + Simulate AOT compilation is driven entirely by `simulate(...)`; calling + `solve(...)` alone leaves `_simulate_compile_cache` empty. """ n_periods = 3 n_subjects = 4 model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) params = get_params(n_periods=n_periods) - assert model._simulate_compile_future is None model.solve(params=params) - assert model._simulate_compile_future is not None + + assert dict(model._simulate_compile_cache) == {} _DECLARED_N = 4 From 1deed362ea2559b0bb21f432ce178c865caa8eb7 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 9 May 2026 13:27:11 +0200 Subject: [PATCH 20/22] tests: drop noqa: ARG001 + collapse x64-fixture signatures Move the ARG001 ignore for the x64_disabled / x64_enabled fixture pattern into pyproject.toml's per-file-ignores for test_dtypes.py and test_float_dtype_invariants.py, then drop the per-call noqa comments and the now-redundant -> None return annotations (tests/* already ignores ANN). Single-arg signatures collapse to one line; longer ones stay wrapped, but without the trailing comma noise. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- pixi.lock | 4 +-- pyproject.toml | 3 +++ tests/test_dtypes.py | 40 +++++++++------------------- tests/test_float_dtype_invariants.py | 39 ++++++++++++--------------- 4 files changed, 34 insertions(+), 52 deletions(-) diff --git a/pixi.lock b/pixi.lock index 1df48761..7507bb3e 100644 --- a/pixi.lock +++ b/pixi.lock @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev254+g2f486dc36.d20260508 - sha256: ba7a3b94073af0c32e3b385199c34d2af9facf7a39cd8d04753b62726f23e044 + version: 0.0.2.dev256+g61c2436b6.d20260509 + sha256: 80f8e8823a9b7d58cdd78b45377e426887832da149342b1563ba0bfbe91653ea requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 2e26b46a..b810d68c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -242,6 +242,9 @@ per-file-ignores."tests/*" = [ "S301", # Use of pickle "SLF001", # Private member access ] +per-file-ignores."tests/test_dtypes.py" = [ + "ARG001", # Unused function argument (x64_enabled / x64_disabled fixtures) +] per-file-ignores."tests/test_float_dtype_invariants.py" = [ "ARG001", # Unused function argument (x64_disabled fixture) ] diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 4106f47c..8121a021 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -12,7 +12,7 @@ [7, np.asarray([0, 1, -3], dtype=np.int64)], ids=["python-int", "int64-array"], ) -def test_safe_to_int_dtype_returns_int32(value: object) -> None: +def test_safe_to_int_dtype_returns_int32(value: object): """`safe_to_int_dtype` returns a `jnp.int32` array for any in-range int input.""" out = safe_to_int_dtype(value, name="x") assert out.dtype == jnp.int32 @@ -26,21 +26,19 @@ def test_safe_to_int_dtype_returns_int32(value: object) -> None: ], ids=["python-int", "int64-array"], ) -def test_safe_to_int_dtype_preserves_in_range_values( - value: object, expected: object -) -> None: +def 
test_safe_to_int_dtype_preserves_in_range_values(value: object, expected: object): """`safe_to_int_dtype` preserves element values for in-range inputs.""" out = safe_to_int_dtype(value, name="x") np.testing.assert_array_equal(np.asarray(out), expected) -def test_safe_to_int_dtype_raises_on_python_int_overflow() -> None: +def test_safe_to_int_dtype_raises_on_python_int_overflow(): """A Python int above int32 max raises `ValueError` naming the leaf.""" with pytest.raises(ValueError, match="my_param"): safe_to_int_dtype(2**32, name="my_param") -def test_safe_to_int_dtype_raises_on_array_overflow() -> None: +def test_safe_to_int_dtype_raises_on_array_overflow(): """An int64 array containing values above int32 max raises with the leaf name.""" # Use numpy here: `jnp.asarray(..., dtype=jnp.int64)` truncates to int32 # under `jax_enable_x64=False` and trips JAX's own overflow guard before @@ -50,38 +48,30 @@ def test_safe_to_int_dtype_raises_on_array_overflow() -> None: safe_to_int_dtype(arr, name="regime") -def test_safe_to_int_dtype_raises_on_underflow() -> None: +def test_safe_to_int_dtype_raises_on_underflow(): """A Python int below int32 min raises `ValueError` naming the leaf.""" with pytest.raises(ValueError, match="offset"): safe_to_int_dtype(-(2**40), name="offset") -def test_canonical_float_dtype_is_float32_under_no_x64( - x64_disabled: None, # noqa: ARG001 -) -> None: +def test_canonical_float_dtype_is_float32_under_no_x64(x64_disabled: None): """`canonical_float_dtype()` is `float32` when `jax_enable_x64=False`.""" assert canonical_float_dtype() == jnp.float32 -def test_canonical_float_dtype_is_float64_under_x64( - x64_enabled: None, # noqa: ARG001 -) -> None: +def test_canonical_float_dtype_is_float64_under_x64(x64_enabled: None): """`canonical_float_dtype()` is `float64` when `jax_enable_x64=True`.""" assert canonical_float_dtype() == jnp.float64 -def test_safe_to_float_dtype_casts_python_float_to_canonical( - x64_disabled: None, # noqa: ARG001 -) -> None: 
+def test_safe_to_float_dtype_casts_python_float_to_canonical(x64_disabled: None): """A Python float lands at `float32` under no-x64.""" out = safe_to_float_dtype(0.5, name="x") assert out.dtype == jnp.float32 assert float(out) == 0.5 -def test_safe_to_float_dtype_casts_float64_array_to_float32( - x64_disabled: None, # noqa: ARG001 -) -> None: +def test_safe_to_float_dtype_casts_float64_array_to_float32(x64_disabled: None): """A `float64` array within float32 range is downcast to `float32`. Build the input with `np.asarray` rather than `jnp.asarray` — under @@ -95,27 +85,21 @@ def test_safe_to_float_dtype_casts_float64_array_to_float32( assert out.dtype == jnp.float32 -def test_safe_to_float_dtype_passes_array_through_under_x64( - x64_enabled: None, # noqa: ARG001 -) -> None: +def test_safe_to_float_dtype_passes_array_through_under_x64(x64_enabled: None): """Under x64, a `float64` array is preserved (no down-cast required).""" arr = jnp.asarray([0.1, 0.2, 0.3], dtype=jnp.float64) out = safe_to_float_dtype(arr, name="x") assert out.dtype == jnp.float64 -def test_safe_to_float_dtype_raises_on_overflow_when_downcasting( - x64_disabled: None, # noqa: ARG001 -) -> None: +def test_safe_to_float_dtype_raises_on_overflow_when_downcasting(x64_disabled: None): """A `float64` value above float32 max raises `OverflowError`, naming the leaf.""" big = 1e40 with pytest.raises(OverflowError, match="big_param"): safe_to_float_dtype(big, name="big_param") -def test_safe_to_float_dtype_no_overflow_check_when_upcasting( - x64_enabled: None, # noqa: ARG001 -) -> None: +def test_safe_to_float_dtype_no_overflow_check_when_upcasting(x64_enabled: None): """Casting `float32` -> `float64` (up) skips the overflow check.""" arr = jnp.asarray([0.1, 0.2], dtype=jnp.float32) out = safe_to_float_dtype(arr, name="x") diff --git a/tests/test_float_dtype_invariants.py b/tests/test_float_dtype_invariants.py index d2367661..068ccbbb 100644 --- a/tests/test_float_dtype_invariants.py +++ 
b/tests/test_float_dtype_invariants.py @@ -20,9 +20,7 @@ ) -def test_build_initial_states_casts_user_float64_to_canonical( - x64_disabled: None, -) -> None: +def test_build_initial_states_casts_user_float64_to_canonical(x64_disabled: None): """A float64 continuous initial state lands at `canonical_float_dtype()`.""" model = get_model(n_periods=3) initial_states = { @@ -36,7 +34,7 @@ def test_build_initial_states_casts_user_float64_to_canonical( assert flat["working_life__wealth"].dtype == canonical_float_dtype() -def test_build_initial_states_casts_user_int_to_canonical(x64_disabled: None) -> None: +def test_build_initial_states_casts_user_int_to_canonical(x64_disabled: None): """A continuous initial state given as int32 lands at `canonical_float_dtype()`.""" model = get_model(n_periods=3) initial_states = { @@ -52,7 +50,7 @@ def test_build_initial_states_casts_user_int_to_canonical(x64_disabled: None) -> def test_build_initial_states_missing_continuous_fallback_dtype_is_canonical( x64_disabled: None, -) -> None: +): """A missing continuous state falls back to a canonical-dtype array.""" model = get_model(n_periods=3) # Supply a placeholder state to set n_subjects without touching `wealth`. @@ -65,7 +63,7 @@ def test_build_initial_states_missing_continuous_fallback_dtype_is_canonical( def test_build_initial_states_missing_continuous_fallback_values_are_nan( x64_disabled: None, -) -> None: +): """A missing continuous state falls back to an all-NaN array. Pinning only the dtype would let a regression that fills the fallback @@ -81,7 +79,7 @@ def test_build_initial_states_missing_continuous_fallback_values_are_nan( def test_process_params_casts_float64_array_to_canonical_under_no_x64( x64_disabled: None, -) -> None: +): """A `float64` array param is downcast to `float32` under `jax_enable_x64=False`. 
Build with `np.asarray` rather than `jnp.asarray` — the JAX builder @@ -102,7 +100,7 @@ def test_process_params_casts_float64_array_to_canonical_under_no_x64( assert schedule.dtype == jnp.float32 -def test_process_params_casts_python_float_to_canonical(x64_disabled: None) -> None: +def test_process_params_casts_python_float_to_canonical(x64_disabled: None): """A Python `float` param leaf is cast to `canonical_float_dtype()`.""" template = MappingProxyType( {"regime_a": MappingProxyType({"discount_factor": "float"})} @@ -121,7 +119,7 @@ def test_process_params_casts_python_float_to_canonical(x64_disabled: None) -> N def test_process_params_float_array_overflow_raises_with_qualified_name( x64_disabled: None, -) -> None: +): """An out-of-float32 float64 array raises naming the qualified leaf.""" template = MappingProxyType({"regime_a": MappingProxyType({"schedule": "Array"})}) user_params = {"regime_a": {"schedule": np.asarray([0.0, 1e40], dtype=np.float64)}} @@ -133,7 +131,7 @@ def test_process_params_float_array_overflow_raises_with_qualified_name( ) -def test_simulate_state_pool_dtype_stable_across_periods(x64_disabled: None) -> None: +def test_simulate_state_pool_dtype_stable_across_periods(x64_disabled: None): """A multi-period simulate keeps every state's dtype stable across periods. 
The intended invariant is per-state stability; failing on any single @@ -162,7 +160,7 @@ def test_simulate_state_pool_dtype_stable_across_periods(x64_disabled: None) -> assert not drifted, f"States drifted across periods: {drifted}" -def test_solve_v_arrays_at_canonical_float_dtype(x64_disabled: None) -> None: +def test_solve_v_arrays_at_canonical_float_dtype(x64_disabled: None): """Every V-array returned by `model.solve()` is at `canonical_float_dtype()`.""" model = get_model(n_periods=3) period_to_regime_to_V_arr = model.solve(params=get_params(n_periods=3)) @@ -188,7 +186,7 @@ def test_solve_v_arrays_at_canonical_float_dtype(x64_disabled: None) -> None: def test_continuous_grid_to_jax_dtype_is_canonical_under_no_x64( make_grid: Callable[[], LinSpacedGrid | LogSpacedGrid | IrregSpacedGrid], x64_disabled: None, -) -> None: +): """Continuous grid `to_jax()` materialises at `float32` under no-x64. Asserts the concrete target dtype rather than `canonical_float_dtype()` @@ -207,9 +205,8 @@ def test_continuous_grid_to_jax_dtype_is_canonical_under_no_x64( @pytest.mark.parametrize("attr", ["start", "stop"]) def test_uniform_grid_stores_endpoints_as_canonical_jax_scalar( - attr: str, - x64_disabled: None, -) -> None: + attr: str, x64_disabled: None +): """`LinSpacedGrid` stores `start`/`stop` as JAX scalars at canonical dtype.""" grid = LinSpacedGrid(start=0.0, stop=100.0, n_points=10) value = getattr(grid, attr) @@ -217,7 +214,7 @@ def test_uniform_grid_stores_endpoints_as_canonical_jax_scalar( assert value.dtype == canonical_float_dtype() -def test_irreg_grid_stores_points_as_canonical_jax_array(x64_disabled: None) -> None: +def test_irreg_grid_stores_points_as_canonical_jax_array(x64_disabled: None): """`IrregSpacedGrid` stores `points` as a JAX array at canonical dtype.""" grid = IrregSpacedGrid(points=(0.0, 0.5, 1.0)) assert isinstance(grid.points, jnp.ndarray) @@ -226,9 +223,8 @@ def test_irreg_grid_stores_points_as_canonical_jax_array(x64_disabled: None) -> 
@pytest.mark.parametrize("key", ["low", "high"]) def test_process_params_casts_float_array_inside_mapping_leaf_to_canonical( - key: str, - x64_disabled: None, -) -> None: + key: str, x64_disabled: None +): """`MappingLeaf` float arrays land at `canonical_float_dtype()`.""" template = MappingProxyType( {"regime_a": MappingProxyType({"sched": "MappingLeaf"})} @@ -257,9 +253,8 @@ def test_process_params_casts_float_array_inside_mapping_leaf_to_canonical( @pytest.mark.parametrize("index", [0, 1]) def test_process_params_casts_float_array_inside_sequence_leaf_to_canonical( - index: int, - x64_disabled: None, -) -> None: + index: int, x64_disabled: None +): """`SequenceLeaf` float arrays land at `canonical_float_dtype()`.""" template = MappingProxyType( {"regime_a": MappingProxyType({"sched": "SequenceLeaf"})} From 00f3b4ae4c131c9317d4dfd9db1e54de1bace07d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 9 May 2026 13:28:23 +0200 Subject: [PATCH 21/22] test_next_state: pass JAX scalars instead of ty:ignore-ing Python ones `period=1, age=1.0, **flat_regime_params={...float...}` was suppressed with `# ty: ignore[invalid-argument-type]` to keep the call site short. Once `ScalarInt` / `ScalarFloat` tightened to JAX-only, the fix is to pass `jnp.int32(1)` / `jnp.asarray(1.0)` (and to wrap the float param leaves in `jnp.asarray`). The ignore comments come out and the call site genuinely type-checks. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_next_state.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_next_state.py b/tests/test_next_state.py index 778bc195..ccb79b39 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -37,19 +37,19 @@ def test_get_next_state_function_with_solve_target(): ) flat_regime_params = { - "discount_factor": 1.0, - "utility__disutility_of_work": 1.0, - "next_wealth__interest_rate": 0.05, + "discount_factor": jnp.asarray(1.0), + "utility__disutility_of_work": jnp.asarray(1.0), + "next_wealth__interest_rate": jnp.asarray(0.05), } - action = {"labor_supply": 1, "consumption": 10} - state = {"wealth": 20} + action = {"labor_supply": jnp.asarray(1), "consumption": jnp.asarray(10.0)} + state = {"wealth": jnp.asarray(20.0)} got = got_func( **action, - **state, # ty: ignore[invalid-argument-type] - period=1, # ty: ignore[invalid-argument-type] - age=1.0, # ty: ignore[invalid-argument-type] - **flat_regime_params, # ty: ignore[invalid-argument-type] + **state, + period=jnp.int32(1), + age=jnp.asarray(1.0), + **flat_regime_params, ) assert got == {"next_wealth": 1.05 * (20 - 10)} From ca66ba9ba140907aecdae4495b7a76c00c93ea41 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 9 May 2026 15:28:33 +0200 Subject: [PATCH 22/22] simulate: swap AOT-compiled regimes for lazy ones on the result MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `SimulationResult.to_pickle()` (and any cloudpickle.dumps on the result) hit `cannot pickle 'jaxlib._jax.LoadedExecutable'` when the result carried the AOT-compiled `internal_regimes`. The compiled callables (`argmax_and_max_Q_over_a`, `next_state`, `compute_regime_transition_probs`) wrap a `LoadedExecutable` that can't survive a process boundary. 
`to_dataframe` only reads `simulate_functions.functions / constraints / transitions / stochastic_transition_names` — none of which the AOT pass replaces. So after `simulate(...)` runs, the result has no use for the compiled callables: `model.simulate()` swaps them out for the lazy `self.internal_regimes` before returning. Add a TDD test that round-trips the result through cloudpickle under `n_subjects` matching, which is the failure mode pytask hit on HPC. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/model.py | 7 +++++++ tests/simulation/test_simulate_aot.py | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/src/lcm/model.py b/src/lcm/model.py index 27e0fb50..23104c5b 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -527,6 +527,13 @@ def simulate( simulation_output_dtypes=self.simulation_output_dtypes, seed=seed, ) + # AOT-compiled regimes carry `jax.stages.Compiled` callables that + # wrap an unpicklable `LoadedExecutable`. `to_dataframe` only reads + # the lazy DAG functions / constraints / transitions on + # `simulate_functions`, never the compiled callables — so swap in + # the lazy regimes to keep the result cloudpickle-safe. + if simulate_internal_regimes is not self.internal_regimes: + result._internal_regimes = self.internal_regimes # noqa: SLF001 if log_level == "debug" and log_path is not None: save_simulate_snapshot( model=self, diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py index 77f01814..698256e4 100644 --- a/tests/simulation/test_simulate_aot.py +++ b/tests/simulation/test_simulate_aot.py @@ -269,6 +269,30 @@ def test_simulate_warns_only_once_per_mismatching_size( assert len(mismatch_warnings) == 1 +def test_simulate_result_pickles_when_n_subjects_matches() -> None: + """`simulate(...)` returns a result that round-trips through cloudpickle. 
+ + With `n_subjects` matching the batch shape, the simulate path runs + AOT-compiled callables that wrap `LoadedExecutable` (unpicklable). + `to_dataframe` doesn't need those callables, so the returned result + must carry the lazy regimes — otherwise downstream pickling + (e.g. pytask handing the result to the next task) fails. + """ + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + + result = model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=_build_initial_conditions(n_subjects=n_subjects), + ) + + restored = cloudpickle.loads(cloudpickle.dumps(result)) + assert restored.n_subjects == n_subjects + + def test_unpickled_model_can_simulate_with_aot() -> None: """A cloudpickle round-tripped `Model` still drives `simulate(...)` with AOT.""" n_periods = 3