From 1c65f45dd8b2153bf794223ecf952f8a6f5e31bc Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 06:29:29 +0200 Subject: [PATCH 01/80] Support runtime-supplied points on continuous-action IrregSpacedGrids MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends the existing runtime-points mechanism (previously state-only) to continuous action grids. With this change, an action declared as `IrregSpacedGrid(n_points=N)` adds an `{action_name: {"points": "Float1D"}}` entry to the regime params template, and `state_action_space()` substitutes the runtime-supplied points into `continuous_actions` at solve / simulate time. Motivation: aca-dev's structural retirement model has a `consumption` action grid whose lower bound is the per-iteration `consumption_floor` parameter. Without this change the c-grid bounds would have to be fixed at build time, which forces either an over-wide grid (wasted density) or model rebuilds per estimation iteration (recompilation). Mirrors the existing state-grid treatment: - `regime_template.py`: walks `regime.actions` alongside `regime.states`, factoring the shared shadowing check into helpers. - `interfaces.InternalRegime.state_action_space()`: builds both state and continuous-action replacements in a single sweep over `self.grids`, then calls `_base_state_action_space.replace(...)` with whichever side actually had substitutions. - `pandas_utils._is_runtime_grid_param`: also recognises action grids so column extraction in `to_dataframe()` keeps working. Tests (TDD): four new tests in `tests/test_runtime_params.py`, mirroring the state-grid counterparts — params-template entry, solve, runtime-vs-fixed equivalence, and a sanity check that varying runtime points actually changes V. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/interfaces.py | 68 ++++++++++++++++------- src/lcm/pandas_utils.py | 15 +++--- src/lcm/params/regime_template.py | 52 +++++++++++++----- tests/test_runtime_params.py | 90 +++++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 39 deletions(-) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 9ff454b2..6c80ddb9 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -242,12 +242,12 @@ class InternalRegime: """Flat resolved fixed params for this regime, used by to_dataframe targets.""" def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpace: - """Return the state-action space with runtime state grids filled in. + """Return the state-action space with runtime grids filled in. - For IrregSpacedGrid with runtime-supplied points, the grid points come from - params as `{state_name}__points`. For _ShockGrid with runtime-supplied params, - the grid points are computed from shock params in the params dict or - resolved_fixed_params. + For IrregSpacedGrid (state or continuous action) with runtime-supplied + points, the grid points come from params as `{name}__points`. For + `_ShockGrid` with runtime-supplied params, the grid points are computed + from shock params in the params dict or `resolved_fixed_params`. Args: regime_params: Flat regime parameters supplied at runtime. @@ -257,35 +257,63 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac """ all_params = {**self.resolved_fixed_params, **regime_params} - replacements: dict[str, ContinuousState | DiscreteState] = {} - for state_name, spec in self.grids.items(): - if state_name not in self._base_state_action_space.states: + state_replacements: dict[str, ContinuousState | DiscreteState] = {} + action_replacements: dict[str, ContinuousAction] = {} + for name, spec in self.grids.items(): + in_states = name in self._base_state_action_space.states + in_continuous_actions = ( + name in self._base_state_action_space.continuous_actions + ) + if not (in_states or in_continuous_actions): continue if isinstance(spec, IrregSpacedGrid) and spec.pass_points_at_runtime: - points_key = f"{state_name}__points" + points_key = f"{name}__points" if points_key not in all_params: continue - replacements[state_name] = cast( - "ContinuousState", all_params[points_key] - ) - elif isinstance(spec, _ShockGrid) and spec.params_to_pass_at_runtime: + if in_states: + state_replacements[name] = cast( + "ContinuousState", all_params[points_key] + ) + else: + action_replacements[name] = cast( + "ContinuousAction", all_params[points_key] + ) + elif ( + in_states + and isinstance(spec, _ShockGrid) + and spec.params_to_pass_at_runtime + ): all_present = all( - f"{state_name}__{p}" in all_params - for p in spec.params_to_pass_at_runtime + f"{name}__{p}" in all_params for p in spec.params_to_pass_at_runtime ) if not all_present: continue shock_kw: dict[str, float] = dict(spec.params) for p in spec.params_to_pass_at_runtime: - shock_kw[p] = cast("float", all_params[f"{state_name}__{p}"]) - replacements[state_name] = spec.compute_gridpoints(**shock_kw) + shock_kw[p] = cast("float", all_params[f"{name}__{p}"]) + state_replacements[name] = spec.compute_gridpoints(**shock_kw) - if not replacements: + if not state_replacements and not action_replacements: return self._base_state_action_space - new_states = dict(self._base_state_action_space.states) | replacements + new_states = ( + MappingProxyType( + dict(self._base_state_action_space.states) | state_replacements + ) + if state_replacements + else None + ) + new_continuous_actions = ( + MappingProxyType( + dict(self._base_state_action_space.continuous_actions) + | action_replacements + ) + if action_replacements + else None + ) return self._base_state_action_space.replace( - states=MappingProxyType(new_states) + states=new_states, + continuous_actions=new_continuous_actions, ) diff --git a/src/lcm/pandas_utils.py b/src/lcm/pandas_utils.py index 805932e2..2d696a45 100644 --- a/src/lcm/pandas_utils.py +++ b/src/lcm/pandas_utils.py @@ -504,12 +504,15 @@ def _resolve_per_target_template_key( def _is_runtime_grid_param(*, func_name: FunctionName, regime: Regime) -> bool: """Check if a template function key refers to a runtime grid param.""" - if func_name not in regime.states: - return False - grid = regime.states[func_name] - return (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime) or ( - isinstance(grid, _ShockGrid) and bool(grid.params_to_pass_at_runtime) - ) + if func_name in regime.states: + grid = regime.states[func_name] + return (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime) or ( + isinstance(grid, _ShockGrid) and bool(grid.params_to_pass_at_runtime) + ) + if func_name in regime.actions: + grid = regime.actions[func_name] + return isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime + return False def _fail_if_period_level(sr: pd.Series) -> None: diff --git a/src/lcm/params/regime_template.py b/src/lcm/params/regime_template.py index 299b2c27..2b9c748e 100644 --- a/src/lcm/params/regime_template.py +++ b/src/lcm/params/regime_template.py @@ -70,27 +70,53 @@ def create_regime_params_template( _validate_no_shadowing(function_params, regime) + _add_runtime_grid_params(function_params, regime) + + return MappingProxyType( + {k: MappingProxyType(v) for k, v in function_params.items()} + ) + + +def _add_runtime_grid_params( + function_params: dict[FunctionName, dict[str, str]], + regime: Regime, +) -> None: + """Add runtime-supplied state/action grid params to the template in place.""" for state_name, grid in regime.states.items(): if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: - if state_name in function_params: - raise InvalidNameError( - f"IrregSpacedGrid state '{state_name}' (with runtime-supplied " - f"points) conflicts with a function of the same name in the regime." - ) + _fail_if_runtime_grid_shadows_function( + function_params=function_params, name=state_name, kind="state" + ) function_params[state_name] = {"points": "Float1D"} elif isinstance(grid, _ShockGrid) and grid.params_to_pass_at_runtime: - if state_name in function_params: - raise InvalidNameError( - f"_ShockGrid state '{state_name}' (with runtime-supplied params) " - f"conflicts with a function of the same name in the regime." - ) + _fail_if_runtime_grid_shadows_function( + function_params=function_params, + name=state_name, + kind="_ShockGrid state", + ) function_params[state_name] = dict.fromkeys( grid.params_to_pass_at_runtime, "float" ) - return MappingProxyType( - {k: MappingProxyType(v) for k, v in function_params.items()} - ) + for action_name, grid in regime.actions.items(): + if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: + _fail_if_runtime_grid_shadows_function( + function_params=function_params, name=action_name, kind="action" + ) + function_params[action_name] = {"points": "Float1D"} + + +def _fail_if_runtime_grid_shadows_function( + *, + function_params: dict[FunctionName, dict[str, str]], + name: str, + kind: str, +) -> None: + if name in function_params: + raise InvalidNameError( + f"IrregSpacedGrid {kind} '{name}' (with runtime-supplied " + f"points/params) conflicts with a function of the same name in the regime." + ) def _collect_all_functions_for_template( diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index 9d588649..ceaee4cf 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -141,3 +141,93 @@ def test_runtime_grid_matches_fixed(): for period in V_fixed: if "alive" in V_fixed[period] and "alive" in V_runtime[period]: aaae(V_fixed[period]["alive"], V_runtime[period]["alive"]) + + +def _make_action_grid_model(*, consumption_grid): + """Create a 2-regime model where consumption is the runtime-points action grid.""" + alive = Regime( + functions={"utility": _utility}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=5)}, + state_transitions={"wealth": _next_wealth}, + actions={"consumption": consumption_grid}, + constraints={"borrowing_constraint": _borrowing_constraint}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + return Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_runtime_action_grid_in_params_template(): + """IrregSpacedGrid action with runtime-supplied points adds 'points' to template.""" + model = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + alive_template = model._params_template["alive"] + assert "consumption" in alive_template + assert "points" in alive_template["consumption"] + + +def test_solve_with_runtime_action_grid(): + """Solve should work when action grid points are provided via params.""" + model = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + params = { + "discount_factor": 0.95, + "interest_rate": 0.05, + "alive": {"consumption": {"points": jnp.linspace(0.1, 5.0, 5)}}, + } + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + assert len(period_to_regime_to_V_arr) > 0 + + +def test_runtime_action_grid_matches_fixed(): + """Runtime action grid with same points gives same V as a fixed action grid.""" + points = jnp.linspace(0.1, 5.0, 5) + + model_fixed = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(points=list(points.tolist())), + ) + params_fixed = {"discount_factor": 0.95, "interest_rate": 0.05} + V_fixed = model_fixed.solve(params=params_fixed, log_level="off") + + model_runtime = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + params_runtime = { + "discount_factor": 0.95, + "interest_rate": 0.05, + "alive": {"consumption": {"points": points}}, + } + V_runtime = model_runtime.solve(params=params_runtime, log_level="off") + + for period in V_fixed: + if "alive" in V_fixed[period] and "alive" in V_runtime[period]: + aaae(V_fixed[period]["alive"], V_runtime[period]["alive"]) + + +def test_runtime_action_grid_changes_solution(): + """Different runtime action points should yield different V (sanity check).""" + model = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + base = {"discount_factor": 0.95, "interest_rate": 0.05} + V_low = model.solve( + params=base | {"alive": {"consumption": {"points": jnp.linspace(0.1, 1.0, 5)}}}, + log_level="off", + ) + V_high = model.solve( + params=base | {"alive": {"consumption": {"points": jnp.linspace(0.1, 5.0, 5)}}}, + log_level="off", + ) + # Period 0 alive value should differ when the action support differs + assert not jnp.allclose(V_low[0]["alive"], V_high[0]["alive"]) From 1792279387abea0004821b5cf706ee60e9752f7c Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 07:24:08 +0200 Subject: [PATCH 02/80] benchmarks: bump aca-model to runtime-consumption-points version MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit aca-model now declares `consumption` as `IrregSpacedGrid(n_points=N)` with runtime-supplied points. The bench builder now passes `model=model` to `get_benchmark_params` so consumption gridpoints are injected into params before solving. aca-model rev: adc8a19 → 4123fe9 (feature/runtime-consumption-points) Co-Authored-By: Claude Opus 4.7 (1M context) --- benchmarks/bench_aca_baseline.py | 2 +- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index 3ef4c9b2..477d5635 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -55,7 +55,7 @@ def _build() -> tuple[object, object, object]: ) model = create_benchmark_model() - _, model_params = get_benchmark_params() + _, model_params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=_N_SUBJECTS, seed=0 ) diff --git a/pixi.lock b/pixi.lock index 5b3f6f2e..3f235281 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=adc8a19328608781a5cb2a65ab2d93d580163aae#adc8a19328608781a5cb2a65ab2d93d580163aae + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=4123fe9739c1c4bccebaa149985d0415a4272ef1#4123fe9739c1c4bccebaa149985d0415a4272ef1 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=adc8a19328608781a5cb2a65ab2d93d580163aae#adc8a19328608781a5cb2a65ab2d93d580163aae +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=4123fe9739c1c4bccebaa149985d0415a4272ef1#4123fe9739c1c4bccebaa149985d0415a4272ef1 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev142+g95ab1b648.d20260423 - sha256: ae4d75e092f6528909d9f185ed13eccc4aabeae96f5c6d41987cbb2704afa7e7 + version: 0.0.2.dev125+g1c65f45dd.d20260429 + sha256: 9d509ac58b2af5658d439c586c2971ead9988f34017c4b0b64c4dd6db51b27aa requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index ced35ccc..7efd8a17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "adc8a19328608781a5cb2a65ab2d93d580163aae" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "4123fe9739c1c4bccebaa149985d0415a4272ef1" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From cf00e99b277e95f903eb3f2b782a288297ec8a63 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 13:55:45 +0200 Subject: [PATCH 03/80] Fail loudly when reading runtime IrregSpacedGrid before substitution `IrregSpacedGrid(n_points=N)` declares a continuous grid whose values are supplied at runtime via `params[regime][grid_name]['points']`. Substitution happens inside `InternalRegime.state_action_space(regime_params=...)` at solve / simulate time. Any code path that calls `to_jax()` on the base grid before substitution silently got `jnp.full(N, jnp.nan)` and went on to compute against the placeholder. That is exactly what fired in `validate_initial_conditions` for `task_simulate_aca`: the validator built the action grid by calling `internal_regime.grids[name].to_jax()` (placeholder NaNs), then asked `borrowing_constraint(consumption=NaN, wealth=W)` whether each gridpoint was feasible. NaN comparisons are False, so every action was reported infeasible for every subject in every initial regime. Make the invariant explicit: `IrregSpacedGrid.to_jax()` raises `GridInitializationError` for runtime-supplied grids, with a message pointing the caller at `state_action_space(regime_params=...)` for real values or `.n_points` for shape. Confine the legitimate "placeholder needed for AOT tracing" caller (the base state-action space) to a private helper in `state_action_space.py` that uses NaN explicitly. Reroute `_check_regime_feasibility` through the substituted state-action space. Add regression tests covering both runtime action and runtime state grids round-tripping `simulate(check_initial_conditions=True)`, and unit tests pinning down the new raise + the existing NaN-source mechanics in `map_coordinates` / `get_irreg_coordinate`. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 19 +- src/lcm/simulation/initial_conditions.py | 21 +- src/lcm/state_action_space.py | 23 +- tests/test_runtime_params.py | 10 +- tests/test_single_feasible_action.py | 562 +++++++++++++++++++++++ 5 files changed, 624 insertions(+), 11 deletions(-) create mode 100644 tests/test_single_feasible_action.py diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index c81f9c6a..db646ca2 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -188,9 +188,24 @@ def pass_points_at_runtime(self) -> bool: return self.points is None def to_jax(self) -> Float1D: - """Convert the grid to a Jax array.""" + """Convert the grid to a Jax array. + + Raises `GridInitializationError` for runtime-supplied grids + (`pass_points_at_runtime=True`). Substitution happens at solve / + simulate time via `InternalRegime.state_action_space(regime_params=...)`; + any code path that reads the base grid's points before substitution is + a bug. + """ if self.points is None: - return jnp.full(self.n_points, jnp.nan) + raise GridInitializationError( + f"IrregSpacedGrid was declared with n_points={self.n_points} " + f"and no points; values are supplied at runtime via " + f"params['']['']['points']. Reading the grid " + f"before substitution is a bug — call " + f"`internal_regime.state_action_space(regime_params=...)` and " + f"read points from there, or use `.n_points` if only the shape " + f"is needed." + ) return jnp.asarray(self.points) @overload diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 1be89309..2cdd3c3b 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -7,7 +7,7 @@ from collections.abc import Callable, Mapping, Sequence from types import MappingProxyType -from typing import Never +from typing import Never, cast import jax import numpy as np @@ -25,6 +25,7 @@ from lcm.regime_building.Q_and_F import _get_feasibility from lcm.typing import ( ActionName, + FlatRegimeParams, InternalParams, RegimeName, RegimeNamesToIds, @@ -579,11 +580,21 @@ def _check_regime_feasibility( # noqa: C901 if not action_names: return None - grids = MappingProxyType( - {name: spec.to_jax() for name, spec in internal_regime.grids.items()} + # Build the state-action space with runtime-supplied grid points + # substituted. The base grid's `to_jax()` raises for runtime-supplied + # `IrregSpacedGrid`s declared with `pass_points_at_runtime=True`, so the + # validator must read points from `state_action_space(regime_params=...)`. + state_action_space = internal_regime.state_action_space( + regime_params=cast("FlatRegimeParams", MappingProxyType(dict(regime_params))), + ) + action_grids: dict[str, Array] = { + **state_action_space.discrete_actions, + **state_action_space.continuous_actions, + } + flat_actions = _build_flat_action_grid( + action_names=action_names, + grids=MappingProxyType(action_grids), ) - - flat_actions = _build_flat_action_grid(action_names=action_names, grids=grids) filtered_params = {k: v for k, v in regime_params.items() if k in accepted} state_names = list(internal_regime.variable_info.query("is_state").index) diff --git a/src/lcm/state_action_space.py b/src/lcm/state_action_space.py index ad6f6c1f..28cbce72 100644 --- a/src/lcm/state_action_space.py +++ b/src/lcm/state_action_space.py @@ -1,13 +1,29 @@ from types import MappingProxyType +import jax.numpy as jnp import pandas as pd from jax import Array -from lcm.grids import Grid +from lcm.grids import Grid, IrregSpacedGrid from lcm.interfaces import StateActionSpace from lcm.typing import StateName, StateOrActionName +def _grid_to_jax_or_placeholder(grid: Grid) -> Array: + """Return the grid's points, or a NaN placeholder for runtime-supplied grids. + + `IrregSpacedGrid.to_jax()` raises when its points haven't been supplied — that + is the right behaviour everywhere except here: the base state-action space + needs a *shape-correct* array to wire through pytree structures and AOT + tracing before runtime substitution by + `InternalRegime.state_action_space(regime_params=...)`. NaN (rather than + zero) makes any accidental computation against the placeholder fail loudly. + """ + if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: + return jnp.full(grid.n_points, jnp.nan) + return grid.to_jax() + + def create_state_action_space( *, variable_info: pd.DataFrame, @@ -33,7 +49,8 @@ def create_state_action_space( """ if states is None: _states = { - sn: grids[sn].to_jax() for sn in variable_info.query("is_state").index + sn: _grid_to_jax_or_placeholder(grids[sn]) + for sn in variable_info.query("is_state").index } else: _validate_all_states_present( @@ -47,7 +64,7 @@ def create_state_action_space( for name in variable_info.query("is_action & is_discrete").index } continuous_actions = { - name: grids[name].to_jax() + name: _grid_to_jax_or_placeholder(grids[name]) for name in variable_info.query("is_action & is_continuous").index } state_and_discrete_action_names = tuple( diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index ceaee4cf..0dab6acd 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -69,7 +69,15 @@ def test_runtime_grid_creation(): grid = IrregSpacedGrid(n_points=5) assert grid.pass_points_at_runtime assert grid.n_points == 5 - assert grid.to_jax().shape == (5,) # placeholder zeros + + +def test_runtime_grid_to_jax_raises_before_substitution(): + """`to_jax()` on a runtime grid is a bug — substitution happens at + solve/simulate time via `InternalRegime.state_action_space(regime_params=...)`. + The error message must point the caller at that API.""" + grid = IrregSpacedGrid(n_points=5) + with pytest.raises(Exception, match="state_action_space"): + grid.to_jax() def test_fixed_grid_not_runtime(): diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py new file mode 100644 index 00000000..70b6dbfe --- /dev/null +++ b/tests/test_single_feasible_action.py @@ -0,0 +1,562 @@ +"""Reproduce the aca-model NaN failure (issue OpenSourceEconomics/aca-model#9): +solve raises `InvalidValueFunctionError: ... regime 'dead': all values are NaN`. + +Two hypotheses are exercised here: + +1. **single/all infeasible action**: when constraints leave a single (or no) + consumption gridpoint feasible. `max(Q, where=F, initial=-inf)` is supposed + to mask infeasible Q values, but if Q itself contains NaN at infeasible + cells (e.g. `log(0)` at `consumption=0`), the mask is enough — proven by + the tests in this file. `-inf` from all-infeasible cells does not cascade + to NaN by itself either. + +2. **CRRA bequest with pref_type indexing under jnp.where**: the bequest + function evaluates *both* branches of `jnp.where(jnp.isclose(gamma, 1), + log_branch, power_branch)`. For a parameter set where gamma is exactly 1 + for one preference type but not the other, the power_branch divides by + `1 - gamma = 0` for the gamma=1 type. NaN/Inf from the *unselected* + branch leaks through `jnp.where`'s gradient/forward pass under JIT + tracing in some configurations. + +The model uses an `IrregSpacedGrid` consumption action with runtime-supplied +points to also exercise the `feature/runtime-action-grids` path (PR #338). +""" + +import jax.numpy as jnp +import pytest + +from lcm import AgeGrid, LinSpacedGrid, Model, Regime, categorical +from lcm.grids import IrregSpacedGrid +from lcm.typing import ContinuousAction, ContinuousState, FloatND + + +@categorical(ordered=False) +class RegimeId: + alive: int + dead: int + + +def _utility(consumption: ContinuousAction) -> FloatND: + """CRRA-like utility. log requires consumption > 0.""" + return jnp.log(consumption) + + +def _next_wealth( + wealth: ContinuousState, + consumption: ContinuousAction, +) -> ContinuousState: + return wealth - consumption + + +def _borrowing_constraint( + consumption: ContinuousAction, wealth: ContinuousState +) -> FloatND: + return consumption <= wealth + + +def _next_regime(age: int, last_alive_age: int) -> FloatND: + """Deterministic regime transition: alive→alive while age= last_alive_age, RegimeId.dead, RegimeId.alive) + + +def _build_model( + *, + wealth_lo: float, + wealth_hi: float, + n_wealth: int, + consumption_lo: float, + consumption_hi: float, + n_consumption: int, + n_periods: int, +) -> tuple[Model, dict]: + """Build a 2-regime (alive, dead) model with runtime consumption points. + + The wealth grid is fixed-LinSpaced (so changing it doesn't perturb the + runtime-action-grid path). The consumption grid is the runtime-supplied + IrregSpacedGrid. + """ + last_alive_age = n_periods - 2 # alive at ages 0..n-2; dead at n-1 + alive = Regime( + functions={"utility": _utility}, + states={"wealth": LinSpacedGrid(start=wealth_lo, stop=wealth_hi, n_points=n_wealth)}, + state_transitions={"wealth": _next_wealth}, + actions={"consumption": IrregSpacedGrid(n_points=n_consumption)}, + constraints={"borrowing_constraint": _borrowing_constraint}, + transition=_next_regime, + active=lambda age: age <= last_alive_age, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age > last_alive_age, + ) + model = Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=n_periods - 1, step="Y"), + regime_id_class=RegimeId, + ) + consumption_points = jnp.linspace(consumption_lo, consumption_hi, n_consumption) + params = { + "discount_factor": 0.95, + "alive": { + "consumption": {"points": consumption_points}, + "next_regime": {"last_alive_age": last_alive_age}, + }, + } + return model, params + + +def test_baseline_no_nan(): + """Healthy regime: at every wealth, every consumption point is feasible. + + consumption_lo == wealth_lo so the smallest action is always feasible at + every wealth gridpoint. `consumption_hi <= wealth_lo` ensures even the + largest consumption is feasible at the lowest wealth state. + """ + model, params = _build_model( + wealth_lo=10.0, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=1.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)) + + +def test_some_states_have_only_one_feasible_action(): + """At low-wealth states, only consumption[0] satisfies `consumption <= wealth`. + + consumption_lo < wealth_lo (so smallest action feasible at every wealth), + but consumption[1:] > wealth_lo (so at the lowest wealth gridpoint only + consumption[0] is feasible). + """ + model, params = _build_model( + wealth_lo=1.0, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=0.5, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + # At wealth = 1.0, only consumption = 0.5 is feasible (1.625, 2.75, 3.875, + # 5.0 all > 1.0). + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)), ( + "single-feasible-action state should not produce NaN V" + ) + + +def test_some_states_have_no_feasible_action(): + """At sufficiently low wealth, *every* consumption gridpoint is infeasible. + + consumption_lo > wealth_lo means at the lowest wealth state no + consumption point satisfies the borrowing constraint. + """ + model, params = _build_model( + wealth_lo=0.1, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=1.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + # At wealth = 0.1, no consumption point is feasible. max returns -inf. + # The next period interpolates V over the resulting -inf cells. + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)), ( + "all-infeasible state should yield -inf, not NaN" + ) + + +def test_log_zero_consumption_propagates_nan_via_max_when_unconstrained(): + """U(c=0) = -inf is fine, but U evaluated at infeasible negative wealth + via `next_wealth = wealth - c` going through interpolation in the next + period should not pollute V.""" + # consumption_lo = 0 → log(0) = -inf at the smallest action, regardless + # of feasibility. `where=F_arr` should mask this out of the max. + model, params = _build_model( + wealth_lo=1.0, + wealth_hi=10.0, + n_wealth=5, + consumption_lo=0.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + # log(0) = -inf is not NaN, but combined with where-mask edge + # cases this could in principle leak; check it does not. + assert not jnp.any(jnp.isnan(V_arr)) + + +@pytest.mark.parametrize( + ("wealth_lo", "consumption_lo", "label"), + [ + (1.0, 0.5, "single-feasible"), + (0.1, 1.0, "all-infeasible"), + ], +) +def test_simulate_with_constrained_action_grid(wealth_lo, consumption_lo, label): + """End-to-end solve+simulate for both regimes.""" + model, params = _build_model( + wealth_lo=wealth_lo, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=consumption_lo, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + initial_conditions = { + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([wealth_lo, 5.0, 20.0]), + "regime": jnp.array( + [RegimeId.alive, RegimeId.alive, RegimeId.alive], dtype=jnp.int32 + ), + } + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + check_initial_conditions=False, + log_level="off", + ) + df = result.to_dataframe() + assert not df["value"].isna().any(), ( + f"{label}: simulated value column should not contain NaN" + ) + + +# --------------------------------------------------------------------------- +# Replicas of the aca-baseline failure path: dead regime with a CRRA bequest +# whose `gamma` is per-pref_type, evaluated through `jnp.where`. +# --------------------------------------------------------------------------- + + +@categorical(ordered=False) +class PrefType: + type_0: int + type_1: int + + +@categorical(ordered=False) +class AliveDeadRegimeId: + alive: int + dead: int + + +def _crra_bequest( + assets: ContinuousState, + pref_type, + bequest_shifter: float, + consumption_weight, + coefficient_rra, +) -> FloatND: + """Replica of aca_model.agent.preferences.bequest, simplified. + + `consumption_weight` and `coefficient_rra` are FloatND indexed by + `pref_type`. Both branches of the `jnp.where` are traced. + """ + alpha = consumption_weight[pref_type] + gamma = coefficient_rra[pref_type] + assets_shifted = jnp.maximum(0.0, assets) + bequest_shifter + one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) + val = jnp.where( + jnp.isclose(gamma, 1.0), + jnp.log(assets_shifted), + assets_shifted ** (one_minus_gamma * alpha) / one_minus_gamma, + ) + return val + + +def _alive_utility( + consumption: ContinuousAction, pref_type, consumption_weight +) -> FloatND: + """Make pref_type matter in alive's utility too (otherwise pylcm complains).""" + alpha = consumption_weight[pref_type] + return alpha * jnp.log(consumption) + + +def _next_assets(assets: ContinuousState, consumption: ContinuousAction) -> ContinuousState: + return assets - consumption + + +def _alive_borrow(consumption: ContinuousAction, assets: ContinuousState) -> FloatND: + return consumption <= assets + + +def _alive_to_dead(age: int, last_alive_age: int) -> FloatND: + """Deterministic regime transition; returns a scalar regime ID.""" + return jnp.where( + age >= last_alive_age, AliveDeadRegimeId.dead, AliveDeadRegimeId.alive + ) + + +def _build_alive_dead_model( + *, + coefficient_rra: tuple[float, float], + consumption_weight: tuple[float, float], + n_periods: int = 3, +) -> tuple[Model, dict]: + last_alive_age = n_periods - 2 + from lcm import DiscreteGrid + + alive = Regime( + functions={"utility": _alive_utility}, + states={ + "assets": LinSpacedGrid(start=1.0, stop=20.0, n_points=5), + "pref_type": DiscreteGrid(PrefType, batch_size=1), + }, + state_transitions={"assets": _next_assets, "pref_type": None}, + actions={"consumption": IrregSpacedGrid(n_points=5)}, + constraints={"borrowing_constraint": _alive_borrow}, + transition=_alive_to_dead, + active=lambda age: age <= last_alive_age, + ) + dead = Regime( + transition=None, + functions={"utility": _crra_bequest}, + states={ + "assets": LinSpacedGrid(start=1.0, stop=20.0, n_points=5), + "pref_type": DiscreteGrid(PrefType, batch_size=1), + }, + active=lambda _age: True, + ) + model = Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=n_periods - 1, step="Y"), + regime_id_class=AliveDeadRegimeId, + ) + cw_arr = jnp.asarray(consumption_weight) + params = { + "discount_factor": 0.95, + "alive": { + "utility": {"consumption_weight": cw_arr}, + "consumption": {"points": jnp.linspace(0.5, 5.0, 5)}, + "next_regime": {"last_alive_age": last_alive_age}, + }, + "dead": { + "utility": { + "bequest_shifter": 100.0, + "consumption_weight": cw_arr, + "coefficient_rra": jnp.asarray(coefficient_rra), + }, + }, + } + return model, params + + +def test_bequest_gamma_close_to_one_is_safe(): + """gamma=0.999077 (the benchmark value for type_1) should not produce NaN.""" + model, params = _build_alive_dead_model( + coefficient_rra=(3.84, 0.999077), + consumption_weight=(0.68, 0.88), + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)) + + +def test_bequest_gamma_exactly_one_for_one_type_only(): + """gamma=1.0 exactly for one pref type triggers `jnp.where(isclose, log, power)`. + + The unselected `power` branch divides by `1 - gamma = 0` for that type. JAX + evaluates both branches of `jnp.where`; the non-finite from the unselected + branch is masked, but `0/0` produces NaN that may not be masked correctly + when the operand to `jnp.where` is itself NaN under XLA's nan-prop rules. + """ + model, params = _build_alive_dead_model( + coefficient_rra=(3.84, 1.0), # type_1 hits the log branch + consumption_weight=(0.68, 0.88), + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for period, regime_to_V in period_to_regime_to_V_arr.items(): + for regime, V_arr in regime_to_V.items(): + assert not jnp.any(jnp.isnan(V_arr)), ( + f"NaN in V[{regime}, period={period}] when one type has gamma=1.0" + ) + + +# --------------------------------------------------------------------------- +# Direct probe: `map_coordinates` produces NaN at ±inf / NaN coordinates. +# This is the concrete NaN source — `lower_weight = 1 - inf = -inf` and +# `upper_weight = inf` combined with positive grid values gives `inf - inf = +# NaN`. The aca-baseline NaN-in-V at age 51 is most plausibly traced back to +# *some* upstream computation (next_assets / next_aime, or a state coordinate +# from `get_irreg_coordinate` / `get_*_coordinate` that divides by zero on a +# degenerate grid segment) producing inf, which then poisons the value +# function via this interpolation path. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("bad_coord", [jnp.inf, -jnp.inf, jnp.nan]) +def test_map_coordinates_returns_nan_for_non_finite_coordinate(bad_coord): + """`map_coordinates` cannot recover from a non-finite continuous-state + coordinate: the linear-interp weights become `inf` and `1 - inf = -inf`, + and `inf * V[k] - inf * V[k-1]` reduces to NaN. + + Implication for callers: any path that can feed `inf` or `NaN` into the + coordinate finder (e.g. division by zero in a state transition, an + overflow when V values are O(1e8), or a `0/0` in a degenerate + IrregSpacedGrid segment) will produce NaN in V. + """ + from lcm.regime_building.ndimage import map_coordinates + + V_arr = jnp.array([1.0, 5.0, 12.0]) + out = map_coordinates(V_arr, coordinates=[jnp.array(bad_coord)]) + assert jnp.isnan(out) + + +def test_irreg_coordinate_divides_by_zero_on_duplicate_grid_points(): + """`get_irreg_coordinate` divides by `upper_point - lower_point`. If a + runtime-supplied points array contains duplicates, the divisor is 0. + + Reproduces a class of failures that is *only* possible under + `feature/runtime-action-grids` / runtime-state-grids when the caller + constructs the points from a parameter that can collapse (e.g. + `geomspace(consumption_floor, MAX, n_points)` with `consumption_floor == + MAX`, or any param-driven `linspace` whose endpoints can coincide). + """ + from lcm.grids.coordinates import get_irreg_coordinate + + # Duplicate adjacent points where the query value equals the duplicate. + # `searchsorted([0, 1, 1], 1.0, side='right')=3` → clipped to n-1=2, + # idx_lower=1, lower_point=points[1]=1.0, upper_point=points[2]=1.0, + # step_size=0 → decimal_part = 0/0 = nan. + points = jnp.array([0.0, 1.0, 1.0]) + coord = get_irreg_coordinate(value=jnp.array(1.0), points=points) + assert not jnp.isfinite(coord), ( + "Duplicate adjacent grid points cause `step_size = 0` in " + "`get_irreg_coordinate`; current behaviour is to silently return " + "inf/nan, which then poisons V interpolation downstream." + ) + + +# --------------------------------------------------------------------------- +# `validate_initial_conditions` uses `_base_state_action_space` directly, +# which still holds the placeholder zeros for runtime-supplied +# `IrregSpacedGrid`. With a feasibility constraint that the all-zero +# placeholder fails, every subject is reported infeasible — even though the +# real (post-substitution) grid would pass. This affects runtime grids +# regardless of whether they are state or action grids. +# --------------------------------------------------------------------------- + + +def _runtime_state_grid_model() -> tuple[Model, dict, dict]: + """A 2-regime model with a runtime-supplied IrregSpacedGrid *state*.""" + + @categorical(ordered=False) + class RuntimeRegimeId: + alive: int + dead: int + + def utility(consumption, wealth): + return jnp.log(consumption) + 0.0 * wealth + + def next_wealth(wealth, consumption): + return wealth - consumption + + def borrow(consumption, wealth): + # The validator sees `wealth` as a per-subject array with the + # subject-supplied initial values, but `consumption` as the *grid* + # (placeholder zeros for runtime grids). With a feasibility check + # that requires `consumption > 0`, every action gridpoint is + # infeasible until the runtime points replace the placeholder. + return consumption > 0 + + def next_regime(age, last_alive_age): + return jnp.where(age >= last_alive_age, RuntimeRegimeId.dead, RuntimeRegimeId.alive) + + last_alive_age = 1 + alive = Regime( + functions={"utility": utility}, + states={"wealth": IrregSpacedGrid(n_points=4)}, # runtime state grid + state_transitions={"wealth": next_wealth}, + actions={"consumption": LinSpacedGrid(start=0.5, stop=5.0, n_points=5)}, + constraints={"borrow": borrow}, + transition=next_regime, + active=lambda age: age <= last_alive_age, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age > last_alive_age, + ) + model = Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RuntimeRegimeId, + ) + params = { + "discount_factor": 0.95, + "alive": { + "wealth": {"points": jnp.linspace(1.0, 10.0, 4)}, + "next_regime": {"last_alive_age": last_alive_age}, + }, + } + initial_conditions = { + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([2.0, 5.0, 9.0]), + "regime": jnp.array( + [RuntimeRegimeId.alive, RuntimeRegimeId.alive, RuntimeRegimeId.alive], + dtype=jnp.int32, + ), + } + return model, params, initial_conditions + + +def test_runtime_action_grid_passes_initial_conditions_validation(): + """`feature/runtime-action-grids` regression: initial-conditions + feasibility check must use the *substituted* action grid, not the + `_base_state_action_space` placeholder zeros.""" + model, params = _build_model( + wealth_lo=10.0, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=1.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + initial_conditions = { + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([10.0, 15.0, 20.0]), + "regime": jnp.array( + [RegimeId.alive, RegimeId.alive, RegimeId.alive], dtype=jnp.int32 + ), + } + # `check_initial_conditions=True` (the default) must pass — the + # runtime-supplied consumption points are well-formed. + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + ) + assert result.n_subjects == 3 + + +def test_runtime_state_grid_passes_initial_conditions_validation(): + """Same regression for runtime-supplied *state* grids.""" + model, params, initial_conditions = _runtime_state_grid_model() + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + ) + assert result.n_subjects == 3 From db98cde25b64ce1ba506b7b32a8c0c3c881b938f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 12:12:19 +0000 Subject: [PATCH 04/80] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_single_feasible_action.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py index 70b6dbfe..20b19152 100644 --- a/tests/test_single_feasible_action.py +++ b/tests/test_single_feasible_action.py @@ -78,7 +78,9 @@ def _build_model( last_alive_age = n_periods - 2 # alive at ages 0..n-2; dead at n-1 alive = Regime( functions={"utility": _utility}, - states={"wealth": LinSpacedGrid(start=wealth_lo, stop=wealth_hi, n_points=n_wealth)}, + states={ + "wealth": LinSpacedGrid(start=wealth_lo, stop=wealth_hi, n_points=n_wealth) + }, state_transitions={"wealth": _next_wealth}, actions={"consumption": IrregSpacedGrid(n_points=n_consumption)}, constraints={"borrowing_constraint": _borrowing_constraint}, @@ -290,7 +292,9 @@ def _alive_utility( return alpha * jnp.log(consumption) -def _next_assets(assets: ContinuousState, consumption: ContinuousAction) -> ContinuousState: +def _next_assets( + assets: ContinuousState, consumption: ContinuousAction +) -> ContinuousState: return assets - consumption @@ -479,7 +483,9 @@ def borrow(consumption, wealth): return consumption > 0 def next_regime(age, last_alive_age): - return jnp.where(age >= last_alive_age, RuntimeRegimeId.dead, RuntimeRegimeId.alive) + return jnp.where( + age >= last_alive_age, RuntimeRegimeId.dead, RuntimeRegimeId.alive + ) last_alive_age = 1 alive = Regime( From 03ba800ff8d9ccc95b75b5ce89026f8c96868611 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 14:35:44 +0200 Subject: [PATCH 05/80] Fix remaining ruff check errors in test_single_feasible_action.py Move late `DiscreteGrid`, `map_coordinates`, and `get_irreg_coordinate` imports to the module top level (PLC0415), drop the unnecessary `val` assignment before return (RET504), and mark the unused `wealth` arg in the local `borrow` constraint as `# noqa: ARG001`. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_single_feasible_action.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py index 20b19152..04628ed2 100644 --- a/tests/test_single_feasible_action.py +++ b/tests/test_single_feasible_action.py @@ -25,8 +25,10 @@ import jax.numpy as jnp import pytest -from lcm import AgeGrid, LinSpacedGrid, Model, Regime, categorical +from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical from lcm.grids import IrregSpacedGrid +from lcm.grids.coordinates import get_irreg_coordinate +from lcm.regime_building.ndimage import map_coordinates from lcm.typing import ContinuousAction, ContinuousState, FloatND @@ -276,12 +278,11 @@ def _crra_bequest( gamma = coefficient_rra[pref_type] assets_shifted = jnp.maximum(0.0, assets) + bequest_shifter one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) - val = jnp.where( + return jnp.where( jnp.isclose(gamma, 1.0), jnp.log(assets_shifted), assets_shifted ** (one_minus_gamma * alpha) / one_minus_gamma, ) - return val def _alive_utility( @@ -316,7 +317,6 @@ def _build_alive_dead_model( n_periods: int = 3, ) -> tuple[Model, dict]: last_alive_age = n_periods - 2 - from lcm import DiscreteGrid alive = Regime( functions={"utility": _alive_utility}, @@ -418,8 +418,6 @@ def test_map_coordinates_returns_nan_for_non_finite_coordinate(bad_coord): overflow when V values are O(1e8), or a `0/0` in a degenerate IrregSpacedGrid segment) will produce NaN in V. """ - from lcm.regime_building.ndimage import map_coordinates - V_arr = jnp.array([1.0, 5.0, 12.0]) out = map_coordinates(V_arr, coordinates=[jnp.array(bad_coord)]) assert jnp.isnan(out) @@ -435,8 +433,6 @@ def test_irreg_coordinate_divides_by_zero_on_duplicate_grid_points(): `geomspace(consumption_floor, MAX, n_points)` with `consumption_floor == MAX`, or any param-driven `linspace` whose endpoints can coincide). """ - from lcm.grids.coordinates import get_irreg_coordinate - # Duplicate adjacent points where the query value equals the duplicate. # `searchsorted([0, 1, 1], 1.0, side='right')=3` → clipped to n-1=2, # idx_lower=1, lower_point=points[1]=1.0, upper_point=points[2]=1.0, @@ -474,7 +470,7 @@ def utility(consumption, wealth): def next_wealth(wealth, consumption): return wealth - consumption - def borrow(consumption, wealth): + def borrow(consumption, wealth): # noqa: ARG001 # The validator sees `wealth` as a per-subject array with the # subject-supplied initial values, but `consumption` as the *grid* # (placeholder zeros for runtime grids). With a feasibility check From 3769d2d63a0f708b2bf4b8a11b603b686282b996 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 17:28:20 +0200 Subject: [PATCH 06/80] Raise on regime-function-output indexed by discrete state in a consumer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A regime function whose output is then re-indexed by a discrete state inside another consumer (function, constraint, or transition) is a silent footgun: pylcm broadcasts function outputs to per-cell scalars before consumption, so the indexing silently produces NaN at runtime instead of the intended scalar. The aca-baseline benchmark hit this via `bequest(... utility_scale_factor[pref_type])` where `utility_scale_factor` is registered as a regime function — the dead regime's V came back all-NaN with no actionable error. Adds an AST-walking validator in `validate_logical_consistency` that inspects every consumer (functions, constraints, transition) for a `Subscript(Name=X, slice=Name=Y)` pattern where `X` is in `regime.functions` and `Y` is a `DiscreteGrid` state. If any clash is found, raises `RegimeInitializationError` listing each clash and pointing the user at the safe pattern (function takes the state, returns a scalar — see `discount_factor`). Three TDD tests in `tests/test_function_output_state_indexing.py`: - the clash raises (functions case) - the safe pattern (function takes the state, scalar return) builds - the check applies to constraints too Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/regime_building/validation.py | 80 ++++++++++ tests/test_function_output_state_indexing.py | 153 +++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 tests/test_function_output_state_indexing.py diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index 0754c0a7..ddc63631 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -5,6 +5,9 @@ """ +import ast +import inspect +import textwrap from collections.abc import Callable, Mapping from typing import TypeAliasType @@ -125,6 +128,7 @@ def validate_logical_consistency(regime: Regime) -> None: error_messages.extend(_validate_active(regime.active)) error_messages.extend(_validate_state_transitions(regime)) + error_messages.extend(_validate_function_output_state_indexing(regime)) states_and_actions_overlap = set(regime.states) & set(regime.actions) if states_and_actions_overlap: @@ -138,6 +142,82 @@ def validate_logical_consistency(regime: Regime) -> None: raise RegimeInitializationError(msg) +def _validate_function_output_state_indexing(regime: Regime) -> list[str]: + """Detect the regime-function-output / state-indexed-input name clash. + + A regime function whose output is then re-indexed by a discrete state inside + another consumer (function, constraint, or transition) is a silent footgun: + pylcm broadcasts function outputs to per-cell scalars before consumption, so + the indexing produces NaN at runtime instead of the intended scalar. + + The safe pattern is to take the state as input on the producing function and + return the scalar directly (see `aca_model.agent.preferences.discount_factor`). + """ + function_output_names = set(regime.functions) + discrete_state_names = { + name for name, grid in regime.states.items() if isinstance(grid, DiscreteGrid) + } + if not function_output_names or not discrete_state_names: + return [] + + consumers: list[tuple[str, Callable]] = [] + consumers.extend(regime.functions.items()) + consumers.extend(regime.constraints.items()) + if callable(regime.transition): + consumers.append(("regime_transition", regime.transition)) + + errors: list[str] = [] + for consumer_name, func in consumers: + clashes = _find_function_output_state_indexing( + func=func, + function_output_names=function_output_names, + discrete_state_names=discrete_state_names, + ) + for func_output_name, state_name in clashes: + errors.append( + f"Consumer '{consumer_name}' indexes regime function output " + f"'{func_output_name}' by discrete state '{state_name}' " + f"(`{func_output_name}[{state_name}]`). pylcm broadcasts " + f"function outputs to per-cell scalars before consumption, so " + f"this indexing silently produces NaN. Refactor " + f"'{func_output_name}' to take '{state_name}' as input and " + f"return the scalar directly." + ) + return errors + + +def _find_function_output_state_indexing( + *, + func: Callable, + function_output_names: set[str], + discrete_state_names: set[str], +) -> list[tuple[str, str]]: + """Return `(function_output_name, state_name)` clashes inside `func`'s body.""" + try: + source = textwrap.dedent(inspect.getsource(func)) + except OSError, TypeError: + return [] + try: + tree = ast.parse(source) + except SyntaxError: + return [] + + clashes: list[tuple[str, str]] = [] + for node in ast.walk(tree): + if not isinstance(node, ast.Subscript): + continue + if not isinstance(node.value, ast.Name): + continue + if node.value.id not in function_output_names: + continue + if not isinstance(node.slice, ast.Name): + continue + if node.slice.id not in discrete_state_names: + continue + clashes.append((node.value.id, node.slice.id)) + return clashes + + def collect_state_transitions( states: Mapping[StateName, Grid], state_transitions: Mapping[ diff --git a/tests/test_function_output_state_indexing.py b/tests/test_function_output_state_indexing.py new file mode 100644 index 00000000..f326833c --- /dev/null +++ b/tests/test_function_output_state_indexing.py @@ -0,0 +1,153 @@ +"""Tests for the regime-function-output / state-indexed-input name clash. + +A regime function whose output is then re-indexed by a discrete state +inside another function is a silent footgun: pylcm broadcasts function +outputs to per-cell scalars before consumption, so the indexing produces +NaN at runtime instead of the intended scalar. + +The validation layer must raise on construction with a clear message +pointing the user at the safe pattern (see `discount_factor` in +`aca_model.agent.preferences`: take the state as input, return a scalar). +""" + +import jax.numpy as jnp +import pytest + +from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical +from lcm.exceptions import RegimeInitializationError +from lcm.typing import ContinuousAction, DiscreteState, FloatND + + +@categorical(ordered=False) +class PrefType: + type_0: int + type_1: int + + +@categorical(ordered=False) +class RegimeId: + alive: int + dead: int + + +def _per_type_scale(some_param: FloatND) -> FloatND: + """Returns one scale per pref_type — depends only on a per-type Series param.""" + return jnp.abs(1.0 / (1.0 - some_param)) + + +def _utility_with_state_indexed_function_output( + consumption: ContinuousAction, + pref_type: DiscreteState, + per_type_scale: FloatND, +) -> FloatND: + # The clash: per_type_scale is registered as a regime function output + # (returns shape (n_pref_types,)), but here it is consumed and indexed + # by the `pref_type` discrete state. + return per_type_scale[pref_type] * jnp.log(consumption + 1.0) + + +def _next_regime(period: int) -> FloatND: + return jnp.where(period >= 1, RegimeId.dead, RegimeId.alive) + + +def _make_clashing_model() -> Model: + alive = Regime( + functions={ + "utility": _utility_with_state_indexed_function_output, + "per_type_scale": _per_type_scale, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + return Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_function_output_indexed_by_state_raises(): + """A regime function output indexed by a discrete state inside another regime + function must raise on construction (silent NaN bug otherwise).""" + with pytest.raises( + RegimeInitializationError, + match=r"per_type_scale.*pref_type", + ): + _make_clashing_model() + + +def _utility_safe( + consumption: ContinuousAction, + pref_type: DiscreteState, # noqa: ARG001 + per_type_scale: FloatND, +) -> FloatND: + # Safe variant: per_type_scale is consumed as a scalar (no [pref_type] indexing). + return per_type_scale * jnp.log(consumption + 1.0) + + +def _per_type_scale_safe(pref_type: DiscreteState, some_param: FloatND) -> FloatND: + """Safe variant: takes pref_type, returns scalar (mirrors discount_factor).""" + return jnp.abs(1.0 / (1.0 - some_param[pref_type])) + + +def test_safe_pattern_does_not_raise(): + """The safe pattern (function takes the state, returns a scalar) builds fine.""" + alive = Regime( + functions={ + "utility": _utility_safe, + "per_type_scale": _per_type_scale_safe, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_constraint_indexing_function_output_by_state_raises(): + """The check applies to regime constraints too, not only `functions`.""" + + def _constraint_indexing_function_output( + consumption: ContinuousAction, + pref_type: DiscreteState, + per_type_scale: FloatND, + ) -> FloatND: + return consumption <= per_type_scale[pref_type] + + with pytest.raises( + RegimeInitializationError, + match=r"per_type_scale.*pref_type", + ): + Regime( + functions={ + "utility": lambda consumption, pref_type, per_type_scale: jnp.log( # noqa: ARG005 + consumption + 1.0 + ), + "per_type_scale": _per_type_scale, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + constraints={"feasibility": _constraint_indexing_function_output}, + transition=_next_regime, + active=lambda age: age < 2, + ) From 282542f25e58d72172e39fb79aebee94d2057590 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 18:32:41 +0200 Subject: [PATCH 07/80] benchmarks: bump aca-model to dead-regime-NaN fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit aca-model `feature/runtime-consumption-points` 4123fe9 → 1342861 (refactors `utility_scale_factor` to take `pref_type` and return a scalar, eliminating the regime-function-output / state-indexed-input clash that produced NaN in the dead regime's V). Co-Authored-By: Claude Opus 4.7 (1M context) --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index 3f235281..e507ccc8 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=4123fe9739c1c4bccebaa149985d0415a4272ef1#4123fe9739c1c4bccebaa149985d0415a4272ef1 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=4123fe9739c1c4bccebaa149985d0415a4272ef1#4123fe9739c1c4bccebaa149985d0415a4272ef1 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev125+g1c65f45dd.d20260429 - sha256: 9d509ac58b2af5658d439c586c2971ead9988f34017c4b0b64c4dd6db51b27aa + version: 0.0.2.dev130+g3769d2d63.d20260429 + sha256: 365fd6893bcba5ab807032371430b19a9a9056c80ef3a9a201b56d34ced99e0e requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 7efd8a17..3fe01162 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "4123fe9739c1c4bccebaa149985d0415a4272ef1" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From 72c83f79faa3b7536bdc6b331d2ffba711ead8a5 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 21:02:11 +0200 Subject: [PATCH 08/80] Substitute runtime-supplied action gridpoints in simulate's state-action space `create_regime_state_action_space` (used during forward simulation) was calling `create_state_action_space` directly, which leaves `pass_points_at_runtime=True` IrregSpacedGrid action grids as their NaN placeholder. The placeholder fed straight into `argmax_and_max_Q_over_a` and `_lookup_values_from_indices`, so optimal actions came back NaN, the source regime's `next_state` propagated NaN into every target regime's namespaced state, and `validate_V` raised on the first downstream regime whose utility depended on those states (the dead regime in aca-model: assets/pref_type both NaN). Route through `internal_regime.state_action_space(regime_params=...)` (the same path solve uses) and overlay the per-subject states. Add a TDD regression test in tests/test_runtime_params.py covering the simulate path. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/simulation/simulate.py | 1 + src/lcm/simulation/transitions.py | 25 ++++++----- tests/test_runtime_params.py | 74 +++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 10 deletions(-) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 9c017db3..d54d2674 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -244,6 +244,7 @@ def _simulate_regime_in_period( state_action_space = create_regime_state_action_space( internal_regime=internal_regime, states=states, + regime_params=internal_params[regime_name], ) # Compute optimal actions # We need to pass the value function array of the next period to the diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index 67e6df67..23bfd58c 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -15,7 +15,6 @@ from lcm.interfaces import InternalRegime, StateActionSpace from lcm.simulation.random import generate_simulation_keys -from lcm.state_action_space import create_state_action_space from lcm.typing import ( ActionName, Bool1D, @@ -31,29 +30,35 @@ def create_regime_state_action_space( *, internal_regime: InternalRegime, states: MappingProxyType[str, Array], + regime_params: FlatRegimeParams, ) -> StateActionSpace: """Create the state-action space containing only the relevant subjects in a regime. + Continuous action grids declared with `pass_points_at_runtime=True` are + completed from `regime_params` (via + `InternalRegime.state_action_space`) — otherwise they would carry the + NaN placeholder used during compilation, which propagates into + `optimal_actions` and ultimately `next_states`. + Args: internal_regime: The internal regime instance. states: The current states of all subjects. + regime_params: Flat regime parameters supplied at runtime, used to + substitute runtime-supplied action gridpoints. Returns: The state-action space for the subjects in the regime. """ - relevant_state_names = internal_regime.variable_info.query("is_state").index + base = internal_regime.state_action_space(regime_params=regime_params) - states_for_state_action_space = { - sn: states[f"{internal_regime.name}__{sn}"] for sn in relevant_state_names - } - - return create_state_action_space( - variable_info=internal_regime.variable_info, - grids=internal_regime.grids, - states=states_for_state_action_space, + relevant_state_names = internal_regime.variable_info.query("is_state").index + states_for_state_action_space = MappingProxyType( + {sn: states[f"{internal_regime.name}__{sn}"] for sn in relevant_state_names} ) + return base.replace(states=states_for_state_action_space) + def calculate_next_states( *, diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index 0dab6acd..a401608d 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -239,3 +239,77 @@ def test_runtime_action_grid_changes_solution(): ) # Period 0 alive value should differ when the action support differs assert not jnp.allclose(V_low[0]["alive"], V_high[0]["alive"]) + + +def _make_action_grid_model_with_stateful_dead(*, consumption_grid): + """Variant where `dead` has a `wealth` state so its utility depends on it. + + Mirrors the aca-model dead regime (carries assets / pref_type so the + bequest function can read them). Used to surface NaN propagation + when the simulate path forgets to substitute runtime-supplied action + gridpoints. + """ + + def _alive_utility( + consumption: ContinuousAction, wealth: ContinuousState + ) -> FloatND: + return jnp.log(consumption + 1) + 0.01 * wealth + + def _dead_utility(wealth: ContinuousState) -> FloatND: + return jnp.log(wealth + 1) + + alive = Regime( + functions={"utility": _alive_utility}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=5)}, + state_transitions={ + "wealth": { + "alive": _next_wealth, + "dead": _next_wealth, + }, + }, + actions={"consumption": consumption_grid}, + constraints={"borrowing_constraint": _borrowing_constraint}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": _dead_utility}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=5)}, + active=lambda _age: True, + ) + return Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_simulate_with_runtime_action_grid_no_nan() -> None: + """Simulate must substitute runtime-supplied action gridpoints into the + state-action space; otherwise the action grid is filled with NaN + placeholders, optimal_actions become NaN, next_states propagate NaN to + the dead regime, and `validate_V` raises. + """ + model = _make_action_grid_model_with_stateful_dead( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + params = { + "discount_factor": 0.95, + "interest_rate": 0.05, + "alive": {"consumption": {"points": jnp.linspace(0.1, 5.0, 5)}}, + } + initial_conditions = { + "regime": jnp.array([RegimeId.alive, RegimeId.alive, RegimeId.alive]), + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([2.0, 5.0, 9.0]), + } + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + check_initial_conditions=False, + ) + df = result.to_dataframe() + assert not df["value"].isna().any() From db6214f275d3863fc11a79395ec6b9a0e300db79 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 30 Apr 2026 15:02:27 +0200 Subject: [PATCH 09/80] Guard log-spaced grids against non-positive start; reject non-finite grid values `LogSpacedGrid` previously inherited only the generic continuous-grid checks (start < stop, n_points > 0). With `start <= 0`, `to_jax()` silently returned NaN/-inf, and the bug would only surface deep inside an interpolation kernel. Now refuses at construction. While here, tighten two adjacent silent-failure modes: - `_validate_continuous_grid` rejects non-finite `start`/`stop`. `start >= stop` is False for NaN, so a NaN bound previously slipped through every check. - `_validate_irreg_spaced_grid` rejects non-finite points. The ascending-order test uses `>=`, which is False for NaN, so a NaN point previously passed the order check silently. Both matter for runtime-supplied grids: e.g. `geomspace(consumption_floor, MAX, N)` with a bad `consumption_floor` produces all-NaN points, and we want that caught at the grid layer rather than as a downstream V_arr NaN diagnostic. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 58 +++++++++++++++++++++++++++++++------ tests/test_grids.py | 30 +++++++++++++++++++ 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index db646ca2..8550188d 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -1,4 +1,5 @@ import dataclasses +import math from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @@ -116,12 +117,24 @@ def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: class LogSpacedGrid(UniformContinuousGrid): """A logarithmically spaced grid of continuous values. + Requires `start > 0`; otherwise `log(start)` is undefined and `to_jax()` + silently returns NaN/-inf, which propagates through every downstream + interpolation. + Example: -------- Let `start = 1`, `stop = 100`, and `n_points = 3`. The grid is `[1, 10, 100]`. """ + def __post_init__(self) -> None: + _validate_continuous_grid( + start=self.start, + stop=self.stop, + n_points=self.n_points, + requires_positive_start=True, + ) + def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" return grid_coordinates.logspace( @@ -228,6 +241,7 @@ def _validate_continuous_grid( start: float, stop: float, n_points: int, + requires_positive_start: bool = False, ) -> None: """Validate the continuous grid parameters. @@ -235,6 +249,8 @@ def _validate_continuous_grid( start: The start value of the grid. stop: The stop value of the grid. n_points: The number of points in the grid. + requires_positive_start: If True, also require `start > 0` (used by + log-spaced grids since `log(x)` is undefined for `x <= 0`). Raises: GridInitializationError: If the grid parameters are invalid. @@ -250,6 +266,15 @@ def _validate_continuous_grid( if not valid_stop_type: error_messages.append("stop must be a scalar int or float value") + # Reject NaN/inf early — `start >= stop` returns False for NaN, so an + # un-finite start would otherwise pass silently and produce a broken grid. + if valid_start_type and not math.isfinite(start): + error_messages.append(f"start must be finite, got {start}") + valid_start_type = False + if valid_stop_type and not math.isfinite(stop): + error_messages.append(f"stop must be finite, got {stop}") + valid_stop_type = False + if not isinstance(n_points, int) or n_points < 1: error_messages.append( f"n_points must be an int greater than 0 but is {n_points}", @@ -258,6 +283,12 @@ def _validate_continuous_grid( if valid_start_type and valid_stop_type and start >= stop: error_messages.append("start must be less than stop") + if valid_start_type and requires_positive_start and start <= 0: + error_messages.append( + f"start must be > 0 for a log-spaced grid (got {start}); " + f"`log(x)` is undefined for `x <= 0`." + ) + if error_messages: msg = format_messages(error_messages) raise GridInitializationError(msg) @@ -290,15 +321,24 @@ def _validate_irreg_spaced_grid(points: Sequence[float] | Float1D) -> None: f"Non-numeric elements found at indices: {non_numeric}" ) else: - # Check that points are in ascending order - for i in range(len(points) - 1): - if points[i] >= points[i + 1]: - error_messages.append( - "Points must be in strictly ascending order. " - f"Found points[{i}]={points[i]} >= " - f"points[{i + 1}]={points[i + 1]}" - ) - break + # Reject NaN/inf — comparisons with NaN are False, so the + # ascending-order check below would silently let them through. + non_finite = [(i, p) for i, p in enumerate(points) if not math.isfinite(p)] + if non_finite: + error_messages.append( + f"All elements of points must be finite. " + f"Non-finite elements found at: {non_finite}" + ) + else: + # Check that points are in strictly ascending order + for i in range(len(points) - 1): + if points[i] >= points[i + 1]: + error_messages.append( + "Points must be in strictly ascending order. " + f"Found points[{i}]={points[i]} >= " + f"points[{i + 1}]={points[i + 1]}" + ) + break if error_messages: msg = format_messages(error_messages) diff --git a/tests/test_grids.py b/tests/test_grids.py index 045f829b..da2370d9 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -217,6 +217,36 @@ def test_logspace_grid_invalid_start(): LogSpacedGrid(start=1, stop=0, n_points=10) +def test_logspace_grid_rejects_zero_start(): + with pytest.raises(GridInitializationError, match="log-spaced grid"): + LogSpacedGrid(start=0, stop=10, n_points=5) + + +def test_logspace_grid_rejects_negative_start(): + with pytest.raises(GridInitializationError, match="log-spaced grid"): + LogSpacedGrid(start=-1.0, stop=10, n_points=5) + + +def test_validate_continuous_grid_rejects_nan_start(): + with pytest.raises(GridInitializationError, match="start must be finite"): + _validate_continuous_grid(start=float("nan"), stop=10, n_points=5) + + +def test_validate_continuous_grid_rejects_inf_stop(): + with pytest.raises(GridInitializationError, match="stop must be finite"): + _validate_continuous_grid(start=1, stop=float("inf"), n_points=5) + + +def test_irreg_spaced_grid_rejects_nan_points(): + with pytest.raises(GridInitializationError, match="must be finite"): + IrregSpacedGrid(points=(1.0, float("nan"), 3.0)) + + +def test_irreg_spaced_grid_rejects_inf_points(): + with pytest.raises(GridInitializationError, match="must be finite"): + IrregSpacedGrid(points=(1.0, 2.0, float("inf"))) + + def test_replace_mixin(): grid = LinSpacedGrid(start=1, stop=5, n_points=5) new_grid = grid.replace(start=0) From f56b9be523be22ecdf331d6c6ca0e21e29ff6a3b Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 30 Apr 2026 18:24:08 +0200 Subject: [PATCH 10/80] Address PR review: docstrings, type hints, validation, drop separator banners MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - tests/test_single_feasible_action.py: drop three decorative section banners (AGENTS.md prohibits `# ---...---` separators); fold the banner prose into the docstrings of the tests/helpers below. - tests/test_single_feasible_action.py: type-annotate `_crra_bequest` and `_alive_utility`'s pref_type / consumption_weight / coefficient_rra arguments (DiscreteState / FloatND). - tests/test_runtime_params.py: type-annotate `_make_action_grid_model` and `_make_action_grid_model_with_stateful_dead`. - src/lcm/simulation/transitions.py: re-run `_validate_all_states_present` in the new `create_regime_state_action_space` (the substitution switch from `create_state_action_space(states=...)` to `base.replace(states=...)` had silently dropped this check). - src/lcm/params/regime_template.py: docstring on `_fail_if_runtime_grid_shadows_function`; fix stale phrasing in `create_regime_params_template` ("matching the state name" → "matching the state or action name"). - src/lcm/interfaces.py: comment why the `_ShockGrid` substitution branch is gated on `in_states` only (state-only by design, AGENTS.md forbids ShockGrids as actions; gate is the explicit enforcement of that invariant). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/interfaces.py | 5 +++ src/lcm/params/regime_template.py | 19 ++++++++- src/lcm/simulation/transitions.py | 11 +++-- tests/test_runtime_params.py | 6 ++- tests/test_single_feasible_action.py | 62 +++++++++++----------------- 5 files changed, 59 insertions(+), 44 deletions(-) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 6c80ddb9..59617eb6 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -278,6 +278,11 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac action_replacements[name] = cast( "ContinuousAction", all_params[points_key] ) + # `_ShockGrid` is state-only by construction (intrinsic + # transitions, forbidden as actions per AGENTS.md). The + # `in_states` gate makes that invariant explicit — a + # `_ShockGrid` reaching the action branch would be a model + # bug, not something this method should silently substitute. elif ( in_states and isinstance(spec, _ShockGrid) diff --git a/src/lcm/params/regime_template.py b/src/lcm/params/regime_template.py index 2b9c748e..85a66b47 100644 --- a/src/lcm/params/regime_template.py +++ b/src/lcm/params/regime_template.py @@ -32,7 +32,7 @@ def create_regime_params_template( Grids with runtime-supplied values (IrregSpacedGrid without points, `_ShockGrid` without full shock_params) add entries to the template under - pseudo-function keys matching the state name. + pseudo-function keys matching the state or action name. Args: regime: The regime as provided by the user. @@ -112,6 +112,23 @@ def _fail_if_runtime_grid_shadows_function( name: str, kind: str, ) -> None: + """Raise if a runtime grid name collides with an existing function name. + + Runtime-supplied state and action grids contribute pseudo-function entries + to the params template (keyed by the state or action name). Letting such a + pseudo-function entry shadow a real regime function would silently break + parameter resolution, so we reject it at template-construction time. + + Args: + function_params: Template entries collected so far, keyed by + (pseudo-)function name. + name: State or action name being added. + kind: `"state"` or `"action"`, surfaced in the error message. + + Raises: + InvalidNameError: If `name` already exists in `function_params`. + + """ if name in function_params: raise InvalidNameError( f"IrregSpacedGrid {kind} '{name}' (with runtime-supplied " diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index 23bfd58c..65c8068b 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -15,6 +15,7 @@ from lcm.interfaces import InternalRegime, StateActionSpace from lcm.simulation.random import generate_simulation_keys +from lcm.state_action_space import _validate_all_states_present from lcm.typing import ( ActionName, Bool1D, @@ -53,11 +54,15 @@ def create_regime_state_action_space( base = internal_regime.state_action_space(regime_params=regime_params) relevant_state_names = internal_regime.variable_info.query("is_state").index - states_for_state_action_space = MappingProxyType( - {sn: states[f"{internal_regime.name}__{sn}"] for sn in relevant_state_names} + states_for_state_action_space = { + sn: states[f"{internal_regime.name}__{sn}"] for sn in relevant_state_names + } + _validate_all_states_present( + provided_states=states_for_state_action_space, + required_state_names=set(relevant_state_names), ) - return base.replace(states=states_for_state_action_space) + return base.replace(states=MappingProxyType(states_for_state_action_space)) def calculate_next_states( diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index a401608d..f68adec4 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -151,7 +151,7 @@ def test_runtime_grid_matches_fixed(): aaae(V_fixed[period]["alive"], V_runtime[period]["alive"]) -def _make_action_grid_model(*, consumption_grid): +def _make_action_grid_model(*, consumption_grid: IrregSpacedGrid) -> Model: """Create a 2-regime model where consumption is the runtime-points action grid.""" alive = Regime( functions={"utility": _utility}, @@ -241,7 +241,9 @@ def test_runtime_action_grid_changes_solution(): assert not jnp.allclose(V_low[0]["alive"], V_high[0]["alive"]) -def _make_action_grid_model_with_stateful_dead(*, consumption_grid): +def _make_action_grid_model_with_stateful_dead( + *, consumption_grid: IrregSpacedGrid +) -> Model: """Variant where `dead` has a `wealth` state so its utility depends on it. Mirrors the aca-model dead regime (carries assets / pref_type so the diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py index 04628ed2..b5ce7a3f 100644 --- a/tests/test_single_feasible_action.py +++ b/tests/test_single_feasible_action.py @@ -29,7 +29,7 @@ from lcm.grids import IrregSpacedGrid from lcm.grids.coordinates import get_irreg_coordinate from lcm.regime_building.ndimage import map_coordinates -from lcm.typing import ContinuousAction, ContinuousState, FloatND +from lcm.typing import ContinuousAction, ContinuousState, DiscreteState, FloatND @categorical(ordered=False) @@ -244,12 +244,6 @@ def test_simulate_with_constrained_action_grid(wealth_lo, consumption_lo, label) ) -# --------------------------------------------------------------------------- -# Replicas of the aca-baseline failure path: dead regime with a CRRA bequest -# whose `gamma` is per-pref_type, evaluated through `jnp.where`. -# --------------------------------------------------------------------------- - - @categorical(ordered=False) class PrefType: type_0: int @@ -264,10 +258,10 @@ class AliveDeadRegimeId: def _crra_bequest( assets: ContinuousState, - pref_type, + pref_type: DiscreteState, bequest_shifter: float, - consumption_weight, - coefficient_rra, + consumption_weight: FloatND, + coefficient_rra: FloatND, ) -> FloatND: """Replica of aca_model.agent.preferences.bequest, simplified. @@ -286,7 +280,9 @@ def _crra_bequest( def _alive_utility( - consumption: ContinuousAction, pref_type, consumption_weight + consumption: ContinuousAction, + pref_type: DiscreteState, + consumption_weight: FloatND, ) -> FloatND: """Make pref_type matter in alive's utility too (otherwise pylcm complains).""" alpha = consumption_weight[pref_type] @@ -395,28 +391,19 @@ def test_bequest_gamma_exactly_one_for_one_type_only(): ) -# --------------------------------------------------------------------------- -# Direct probe: `map_coordinates` produces NaN at ±inf / NaN coordinates. -# This is the concrete NaN source — `lower_weight = 1 - inf = -inf` and -# `upper_weight = inf` combined with positive grid values gives `inf - inf = -# NaN`. The aca-baseline NaN-in-V at age 51 is most plausibly traced back to -# *some* upstream computation (next_assets / next_aime, or a state coordinate -# from `get_irreg_coordinate` / `get_*_coordinate` that divides by zero on a -# degenerate grid segment) producing inf, which then poisons the value -# function via this interpolation path. -# --------------------------------------------------------------------------- - - @pytest.mark.parametrize("bad_coord", [jnp.inf, -jnp.inf, jnp.nan]) def test_map_coordinates_returns_nan_for_non_finite_coordinate(bad_coord): """`map_coordinates` cannot recover from a non-finite continuous-state coordinate: the linear-interp weights become `inf` and `1 - inf = -inf`, - and `inf * V[k] - inf * V[k-1]` reduces to NaN. + and `inf * V[k] - inf * V[k-1]` reduces to NaN. The aca-baseline NaN + at age 51 traces back to some upstream computation (next_assets, + next_aime, or a coordinate finder that divides by zero on a degenerate + grid segment) feeding inf into this path. Implication for callers: any path that can feed `inf` or `NaN` into the - coordinate finder (e.g. division by zero in a state transition, an - overflow when V values are O(1e8), or a `0/0` in a degenerate - IrregSpacedGrid segment) will produce NaN in V. + coordinate finder (division by zero in a state transition, overflow + when V values are O(1e8), or `0/0` in a degenerate IrregSpacedGrid + segment) will produce NaN in V. """ V_arr = jnp.array([1.0, 5.0, 12.0]) out = map_coordinates(V_arr, coordinates=[jnp.array(bad_coord)]) @@ -446,18 +433,17 @@ def test_irreg_coordinate_divides_by_zero_on_duplicate_grid_points(): ) -# --------------------------------------------------------------------------- -# `validate_initial_conditions` uses `_base_state_action_space` directly, -# which still holds the placeholder zeros for runtime-supplied -# `IrregSpacedGrid`. With a feasibility constraint that the all-zero -# placeholder fails, every subject is reported infeasible — even though the -# real (post-substitution) grid would pass. This affects runtime grids -# regardless of whether they are state or action grids. -# --------------------------------------------------------------------------- - - def _runtime_state_grid_model() -> tuple[Model, dict, dict]: - """A 2-regime model with a runtime-supplied IrregSpacedGrid *state*.""" + """Build a 2-regime model with a runtime-supplied IrregSpacedGrid *state*. + + Reproduces the failure mode where `validate_initial_conditions` reads + `_base_state_action_space` directly — still holding the placeholder + zeros for runtime-supplied `IrregSpacedGrid`s. With a feasibility + constraint that the all-zero placeholder fails, every subject is + reported infeasible even though the real (post-substitution) grid + would pass. The same mechanism affects runtime grids whether they + are states or actions. + """ @categorical(ordered=False) class RuntimeRegimeId: From 3b7be82be2e9e7b82ea80173ce4c5e82148eb132 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:42:03 +0200 Subject: [PATCH 11/80] LogSpacedGrid docstring: drop redundant rationale The validator's error message already explains why; the class docstring only needs the contract. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 8550188d..20811ae9 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -117,9 +117,7 @@ def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: class LogSpacedGrid(UniformContinuousGrid): """A logarithmically spaced grid of continuous values. - Requires `start > 0`; otherwise `log(start)` is undefined and `to_jax()` - silently returns NaN/-inf, which propagates through every downstream - interpolation. + Requires `start > 0`. Example: -------- From 8589146010a60274083075f22d5e23f6e7d73a94 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:44:14 +0200 Subject: [PATCH 12/80] IrregSpacedGrid.to_jax docstring: shorter, point at the substituted-grid path Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 20811ae9..664dc829 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -202,10 +202,9 @@ def to_jax(self) -> Float1D: """Convert the grid to a Jax array. Raises `GridInitializationError` for runtime-supplied grids - (`pass_points_at_runtime=True`). Substitution happens at solve / - simulate time via `InternalRegime.state_action_space(regime_params=...)`; - any code path that reads the base grid's points before substitution is - a bug. + (`pass_points_at_runtime=True`). To get the substituted points, + call `internal_regime.state_action_space(regime_params=...)` and + read from `.states[name]` or `.continuous_actions[name]`. """ if self.points is None: raise GridInitializationError( From 541392c0f649741991bf81c7bd9a4e41d6491424 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:44:50 +0200 Subject: [PATCH 13/80] IrregSpacedGrid.to_jax error message: same shape as docstring Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 664dc829..1e4b1140 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -208,13 +208,13 @@ def to_jax(self) -> Float1D: """ if self.points is None: raise GridInitializationError( - f"IrregSpacedGrid was declared with n_points={self.n_points} " - f"and no points; values are supplied at runtime via " - f"params['']['']['points']. Reading the grid " - f"before substitution is a bug — call " + f"IrregSpacedGrid declared with n_points={self.n_points} and " + f"no points; values are supplied at runtime via " + f"params['']['']['points']. To get the " + f"substituted points, call " f"`internal_regime.state_action_space(regime_params=...)` and " - f"read points from there, or use `.n_points` if only the shape " - f"is needed." + f"read from `.states[name]` or `.continuous_actions[name]`. " + f"Use `.n_points` if only the shape is needed." ) return jnp.asarray(self.points) From 54d22b00b7062403790fb7fa9a0795e1335e9649 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:46:11 +0200 Subject: [PATCH 14/80] Use jnp.isfinite in grid validators; drop math import Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 1e4b1140..9a147c44 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -1,5 +1,4 @@ import dataclasses -import math from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @@ -265,10 +264,10 @@ def _validate_continuous_grid( # Reject NaN/inf early — `start >= stop` returns False for NaN, so an # un-finite start would otherwise pass silently and produce a broken grid. - if valid_start_type and not math.isfinite(start): + if valid_start_type and not jnp.isfinite(start): error_messages.append(f"start must be finite, got {start}") valid_start_type = False - if valid_stop_type and not math.isfinite(stop): + if valid_stop_type and not jnp.isfinite(stop): error_messages.append(f"stop must be finite, got {stop}") valid_stop_type = False @@ -320,7 +319,7 @@ def _validate_irreg_spaced_grid(points: Sequence[float] | Float1D) -> None: else: # Reject NaN/inf — comparisons with NaN are False, so the # ascending-order check below would silently let them through. - non_finite = [(i, p) for i, p in enumerate(points) if not math.isfinite(p)] + non_finite = [(i, p) for i, p in enumerate(points) if not jnp.isfinite(p)] if non_finite: error_messages.append( f"All elements of points must be finite. " From ba38876a5c0065d581890cbaf7225546b4c7db00 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:51:24 +0200 Subject: [PATCH 15/80] Drop cryptic aca_model reference from validator docstring Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/regime_building/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index ddc63631..96df0dce 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -151,7 +151,7 @@ def _validate_function_output_state_indexing(regime: Regime) -> list[str]: the indexing produces NaN at runtime instead of the intended scalar. The safe pattern is to take the state as input on the producing function and - return the scalar directly (see `aca_model.agent.preferences.discount_factor`). + return the scalar directly. """ function_output_names = set(regime.functions) discrete_state_names = { From 4d11dead41e59dbc14ea3daf3cf35303cecc4b3b Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:55:59 +0200 Subject: [PATCH 16/80] Validator: also flag function-output indexed by a derived categorical MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Derived categoricals (`regime.derived_categoricals`, function outputs that pylcm treats as categoricals — see https://pylcm.readthedocs.io/en/latest/pandas-interop/#derived-categoricals) suffer the same per-cell broadcast clash as discrete states. Extend `discrete_state_names` in `_validate_function_output_state_indexing` to include them; add a TDD test. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/regime_building/validation.py | 2 +- tests/test_function_output_state_indexing.py | 47 ++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index 96df0dce..9f883e0c 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -156,7 +156,7 @@ def _validate_function_output_state_indexing(regime: Regime) -> list[str]: function_output_names = set(regime.functions) discrete_state_names = { name for name, grid in regime.states.items() if isinstance(grid, DiscreteGrid) - } + } | set(regime.derived_categoricals) if not function_output_names or not discrete_state_names: return [] diff --git a/tests/test_function_output_state_indexing.py b/tests/test_function_output_state_indexing.py index f326833c..b4d690ed 100644 --- a/tests/test_function_output_state_indexing.py +++ b/tests/test_function_output_state_indexing.py @@ -123,6 +123,53 @@ def test_safe_pattern_does_not_raise(): ) +def test_function_output_indexed_by_derived_categorical_raises(): + """The check applies to derived categoricals (function outputs treated as + categoricals via `derived_categoricals`), not only states.""" + + @categorical(ordered=False) + class IsMarried: + single: int + married: int + + def _is_married(spousal_income: DiscreteState) -> DiscreteState: + return jnp.int32(spousal_income > 0) + + def _per_marital_scale(some_param: FloatND) -> FloatND: + return jnp.abs(1.0 / (1.0 - some_param)) + + def _utility_clash( + consumption: ContinuousAction, + is_married: DiscreteState, + per_marital_scale: FloatND, + ) -> FloatND: + return per_marital_scale[is_married] * jnp.log(consumption + 1.0) + + @categorical(ordered=True) + class SpousalIncome: + single: int + married_no_inc: int + married_has_inc: int + + with pytest.raises( + RegimeInitializationError, + match=r"per_marital_scale.*is_married", + ): + Regime( + functions={ + "utility": _utility_clash, + "per_marital_scale": _per_marital_scale, + "is_married": _is_married, + }, + states={"spousal_income": DiscreteGrid(SpousalIncome)}, + state_transitions={"spousal_income": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + derived_categoricals={"is_married": DiscreteGrid(IsMarried)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + + def test_constraint_indexing_function_output_by_state_raises(): """The check applies to regime constraints too, not only `functions`.""" From 01608bb049d41448e7a2f9bd9bd97de26030e425 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:57:50 +0200 Subject: [PATCH 17/80] create_regime_state_action_space docstring: trim rationale Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/simulation/transitions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index 65c8068b..18bdbb6b 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -37,9 +37,7 @@ def create_regime_state_action_space( Continuous action grids declared with `pass_points_at_runtime=True` are completed from `regime_params` (via - `InternalRegime.state_action_space`) — otherwise they would carry the - NaN placeholder used during compilation, which propagates into - `optimal_actions` and ultimately `next_states`. + `InternalRegime.state_action_space`). Args: internal_regime: The internal regime instance. From 6241056a354f852028ebbfa36b02c524523438d1 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 07:00:42 +0200 Subject: [PATCH 18/80] state_action_space: move private helpers below public function (deep module) Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/state_action_space.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/lcm/state_action_space.py b/src/lcm/state_action_space.py index 28cbce72..7c88e165 100644 --- a/src/lcm/state_action_space.py +++ b/src/lcm/state_action_space.py @@ -9,21 +9,6 @@ from lcm.typing import StateName, StateOrActionName -def _grid_to_jax_or_placeholder(grid: Grid) -> Array: - """Return the grid's points, or a NaN placeholder for runtime-supplied grids. - - `IrregSpacedGrid.to_jax()` raises when its points haven't been supplied — that - is the right behaviour everywhere except here: the base state-action space - needs a *shape-correct* array to wire through pytree structures and AOT - tracing before runtime substitution by - `InternalRegime.state_action_space(regime_params=...)`. NaN (rather than - zero) makes any accidental computation against the placeholder fail loudly. - """ - if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: - return jnp.full(grid.n_points, jnp.nan) - return grid.to_jax() - - def create_state_action_space( *, variable_info: pd.DataFrame, @@ -79,6 +64,21 @@ def create_state_action_space( ) +def _grid_to_jax_or_placeholder(grid: Grid) -> Array: + """Return the grid's points, or a NaN placeholder for runtime-supplied grids. + + `IrregSpacedGrid.to_jax()` raises when its points haven't been supplied — that + is the right behaviour everywhere except here: the base state-action space + needs a *shape-correct* array to wire through pytree structures and AOT + tracing before runtime substitution by + `InternalRegime.state_action_space(regime_params=...)`. NaN (rather than + zero) makes any accidental computation against the placeholder fail loudly. + """ + if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: + return jnp.full(grid.n_points, jnp.nan) + return grid.to_jax() + + def _validate_all_states_present( *, provided_states: dict[StateName, Array], required_state_names: set[StateName] ) -> None: From 2efe9e1b5bf9381badd66775069fdeaafcfd918c Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 07:07:08 +0200 Subject: [PATCH 19/80] Drop application-specific (aca) references from test docstrings pylcm is a general library; references to a particular companion application become stale fast and force readers to know unrelated projects to follow the test rationale. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_function_output_state_indexing.py | 4 +-- tests/test_runtime_params.py | 6 ++--- tests/test_single_feasible_action.py | 27 ++++++++------------ tests/test_validation_scalar_actions.py | 4 +-- 4 files changed, 17 insertions(+), 24 deletions(-) diff --git a/tests/test_function_output_state_indexing.py b/tests/test_function_output_state_indexing.py index b4d690ed..2400f33b 100644 --- a/tests/test_function_output_state_indexing.py +++ b/tests/test_function_output_state_indexing.py @@ -6,8 +6,8 @@ NaN at runtime instead of the intended scalar. The validation layer must raise on construction with a clear message -pointing the user at the safe pattern (see `discount_factor` in -`aca_model.agent.preferences`: take the state as input, return a scalar). +pointing the user at the safe pattern: take the discrete state as input +on the producing function, return a scalar. """ import jax.numpy as jnp diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index f68adec4..61b2414a 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -246,10 +246,8 @@ def _make_action_grid_model_with_stateful_dead( ) -> Model: """Variant where `dead` has a `wealth` state so its utility depends on it. - Mirrors the aca-model dead regime (carries assets / pref_type so the - bequest function can read them). Used to surface NaN propagation - when the simulate path forgets to substitute runtime-supplied action - gridpoints. + Used to surface NaN propagation when the simulate path forgets to + substitute runtime-supplied action gridpoints. """ def _alive_utility( diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py index b5ce7a3f..53a18594 100644 --- a/tests/test_single_feasible_action.py +++ b/tests/test_single_feasible_action.py @@ -1,5 +1,4 @@ -"""Reproduce the aca-model NaN failure (issue OpenSourceEconomics/aca-model#9): -solve raises `InvalidValueFunctionError: ... regime 'dead': all values are NaN`. +"""Reproduce ways `solve` can raise `InvalidValueFunctionError: all values are NaN`. Two hypotheses are exercised here: @@ -10,16 +9,16 @@ the tests in this file. `-inf` from all-infeasible cells does not cascade to NaN by itself either. -2. **CRRA bequest with pref_type indexing under jnp.where**: the bequest +2. **CRRA bequest with discrete-state indexing under jnp.where**: a bequest function evaluates *both* branches of `jnp.where(jnp.isclose(gamma, 1), log_branch, power_branch)`. For a parameter set where gamma is exactly 1 - for one preference type but not the other, the power_branch divides by + for one categorical type but not the other, the power_branch divides by `1 - gamma = 0` for the gamma=1 type. NaN/Inf from the *unselected* branch leaks through `jnp.where`'s gradient/forward pass under JIT tracing in some configurations. The model uses an `IrregSpacedGrid` consumption action with runtime-supplied -points to also exercise the `feature/runtime-action-grids` path (PR #338). +points to also exercise the runtime-action-grids substitution path. """ import jax.numpy as jnp @@ -263,7 +262,7 @@ def _crra_bequest( consumption_weight: FloatND, coefficient_rra: FloatND, ) -> FloatND: - """Replica of aca_model.agent.preferences.bequest, simplified. + """Simplified CRRA bequest used to probe NaN propagation under jnp.where. `consumption_weight` and `coefficient_rra` are FloatND indexed by `pref_type`. Both branches of the `jnp.where` are traced. @@ -395,14 +394,11 @@ def test_bequest_gamma_exactly_one_for_one_type_only(): def test_map_coordinates_returns_nan_for_non_finite_coordinate(bad_coord): """`map_coordinates` cannot recover from a non-finite continuous-state coordinate: the linear-interp weights become `inf` and `1 - inf = -inf`, - and `inf * V[k] - inf * V[k-1]` reduces to NaN. The aca-baseline NaN - at age 51 traces back to some upstream computation (next_assets, - next_aime, or a coordinate finder that divides by zero on a degenerate - grid segment) feeding inf into this path. + and `inf * V[k] - inf * V[k-1]` reduces to NaN. Implication for callers: any path that can feed `inf` or `NaN` into the coordinate finder (division by zero in a state transition, overflow - when V values are O(1e8), or `0/0` in a degenerate IrregSpacedGrid + when V values are large, or `0/0` in a degenerate IrregSpacedGrid segment) will produce NaN in V. """ V_arr = jnp.array([1.0, 5.0, 12.0]) @@ -414,11 +410,10 @@ def test_irreg_coordinate_divides_by_zero_on_duplicate_grid_points(): """`get_irreg_coordinate` divides by `upper_point - lower_point`. If a runtime-supplied points array contains duplicates, the divisor is 0. - Reproduces a class of failures that is *only* possible under - `feature/runtime-action-grids` / runtime-state-grids when the caller - constructs the points from a parameter that can collapse (e.g. - `geomspace(consumption_floor, MAX, n_points)` with `consumption_floor == - MAX`, or any param-driven `linspace` whose endpoints can coincide). + Reproduces a class of failures specific to runtime-supplied grids: when + the caller constructs points from a parameter that can collapse (e.g. a + `geomspace(lo, hi, n)` with `lo == hi`, or any param-driven `linspace` + whose endpoints can coincide). """ # Duplicate adjacent points where the query value equals the duplicate. # `searchsorted([0, 1, 1], 1.0, side='right')=3` → clipped to n-1=2, diff --git a/tests/test_validation_scalar_actions.py b/tests/test_validation_scalar_actions.py index 6855e4b9..89289f5b 100644 --- a/tests/test_validation_scalar_actions.py +++ b/tests/test_validation_scalar_actions.py @@ -7,8 +7,8 @@ Validation must do the same. This pattern arises in models that pass multi-dimensional lookup tables as parameters -via MappingLeaf — e.g. tax schedules and pension accrual tables in aca-model, or -tax-transfer schedules in ttsim/gettsim. +via MappingLeaf — e.g. tax schedules, pension accrual tables, or tax-transfer +schedules. """ import jax From f1c4d5dc2be4b8c28128a45b504276f31e5230ce Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 07:16:22 +0200 Subject: [PATCH 20/80] Validator: rename to discrete_grid_names, also include discrete actions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The variable previously named `discrete_state_names` accumulated state DiscreteGrids, derived categoricals, and now discrete actions — all three suffer the same per-cell broadcast clash when a consumer does `func_output[X]`. Renamed the variable, the two helpers (`_validate_function_output_grid_indexing`, `_find_function_output_grid_indexing`), the test module (`test_function_output_grid_indexing.py`), and the error-message wording ("discrete state" → "discrete grid"). Added a TDD test for the discrete-action case. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/regime_building/validation.py | 53 +++++++++++-------- ... => test_function_output_grid_indexing.py} | 53 ++++++++++++++++--- 2 files changed, 76 insertions(+), 30 deletions(-) rename tests/{test_function_output_state_indexing.py => test_function_output_grid_indexing.py} (78%) diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index 9f883e0c..1e2be941 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -128,7 +128,7 @@ def validate_logical_consistency(regime: Regime) -> None: error_messages.extend(_validate_active(regime.active)) error_messages.extend(_validate_state_transitions(regime)) - error_messages.extend(_validate_function_output_state_indexing(regime)) + error_messages.extend(_validate_function_output_grid_indexing(regime)) states_and_actions_overlap = set(regime.states) & set(regime.actions) if states_and_actions_overlap: @@ -142,22 +142,29 @@ def validate_logical_consistency(regime: Regime) -> None: raise RegimeInitializationError(msg) -def _validate_function_output_state_indexing(regime: Regime) -> list[str]: - """Detect the regime-function-output / state-indexed-input name clash. +def _validate_function_output_grid_indexing(regime: Regime) -> list[str]: + """Detect the regime-function-output / discrete-grid-indexed-input name clash. - A regime function whose output is then re-indexed by a discrete state inside - another consumer (function, constraint, or transition) is a silent footgun: - pylcm broadcasts function outputs to per-cell scalars before consumption, so - the indexing produces NaN at runtime instead of the intended scalar. + A regime function whose output is then re-indexed by a discrete grid (state, + action, or derived categorical) inside another consumer (function, + constraint, or transition) is a silent footgun: pylcm broadcasts function + outputs to per-cell scalars before consumption, so the indexing produces + NaN at runtime instead of the intended scalar. - The safe pattern is to take the state as input on the producing function and - return the scalar directly. + The safe pattern is to take the discrete grid as input on the producing + function and return the scalar directly. """ function_output_names = set(regime.functions) - discrete_state_names = { - name for name, grid in regime.states.items() if isinstance(grid, DiscreteGrid) - } | set(regime.derived_categoricals) - if not function_output_names or not discrete_state_names: + discrete_grid_names = ( + {name for name, grid in regime.states.items() if isinstance(grid, DiscreteGrid)} + | { + name + for name, grid in regime.actions.items() + if isinstance(grid, DiscreteGrid) + } + | set(regime.derived_categoricals) + ) + if not function_output_names or not discrete_grid_names: return [] consumers: list[tuple[str, Callable]] = [] @@ -168,31 +175,31 @@ def _validate_function_output_state_indexing(regime: Regime) -> list[str]: errors: list[str] = [] for consumer_name, func in consumers: - clashes = _find_function_output_state_indexing( + clashes = _find_function_output_grid_indexing( func=func, function_output_names=function_output_names, - discrete_state_names=discrete_state_names, + discrete_grid_names=discrete_grid_names, ) - for func_output_name, state_name in clashes: + for func_output_name, grid_name in clashes: errors.append( f"Consumer '{consumer_name}' indexes regime function output " - f"'{func_output_name}' by discrete state '{state_name}' " - f"(`{func_output_name}[{state_name}]`). pylcm broadcasts " + f"'{func_output_name}' by discrete grid '{grid_name}' " + f"(`{func_output_name}[{grid_name}]`). pylcm broadcasts " f"function outputs to per-cell scalars before consumption, so " f"this indexing silently produces NaN. Refactor " - f"'{func_output_name}' to take '{state_name}' as input and " + f"'{func_output_name}' to take '{grid_name}' as input and " f"return the scalar directly." ) return errors -def _find_function_output_state_indexing( +def _find_function_output_grid_indexing( *, func: Callable, function_output_names: set[str], - discrete_state_names: set[str], + discrete_grid_names: set[str], ) -> list[tuple[str, str]]: - """Return `(function_output_name, state_name)` clashes inside `func`'s body.""" + """Return `(function_output_name, grid_name)` clashes inside `func`'s body.""" try: source = textwrap.dedent(inspect.getsource(func)) except OSError, TypeError: @@ -212,7 +219,7 @@ def _find_function_output_state_indexing( continue if not isinstance(node.slice, ast.Name): continue - if node.slice.id not in discrete_state_names: + if node.slice.id not in discrete_grid_names: continue clashes.append((node.value.id, node.slice.id)) return clashes diff --git a/tests/test_function_output_state_indexing.py b/tests/test_function_output_grid_indexing.py similarity index 78% rename from tests/test_function_output_state_indexing.py rename to tests/test_function_output_grid_indexing.py index 2400f33b..fa0f2753 100644 --- a/tests/test_function_output_state_indexing.py +++ b/tests/test_function_output_grid_indexing.py @@ -1,12 +1,13 @@ -"""Tests for the regime-function-output / state-indexed-input name clash. +"""Tests for the regime-function-output / discrete-grid-indexed-input name clash. -A regime function whose output is then re-indexed by a discrete state -inside another function is a silent footgun: pylcm broadcasts function -outputs to per-cell scalars before consumption, so the indexing produces -NaN at runtime instead of the intended scalar. +A regime function whose output is then re-indexed by a discrete grid +(state, action, or derived categorical) inside another function is a +silent footgun: pylcm broadcasts function outputs to per-cell scalars +before consumption, so the indexing produces NaN at runtime instead of +the intended scalar. The validation layer must raise on construction with a clear message -pointing the user at the safe pattern: take the discrete state as input +pointing the user at the safe pattern: take the discrete grid as input on the producing function, return a scalar. """ @@ -15,7 +16,7 @@ from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical from lcm.exceptions import RegimeInitializationError -from lcm.typing import ContinuousAction, DiscreteState, FloatND +from lcm.typing import ContinuousAction, DiscreteAction, DiscreteState, FloatND @categorical(ordered=False) @@ -170,6 +171,44 @@ class SpousalIncome: ) +def test_function_output_indexed_by_discrete_action_raises(): + """The check applies to discrete actions, not only states/derived categoricals.""" + + @categorical(ordered=False) + class WorkChoice: + no_work: int + work: int + + def _per_choice_scale(some_param: FloatND) -> FloatND: + return jnp.abs(1.0 / (1.0 - some_param)) + + def _utility_clash( + consumption: ContinuousAction, + labor_supply: DiscreteAction, + per_choice_scale: FloatND, + ) -> FloatND: + return per_choice_scale[labor_supply] * jnp.log(consumption + 1.0) + + with pytest.raises( + RegimeInitializationError, + match=r"per_choice_scale.*labor_supply", + ): + Regime( + functions={ + "utility": _utility_clash, + "per_choice_scale": _per_choice_scale, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={ + "consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5), + "labor_supply": DiscreteGrid(WorkChoice), + }, + transition=_next_regime, + active=lambda age: age < 2, + ) + + def test_constraint_indexing_function_output_by_state_raises(): """The check applies to regime constraints too, not only `functions`.""" From d9f37ceeda2382f527dad8c99c9bc3144696a502 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 07:22:32 +0200 Subject: [PATCH 21/80] Validator: tighten to actual footgun shape; correct behaviour description MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous docstring claimed the indexing 'silently produces NaN', but a disabled-validator probe shows otherwise: - When the producer takes the discrete grid as input, its output is a per-cell scalar; `func_output[grid]` raises `IndexError: Too many indices` at trace time. This is the real footgun the validator should catch. - When the producer does NOT take the discrete grid as input, its output stays array-shaped and `func_output[grid]` is correct code that solves to sensible V values. The previous validator flagged both shapes — including the safe one — as a clash. Tighten: only fire when the producing function also takes the discrete grid as input. Update the description to match observed behaviour (IndexError, not NaN). Add a regression test that exercises the array-valued-producer + state-indexed-consumer shape and asserts it builds without raising. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/regime_building/validation.py | 40 ++++++--- tests/test_function_output_grid_indexing.py | 99 ++++++++++++++------- 2 files changed, 93 insertions(+), 46 deletions(-) diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index 1e2be941..92656692 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -145,14 +145,12 @@ def validate_logical_consistency(regime: Regime) -> None: def _validate_function_output_grid_indexing(regime: Regime) -> list[str]: """Detect the regime-function-output / discrete-grid-indexed-input name clash. - A regime function whose output is then re-indexed by a discrete grid (state, - action, or derived categorical) inside another consumer (function, - constraint, or transition) is a silent footgun: pylcm broadcasts function - outputs to per-cell scalars before consumption, so the indexing produces - NaN at runtime instead of the intended scalar. - - The safe pattern is to take the discrete grid as input on the producing - function and return the scalar directly. + The unsafe pattern is: a regime function `f` takes a discrete grid `g` + (state, action, or derived categorical) as an input — so `f`'s output + is a per-cell scalar — and a consumer then indexes `f[g]`. The + consumer is indexing a 0-d array by a scalar integer, which raises + `IndexError` at trace time. The fix is to drop the redundant `[g]` + in the consumer (or refactor `f` not to take `g`). """ function_output_names = set(regime.functions) discrete_grid_names = ( @@ -167,6 +165,19 @@ def _validate_function_output_grid_indexing(regime: Regime) -> list[str]: if not function_output_names or not discrete_grid_names: return [] + # Only treat `func_output[grid]` as unsafe when the producing function + # *also* takes `grid` as an input — that is the case where the output + # is per-cell scalar and the consumer's indexing is wrong. If the + # producing function does not take `grid`, its output shape is + # whatever it computed (typically an array indexable by `grid`) and + # the consumer pattern is correct. + function_inputs: dict[str, set[str]] = {} + for name, func in regime.functions.items(): + try: + function_inputs[name] = set(inspect.signature(func).parameters) + except ValueError, TypeError: + function_inputs[name] = set() + consumers: list[tuple[str, Callable]] = [] consumers.extend(regime.functions.items()) consumers.extend(regime.constraints.items()) @@ -181,14 +192,17 @@ def _validate_function_output_grid_indexing(regime: Regime) -> list[str]: discrete_grid_names=discrete_grid_names, ) for func_output_name, grid_name in clashes: + if grid_name not in function_inputs.get(func_output_name, set()): + continue errors.append( f"Consumer '{consumer_name}' indexes regime function output " f"'{func_output_name}' by discrete grid '{grid_name}' " - f"(`{func_output_name}[{grid_name}]`). pylcm broadcasts " - f"function outputs to per-cell scalars before consumption, so " - f"this indexing silently produces NaN. Refactor " - f"'{func_output_name}' to take '{grid_name}' as input and " - f"return the scalar directly." + f"(`{func_output_name}[{grid_name}]`), but '{func_output_name}' " + f"already takes '{grid_name}' as input — its output is a " + f"per-cell scalar, so the indexing raises IndexError at trace " + f"time. Drop the redundant `[{grid_name}]` in '{consumer_name}', " + f"or refactor '{func_output_name}' not to take '{grid_name}' " + f"as input." ) return errors diff --git a/tests/test_function_output_grid_indexing.py b/tests/test_function_output_grid_indexing.py index fa0f2753..adc37816 100644 --- a/tests/test_function_output_grid_indexing.py +++ b/tests/test_function_output_grid_indexing.py @@ -1,14 +1,14 @@ """Tests for the regime-function-output / discrete-grid-indexed-input name clash. -A regime function whose output is then re-indexed by a discrete grid -(state, action, or derived categorical) inside another function is a -silent footgun: pylcm broadcasts function outputs to per-cell scalars -before consumption, so the indexing produces NaN at runtime instead of -the intended scalar. - -The validation layer must raise on construction with a clear message -pointing the user at the safe pattern: take the discrete grid as input -on the producing function, return a scalar. +The unsafe pattern is: a regime function `f` takes a discrete grid `g` (state, +action, or derived categorical) as an input — so `f`'s output is a per-cell +scalar — and a consumer then indexes `f[g]`. The consumer is indexing a 0-d +array by a scalar integer, which raises `IndexError` at trace time. The +validator catches the pattern at construction so the user gets a clear message +instead of a cryptic JAX trace error during solve. + +The fix is to drop the redundant `[g]` in the consumer (or refactor `f` not +to take `g`). """ import jax.numpy as jnp @@ -31,19 +31,21 @@ class RegimeId: dead: int -def _per_type_scale(some_param: FloatND) -> FloatND: - """Returns one scale per pref_type — depends only on a per-type Series param.""" - return jnp.abs(1.0 / (1.0 - some_param)) +def _per_type_scale_takes_pref_type( + pref_type: DiscreteState, some_param: FloatND +) -> FloatND: + """Takes pref_type — output is per-cell scalar; consumer must not re-index.""" + return jnp.abs(1.0 / (1.0 - some_param[pref_type])) -def _utility_with_state_indexed_function_output( +def _utility_redundantly_indexes( consumption: ContinuousAction, pref_type: DiscreteState, per_type_scale: FloatND, ) -> FloatND: - # The clash: per_type_scale is registered as a regime function output - # (returns shape (n_pref_types,)), but here it is consumed and indexed - # by the `pref_type` discrete state. + # The clash: per_type_scale's producer takes pref_type, so its output is a + # per-cell scalar. Indexing that scalar by pref_type again raises IndexError + # at trace time. return per_type_scale[pref_type] * jnp.log(consumption + 1.0) @@ -54,8 +56,8 @@ def _next_regime(period: int) -> FloatND: def _make_clashing_model() -> Model: alive = Regime( functions={ - "utility": _utility_with_state_indexed_function_output, - "per_type_scale": _per_type_scale, + "utility": _utility_redundantly_indexes, + "per_type_scale": _per_type_scale_takes_pref_type, }, states={"pref_type": DiscreteGrid(PrefType)}, state_transitions={"pref_type": None}, @@ -76,8 +78,8 @@ def _make_clashing_model() -> Model: def test_function_output_indexed_by_state_raises(): - """A regime function output indexed by a discrete state inside another regime - function must raise on construction (silent NaN bug otherwise).""" + """A regime function output redundantly indexed by a discrete state inside + another regime function must raise on construction.""" with pytest.raises( RegimeInitializationError, match=r"per_type_scale.*pref_type", @@ -85,7 +87,7 @@ def test_function_output_indexed_by_state_raises(): _make_clashing_model() -def _utility_safe( +def _utility_consumes_scalar( consumption: ContinuousAction, pref_type: DiscreteState, # noqa: ARG001 per_type_scale: FloatND, @@ -94,17 +96,48 @@ def _utility_safe( return per_type_scale * jnp.log(consumption + 1.0) -def _per_type_scale_safe(pref_type: DiscreteState, some_param: FloatND) -> FloatND: - """Safe variant: takes pref_type, returns scalar (mirrors discount_factor).""" - return jnp.abs(1.0 / (1.0 - some_param[pref_type])) +def test_safe_pattern_does_not_raise(): + """The safe pattern (function takes the state, returns a scalar; consumer + uses it directly) builds fine.""" + alive = Regime( + functions={ + "utility": _utility_consumes_scalar, + "per_type_scale": _per_type_scale_takes_pref_type, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) -def test_safe_pattern_does_not_raise(): - """The safe pattern (function takes the state, returns a scalar) builds fine.""" +def _per_type_scale_array_output(some_param: FloatND) -> FloatND: + """Does NOT take pref_type — output is `(n_pref_types,)`-shaped. + + A consumer indexing this output by `pref_type` is correct: the indexing + selects the per-type entry. The validator must NOT flag this case. + """ + return jnp.abs(1.0 / (1.0 - some_param)) + + +def test_array_valued_producer_indexed_by_state_does_not_raise(): + """When the producing function does NOT take the discrete grid as input, + its output stays array-shaped and `func_output[grid]` is correct code.""" alive = Regime( functions={ - "utility": _utility_safe, - "per_type_scale": _per_type_scale_safe, + "utility": _utility_redundantly_indexes, + "per_type_scale": _per_type_scale_array_output, }, states={"pref_type": DiscreteGrid(PrefType)}, state_transitions={"pref_type": None}, @@ -136,8 +169,8 @@ class IsMarried: def _is_married(spousal_income: DiscreteState) -> DiscreteState: return jnp.int32(spousal_income > 0) - def _per_marital_scale(some_param: FloatND) -> FloatND: - return jnp.abs(1.0 / (1.0 - some_param)) + def _per_marital_scale(is_married: DiscreteState, some_param: FloatND) -> FloatND: + return jnp.abs(1.0 / (1.0 - some_param[is_married])) def _utility_clash( consumption: ContinuousAction, @@ -179,8 +212,8 @@ class WorkChoice: no_work: int work: int - def _per_choice_scale(some_param: FloatND) -> FloatND: - return jnp.abs(1.0 / (1.0 - some_param)) + def _per_choice_scale(labor_supply: DiscreteAction, some_param: FloatND) -> FloatND: + return jnp.abs(1.0 / (1.0 - some_param[labor_supply])) def _utility_clash( consumption: ContinuousAction, @@ -228,7 +261,7 @@ def _constraint_indexing_function_output( "utility": lambda consumption, pref_type, per_type_scale: jnp.log( # noqa: ARG005 consumption + 1.0 ), - "per_type_scale": _per_type_scale, + "per_type_scale": _per_type_scale_takes_pref_type, }, states={"pref_type": DiscreteGrid(PrefType)}, state_transitions={"pref_type": None}, From 2cc46ff6af60a76e35b97c68a055f433667e070d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 07:44:25 +0200 Subject: [PATCH 22/80] solve_brute: stream NaN/Inf reductions instead of stacking-and-flushing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #334 introduced a deferred-diagnostics accumulator that appends every (regime, period) NaN/Inf flag to a Python list, stacks the lists at end of solve, and `.tolist()`s the stacks to host. On a 16 GB V100 at production aca-baseline grid sizes the stacked reduction graph holds the per-period `isnan(V_arr)` / `isinf(V_arr)` intermediates alive simultaneously; the post-loop `.tolist()` then asks XLA to compile the fan-in and OOMs on a ~7.3 GiB allocation on top of the already-resident solution V arrays. Symptom: backward induction reports every age as "finished in ~14 ms" (dispatch-async times), then `JaxRuntimeError: RESOURCE_EXHAUSTED` at the first `.tolist()`. Fix: replace the per-period list-append with a running scalar OR; add a per-period `block_until_ready()` so each period's reduction kernel finishes (and its intermediate is freed) before the next period dispatches. `block_until_ready` is device-only — no host transfer, no PCIe round-trip — so it doesn't reintroduce the per-period sync that #334 removed; in practice the small reduction has finished by the time `max_Q_over_a` (~14 ms/period) returns. End of solve: one `.item()` per running scalar. On a healthy solve those two bools are False and we return without materialising any per-row state. Failure paths (`running_any_nan` / `running_any_inf` True) walk `diagnostic_rows` and materialise one bool per row to localise the offender — same total host transfers as the prior code, but only on the failure path. Debug-stats path (`log_level="debug"`) still appends min/max/mean per period; a single per-period `block_until_ready` after the appends frees those intermediates too. The end-of-solve `_log_per_period_stats` keeps the existing per-(regime, period) log line. `_StackedReductions`, `_emit_deferred_diagnostics`, and the old `_raise_if_nan` / `_warn_if_inf` (taking pre-materialised flag lists) are replaced by `_emit_post_loop_diagnostics` (orchestrator), `_raise_first_nan_row`, `_warn_inf_rows`, and `_log_per_period_stats`. Tests: new `tests/solution/test_diagnostics.py` covering the four log levels — happy-path warning, NaN-raise with `(regime, age)` in the message, off-level skip, and per-period debug stats. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/solution/solve_brute.py | 235 ++++++++++++++++------------- tests/solution/test_diagnostics.py | 126 ++++++++++++++++ 2 files changed, 254 insertions(+), 107 deletions(-) create mode 100644 tests/solution/test_diagnostics.py diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index eb70efe9..711c3327 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -73,15 +73,29 @@ def solve( solution: dict[int, MappingProxyType[RegimeName, FloatND]] = {} - # Async diagnostics accumulators: every `jnp.any(isnan)`, - # `jnp.any(isinf)` (and the debug min/max/mean trio) lives here as - # a device-side scalar during the hot loop. No host sync happens - # until the single flush in `_emit_deferred_diagnostics` post-loop. - # This replaces the pre-existing synchronous `log_nan_in_V` + - # `log_V_stats` + `validate_V` triple, which forced one host - # transfer per (regime, period) — ~n_regimes * n_periods stalls - # per solve, a meaningful throughput tax in MSM-style loops. - # Both gates fall out of the public log level: `"off"` ⇒ nothing, + # Async diagnostics accumulators: per-period `jnp.any(isnan)` / + # `jnp.any(isinf)` (and the debug min/max/mean trio) live here as + # device-side scalars during the hot loop. The two NaN/Inf flags + # are folded into single running scalars; the per-period min/max/ + # mean trio is appended to a list (only emitted at debug, where + # we genuinely want every number on host). + # + # Per-period `block_until_ready()` after the running update forces + # the device kernel to finish before the next period dispatches. + # This frees the per-period `isnan(V_arr)` / `isinf(V_arr)` + # intermediate buffers (~2 MB each at production grid sizes) so + # they don't stack up. `block_until_ready` is a *device-only* sync + # — no host transfer, no PCIe round-trip — so it doesn't + # re-introduce the per-period host stalls that #334 removed; if + # `max_Q_over_a` (the dominant per-period kernel) is in flight, + # the call returns immediately when the small reduction is done. + # + # One host transfer per stat at end of solve (`.item()` on the + # running scalars) decides whether to enter the failure-path + # localisation. On a healthy solve no per-row materialisation + # happens. + # + # Gate falls out of the public log level: `"off"` ⇒ nothing, # `"warning"` / `"progress"` ⇒ NaN/Inf only, `"debug"` ⇒ adds the # min/max/mean trio. `"off"` skips even the NaN fail-fast — that # is the documented contract of `"off"` (suppress all output) and @@ -92,8 +106,8 @@ def solve( diagnostic_min: list[FloatND] = [] diagnostic_max: list[FloatND] = [] diagnostic_mean: list[FloatND] = [] - diagnostic_any_nan: list[FloatND] = [] - diagnostic_any_inf: list[FloatND] = [] + running_any_nan: FloatND = jnp.zeros((), dtype=bool) + running_any_inf: FloatND = jnp.zeros((), dtype=bool) logger.info("Starting solution") total_start = time.monotonic() @@ -136,18 +150,19 @@ def solve( # Async reductions: gated on log level. `"off"` skips # everything — no kernel launches, no host syncs, no - # NaN fail-fast. `"warning"` / `"progress"` launches the - # two cheap isnan/isinf reductions; `"debug"` adds the - # min/max/mean trio. Each extra full-V read is a - # memory-bandwidth tax on the larger models, so the - # default keeps it to two reductions per (regime, period). + # NaN fail-fast. `"warning"` / `"progress"` folds two + # cheap isnan/isinf reductions into the running scalars; + # `"debug"` adds the min/max/mean trio. Each extra full-V + # read is a memory-bandwidth tax on the larger models, so + # the default keeps it to two reductions per (regime, + # period). if diagnostics_enabled: if stats_enabled: diagnostic_min.append(jnp.min(V_arr)) diagnostic_max.append(jnp.max(V_arr)) diagnostic_mean.append(jnp.mean(V_arr)) - diagnostic_any_nan.append(jnp.any(jnp.isnan(V_arr))) - diagnostic_any_inf.append(jnp.any(jnp.isinf(V_arr))) + running_any_nan = running_any_nan | jnp.any(jnp.isnan(V_arr)) + running_any_inf = running_any_inf | jnp.any(jnp.isinf(V_arr)) diagnostic_rows.append( _DiagnosticRow( regime_name=regime_name, @@ -166,6 +181,21 @@ def solve( period_solution[regime_name] = V_arr + # Force the device-side reduction kernels to finish before the + # next period dispatches, so each period's `isnan` / `isinf` + # (and min/max/mean) intermediate buffers can be freed instead + # of stacking up. `block_until_ready` does NOT transfer to host + # — it is a device-side wait, cheap when the dominant + # per-period kernel (`max_Q_over_a`) is the actual bottleneck. + if diagnostics_enabled: + running_any_nan.block_until_ready() + running_any_inf.block_until_ready() + if stats_enabled and diagnostic_mean: + # Blocking on the last-appended stat suffices: XLA + # serialises dispatch order, so a finished `mean` + # implies a finished `min`/`max` too. + diagnostic_mean[-1].block_until_ready() + # Maintain consistent pytree structure: keep all regime keys, # update active regimes with solved V arrays. next_regime_to_V_arr = MappingProxyType( @@ -181,22 +211,16 @@ def solve( elapsed = time.monotonic() - period_start log_period_timing(logger=logger, elapsed=elapsed) - # One flush of the GPU kernel queue: ship the stacked reductions - # to host in two transfers (isnan / isinf) by default, plus three - # more (min / max / mean) when debug stats were enabled. Skipped - # entirely at `log_level="off"` — nothing was accumulated. if diagnostics_enabled: - _emit_deferred_diagnostics( + _emit_post_loop_diagnostics( logger=logger, diagnostic_rows=diagnostic_rows, - reductions=_StackedReductions( - mins=jnp.stack(diagnostic_min) if diagnostic_min else None, - maxs=jnp.stack(diagnostic_max) if diagnostic_max else None, - means=jnp.stack(diagnostic_mean) if diagnostic_mean else None, - any_nan=jnp.stack(diagnostic_any_nan), - any_inf=jnp.stack(diagnostic_any_inf), - ), solution=MappingProxyType(solution), + running_any_nan=running_any_nan, + running_any_inf=running_any_inf, + diagnostic_min=diagnostic_min if stats_enabled else None, + diagnostic_max=diagnostic_max if stats_enabled else None, + diagnostic_mean=diagnostic_mean if stats_enabled else None, ) total_elapsed = time.monotonic() - total_start @@ -417,91 +441,61 @@ class _DiagnosticRow: compute-intermediates closure (e.g. terminal periods).""" -@dataclass(frozen=True) -class _StackedReductions: - """Per-stat JAX arrays stacked across all diagnostic rows; still on device. - - `mins` / `maxs` / `means` are `None` when the solve ran with a log - level below `debug` — the GPU wasn't asked to compute those - statistics so there's nothing to stack. - """ - - mins: FloatND | None - """Per-row min of V, or `None` below debug log level.""" - maxs: FloatND | None - """Per-row max of V, or `None` below debug log level.""" - means: FloatND | None - """Per-row mean of V, or `None` below debug log level.""" - any_nan: FloatND - """Per-row boolean flag: any NaN in V at this (regime, period).""" - any_inf: FloatND - """Per-row boolean flag: any Inf in V at this (regime, period).""" - - -def _emit_deferred_diagnostics( +def _emit_post_loop_diagnostics( *, logger: logging.Logger, diagnostic_rows: list[_DiagnosticRow], - reductions: _StackedReductions, solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], + running_any_nan: FloatND, + running_any_inf: FloatND, + diagnostic_min: list[FloatND] | None, + diagnostic_max: list[FloatND] | None, + diagnostic_mean: list[FloatND] | None, ) -> None: - """Flush async diagnostics to host, emit logs, raise on NaN. - - Exactly two host transfers by default (one per stat stack), plus - three more (min / max / mean) when debug stats were enabled. - Ordering: NaN check first so we raise before emitting any stats - lines the user wouldn't see anyway; inf check next (warning only); - per-period stats last at debug log level. The `.tolist()` calls - are what actually block on the GPU queue — everything above this - function ran async. + """Flush async diagnostics: raise on NaN, warn on Inf, log debug stats. + + Two host transfers (the `.item()` calls on the running scalars) + decide whether we enter the per-row failure-path localisation. On + a healthy solve neither inner walk runs and no per-row scalar is + materialised — the property that lets a production-sized solve at + `log_level="warning"` fit on a 16 GB device that was OOMing on the + previous stack-and-flush pattern. """ - any_nan = reductions.any_nan.tolist() - any_inf = reductions.any_inf.tolist() - - _raise_if_nan( - diagnostic_rows=diagnostic_rows, - any_nan_per_row=any_nan, - solution=solution, - ) - _warn_if_inf( - logger=logger, - diagnostic_rows=diagnostic_rows, - any_inf_per_row=any_inf, - ) - - if ( - not logger.isEnabledFor(logging.DEBUG) - or reductions.mins is None - or reductions.maxs is None - or reductions.means is None - ): - return - - mins = reductions.mins.tolist() - maxs = reductions.maxs.tolist() - means = reductions.means.tolist() - for row, v_min, v_max, v_mean in zip( - diagnostic_rows, mins, maxs, means, strict=True - ): - logger.debug( - " %s age %s V min=%.3g max=%.3g mean=%.3g", - row.regime_name, - row.age, - v_min, - v_max, - v_mean, + if running_any_nan.item(): + _raise_first_nan_row( + diagnostic_rows=diagnostic_rows, + solution=solution, + ) + if running_any_inf.item(): + _warn_inf_rows( + logger=logger, + diagnostic_rows=diagnostic_rows, + solution=solution, + ) + if diagnostic_min is not None and diagnostic_max is not None and diagnostic_mean: + _log_per_period_stats( + logger=logger, + diagnostic_rows=diagnostic_rows, + mins=jnp.stack(diagnostic_min), + maxs=jnp.stack(diagnostic_max), + means=jnp.stack(diagnostic_mean), ) -def _raise_if_nan( +def _raise_first_nan_row( *, diagnostic_rows: list[_DiagnosticRow], - any_nan_per_row: list, # list[bool] solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], ) -> None: - """Find the first NaN-bearing (regime, period) and raise.""" - for row, flag in zip(diagnostic_rows, any_nan_per_row, strict=True): - if flag: + """Find the first NaN-bearing (regime, period) and raise. + + Only invoked on the failure path (`running_any_nan` was True). + Materialises one host-side bool per row until the first hit; on + a healthy solve this function is never called. + """ + for row in diagnostic_rows: + V_arr = solution[row.period][row.regime_name] + if jnp.any(jnp.isnan(V_arr)).item(): _raise_at(row=row, solution=solution) @@ -525,17 +519,44 @@ def _raise_at( ) -def _warn_if_inf( +def _warn_inf_rows( *, logger: logging.Logger, diagnostic_rows: list[_DiagnosticRow], - any_inf_per_row: list, # list[bool] + solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], ) -> None: - """Emit a warning per (regime, period) with Inf values.""" - for row, flag in zip(diagnostic_rows, any_inf_per_row, strict=True): - if flag: + """Emit a warning per (regime, period) with Inf values. + + Only invoked on the failure path (`running_any_inf` was True). + Materialises one host-side bool per row. + """ + for row in diagnostic_rows: + V_arr = solution[row.period][row.regime_name] + if jnp.any(jnp.isinf(V_arr)).item(): logger.warning( "Inf in V_arr for regime '%s' at age %s", row.regime_name, row.age, ) + + +def _log_per_period_stats( + *, + logger: logging.Logger, + diagnostic_rows: list[_DiagnosticRow], + mins: FloatND, + maxs: FloatND, + means: FloatND, +) -> None: + """Emit one debug log line per (regime, period) with V min/max/mean.""" + for row, v_min, v_max, v_mean in zip( + diagnostic_rows, mins.tolist(), maxs.tolist(), means.tolist(), strict=True + ): + logger.debug( + " %s age %s V min=%.3g max=%.3g mean=%.3g", + row.regime_name, + row.age, + v_min, + v_max, + v_mean, + ) diff --git a/tests/solution/test_diagnostics.py b/tests/solution/test_diagnostics.py new file mode 100644 index 00000000..dbdc9497 --- /dev/null +++ b/tests/solution/test_diagnostics.py @@ -0,0 +1,126 @@ +"""Tests for the post-loop diagnostics path in `solve_brute.solve`. + +These cover: +- happy path at `log_level="warning"` runs without raising and without + the deferred-stack fan-in that previously OOMed at production sizes; +- NaN-bearing solves raise `InvalidValueFunctionError` and the message + identifies the offending `(regime, age)`; +- `log_level="debug"` emits one stat line per `(regime, period)`; +- `log_level="off"` emits nothing and skips even the NaN fail-fast. +""" + +import logging +from pathlib import Path + +import jax.numpy as jnp +import pytest + +from lcm import AgeGrid, LinSpacedGrid, Model, Regime, categorical +from lcm.exceptions import InvalidValueFunctionError +from lcm.typing import ContinuousAction, ContinuousState, FloatND + + +@categorical(ordered=False) +class RegimeId: + alive: int + dead: int + + +def _utility(consumption: ContinuousAction, wealth: ContinuousState) -> FloatND: + return jnp.log(consumption + 1) + 0.01 * wealth + + +def _next_wealth( + wealth: ContinuousState, + consumption: ContinuousAction, + interest_rate: float, +) -> ContinuousState: + return (1 + interest_rate) * (wealth - consumption) + + +def _borrowing_constraint( + consumption: ContinuousAction, wealth: ContinuousState +) -> FloatND: + return consumption <= wealth + + +def _next_regime(period: int) -> FloatND: + return jnp.where(period >= 1, RegimeId.dead, RegimeId.alive) + + +def _make_model() -> Model: + alive = Regime( + functions={"utility": _utility}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=5)}, + state_transitions={"wealth": _next_wealth}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5, n_points=5)}, + constraints={"borrowing_constraint": _borrowing_constraint}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + return Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +_HEALTHY_PARAMS = {"discount_factor": 0.95, "interest_rate": 0.05} + + +def test_warning_level_solves_without_per_row_materialisation(): + """Happy-path solve at log_level="warning" returns finite V without + entering the failure-path localisation.""" + model = _make_model() + period_to_regime_to_V_arr = model.solve(params=_HEALTHY_PARAMS, log_level="warning") + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)) + assert not jnp.any(jnp.isinf(V_arr)) + + +def test_nan_failure_raises_with_regime_and_age(): + """A NaN-producing parameter set raises with the offending (regime, age). + + `discount_factor=NaN` poisons the next-V contribution to Q on the + first non-terminal period; the validator must surface the offending + regime in the error message. + """ + model = _make_model() + params = {**_HEALTHY_PARAMS, "discount_factor": float("nan")} + with pytest.raises(InvalidValueFunctionError, match=r"alive"): + model.solve(params=params, log_level="warning") + + +def test_off_level_solves_without_diagnostics(caplog: pytest.LogCaptureFixture): + """log_level="off" emits no diagnostic records and skips the NaN fail-fast. + + Even with a NaN-producing parameter set, solve() returns instead of + raising — the documented contract of `"off"`. + """ + model = _make_model() + params = {**_HEALTHY_PARAMS, "discount_factor": float("nan")} + with caplog.at_level(logging.DEBUG): + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + assert period_to_regime_to_V_arr is not None + assert not [r for r in caplog.records if r.levelno >= logging.WARNING] + + +def test_debug_level_emits_per_period_stats( + caplog: pytest.LogCaptureFixture, tmp_path: Path +): + """log_level="debug" logs a min/max/mean line for every (regime, period).""" + model = _make_model() + with caplog.at_level(logging.DEBUG, logger="lcm"): + model.solve(params=_HEALTHY_PARAMS, log_level="debug", log_path=tmp_path) + debug_stat_lines = [ + r + for r in caplog.records + if "V min=" in r.getMessage() and "max=" in r.getMessage() + ] + assert len(debug_stat_lines) >= 1 From 365da0722f9e308cd84270d86c7e8c42374939e4 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 12:19:08 +0200 Subject: [PATCH 23/80] solve_brute: stop pinning per-period V templates in diagnostic_rows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each `_DiagnosticRow` previously held the active-period `state_action_space`, the rolling `next_regime_to_V_arr`, the regime's flat params, and a `compute_intermediates` closure (which itself captured the state_action_space). At production grid sizes — 50+ periods × ~6 active regimes — the accumulated references pin every period's full-shape V mapping in device memory, OOMing the V100 16 GB mid-loop on `block_until_ready` (the next allocation that has nowhere to go). The streaming NaN/Inf reduction landed earlier addressed only the per-period reduction buffers; the row-level retention is the larger leak. Strip `_DiagnosticRow` to the three Python scalars actually needed for failure-path localisation (`regime_name`, `period`, `age`) and reconstruct the heavy bits from `solution`, `internal_regimes`, and `internal_params` inside `_raise_at`. The reconstruction mirrors the loop's roll-forward semantics: for each regime, take the smallest later period in `solution` where the regime was active, falling back to a zeros template — the same value the rolling `next_regime_to_V_arr` slot held during the live dispatch. Also lock the row's shape via a structural test so future changes that re-introduce device-backed fields fail loudly in CI rather than silently regressing OOM behaviour. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/solution/solve_brute.py | 109 +++++++++++++++++++++++--------- tests/test_nan_diagnostics.py | 22 +++++++ 2 files changed, 100 insertions(+), 31 deletions(-) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 711c3327..b9ead0ed 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -12,7 +12,7 @@ from lcm.ages import AgeGrid from lcm.interfaces import InternalRegime -from lcm.typing import FlatRegimeParams, FloatND, InternalParams, RegimeName +from lcm.typing import FloatND, InternalParams, RegimeName from lcm.utils.error_handling import validate_V from lcm.utils.logging import ( format_duration, @@ -168,14 +168,6 @@ def solve( regime_name=regime_name, period=period, age=float(ages.values[period]), - state_action_space=state_action_space, - next_regime_to_V_arr=next_regime_to_V_arr, - regime_params=internal_params[regime_name], - compute_intermediates=( - internal_regime.solve_functions.compute_intermediates.get( - period - ) - ), ) ) @@ -216,6 +208,8 @@ def solve( logger=logger, diagnostic_rows=diagnostic_rows, solution=MappingProxyType(solution), + internal_regimes=internal_regimes, + internal_params=internal_params, running_any_nan=running_any_nan, running_any_inf=running_any_inf, diagnostic_min=diagnostic_min if stats_enabled else None, @@ -414,11 +408,17 @@ def _get_regime_V_shapes( class _DiagnosticRow: """Metadata captured during the backward-induction loop. - Stored refs only — no device work — so appending these rows inside - the hot loop costs essentially nothing. The expensive part (NaN - diagnostic enrichment via `compute_intermediates`) runs at most - once per solve, on the first offending row found after the single - post-loop host flush. + Holds only Python-scalar metadata — no device-array references — so + every (regime, period) row stays at a few bytes. The expensive bits + (state-action space, next-period V mapping, params, the + `compute_intermediates` closure) are reconstructed lazily on the + failure path from `internal_regimes`, `internal_params`, and the + partial `solution` that has been built up to that point. + + The earlier design captured those device-backed objects directly on + each row, which pinned every period's V template in device memory + until the post-loop flush — at production grid sizes that hits OOM + well before the loop completes. """ regime_name: RegimeName @@ -427,18 +427,6 @@ class _DiagnosticRow: """Period index in the backward-induction loop.""" age: float """Age corresponding to `period` (pulled off `AgeGrid.values`).""" - state_action_space: object - """Typed as `object` to avoid a heavy import cycle; consumers know - the actual runtime type from the `max_Q_over_a` signature.""" - next_regime_to_V_arr: MappingProxyType[RegimeName, FloatND] - """Incoming next-period V-arrays, passed through unchanged to - `compute_intermediates` when a NaN is detected.""" - regime_params: FlatRegimeParams - """Flat regime parameters used at this (regime, period).""" - compute_intermediates: Callable | None - """Optional closure that recomputes U / F / E[V] / Q for NaN - diagnostic enrichment. `None` when the regime has no - compute-intermediates closure (e.g. terminal periods).""" def _emit_post_loop_diagnostics( @@ -446,6 +434,8 @@ def _emit_post_loop_diagnostics( logger: logging.Logger, diagnostic_rows: list[_DiagnosticRow], solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, running_any_nan: FloatND, running_any_inf: FloatND, diagnostic_min: list[FloatND] | None, @@ -465,6 +455,8 @@ def _emit_post_loop_diagnostics( _raise_first_nan_row( diagnostic_rows=diagnostic_rows, solution=solution, + internal_regimes=internal_regimes, + internal_params=internal_params, ) if running_any_inf.item(): _warn_inf_rows( @@ -486,6 +478,8 @@ def _raise_first_nan_row( *, diagnostic_rows: list[_DiagnosticRow], solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, ) -> None: """Find the first NaN-bearing (regime, period) and raise. @@ -496,29 +490,82 @@ def _raise_first_nan_row( for row in diagnostic_rows: V_arr = solution[row.period][row.regime_name] if jnp.any(jnp.isnan(V_arr)).item(): - _raise_at(row=row, solution=solution) + _raise_at( + row=row, + solution=solution, + internal_regimes=internal_regimes, + internal_params=internal_params, + ) def _raise_at( *, row: _DiagnosticRow, solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, ) -> None: """Run the enriched NaN diagnostic on a single offending row and raise.""" + internal_regime = internal_regimes[row.regime_name] + regime_params = internal_params[row.regime_name] + state_action_space = internal_regime.state_action_space(regime_params=regime_params) + next_regime_to_V_arr = _reconstruct_next_regime_to_V_arr( + period=row.period, + internal_regimes=internal_regimes, + internal_params=internal_params, + solution=solution, + ) + compute_intermediates = internal_regime.solve_functions.compute_intermediates.get( + row.period + ) V_arr = solution[row.period][row.regime_name] validate_V( V_arr=V_arr, age=row.age, regime_name=row.regime_name, partial_solution=solution, - compute_intermediates=row.compute_intermediates, - state_action_space=row.state_action_space, # ty: ignore[invalid-argument-type] - next_regime_to_V_arr=row.next_regime_to_V_arr, - internal_params=row.regime_params, + compute_intermediates=compute_intermediates, + state_action_space=state_action_space, + next_regime_to_V_arr=next_regime_to_V_arr, + internal_params=regime_params, period=row.period, ) +def _reconstruct_next_regime_to_V_arr( + *, + period: int, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, + solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], +) -> MappingProxyType[RegimeName, FloatND]: + """Recreate the rolling `next_regime_to_V_arr` that was used at `period`. + + The hot loop rolls the per-regime V forward via `period_solution.get(name, + next_regime_to_V_arr[name])`, so at iteration `period` each regime's slot + holds its V from the smallest later period where it was active, falling + back to a zeros template otherwise. + + We rebuild the same mapping post-hoc from `solution`. The shapes come from + the regime's state-action space at the supplied params — identical to what + `_get_regime_V_shapes` saw during solve setup. + """ + regime_V_shapes = _get_regime_V_shapes( + internal_regimes=internal_regimes, + internal_params=internal_params, + ) + later_periods = sorted(p for p in solution if p > period) + result: dict[RegimeName, FloatND] = {} + for regime_name, shape in regime_V_shapes.items(): + rolled: FloatND | None = None + for q in later_periods: + if regime_name in solution[q]: + rolled = solution[q][regime_name] + break + result[regime_name] = rolled if rolled is not None else jnp.zeros(shape) + return MappingProxyType(result) + + def _warn_inf_rows( *, logger: logging.Logger, diff --git a/tests/test_nan_diagnostics.py b/tests/test_nan_diagnostics.py index 47c42f32..779f90c8 100644 --- a/tests/test_nan_diagnostics.py +++ b/tests/test_nan_diagnostics.py @@ -146,6 +146,28 @@ def borrowing_constraint( return model, params +def test_diagnostic_row_holds_only_python_scalars() -> None: + """`_DiagnosticRow` must not pin device-backed objects. + + Earlier the row stored `state_action_space`, `next_regime_to_V_arr`, + `regime_params`, and a `compute_intermediates` closure (which itself + captured the state_action_space). Across periods these refs accumulated, + pinning every period's V template in device memory until the post-loop + flush — at production grid sizes that hits OOM well before the loop + completes. The failure path now reconstructs those objects from `solution`, + `internal_regimes`, and `internal_params` instead. + """ + from lcm.solution.solve_brute import _DiagnosticRow # noqa: PLC0415 + + expected = {"regime_name", "period", "age"} + actual = set(_DiagnosticRow.__dataclass_fields__) + assert actual == expected, ( + f"_DiagnosticRow must hold only {expected}; got {actual}. Adding " + "device-backed fields here pins per-period V templates in device " + "memory and re-introduces the OOM during long backward inductions." + ) + + def test_nan_diagnostics_end_to_end() -> None: """Real model: `model.solve()` attaches a diagnostics dict when V has NaN. From bf1cdf4dab35867a3280c9be7fe1a1f6f9f2a493 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 07:13:33 +0200 Subject: [PATCH 24/80] solve_brute: fail-fast on NaN per period; rewrite stale diagnostic hint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two changes targeting the NaN-in-V failure path: 1. Fail-fast at age boundary. Adds a per-period `running_any_nan.item()` host transfer right after the existing `block_until_ready`. On True, the loop breaks out and the existing post-loop emitter raises immediately. Cost: one scalar bool transfer per period — negligible next to `max_Q_over_a`. Without this, backward induction would finish the entire age range (potentially ~2h on production grids) before raising at the first-NaN row, leaving the user staring at an idle-looking solve. Inf stays non-fatal; the post-loop warning still fires for any period that flagged it. 2. Drop the misleading "re-solve with debug logging" suggestion from `validate_V`. The diagnostic [NOTE] is added inline by `_enrich_with_diagnostics` whenever `compute_intermediates` is wired up — i.e. on the default path — so suggesting a re-solve to "produce" diagnostics is wrong: they were already produced. Replace with a pointer to the [NOTE] for the per-axis breakdown plus a mention of `log_path=...` for snapshot persistence (the only thing debug-mode actually adds beyond the inline diagnostic). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/solution/solve_brute.py | 10 ++++++++++ src/lcm/utils/error_handling.py | 12 ++++++------ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index b9ead0ed..653b013d 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -203,6 +203,16 @@ def solve( elapsed = time.monotonic() - period_start log_period_timing(logger=logger, elapsed=elapsed) + # Fail-fast on NaN: surface the offending period immediately + # instead of finishing the whole backward induction. Costs one + # host transfer of a scalar bool per period — negligible next + # to the per-period `max_Q_over_a` kernel, and only paid when + # diagnostics are on. Inf is non-fatal so we don't break on + # it; the post-loop emitter still raises a warning if any + # period flagged Inf. + if diagnostics_enabled and running_any_nan.item(): + break + if diagnostics_enabled: _emit_post_loop_diagnostics( logger=logger, diff --git a/src/lcm/utils/error_handling.py b/src/lcm/utils/error_handling.py index c7f16008..06929765 100644 --- a/src/lcm/utils/error_handling.py +++ b/src/lcm/utils/error_handling.py @@ -93,12 +93,12 @@ def validate_V( "(e.g. from a NaN survival probability or a NaN fixed param).\n" "- A per-target state_transitions dict omits a reachable target " "(non-zero transition probability to an incomplete target).\n\n" - "To diagnose, re-solve with debug logging:\n\n" - ' model.solve(params=params, log_level="debug", ' - 'log_path="./debug/")\n\n' - "The snapshot saved on failure contains diagnostics that pinpoint " - "where NaN enters (U, E[V], or regime transitions). See the " - "debugging guide:\n" + "When `compute_intermediates` is wired up (the default), the " + "[NOTE] below pinpoints which intermediate (U, F, E[V], or " + "regime transitions) introduces the NaN and along which state " + "axes it concentrates. To persist a snapshot of the partial " + "solution for offline inspection, pass `log_path=...` to " + "`solve(...)` / `simulate(...)`. See the debugging guide:\n" "https://pylcm.readthedocs.io/en/latest/user_guide/debugging/" ) exc.partial_solution = partial_solution From c7745f34821f18042e59db49b47355905392f896 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 07:53:34 +0200 Subject: [PATCH 25/80] solve/simulate: surface snapshot path in NaN exception note MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When `log_path` is configured, the failure path already calls `save_solve_snapshot(...)` (`model.py:223-230` and `:334-341`) before re-raising — but the path it returns wasn't surfaced anywhere, so the user saw a generic "pass `log_path=...`" hint pointing them to do something they had already done. Capture the returned `snap_dir` and attach it via `exc.add_note(f"Snapshot saved to {snap_dir}")`. The note appears alongside the diagnostic-summary note that `_enrich_with_diagnostics` adds, so the user sees both the per-axis NaN breakdown and the exact `solve_snapshot_NNN/` directory in one exception. Drop the now-redundant `log_path=...` suggestion from `validate_V`'s message. Replace with a short pointer to the [NOTE] block: when `log_path` is set, the second note has the path; when it isn't, the inline diagnostic still pinpoints the offending intermediate. The debugging-guide link stays. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/model.py | 6 ++++-- src/lcm/utils/error_handling.py | 11 +++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/lcm/model.py b/src/lcm/model.py index 7ed2981a..9412d660 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -221,13 +221,14 @@ def solve( ) except InvalidValueFunctionError as exc: if log_path is not None and exc.partial_solution is not None: - save_solve_snapshot( + snap_dir = save_solve_snapshot( model=self, params=params, period_to_regime_to_V_arr=exc.partial_solution, # ty: ignore[invalid-argument-type] log_path=Path(log_path), log_keep_n_latest=log_keep_n_latest, ) + exc.add_note(f"Snapshot saved to {snap_dir}") raise if log_level == "debug" and log_path is not None: save_solve_snapshot( @@ -331,13 +332,14 @@ def simulate( ) except InvalidValueFunctionError as exc: if log_path is not None and exc.partial_solution is not None: - save_solve_snapshot( + snap_dir = save_solve_snapshot( model=self, params=params, period_to_regime_to_V_arr=exc.partial_solution, # ty: ignore[invalid-argument-type] log_path=Path(log_path), log_keep_n_latest=log_keep_n_latest, ) + exc.add_note(f"Snapshot saved to {snap_dir}") raise result = simulate( internal_params=internal_params, diff --git a/src/lcm/utils/error_handling.py b/src/lcm/utils/error_handling.py index 06929765..ef631349 100644 --- a/src/lcm/utils/error_handling.py +++ b/src/lcm/utils/error_handling.py @@ -93,12 +93,11 @@ def validate_V( "(e.g. from a NaN survival probability or a NaN fixed param).\n" "- A per-target state_transitions dict omits a reachable target " "(non-zero transition probability to an incomplete target).\n\n" - "When `compute_intermediates` is wired up (the default), the " - "[NOTE] below pinpoints which intermediate (U, F, E[V], or " - "regime transitions) introduces the NaN and along which state " - "axes it concentrates. To persist a snapshot of the partial " - "solution for offline inspection, pass `log_path=...` to " - "`solve(...)` / `simulate(...)`. See the debugging guide:\n" + "See the [NOTE] below for the per-intermediate / per-axis " + "breakdown produced by `compute_intermediates`. When `log_path` " + "is configured, an additional [NOTE] points to the on-disk " + "snapshot directory written before this exception was raised. " + "Debugging guide:\n" "https://pylcm.readthedocs.io/en/latest/user_guide/debugging/" ) exc.partial_solution = partial_solution From bc067c10215733c563463d7dd5c35d8ff0679678 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 11:50:52 +0200 Subject: [PATCH 26/80] Model.n_subjects: AOT-compile simulate functions for fixed batch shape When the user declares the simulate batch size up front via `Model(n_subjects=N)`, the first matching `simulate(...)` call now AOT- compiles every unique simulate function for that shape in parallel (`ThreadPoolExecutor` over `lower(...).compile()`), mirroring solve's existing AOT path in `solve_brute._compile_all_functions`. Subsequent calls with the same size hit the cache; calls with a mismatching size warn once per size and fall back to the runtime-traced path. Also normalises `period_to_regime_to_V_arr` at the entry of `simulate` so every period dispatches with the same pytree (active-regime padding with zeros). Without this the last period's empty `next_regime_to_V_arr` breaks both the AOT pytree signature and JAX's own JIT cache reuse. `n_subjects=None` (the default) preserves the previous lazy behaviour. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/model.py | 72 +++++- src/lcm/model_processing.py | 16 ++ src/lcm/simulation/compile.py | 359 ++++++++++++++++++++++++++ src/lcm/simulation/simulate.py | 67 ++++- tests/simulation/test_simulate_aot.py | 180 +++++++++++++ 5 files changed, 689 insertions(+), 5 deletions(-) create mode 100644 src/lcm/simulation/compile.py create mode 100644 tests/simulation/test_simulate_aot.py diff --git a/src/lcm/model.py b/src/lcm/model.py index 9412d660..8935535a 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -1,6 +1,7 @@ """Collection of classes that are used by the user to define the model and grids.""" import dataclasses +import logging from collections.abc import Mapping from pathlib import Path from types import MappingProxyType @@ -30,6 +31,7 @@ ) from lcm.regime import Regime from lcm.regime_building.processing import InternalRegime +from lcm.simulation.compile import compile_all_simulate_functions from lcm.simulation.initial_conditions import validate_initial_conditions from lcm.simulation.result import SimulationResult, get_simulation_output_dtypes from lcm.simulation.simulate import simulate @@ -85,6 +87,16 @@ class Model: fixed_params: UserParams """Parameters fixed at model initialization.""" + n_subjects: int | None = None + """Expected simulate batch size; enables AOT compile of simulate functions. + + When set, the first matching `simulate(...)` call AOT-compiles all simulate + functions for batch shape `n_subjects` in parallel. Subsequent calls with the + same size reuse the compiled programs. Calls with a mismatching size warn + once per size and fall back to the runtime-traced path. `None` keeps the + purely lazy behaviour. + """ + _params_template: ParamsTemplate """Template for the model parameters.""" @@ -100,6 +112,7 @@ def __init__( derived_categoricals: Mapping[FunctionName, DiscreteGrid] = MappingProxyType( {} ), + n_subjects: int | None = None, ) -> None: """Initialize the Model. @@ -115,17 +128,27 @@ def __init__( not in states/actions. Broadcast to all regimes (merged with each regime's own `derived_categoricals`). Raises if a regime already has a conflicting entry. + n_subjects: Expected simulate batch size; if set, the first matching + `simulate(...)` call AOT-compiles all simulate functions for + batch shape `n_subjects` in parallel. `None` keeps the purely + lazy behaviour. """ self.description = description self.ages = ages self.n_periods = ages.n_periods self.fixed_params = ensure_containers_are_immutable(fixed_params) + self.n_subjects = n_subjects + self._simulate_compile_cache: dict[ + int, MappingProxyType[RegimeName, InternalRegime] + ] = {} + self._warned_n_subjects: set[int] = set() validate_model_inputs( n_periods=self.n_periods, regimes=regimes, regime_id_class=regime_id_class, + n_subjects=n_subjects, ) self.regime_names_to_ids = MappingProxyType( dict( @@ -240,6 +263,46 @@ def solve( ) return period_to_regime_to_V_arr + def _resolve_simulate_internal_regimes( + self, + *, + actual_n_subjects: int, + internal_params: InternalParams, + log: logging.Logger, + max_compilation_workers: int | None, + ) -> MappingProxyType[RegimeName, InternalRegime]: + """Return internal_regimes to use for simulate; AOT cache when matching. + + Returns the original `internal_regimes` when `n_subjects` is `None` or + when the actual batch size mismatches the declared one (logging a + warning once per mismatching size). Otherwise builds and caches the + AOT-compiled regimes for the matching size. + """ + if self.n_subjects is None: + return self.internal_regimes + if actual_n_subjects != self.n_subjects: + if actual_n_subjects not in self._warned_n_subjects: + log.warning( + "simulate called with n_subjects=%d but model declared " + "n_subjects=%d; falling back to runtime compile.", + actual_n_subjects, + self.n_subjects, + ) + self._warned_n_subjects.add(actual_n_subjects) + return self.internal_regimes + if self.n_subjects not in self._simulate_compile_cache: + self._simulate_compile_cache[self.n_subjects] = ( + compile_all_simulate_functions( + internal_regimes=self.internal_regimes, + internal_params=internal_params, + ages=self.ages, + n_subjects=self.n_subjects, + max_compilation_workers=max_compilation_workers, + logger=log, + ) + ) + return self._simulate_compile_cache[self.n_subjects] + def simulate( self, *, @@ -341,10 +404,17 @@ def simulate( ) exc.add_note(f"Snapshot saved to {snap_dir}") raise + actual_n_subjects = len(next(iter(initial_conditions.values()))) + simulate_internal_regimes = self._resolve_simulate_internal_regimes( + actual_n_subjects=actual_n_subjects, + internal_params=internal_params, + log=log, + max_compilation_workers=max_compilation_workers, + ) result = simulate( internal_params=internal_params, initial_conditions=initial_conditions, - internal_regimes=self.internal_regimes, + internal_regimes=simulate_internal_regimes, regime_names_to_ids=self.regime_names_to_ids, logger=log, period_to_regime_to_V_arr=period_to_regime_to_V_arr, diff --git a/src/lcm/model_processing.py b/src/lcm/model_processing.py index 23dd60a7..d141c6a3 100644 --- a/src/lcm/model_processing.py +++ b/src/lcm/model_processing.py @@ -147,8 +147,11 @@ def validate_model_inputs( n_periods: int, regimes: Mapping[RegimeName, Regime], regime_id_class: type, + n_subjects: int | None = None, ) -> None: """Validate model constructor inputs.""" + _fail_if_invalid_n_subjects(n_subjects=n_subjects) + # Early exit if regimes are not lcm.Regime instances if not all(isinstance(regime, Regime) for regime in regimes.values()): raise ModelInitializationError( @@ -201,6 +204,19 @@ def validate_model_inputs( raise ModelInitializationError(msg) +def _fail_if_invalid_n_subjects(*, n_subjects: int | None) -> None: + """Raise TypeError if non-int, ValueError if non-positive.""" + if n_subjects is None: + return + # `bool` is a subclass of `int`; reject explicitly so True/False don't slip through. + if not isinstance(n_subjects, int) or isinstance(n_subjects, bool): + msg = f"n_subjects must be an int or None, got {type(n_subjects).__name__}." + raise TypeError(msg) + if n_subjects <= 0: + msg = f"n_subjects must be a positive integer, got {n_subjects}." + raise ValueError(msg) + + def _validate_all_variables_used(regimes: Mapping[RegimeName, Regime]) -> list[str]: """Validate that all states and actions are used somewhere in each regime. diff --git a/src/lcm/simulation/compile.py b/src/lcm/simulation/compile.py new file mode 100644 index 00000000..316d877b --- /dev/null +++ b/src/lcm/simulation/compile.py @@ -0,0 +1,359 @@ +"""AOT-compile simulate functions for a fixed batch size. + +When `Model(n_subjects=N)` is set, `compile_all_simulate_functions(...)` returns +an `internal_regimes` mapping with each regime's `simulate_functions` callables +swapped for AOT-compiled programs sized for batch shape `N`. The existing +simulate call sites then pick them up transparently — no signature changes +downstream. + +Mirrors the pattern in `solve_brute._compile_all_functions`: deduplicate by +callable identity, sequentially lower (tracing is not thread-safe), then +parallel-compile via `ThreadPoolExecutor` (XLA releases the GIL). +""" + +import dataclasses +import logging +import time +from collections.abc import Callable, Hashable, Mapping +from concurrent.futures import ThreadPoolExecutor, as_completed +from types import MappingProxyType + +import jax +import jax.numpy as jnp +from dags.tree import tree_path_from_qname +from jax import Array + +from lcm.ages import AgeGrid +from lcm.interfaces import InternalRegime +from lcm.simulation.random import generate_simulation_keys +from lcm.solution.solve_brute import ( + _func_dedup_key, + _resolve_compilation_workers, +) +from lcm.typing import ( + FlatRegimeParams, + InternalParams, + RegimeName, +) +from lcm.utils.logging import format_duration +from lcm.utils.namespace import flatten_regime_namespace + + +def compile_all_simulate_functions( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, + ages: AgeGrid, + n_subjects: int, + max_compilation_workers: int | None, + logger: logging.Logger, +) -> MappingProxyType[RegimeName, InternalRegime]: + """AOT-compile every unique simulate function for batch shape `n_subjects`. + + Args: + internal_regimes: Original internal regimes from the Model. + internal_params: Immutable mapping of regime names to flat parameter mappings. + ages: AgeGrid for the model. + n_subjects: Batch size for which to compile. + max_compilation_workers: Maximum threads for parallel XLA compilation. + Defaults to `os.cpu_count()`. + logger: Logger. + + Returns: + Immutable mapping of regime names to InternalRegime where each + regime's `simulate_functions` has its callables replaced by + AOT-compiled programs. + + """ + regime_V_shapes = _get_regime_V_shapes( + internal_regimes=internal_regimes, + internal_params=internal_params, + ) + next_regime_to_V_arr = MappingProxyType( + { + regime_name: jnp.zeros(shape) + for regime_name, shape in regime_V_shapes.items() + } + ) + + unique, func_keys = _collect_unique_simulate_functions( + internal_regimes=internal_regimes, + internal_params=internal_params, + ages=ages, + n_subjects=n_subjects, + next_regime_to_V_arr=next_regime_to_V_arr, + ) + + n_workers = _resolve_compilation_workers( + max_compilation_workers=max_compilation_workers + ) + n_unique = len(unique) + logger.info( + "Simulate AOT compilation: %d unique functions (%d workers)", + n_unique, + n_workers, + ) + + lowered: dict[Hashable, jax.stages.Lowered] = {} + for i, (key, (func, args, label)) in enumerate(unique.items(), 1): + logger.info("%d/%d %s", i, n_unique, label) + logger.info(" lowering ...") + start = time.monotonic() + # `func` is a `jax.jit`-wrapped callable; ty sees only the abstract + # Callable type, so it can't see `.lower(...)`. + lowered[key] = func.lower(**args) # ty: ignore[unresolved-attribute] + logger.info( + " lowered in %s", format_duration(seconds=time.monotonic() - start) + ) + + compiled: dict[Hashable, jax.stages.Compiled] = {} + + def _compile_and_log( + *, + key: Hashable, + low: jax.stages.Lowered, + label: str, + ) -> tuple[Hashable, jax.stages.Compiled]: + logger.info(" compiling %s ...", label) + start = time.monotonic() + result = low.compile() + logger.info( + " compiled %s %s", + label, + format_duration(seconds=time.monotonic() - start), + ) + return key, result + + with ThreadPoolExecutor(max_workers=n_workers) as pool: + futures = [ + pool.submit(_compile_and_log, key=key, low=low, label=unique[key][2]) + for key, low in lowered.items() + ] + for future in as_completed(futures): + k, c = future.result() + compiled[k] = c + + return _swap_in_compiled( + internal_regimes=internal_regimes, + compiled=compiled, + func_keys=func_keys, + ) + + +def _collect_unique_simulate_functions( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, + ages: AgeGrid, + n_subjects: int, + next_regime_to_V_arr: MappingProxyType[RegimeName, Array], +) -> tuple[ + dict[Hashable, tuple[Callable, dict, str]], + dict[tuple[RegimeName, str, int | None], Hashable], +]: + """Walk every regime/period and dedup the simulate functions to compile.""" + unique: dict[Hashable, tuple[Callable, dict, str]] = {} + func_keys: dict[tuple[RegimeName, str, int | None], Hashable] = {} + + for regime_name, regime in internal_regimes.items(): + regime_params = internal_params.get(regime_name, MappingProxyType({})) + sf = regime.simulate_functions + + for period, argmax_func in sf.argmax_and_max_Q_over_a.items(): + args = _build_argmax_args( + internal_regime=regime, + regime_params=regime_params, + ages=ages, + period=period, + n_subjects=n_subjects, + next_regime_to_V_arr=next_regime_to_V_arr, + ) + key = ("argmax", _func_dedup_key(func=argmax_func)) + func_keys[(regime_name, "argmax", period)] = key + if key not in unique: + label = ( + f"{regime_name}/argmax_and_max_Q_over_a " + f"(age {ages.values[period].item()})" + ) + unique[key] = (jax.jit(argmax_func), args, label) + + if not regime.terminal: + args = _build_next_state_args( + internal_regime=regime, + regime_params=regime_params, + ages=ages, + n_subjects=n_subjects, + ) + key = ("next_state", _func_dedup_key(func=sf.next_state)) + func_keys[(regime_name, "next_state", None)] = key + if key not in unique: + unique[key] = (sf.next_state, args, f"{regime_name}/next_state") + + if sf.compute_regime_transition_probs is not None: + args = _build_crtp_args( + internal_regime=regime, + regime_params=regime_params, + ages=ages, + n_subjects=n_subjects, + ) + key = ("crtp", _func_dedup_key(func=sf.compute_regime_transition_probs)) + func_keys[(regime_name, "crtp", None)] = key + if key not in unique: + unique[key] = ( + sf.compute_regime_transition_probs, + args, + f"{regime_name}/compute_regime_transition_probs", + ) + + return unique, func_keys + + +def _swap_in_compiled( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + compiled: dict[Hashable, jax.stages.Compiled], + func_keys: dict[tuple[RegimeName, str, int | None], Hashable], +) -> MappingProxyType[RegimeName, InternalRegime]: + """Swap compiled programs into each regime's `simulate_functions`.""" + new_regimes: dict[RegimeName, InternalRegime] = {} + for regime_name, regime in internal_regimes.items(): + sf = regime.simulate_functions + argmax_compiled = MappingProxyType( + { + period: compiled[func_keys[(regime_name, "argmax", period)]] + for period in sf.argmax_and_max_Q_over_a + } + ) + if regime.terminal: + next_state_compiled = sf.next_state + else: + next_state_compiled = compiled[func_keys[(regime_name, "next_state", None)]] + if sf.compute_regime_transition_probs is None: + crtp_compiled = None + else: + crtp_compiled = compiled[func_keys[(regime_name, "crtp", None)]] + + new_sf = dataclasses.replace( + sf, + argmax_and_max_Q_over_a=argmax_compiled, + next_state=next_state_compiled, + compute_regime_transition_probs=crtp_compiled, + ) + new_regimes[regime_name] = dataclasses.replace( + regime, simulate_functions=new_sf + ) + + return MappingProxyType(new_regimes) + + +def _get_regime_V_shapes( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, +) -> dict[RegimeName, tuple[int, ...]]: + shapes: dict[RegimeName, tuple[int, ...]] = {} + for regime_name, regime in internal_regimes.items(): + space = regime.state_action_space( + regime_params=internal_params.get(regime_name, MappingProxyType({})) + ) + shapes[regime_name] = tuple(len(v) for v in space.states.values()) + return shapes + + +def _build_argmax_args( + *, + internal_regime: InternalRegime, + regime_params: FlatRegimeParams, + ages: AgeGrid, + period: int, + n_subjects: int, + next_regime_to_V_arr: MappingProxyType[RegimeName, Array], +) -> dict[str, object]: + base = internal_regime.state_action_space(regime_params=regime_params) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + return { + **subject_states, + **base.discrete_actions, + **base.continuous_actions, + "next_regime_to_V_arr": next_regime_to_V_arr, + **regime_params, + "period": jnp.int32(period), + "age": ages.values[period], + } + + +def _build_next_state_args( + *, + internal_regime: InternalRegime, + regime_params: FlatRegimeParams, + ages: AgeGrid, + n_subjects: int, +) -> dict[str, object]: + base = internal_regime.state_action_space(regime_params=regime_params) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + subject_actions = _subject_shape_arrays( + {**base.discrete_actions, **base.continuous_actions}, + n_subjects=n_subjects, + ) + + stoch_transition_names = ( + internal_regime.simulate_functions.stochastic_transition_names + ) + stoch_next_func_names = sorted( + next_func_name + for next_func_name in flatten_regime_namespace( + internal_regime.simulate_functions.transitions + ) + if tree_path_from_qname(next_func_name)[-1] in stoch_transition_names + ) + _, stoch_keys = generate_simulation_keys( + key=jax.random.key(0), + names=stoch_next_func_names, + n_initial_states=n_subjects, + ) + + # `period` is passed as a plain Python int by `calculate_next_states` + # (transitions.py), which traces as the default-precision int. Match that + # here so the lowered shape signature lines up with the runtime call. + return { + **subject_states, + **subject_actions, + **stoch_keys, + "period": 0, + "age": ages.values[0], + **regime_params, + } + + +def _build_crtp_args( + *, + internal_regime: InternalRegime, + regime_params: FlatRegimeParams, + ages: AgeGrid, + n_subjects: int, +) -> dict[str, object]: + base = internal_regime.state_action_space(regime_params=regime_params) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + subject_actions = _subject_shape_arrays( + {**base.discrete_actions, **base.continuous_actions}, + n_subjects=n_subjects, + ) + return { + **subject_states, + **subject_actions, + "period": 0, + "age": ages.values[0], + **regime_params, + } + + +def _subject_shape_arrays( + base_arrays: Mapping[str, Array], + *, + n_subjects: int, +) -> dict[str, Array]: + """Return zeros of shape `(n_subjects,)` mirroring each base array's dtype.""" + return { + name: jnp.zeros((n_subjects,), dtype=arr.dtype) + for name, arr in base_arrays.items() + } diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index d54d2674..d90e2175 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -108,6 +108,16 @@ def simulate( # Build reverse lookup for regime transition logging ids_to_names: dict[int, RegimeName] = {v: k for k, v in regime_names_to_ids.items()} + # Normalize V arrays so every period (including the post-last fallback) + # has the same pytree structure: all regime keys, zeros for inactive + # regimes. This collapses the per-period dispatch to a single JIT-trace + # signature and matches the AOT-compiled programs. + period_to_regime_to_V_arr, empty_next_V = _normalize_period_V_arr( + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + internal_regimes=internal_regimes, + internal_params=internal_params, + ) + for period, age in enumerate(ages.values): period_start = time.monotonic() @@ -146,6 +156,7 @@ def simulate( subject_regime_ids=subject_regime_ids, new_subject_regime_ids=new_subject_regime_ids, period_to_regime_to_V_arr=period_to_regime_to_V_arr, + empty_next_V=empty_next_V, internal_params=internal_params, regime_names_to_ids=regime_names_to_ids, active_regimes_next_period=active_regimes_next_period, @@ -204,6 +215,7 @@ def _simulate_regime_in_period( period_to_regime_to_V_arr: MappingProxyType[ int, MappingProxyType[RegimeName, FloatND] ], + empty_next_V: MappingProxyType[RegimeName, FloatND], internal_params: InternalParams, regime_names_to_ids: MappingProxyType[RegimeName, int], active_regimes_next_period: tuple[RegimeName, ...], @@ -249,10 +261,10 @@ def _simulate_regime_in_period( # Compute optimal actions # We need to pass the value function array of the next period to the # argmax_and_max_Q_over_a function, as the current Q-function requires the - # next period's value function. In the last period, we pass an empty dict. - next_regime_to_V_arr = period_to_regime_to_V_arr.get( - period + 1, MappingProxyType({}) - ) + # next period's value function. In the last period the next-period V is + # zeros (same shape as a populated period) — this keeps a single JIT-trace + # signature across all periods and matches the AOT-compile signature. + next_regime_to_V_arr = period_to_regime_to_V_arr.get(period + 1, empty_next_V) # The Q-function values contain the information of how much value each # action combination is worth. To find the optimal discrete action, we @@ -364,6 +376,53 @@ def _lookup_values_from_indices( vmapped_unravel_index = vmap(jnp.unravel_index, in_axes=(0, None)) +def _normalize_period_V_arr( + *, + period_to_regime_to_V_arr: MappingProxyType[ + int, MappingProxyType[RegimeName, FloatND] + ], + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, +) -> tuple[ + MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], + MappingProxyType[RegimeName, FloatND], +]: + """Fill missing regime keys with zero V arrays so every period has the same shape. + + `solve()` returns active-only mappings per period; AOT compilation and the + JIT cache both work best with a fixed pytree, so we pad with zeros here. + + Returns: + Tuple of (normalized period→regime→V mapping, all-zeros fallback for + post-last periods). + + """ + regime_V_shapes: dict[RegimeName, tuple[int, ...]] = {} + for regime_name, regime in internal_regimes.items(): + space = regime.state_action_space( + regime_params=internal_params.get(regime_name, MappingProxyType({})) + ) + regime_V_shapes[regime_name] = tuple(len(v) for v in space.states.values()) + + empty_next_V = MappingProxyType( + { + regime_name: jnp.zeros(shape) + for regime_name, shape in regime_V_shapes.items() + } + ) + + normalized: dict[int, MappingProxyType[RegimeName, FloatND]] = {} + for period, regime_to_V in period_to_regime_to_V_arr.items(): + normalized[period] = MappingProxyType( + { + regime_name: regime_to_V.get(regime_name, empty_next_V[regime_name]) + for regime_name in internal_regimes + } + ) + + return MappingProxyType(normalized), empty_next_V + + def _compute_starting_periods( *, initial_ages: Array, diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py new file mode 100644 index 00000000..3b8e57a1 --- /dev/null +++ b/tests/simulation/test_simulate_aot.py @@ -0,0 +1,180 @@ +"""Tests for simulate-AOT compilation via `Model.n_subjects`. + +When `Model(n_subjects=N)` is set, the first matching `simulate(...)` call +parallel-compiles all simulate functions for batch shape `N`. Subsequent calls +with size `N` reuse the cache; calls with a mismatching size warn once per size +and fall back to the runtime-traced path. +""" + +import logging + +import jax.numpy as jnp +import jax.stages +import pytest + +from lcm import Model +from lcm.ages import AgeGrid +from tests.test_models.deterministic.regression import ( + RegimeId, + dead, + get_params, + working_life, +) + + +def _build_test_model(*, n_periods: int, n_subjects: int | None = None) -> Model: + """Construct the small 2-regime regression model with optional n_subjects.""" + final_age_alive = 18 + n_periods - 2 + return Model( + regimes={ + "working_life": working_life.replace( + active=lambda age: age <= final_age_alive, + ), + "dead": dead, + }, + ages=AgeGrid(start=18, stop=final_age_alive + 1, step="Y"), + regime_id_class=RegimeId, + n_subjects=n_subjects, + ) + + +def _build_initial_conditions(*, n_subjects: int) -> dict: + """Subject array of size `n_subjects` matching the regression test model.""" + wealths = jnp.linspace(20.0, 320.0, num=n_subjects) + return { + "wealth": wealths, + "age": jnp.full((n_subjects,), 18.0), + "regime": jnp.array([RegimeId.working_life] * n_subjects), + } + + +@pytest.mark.parametrize("invalid", [0, -3]) +def test_n_subjects_validation_rejects_non_positive(invalid: int) -> None: + with pytest.raises(ValueError, match="n_subjects"): + _build_test_model(n_periods=3, n_subjects=invalid) + + +def test_n_subjects_validation_rejects_non_int() -> None: + with pytest.raises(TypeError, match="n_subjects"): + _build_test_model(n_periods=3, n_subjects=1.5) # ty: ignore[invalid-argument-type] + + +def test_n_subjects_none_keeps_lazy_behavior() -> None: + """Without n_subjects, simulate works and no AOT cache is populated.""" + n_periods = 3 + model = _build_test_model(n_periods=n_periods, n_subjects=None) + params = get_params(n_periods=n_periods) + + result = model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=_build_initial_conditions(n_subjects=4), + ) + + assert result.n_subjects == 4 + assert model.n_subjects is None + assert not getattr(model, "_simulate_compile_cache", {}) + + +def test_simulate_compiles_only_once_with_matching_n_subjects( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """First simulate call AOT-compiles; second call hits the cache.""" + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + + counter = {"count": 0} + original_compile = jax.stages.Lowered.compile + + def counting_compile(self: jax.stages.Lowered, *args, **kwargs): + counter["count"] += 1 + return original_compile(self, *args, **kwargs) + + monkeypatch.setattr(jax.stages.Lowered, "compile", counting_compile) + + initial_conditions = _build_initial_conditions(n_subjects=n_subjects) + + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + n_first = counter["count"] + counter["count"] = 0 + + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + n_second = counter["count"] + + assert n_first > 0, "First simulate must trigger compilation." + assert n_second == 0, "Second simulate must hit the AOT cache." + assert n_subjects in model._simulate_compile_cache + + +def test_simulate_warns_on_n_subjects_mismatch( + caplog: pytest.LogCaptureFixture, +) -> None: + """Mismatching size logs WARNING naming both N and M, falls back to lazy path.""" + n_periods = 3 + declared_n = 4 + actual_n = 7 + model = _build_test_model(n_periods=n_periods, n_subjects=declared_n) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + + with caplog.at_level(logging.WARNING, logger="lcm"): + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=_build_initial_conditions(n_subjects=actual_n), + ) + + mismatch_warnings = [ + r + for r in caplog.records + if r.levelno == logging.WARNING and "n_subjects" in r.getMessage() + ] + assert len(mismatch_warnings) == 1 + msg = mismatch_warnings[0].getMessage() + assert str(declared_n) in msg + assert str(actual_n) in msg + # Cache is NOT populated for mismatching size — fallback path was taken. + assert actual_n not in model._simulate_compile_cache + + +def test_simulate_caches_recompiled_size_no_second_warning( + caplog: pytest.LogCaptureFixture, +) -> None: + """Two calls with the same mismatching size produce only one WARNING.""" + n_periods = 3 + declared_n = 4 + actual_n = 7 + model = _build_test_model(n_periods=n_periods, n_subjects=declared_n) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + initial_conditions = _build_initial_conditions(n_subjects=actual_n) + + with caplog.at_level(logging.WARNING, logger="lcm"): + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + + mismatch_warnings = [ + r + for r in caplog.records + if r.levelno == logging.WARNING and "n_subjects" in r.getMessage() + ] + assert len(mismatch_warnings) == 1 From 8bb8259cf64849423b49e0d4b4bc3a6db7ca46d7 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 12:21:51 +0200 Subject: [PATCH 27/80] simulate AOT: match runtime's sparse pytree, drop runtime padding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous round padded `next_regime_to_V_arr` to all 19 regime keys at every period inside `simulate.simulate(...)`. That was a workaround for a pytree mismatch I'd introduced on the AOT side, not a real requirement — runtime has always passed only the active-at-P+1 regime keys (or `{}` past the last period), and `argmax_and_max_Q_over_a` traced fine against that sparse mapping. Padding everywhere widened the live device footprint of every dispatch (aca-baseline benchmark went 539 MB → 1.03 GB peak GPU, +11% execution time). Fix: keep the runtime path sparse and have AOT compile against the same sparse pytree per period. `_collect_unique_simulate_functions` now keys the argmax dedup on `(func_id, active_at_next_period)` so two periods sharing the same Q_and_F closure but seeing different active-regime sets at P+1 each get their own compiled program. The lower-args template is built per period from those active regimes only. Net effect: - Default (lazy) path: identical pytree to before this PR; the benchmark regression goes away. - AOT path: same correctness, programs sized to the actual runtime signature, dedup still effective when consecutive periods share the active set. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/simulation/compile.py | 47 +++++++++++++++++++----- src/lcm/simulation/simulate.py | 67 ++-------------------------------- 2 files changed, 41 insertions(+), 73 deletions(-) diff --git a/src/lcm/simulation/compile.py b/src/lcm/simulation/compile.py index 316d877b..9ee4d8c8 100644 --- a/src/lcm/simulation/compile.py +++ b/src/lcm/simulation/compile.py @@ -65,23 +65,20 @@ def compile_all_simulate_functions( AOT-compiled programs. """ + # Per-regime V-shape lookup for building period-specific templates that + # match the *sparse* mapping `simulate.simulate(...)` actually dispatches: + # `period_to_regime_to_V_arr.get(P+1, {})` — only regimes active at P+1. regime_V_shapes = _get_regime_V_shapes( internal_regimes=internal_regimes, internal_params=internal_params, ) - next_regime_to_V_arr = MappingProxyType( - { - regime_name: jnp.zeros(shape) - for regime_name, shape in regime_V_shapes.items() - } - ) unique, func_keys = _collect_unique_simulate_functions( internal_regimes=internal_regimes, internal_params=internal_params, ages=ages, n_subjects=n_subjects, - next_regime_to_V_arr=next_regime_to_V_arr, + regime_V_shapes=regime_V_shapes, ) n_workers = _resolve_compilation_workers( @@ -146,12 +143,19 @@ def _collect_unique_simulate_functions( internal_params: InternalParams, ages: AgeGrid, n_subjects: int, - next_regime_to_V_arr: MappingProxyType[RegimeName, Array], + regime_V_shapes: dict[RegimeName, tuple[int, ...]], ) -> tuple[ dict[Hashable, tuple[Callable, dict, str]], dict[tuple[RegimeName, str, int | None], Hashable], ]: - """Walk every regime/period and dedup the simulate functions to compile.""" + """Walk every regime/period and dedup the simulate functions to compile. + + `argmax_and_max_Q_over_a` dedup keys on `(func_id, active_at_next_period)` + so two periods that share the same argmax callable but see a different + `next_regime_to_V_arr` pytree (different active-regime set at P+1) get + separate compiled programs whose signature matches what runtime actually + dispatches. + """ unique: dict[Hashable, tuple[Callable, dict, str]] = {} func_keys: dict[tuple[RegimeName, str, int | None], Hashable] = {} @@ -160,6 +164,12 @@ def _collect_unique_simulate_functions( sf = regime.simulate_functions for period, argmax_func in sf.argmax_and_max_Q_over_a.items(): + active_next = _active_regimes_at_period( + internal_regimes=internal_regimes, period=period + 1 + ) + next_regime_to_V_arr = MappingProxyType( + {name: jnp.zeros(regime_V_shapes[name]) for name in active_next} + ) args = _build_argmax_args( internal_regime=regime, regime_params=regime_params, @@ -168,7 +178,7 @@ def _collect_unique_simulate_functions( n_subjects=n_subjects, next_regime_to_V_arr=next_regime_to_V_arr, ) - key = ("argmax", _func_dedup_key(func=argmax_func)) + key = ("argmax", _func_dedup_key(func=argmax_func), active_next) func_keys[(regime_name, "argmax", period)] = key if key not in unique: label = ( @@ -260,6 +270,23 @@ def _get_regime_V_shapes( return shapes +def _active_regimes_at_period( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + period: int, +) -> tuple[RegimeName, ...]: + """Tuple of regime names active at `period`, in `internal_regimes` order. + + Returned as a `tuple` so it's hashable and pytree-key-stable. An empty + tuple matches the runtime fallback for periods past the last (`{}`). + """ + return tuple( + regime_name + for regime_name, regime in internal_regimes.items() + if period in regime.active_periods + ) + + def _build_argmax_args( *, internal_regime: InternalRegime, diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index d90e2175..d54d2674 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -108,16 +108,6 @@ def simulate( # Build reverse lookup for regime transition logging ids_to_names: dict[int, RegimeName] = {v: k for k, v in regime_names_to_ids.items()} - # Normalize V arrays so every period (including the post-last fallback) - # has the same pytree structure: all regime keys, zeros for inactive - # regimes. This collapses the per-period dispatch to a single JIT-trace - # signature and matches the AOT-compiled programs. - period_to_regime_to_V_arr, empty_next_V = _normalize_period_V_arr( - period_to_regime_to_V_arr=period_to_regime_to_V_arr, - internal_regimes=internal_regimes, - internal_params=internal_params, - ) - for period, age in enumerate(ages.values): period_start = time.monotonic() @@ -156,7 +146,6 @@ def simulate( subject_regime_ids=subject_regime_ids, new_subject_regime_ids=new_subject_regime_ids, period_to_regime_to_V_arr=period_to_regime_to_V_arr, - empty_next_V=empty_next_V, internal_params=internal_params, regime_names_to_ids=regime_names_to_ids, active_regimes_next_period=active_regimes_next_period, @@ -215,7 +204,6 @@ def _simulate_regime_in_period( period_to_regime_to_V_arr: MappingProxyType[ int, MappingProxyType[RegimeName, FloatND] ], - empty_next_V: MappingProxyType[RegimeName, FloatND], internal_params: InternalParams, regime_names_to_ids: MappingProxyType[RegimeName, int], active_regimes_next_period: tuple[RegimeName, ...], @@ -261,10 +249,10 @@ def _simulate_regime_in_period( # Compute optimal actions # We need to pass the value function array of the next period to the # argmax_and_max_Q_over_a function, as the current Q-function requires the - # next period's value function. In the last period the next-period V is - # zeros (same shape as a populated period) — this keeps a single JIT-trace - # signature across all periods and matches the AOT-compile signature. - next_regime_to_V_arr = period_to_regime_to_V_arr.get(period + 1, empty_next_V) + # next period's value function. In the last period, we pass an empty dict. + next_regime_to_V_arr = period_to_regime_to_V_arr.get( + period + 1, MappingProxyType({}) + ) # The Q-function values contain the information of how much value each # action combination is worth. To find the optimal discrete action, we @@ -376,53 +364,6 @@ def _lookup_values_from_indices( vmapped_unravel_index = vmap(jnp.unravel_index, in_axes=(0, None)) -def _normalize_period_V_arr( - *, - period_to_regime_to_V_arr: MappingProxyType[ - int, MappingProxyType[RegimeName, FloatND] - ], - internal_regimes: MappingProxyType[RegimeName, InternalRegime], - internal_params: InternalParams, -) -> tuple[ - MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], - MappingProxyType[RegimeName, FloatND], -]: - """Fill missing regime keys with zero V arrays so every period has the same shape. - - `solve()` returns active-only mappings per period; AOT compilation and the - JIT cache both work best with a fixed pytree, so we pad with zeros here. - - Returns: - Tuple of (normalized period→regime→V mapping, all-zeros fallback for - post-last periods). - - """ - regime_V_shapes: dict[RegimeName, tuple[int, ...]] = {} - for regime_name, regime in internal_regimes.items(): - space = regime.state_action_space( - regime_params=internal_params.get(regime_name, MappingProxyType({})) - ) - regime_V_shapes[regime_name] = tuple(len(v) for v in space.states.values()) - - empty_next_V = MappingProxyType( - { - regime_name: jnp.zeros(shape) - for regime_name, shape in regime_V_shapes.items() - } - ) - - normalized: dict[int, MappingProxyType[RegimeName, FloatND]] = {} - for period, regime_to_V in period_to_regime_to_V_arr.items(): - normalized[period] = MappingProxyType( - { - regime_name: regime_to_V.get(regime_name, empty_next_V[regime_name]) - for regime_name in internal_regimes - } - ) - - return MappingProxyType(normalized), empty_next_V - - def _compute_starting_periods( *, initial_ages: Array, From 54c72a0dd95114b1cea7963001470ffa1aaf5f94 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 13:31:21 +0200 Subject: [PATCH 28/80] bench_aca_baseline: pass n_subjects=_N_SUBJECTS to create_benchmark_model Exercises the AOT-simulate path so the benchmark actually measures it. --- benchmarks/bench_aca_baseline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index 477d5635..a9364879 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -54,7 +54,7 @@ def _build() -> tuple[object, object, object]: get_benchmark_params, ) - model = create_benchmark_model() + model = create_benchmark_model(n_subjects=_N_SUBJECTS) _, model_params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=_N_SUBJECTS, seed=0 From 92d038caee1b19852a1d8a453ff1346d30668580 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 18:58:34 +0200 Subject: [PATCH 29/80] benchmarks: bump aca-model rev to carry n_subjects on factories The benchmark env pins aca-model by SHA. The previous SHA pre-dates `create_benchmark_model(n_subjects=...)`, so the aca-baseline benchmark fails at `setup_cache` with `unexpected keyword argument 'n_subjects'`. Bump to the tip of `feature/runtime-consumption-points`. --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index e507ccc8..1940247e 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=9e252051ad53683a8ad65e9dba68a910240103c0#9e252051ad53683a8ad65e9dba68a910240103c0 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=9e252051ad53683a8ad65e9dba68a910240103c0#9e252051ad53683a8ad65e9dba68a910240103c0 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev130+g3769d2d63.d20260429 - sha256: 365fd6893bcba5ab807032371430b19a9a9056c80ef3a9a201b56d34ced99e0e + version: 0.0.2.dev150+g8d796d824.d20260502 + sha256: dccd6f67e478bf10ca02dffd861e45de074f35d82a50262aa2eb47d3132022d6 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 3fe01162..3104a252 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "9e252051ad53683a8ad65e9dba68a910240103c0" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From 648afcc9b623c6864433b95073f7e0185f15a7a6 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 21:03:11 +0200 Subject: [PATCH 30/80] benchmarks: bump aca-model rev + pass max_consumption to factory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit aca-model now requires `max_consumption` on every `create_model*` factory (no default) — pass `_MAX_CONSUMPTION=300_000.0` to `create_benchmark_model` so the benchmark builds. --- benchmarks/bench_aca_baseline.py | 5 ++++- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index a9364879..064281fe 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -44,6 +44,7 @@ from benchmarks import _gpu_mem _N_SUBJECTS = 1000 +_MAX_CONSUMPTION = 300_000.0 def _build() -> tuple[object, object, object]: @@ -54,7 +55,9 @@ def _build() -> tuple[object, object, object]: get_benchmark_params, ) - model = create_benchmark_model(n_subjects=_N_SUBJECTS) + model = create_benchmark_model( + n_subjects=_N_SUBJECTS, max_consumption=_MAX_CONSUMPTION + ) _, model_params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=_N_SUBJECTS, seed=0 diff --git a/pixi.lock b/pixi.lock index 1940247e..b2f7c782 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=9e252051ad53683a8ad65e9dba68a910240103c0#9e252051ad53683a8ad65e9dba68a910240103c0 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=cdd10169a13fbc74604f9e879276ddb4c17b53c4#cdd10169a13fbc74604f9e879276ddb4c17b53c4 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=9e252051ad53683a8ad65e9dba68a910240103c0#9e252051ad53683a8ad65e9dba68a910240103c0 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=cdd10169a13fbc74604f9e879276ddb4c17b53c4#cdd10169a13fbc74604f9e879276ddb4c17b53c4 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev150+g8d796d824.d20260502 - sha256: dccd6f67e478bf10ca02dffd861e45de074f35d82a50262aa2eb47d3132022d6 + version: 0.0.2.dev151+gbb822917b.d20260502 + sha256: f321dce724999c07d44f6fe6a27c571bde1a33e464cc81c18a613f93889169cc requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 3104a252..64e8ae0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "9e252051ad53683a8ad65e9dba68a910240103c0" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "cdd10169a13fbc74604f9e879276ddb4c17b53c4" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From 596f150885324db74922b68f285ef3249584ec23 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 21:13:22 +0200 Subject: [PATCH 31/80] simulate AOT: re-jit `next_state` / `compute_regime_transition_probs` When `fixed_params` are partialled into a regime via `_partial_fixed_params_into_regimes`, `simulate_functions.next_state` and `simulate_functions.compute_regime_transition_probs` become `functools.partial(jit_wrapped, **fixed)` objects, which don't have `.lower()`. The argmax path was already wrapping with `jax.jit` before lowering; do the same for these two so AOT compile works whether or not the regime carries fixed params. --- src/lcm/simulation/compile.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/lcm/simulation/compile.py b/src/lcm/simulation/compile.py index 9ee4d8c8..2d492aef 100644 --- a/src/lcm/simulation/compile.py +++ b/src/lcm/simulation/compile.py @@ -197,7 +197,14 @@ def _collect_unique_simulate_functions( key = ("next_state", _func_dedup_key(func=sf.next_state)) func_keys[(regime_name, "next_state", None)] = key if key not in unique: - unique[key] = (sf.next_state, args, f"{regime_name}/next_state") + # Re-wrap with `jax.jit`: when `fixed_params` are partialled + # into the regime, `sf.next_state` is a `functools.partial` + # (no `.lower()`); plain jit objects are also fine to re-jit. + unique[key] = ( + jax.jit(sf.next_state), + args, + f"{regime_name}/next_state", + ) if sf.compute_regime_transition_probs is not None: args = _build_crtp_args( @@ -210,7 +217,7 @@ def _collect_unique_simulate_functions( func_keys[(regime_name, "crtp", None)] = key if key not in unique: unique[key] = ( - sf.compute_regime_transition_probs, + jax.jit(sf.compute_regime_transition_probs), args, f"{regime_name}/compute_regime_transition_probs", ) From 9fd0524049dbe72762877b9508204f0a3979e6f0 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 3 May 2026 07:41:26 +0200 Subject: [PATCH 32/80] simulate AOT: only compile active-period argmax, not the full age range MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `sf.argmax_and_max_Q_over_a` carries an entry for every period in the model's age grid (pylcm builds Q_and_F per period across the whole range), but a regime is only dispatched at runtime for periods in `regime.active_periods`. The lazy path never traces the unused entries — and that turns out to matter: a regime's stale per-period `complete_targets` can include reachable-by-spec-but-actually-inactive target regimes whose `transitions[target]` is missing the `next_` for some target state. Tracing such an entry blows up in `Q_and_F`'s `extra_kw = {k: states_actions_params[k] ...}` lookup on a `next_` key that runtime would never have produced. Concrete repro on aca-baseline: `retiree_oamc_forced_canwork`'s argmax for age 61 (the regime is only active at age 70+) carries choose-target entries with `next_claimed_ss` missing. The lazy path never dispatches that (regime, period) pair. AOT was iterating `sf.argmax_and_max_Q_over_a.items()` unconditionally, so it tried to lower it and tripped the KeyError. Fix: iterate `regime.active_periods` for the argmax sweep, and have `_swap_in_compiled` leave the inactive-period entries as the original closure (they're never invoked anyway). --- src/lcm/simulation/compile.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/lcm/simulation/compile.py b/src/lcm/simulation/compile.py index 2d492aef..2e64ab0d 100644 --- a/src/lcm/simulation/compile.py +++ b/src/lcm/simulation/compile.py @@ -163,7 +163,18 @@ def _collect_unique_simulate_functions( regime_params = internal_params.get(regime_name, MappingProxyType({})) sf = regime.simulate_functions - for period, argmax_func in sf.argmax_and_max_Q_over_a.items(): + # `sf.argmax_and_max_Q_over_a` has entries for *every* period + # (pylcm builds them across the full age grid), but the regime is + # only dispatched at runtime for periods in `regime.active_periods`. + # The unused entries can carry a stale `complete_targets` set + # whose shape doesn't match the regime's actual transitions + # (e.g. a forced-canwork regime's argmax for a pre-FRA period + # has choose targets in scope, even though the regime never + # reaches that period at runtime). Tracing those would surface + # `next_` bookkeeping inconsistencies that the lazy path + # never trips. Restrict AOT to active periods to mirror runtime. + for period in regime.active_periods: + argmax_func = sf.argmax_and_max_Q_over_a[period] active_next = _active_regimes_at_period( internal_regimes=internal_regimes, period=period + 1 ) @@ -235,10 +246,18 @@ def _swap_in_compiled( new_regimes: dict[RegimeName, InternalRegime] = {} for regime_name, regime in internal_regimes.items(): sf = regime.simulate_functions + # Only active periods are AOT-compiled (see + # `_collect_unique_simulate_functions`); leave inactive-period + # entries untouched so the existing closure stays in place — they + # are never dispatched at runtime anyway. + argmax_compiled_for_active = { + period: compiled[func_keys[(regime_name, "argmax", period)]] + for period in regime.active_periods + } argmax_compiled = MappingProxyType( { - period: compiled[func_keys[(regime_name, "argmax", period)]] - for period in sf.argmax_and_max_Q_over_a + period: argmax_compiled_for_active.get(period, original_func) + for period, original_func in sf.argmax_and_max_Q_over_a.items() } ) if regime.terminal: From dfb0e8b1e930fe6952b138bb370b02aa4c92895c Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 3 May 2026 09:28:25 +0200 Subject: [PATCH 33/80] simulate AOT: int32 for discrete state lower-args (match runtime) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Runtime always builds discrete state arrays as `jnp.int32` (forced by `simulation.initial_conditions.build_initial_states` regardless of the grid array's own dtype), but my AOT lower-args were using `arr.dtype` from the regime's `state_action_space.states[name]`. For `DiscreteGrid` states (`health`, `pref_type`, `lagged_labor_supply`, `spousal_income`), JAX's default `jnp.array([0, 1, 2])` gives `int64` when x64 is enabled — so the AOT-compiled programs expected `int64`, runtime called with `int32`, and JAX raised: TypeError: Argument types differ from the types for which this computation was compiled. Argument 'health' compiled with int64[1000] and called with int32[1000] Fix: surfaced via `_subject_shape_state_arrays` which now consults `internal_regime.grids[name]` and forces `int32` for `DiscreteGrid` states, matching `build_initial_states`. Continuous states keep the grid's dtype (which already matches runtime since continuous states flow through `initial_states` directly). Action grids (full grids, not subject-shaped) keep `arr.dtype` — runtime calls argmax with the same grid arrays, so dtypes already match there. --- src/lcm/simulation/compile.py | 42 ++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/src/lcm/simulation/compile.py b/src/lcm/simulation/compile.py index 2e64ab0d..f8d5971b 100644 --- a/src/lcm/simulation/compile.py +++ b/src/lcm/simulation/compile.py @@ -24,6 +24,7 @@ from jax import Array from lcm.ages import AgeGrid +from lcm.grids import DiscreteGrid, Grid from lcm.interfaces import InternalRegime from lcm.simulation.random import generate_simulation_keys from lcm.solution.solve_brute import ( @@ -323,7 +324,11 @@ def _build_argmax_args( next_regime_to_V_arr: MappingProxyType[RegimeName, Array], ) -> dict[str, object]: base = internal_regime.state_action_space(regime_params=regime_params) - subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + subject_states = _subject_shape_state_arrays( + states=base.states, + grids=internal_regime.grids, + n_subjects=n_subjects, + ) return { **subject_states, **base.discrete_actions, @@ -343,7 +348,11 @@ def _build_next_state_args( n_subjects: int, ) -> dict[str, object]: base = internal_regime.state_action_space(regime_params=regime_params) - subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + subject_states = _subject_shape_state_arrays( + states=base.states, + grids=internal_regime.grids, + n_subjects=n_subjects, + ) subject_actions = _subject_shape_arrays( {**base.discrete_actions, **base.continuous_actions}, n_subjects=n_subjects, @@ -386,7 +395,11 @@ def _build_crtp_args( n_subjects: int, ) -> dict[str, object]: base = internal_regime.state_action_space(regime_params=regime_params) - subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + subject_states = _subject_shape_state_arrays( + states=base.states, + grids=internal_regime.grids, + n_subjects=n_subjects, + ) subject_actions = _subject_shape_arrays( {**base.discrete_actions, **base.continuous_actions}, n_subjects=n_subjects, @@ -410,3 +423,26 @@ def _subject_shape_arrays( name: jnp.zeros((n_subjects,), dtype=arr.dtype) for name, arr in base_arrays.items() } + + +def _subject_shape_state_arrays( + *, + states: Mapping[str, Array], + grids: Mapping[str, Grid], + n_subjects: int, +) -> dict[str, Array]: + """Subject-shape arrays for state inputs that match runtime dtypes. + + Discrete states arrive at runtime as `int32` (forced by + `simulation.initial_conditions.build_initial_states`), regardless of + the grid array's own dtype. Continuous states inherit the grid dtype. + AOT lower-args must mirror this so the compiled program's expected + signature lines up with runtime dispatch. + """ + out: dict[str, Array] = {} + for name, arr in states.items(): + if isinstance(grids.get(name), DiscreteGrid): + out[name] = jnp.zeros((n_subjects,), dtype=jnp.int32) + else: + out[name] = jnp.zeros((n_subjects,), dtype=arr.dtype) + return out From 43723c43d0f4e915fb53aac83fe27766ae76cd45 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 3 May 2026 09:41:53 +0200 Subject: [PATCH 34/80] build_initial_states: cast discrete states to grid dtype (one-shot) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The simulate loop carried a dtype regression that was invisible until AOT compile flagged it: `build_initial_states` forced `int32` for discrete states (matching the hardcoded `MISSING_CAT_CODE` dtype), but `next_state_vmapped` returns post-transition values at the grid's dtype — int64 under x64 mode (which `aca_model` enables). Result: period 0 dispatches argmax with int32 state args, period 1+ dispatches with int64 (after the first transition rolls forward). The lazy JIT cache shrugs and compiles two argmax variants per regime; the AOT path can only ship one signature and breaks. Fix: have `build_initial_states` cast every discrete state — both the user-supplied entry and the missing-state sentinel — to the grid's dtype. All periods now share one signature, the lazy cache shrinks to one program per regime, and the AOT path's lower-args (which already used `arr.dtype` from the grid) lines up. Also revert the previous AOT-side dtype hack in compile.py: the `_subject_shape_state_arrays` workaround was patching the symptom, not the cause, and would have left initial-vs-transition-dtype mismatches in place. --- src/lcm/simulation/compile.py | 49 +++++------------------- src/lcm/simulation/initial_conditions.py | 20 ++++++++-- 2 files changed, 26 insertions(+), 43 deletions(-) diff --git a/src/lcm/simulation/compile.py b/src/lcm/simulation/compile.py index f8d5971b..c325b27f 100644 --- a/src/lcm/simulation/compile.py +++ b/src/lcm/simulation/compile.py @@ -24,7 +24,6 @@ from jax import Array from lcm.ages import AgeGrid -from lcm.grids import DiscreteGrid, Grid from lcm.interfaces import InternalRegime from lcm.simulation.random import generate_simulation_keys from lcm.solution.solve_brute import ( @@ -324,11 +323,7 @@ def _build_argmax_args( next_regime_to_V_arr: MappingProxyType[RegimeName, Array], ) -> dict[str, object]: base = internal_regime.state_action_space(regime_params=regime_params) - subject_states = _subject_shape_state_arrays( - states=base.states, - grids=internal_regime.grids, - n_subjects=n_subjects, - ) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) return { **subject_states, **base.discrete_actions, @@ -348,11 +343,7 @@ def _build_next_state_args( n_subjects: int, ) -> dict[str, object]: base = internal_regime.state_action_space(regime_params=regime_params) - subject_states = _subject_shape_state_arrays( - states=base.states, - grids=internal_regime.grids, - n_subjects=n_subjects, - ) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) subject_actions = _subject_shape_arrays( {**base.discrete_actions, **base.continuous_actions}, n_subjects=n_subjects, @@ -395,11 +386,7 @@ def _build_crtp_args( n_subjects: int, ) -> dict[str, object]: base = internal_regime.state_action_space(regime_params=regime_params) - subject_states = _subject_shape_state_arrays( - states=base.states, - grids=internal_regime.grids, - n_subjects=n_subjects, - ) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) subject_actions = _subject_shape_arrays( {**base.discrete_actions, **base.continuous_actions}, n_subjects=n_subjects, @@ -418,31 +405,13 @@ def _subject_shape_arrays( *, n_subjects: int, ) -> dict[str, Array]: - """Return zeros of shape `(n_subjects,)` mirroring each base array's dtype.""" + """Return zeros of shape `(n_subjects,)` mirroring each base array's dtype. + + With `build_initial_states` casting discrete states to the grid dtype, + runtime states (initial + post-transition) share the grid's dtype, so + using `arr.dtype` from the regime's grid here matches runtime. + """ return { name: jnp.zeros((n_subjects,), dtype=arr.dtype) for name, arr in base_arrays.items() } - - -def _subject_shape_state_arrays( - *, - states: Mapping[str, Array], - grids: Mapping[str, Grid], - n_subjects: int, -) -> dict[str, Array]: - """Subject-shape arrays for state inputs that match runtime dtypes. - - Discrete states arrive at runtime as `int32` (forced by - `simulation.initial_conditions.build_initial_states`), regardless of - the grid array's own dtype. Continuous states inherit the grid dtype. - AOT lower-args must mirror this so the compiled program's expected - signature lines up with runtime dispatch. - """ - out: dict[str, Array] = {} - for name, arr in states.items(): - if isinstance(grids.get(name), DiscreteGrid): - out[name] = jnp.zeros((n_subjects,), dtype=jnp.int32) - else: - out[name] = jnp.zeros((n_subjects,), dtype=arr.dtype) - return out diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 2cdd3c3b..22a8dd2e 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -64,10 +64,24 @@ def build_initial_states( for regime_name, internal_regime in internal_regimes.items(): for state_name in _get_regime_state_names(internal_regime): key = f"{regime_name}__{state_name}" - if state_name in initial_states: + grid = internal_regime.grids[state_name] + if isinstance(grid, DiscreteGrid): + # Match the grid's index dtype so the state is index-stable + # across the simulate loop. Without this, period-0 dispatch + # carries the user-supplied dtype (often int32) but post- + # transition states are promoted to the grid dtype (int64 + # under x64), forcing JAX to compile two argmax variants + # per regime and breaking AOT-compiled programs that key + # on a single signature. + target_dtype = grid.to_jax().dtype + if state_name in initial_states: + flat[key] = initial_states[state_name].astype(target_dtype) + else: + flat[key] = jnp.full( + n_subjects, MISSING_CAT_CODE, dtype=target_dtype + ) + elif state_name in initial_states: flat[key] = initial_states[state_name] - elif isinstance(internal_regime.grids[state_name], DiscreteGrid): - flat[key] = jnp.full(n_subjects, MISSING_CAT_CODE, dtype=jnp.int32) else: flat[key] = jnp.full(n_subjects, jnp.nan) From 99f3f15ee92eeca50ee2e515edda62ae181b3eb9 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 3 May 2026 10:09:01 +0200 Subject: [PATCH 35/80] DiscreteGrid: pin to_jax() to int32 regardless of x64 mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `jnp.array([...])` defaults to `int32` without x64 and `int64` with x64. That makes every downstream value carrying a discrete code — transitions, V-array indexing, action lookups — depend silently on a process-wide setting. The lazy JIT cache hides the consequences (it just compiles two variants per regime, one per dtype), but AOT can only ship one signature, and any caller that mixes init-time int32 with post-transition int64 (or vice versa) trips the `int32[...] vs int64[...]` mismatch JAX raises at AOT call time. Fix: pin `DiscreteGrid.to_jax()` to `int32`. Every derived discrete value inherits that dtype — initial-state casts (already grid-dtype since the previous fix), transitions, indexing, action lookups — locking the discrete dtype to a single explicit choice rather than relying on whichever JAX precision mode happens to be active. `int32` covers any realistic category count and matches the existing `MISSING_CAT_CODE` sentinel. --- src/lcm/grids/discrete.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/lcm/grids/discrete.py b/src/lcm/grids/discrete.py index f844aad6..acdb850d 100644 --- a/src/lcm/grids/discrete.py +++ b/src/lcm/grids/discrete.py @@ -48,5 +48,15 @@ def batch_size(self) -> int: return self.__batch_size def to_jax(self) -> Int1D: - """Convert the grid to a Jax array.""" - return jnp.array(self.codes) + """Convert the grid to a Jax array. + + Discrete state/action codes are pinned to `int32` regardless of the + ambient `jax_enable_x64` setting. `jnp.array([...])` would otherwise + produce `int32` in 32-bit mode and `int64` in x64 mode, and + downstream values (transitions, V-array indexing, action lookups) + inherit that ambiguity — which silently splits the JIT cache into + per-period int32/int64 variants and breaks any AOT-compiled + program that ships a single signature. `int32` covers any realistic + category count and matches the `MISSING_CAT_CODE` sentinel. + """ + return jnp.array(self.codes, dtype=jnp.int32) From 458d36fcc6c8ccd47e7b07abb52556db7e478aae Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 3 May 2026 16:15:16 +0200 Subject: [PATCH 36/80] Lock integer dtype to int32 end-to-end (#341) Co-authored-by: Claude Opus 4.7 (1M context) --- src/lcm/pandas_utils.py | 3 +- src/lcm/regime_building/argmax.py | 4 +-- src/lcm/simulation/initial_conditions.py | 4 +-- src/lcm/simulation/simulate.py | 4 ++- src/lcm/typing.py | 12 +++---- src/lcm/utils/error_handling.py | 2 +- tests/test_int_dtype_invariants.py | 41 ++++++++++++++++++++++++ 7 files changed, 57 insertions(+), 13 deletions(-) create mode 100644 tests/test_int_dtype_invariants.py diff --git a/src/lcm/pandas_utils.py b/src/lcm/pandas_utils.py index 2d696a45..cf1b4cb0 100644 --- a/src/lcm/pandas_utils.py +++ b/src/lcm/pandas_utils.py @@ -153,7 +153,8 @@ def initial_conditions_from_dataframe( # noqa: C901 for col, arr in result_arrays.items() } initial_conditions["regime"] = jnp.array( - df["regime"].map(dict(regime_names_to_ids)).to_numpy() + df["regime"].map(dict(regime_names_to_ids)).to_numpy(), + dtype=jnp.int32, ) return initial_conditions diff --git a/src/lcm/regime_building/argmax.py b/src/lcm/regime_building/argmax.py index 0e48cf2b..a4271e2f 100644 --- a/src/lcm/regime_building/argmax.py +++ b/src/lcm/regime_building/argmax.py @@ -43,7 +43,7 @@ def argmax_and_max( # When there are no dimensions to reduce over, return: # - index 0 (trivial argmax since there's only one element) # - the array itself (already the maximum) - return jnp.array(0), a + return jnp.array(0, dtype=jnp.int32), a # Move axis over which to compute the argmax to the back and flatten last dims # ================================================================================== @@ -65,7 +65,7 @@ def argmax_and_max( max_value_mask = a == _max if where is not None: max_value_mask = jnp.logical_and(max_value_mask, where) - _argmax = jnp.argmax(max_value_mask, axis=-1) + _argmax = jnp.argmax(max_value_mask, axis=-1).astype(jnp.int32) return _argmax, _max.reshape(_argmax.shape) diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 22a8dd2e..127059e1 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -362,7 +362,7 @@ def _collect_structural_errors( active_mask = active_mask & (~in_regime | period_active) if not jnp.all(active_mask): - invalid_indices = jnp.where(~active_mask)[0] + invalid_indices = jnp.where(~active_mask)[0].astype(jnp.int32) invalid_combos = { (ids_to_regime_names[int(regime_id_arr[i])], float(age_values[i])) for i in invalid_indices @@ -406,7 +406,7 @@ def _collect_feasibility_errors( errors: list[str] = [] for regime_name, internal_regime in internal_regimes.items(): regime_id = regime_names_to_ids[regime_name] - idx_arr = jnp.where(regime_id_arr == regime_id)[0] + idx_arr = jnp.where(regime_id_arr == regime_id)[0].astype(jnp.int32) subject_indices = idx_arr.tolist() if idx_arr.size > 0 else [] if not subject_indices: continue diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index d54d2674..d1ab42ab 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -99,7 +99,9 @@ def simulate( starting_periods = _compute_starting_periods( initial_ages=initial_states["age"], ages=ages ) - subject_regime_ids = jnp.full_like(initial_conditions["regime"], MISSING_CAT_CODE) + subject_regime_ids = jnp.full_like( + initial_conditions["regime"], MISSING_CAT_CODE, dtype=jnp.int32 + ) # Forward simulation simulation_results: dict[RegimeName, dict[int, PeriodRegimeSimulationData]] = { diff --git a/src/lcm/typing.py b/src/lcm/typing.py index c73b4c33..62492770 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -4,27 +4,27 @@ import pandas as pd from jax import Array -from jaxtyping import Bool, Float, Int, Scalar +from jaxtyping import Bool, Float, Int32, Scalar from lcm.params import MappingLeaf from lcm.params.sequence_leaf import SequenceLeaf type ContinuousState = Float[Array, "..."] type ContinuousAction = Float[Array, "..."] -type DiscreteState = Int[Array, "..."] -type DiscreteAction = Int[Array, "..."] +type DiscreteState = Int32[Array, "..."] +type DiscreteAction = Int32[Array, "..."] type FloatND = Float[Array, "..."] -type IntND = Int[Array, "..."] +type IntND = Int32[Array, "..."] type BoolND = Bool[Array, "..."] type Float1D = Float[Array, "_"] # noqa: F821 -type Int1D = Int[Array, "_"] # noqa: F821 +type Int1D = Int32[Array, "_"] # noqa: F821 type Bool1D = Bool[Array, "_"] # noqa: F821 # Many JAX functions are designed to work with scalar numerical values. This also # includes zero dimensional jax arrays. -type ScalarInt = int | Int[Scalar, ""] +type ScalarInt = int | Int32[Scalar, ""] type ScalarFloat = float | Float[Scalar, ""] type Period = int | Int1D diff --git a/src/lcm/utils/error_handling.py b/src/lcm/utils/error_handling.py index ef631349..faeed9fe 100644 --- a/src/lcm/utils/error_handling.py +++ b/src/lcm/utils/error_handling.py @@ -372,7 +372,7 @@ def _format_sum_violation( {name: jnp.atleast_1d(arr) for name, arr in state_action_values.items()} ) failing_mask = ~jnp.isclose(sum_all, 1.0) - failing_indices = jnp.where(failing_mask)[0] + failing_indices = jnp.where(failing_mask)[0].astype(jnp.int32) failing_sums = sum_all[failing_mask] n_failing = int(failing_indices.shape[0]) n_show = min(n_failing, 5) diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py new file mode 100644 index 00000000..3147de3f --- /dev/null +++ b/tests/test_int_dtype_invariants.py @@ -0,0 +1,41 @@ +"""Integer dtypes are pinned to int32 across pylcm regardless of x64 mode.""" + +import jax.numpy as jnp + +from lcm.simulation.initial_conditions import ( + MISSING_CAT_CODE, + build_initial_states, +) +from tests.test_models.deterministic.regression import get_model + + +def test_discrete_grid_to_jax_is_int32() -> None: + model = get_model(n_periods=3) + for regime in model.regimes.values(): + for grid in {**regime.states, **regime.actions}.values(): + jax_arr = grid.to_jax() + if jax_arr.dtype.kind == "i": + assert jax_arr.dtype == jnp.int32, ( + f"Discrete grid yielded {jax_arr.dtype}, expected int32." + ) + + +def test_build_initial_states_discrete_dtype_is_int32() -> None: + model = get_model(n_periods=3) + initial_states = { + "wealth": jnp.array([20.0, 50.0]), + "age": jnp.array([18.0, 18.0]), + } + flat = build_initial_states( + initial_states=initial_states, + internal_regimes=model.internal_regimes, + ) + for key, arr in flat.items(): + if arr.dtype.kind == "i": + assert arr.dtype == jnp.int32, ( + f"Initial state {key} has dtype {arr.dtype}, expected int32." + ) + + +def test_missing_cat_code_is_int32_minimum() -> None: + assert jnp.iinfo(jnp.int32).min == MISSING_CAT_CODE From e8ede00f0ede530e4c0142b759cc7f69a39f8474 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 3 May 2026 17:31:48 +0200 Subject: [PATCH 37/80] benchmarks: bump aca-model rev; drop max_consumption kwarg aca-model dropped the `max_consumption` factory kwarg in favour of a canonical `MAX_CONSUMPTION` constant in `baseline.regimes._common` attached to `model.max_consumption`. Update the pin and the bench-script call accordingly. Co-Authored-By: Claude Opus 4.7 (1M context) --- benchmarks/bench_aca_baseline.py | 5 +---- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index 064281fe..a9364879 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -44,7 +44,6 @@ from benchmarks import _gpu_mem _N_SUBJECTS = 1000 -_MAX_CONSUMPTION = 300_000.0 def _build() -> tuple[object, object, object]: @@ -55,9 +54,7 @@ def _build() -> tuple[object, object, object]: get_benchmark_params, ) - model = create_benchmark_model( - n_subjects=_N_SUBJECTS, max_consumption=_MAX_CONSUMPTION - ) + model = create_benchmark_model(n_subjects=_N_SUBJECTS) _, model_params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=_N_SUBJECTS, seed=0 diff --git a/pixi.lock b/pixi.lock index b2f7c782..46a12d93 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=cdd10169a13fbc74604f9e879276ddb4c17b53c4#cdd10169a13fbc74604f9e879276ddb4c17b53c4 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=31a0ad20e70c9e1859f4268cd954979070d8b17f#31a0ad20e70c9e1859f4268cd954979070d8b17f - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=cdd10169a13fbc74604f9e879276ddb4c17b53c4#cdd10169a13fbc74604f9e879276ddb4c17b53c4 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=31a0ad20e70c9e1859f4268cd954979070d8b17f#31a0ad20e70c9e1859f4268cd954979070d8b17f name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev151+gbb822917b.d20260502 - sha256: f321dce724999c07d44f6fe6a27c571bde1a33e464cc81c18a613f93889169cc + version: 0.0.2.dev159+g430609206.d20260503 + sha256: ae00b59638af14bf39d12c29ba8ee6c2664799b254ebc4a2c7dffe20d0013e80 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 64e8ae0c..b8d5b7fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "cdd10169a13fbc74604f9e879276ddb4c17b53c4" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "31a0ad20e70c9e1859f4268cd954979070d8b17f" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From 4316b6e2c5a674502df651e8d2f7ad8274fdf620 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 3 May 2026 20:36:39 +0200 Subject: [PATCH 38/80] solve_brute: merge resolved_fixed_params into NaN diagnostic regime_params MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `_raise_at` calls `compute_intermediates` directly (not through the solve loop's partialled closures), but was passing only the per-iteration `regime_params` slice from `internal_params`. The diagnostic was traced against the regime's full `flat_param_names` — per-iteration plus fixed — so the call failed with `ValueError: ... missing: {}` and the "Diagnostic enrichment failed; raising original NaN error" warning fired instead of surfacing the actual NaN-source breakdown. Mirror the merge pattern from `interfaces.state_action_space` and `simulation.result.to_dataframe`: build `{**resolved_fixed_params, **regime_params}` before forwarding to `validate_V`. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/solution/solve_brute.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 653b013d..c80dceb2 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -518,6 +518,15 @@ def _raise_at( """Run the enriched NaN diagnostic on a single offending row and raise.""" internal_regime = internal_regimes[row.regime_name] regime_params = internal_params[row.regime_name] + # `compute_intermediates` was built from the regime's full `flat_param_names` + # (per-iteration params + fixed params); the live solve loop merges + # `resolved_fixed_params` into `regime_params` implicitly via the partialled + # closures, but we have to do it by hand here to call the diagnostic + # directly. Same merge order as `interfaces.state_action_space` and + # `simulation.result`. + diag_params = MappingProxyType( + {**internal_regime.resolved_fixed_params, **regime_params} + ) state_action_space = internal_regime.state_action_space(regime_params=regime_params) next_regime_to_V_arr = _reconstruct_next_regime_to_V_arr( period=row.period, @@ -537,7 +546,7 @@ def _raise_at( compute_intermediates=compute_intermediates, state_action_space=state_action_space, next_regime_to_V_arr=next_regime_to_V_arr, - internal_params=regime_params, + internal_params=diag_params, period=row.period, ) From 866a5bbb138d3946b2ffab08d5245d849ec57f6f Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 07:17:21 +0200 Subject: [PATCH 39/80] benchmarks: bump aca-model rev to 714fee0 (assets-floor margin) Co-Authored-By: Claude Opus 4.7 (1M context) --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index 46a12d93..340187f8 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=31a0ad20e70c9e1859f4268cd954979070d8b17f#31a0ad20e70c9e1859f4268cd954979070d8b17f + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=714fee0496c63547da047670fae058acfae6bfa2#714fee0496c63547da047670fae058acfae6bfa2 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=31a0ad20e70c9e1859f4268cd954979070d8b17f#31a0ad20e70c9e1859f4268cd954979070d8b17f +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=714fee0496c63547da047670fae058acfae6bfa2#714fee0496c63547da047670fae058acfae6bfa2 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev159+g430609206.d20260503 - sha256: ae00b59638af14bf39d12c29ba8ee6c2664799b254ebc4a2c7dffe20d0013e80 + version: 0.0.2.dev161+g5e09d4684.d20260504 + sha256: 6f213ba89f5e74b2c69744cc062f8f1088d8f3b31475e1f8d142128b0d111504 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index b8d5b7fd..bc8e74b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "31a0ad20e70c9e1859f4268cd954979070d8b17f" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "714fee0496c63547da047670fae058acfae6bfa2" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From a51edae6f8d759ad8bf5856e4fd80382c38a1936 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 10:53:23 +0200 Subject: [PATCH 40/80] regime_template: exempt next_ names from fixed_param extraction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A state transition that consumes another state transition's output (e.g. `next_wealth(next_aime, ...)`) was rejected at solve time with `InvalidParamsError: Missing required parameter: '__next_wealth__next_aime'`. The upstream cause is in `create_regime_params_template`, which classified `next_` references inside transition signatures as regime-level fixed_params. `get_next_state_function_for_solution` already merges all transitions and DAG functions into a single dict before calling `concatenate_functions`, so dags resolves the chain at evaluation time. The block is purely the params-template step. Extending the exempt set with `{f"next_{name}" for name in regime.states}` lets dags do its job. This unlocks per-target transition factories whose outputs feed each other — e.g., a `next_assets` that reads `next_aime` to compute a next-period imputed value (the aca-model pension correction use case). No JAX parallelism implications: the fix is build-time bookkeeping only. JIT scope, vmap structure, and scan layout are unchanged; the new dependency edge runs per-gridpoint inside the same merged DAG. --- src/lcm/params/regime_template.py | 16 ++- .../test_create_regime_params_template.py | 40 +++++++ tests/test_chained_state_transitions.py | 100 ++++++++++++++++++ 3 files changed, 153 insertions(+), 3 deletions(-) create mode 100644 tests/test_chained_state_transitions.py diff --git a/src/lcm/params/regime_template.py b/src/lcm/params/regime_template.py index 85a66b47..8f9765a5 100644 --- a/src/lcm/params/regime_template.py +++ b/src/lcm/params/regime_template.py @@ -23,8 +23,15 @@ def create_regime_params_template( """Create parameter template from a regime specification. Discover parameters from function signatures via `dags.tree`. Parameters are - function arguments that are not states, actions, other regime functions, or - special variables (period, age, E_next_V). + function arguments that are not states, actions, other regime functions, + `next_` transition outputs, or special variables (period, age, + E_next_V). + + The `next_` exemption lets a state transition consume the output of + another state transition: dags resolves the chain at evaluation time + (`get_next_state_function_for_solution` merges all transitions and DAG + functions into a single dict before calling `concatenate_functions`), so + these names must not surface as user-facing fixed_params. For `SolveSimulateFunctionPair` entries, the template contains the **union** of both variants' parameters so the user can provide a single flat params @@ -42,7 +49,10 @@ def create_regime_params_template( """ H_variables = {*regime.functions, "period", "age", "E_next_V"} - variables = H_variables | set(regime.actions) | set(regime.states) + next_state_names = {f"next_{name}" for name in regime.states} + variables = ( + H_variables | set(regime.actions) | set(regime.states) | next_state_names + ) function_params: dict[FunctionName, dict[str, str]] = {} diff --git a/tests/regime_building/test_create_regime_params_template.py b/tests/regime_building/test_create_regime_params_template.py index 9f912070..62d5a5cb 100644 --- a/tests/regime_building/test_create_regime_params_template.py +++ b/tests/regime_building/test_create_regime_params_template.py @@ -150,3 +150,43 @@ def test_regular_function_taking_state_as_argument_no_error(binary_category_clas "next_regime": {}, } ) + + +def test_state_transition_consuming_other_next_state_is_not_a_param( + binary_category_class, +): + """`next_` names are exempt from param-template extraction. + + A state transition (here, `next_wealth`) that consumes another transition's + output (here, `next_aime`) must not have `next_aime` classified as a + regime-level fixed_param. dags resolves the chain at evaluation time + (`get_next_state_function_for_solution` merges all transitions into a + single dict before calling `concatenate_functions`). + """ + + def next_wealth(wealth: float, next_aime: float) -> float: + return wealth + next_aime + + regime = RegimeMock( + actions={"a": DiscreteGrid(binary_category_class)}, + states={ + "wealth": DiscreteGrid(binary_category_class), + "aime": DiscreteGrid(binary_category_class), + }, + state_transitions={ + "wealth": next_wealth, + "aime": lambda aime: aime, + }, + transition=lambda: 0, + functions={"utility": lambda a, wealth, aime: None}, # noqa: ARG005 + ) + got = create_regime_params_template(regime) # ty: ignore[invalid-argument-type] + assert got == ensure_containers_are_immutable( + { + "H": {"discount_factor": "float"}, + "utility": {}, + "next_wealth": {}, + "next_aime": {}, + "next_regime": {}, + } + ) diff --git a/tests/test_chained_state_transitions.py b/tests/test_chained_state_transitions.py new file mode 100644 index 00000000..4d8a61d4 --- /dev/null +++ b/tests/test_chained_state_transitions.py @@ -0,0 +1,100 @@ +"""End-to-end check that one state transition can consume another's output. + +dags resolves dependencies between state-transition functions when they +appear in the merged transitions+functions dict that +`get_next_state_function_for_solution` builds. The blocker fixed here is in +the upstream `create_regime_params_template`: it must not classify +`next_` names as regime-level fixed_params, otherwise param resolution +fails before dags ever runs. +""" + +import jax.numpy as jnp + +from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical +from lcm.typing import DiscreteAction, FloatND, ScalarInt + + +@categorical(ordered=False) +class _LaborSupply: + work: int + rest: int + + +@categorical(ordered=False) +class _RegimeId: + active: int + dead: int + + +def _next_aime(aime: float, labor_supply: DiscreteAction) -> FloatND: + """AIME accumulates only when working.""" + return aime + jnp.where(labor_supply == _LaborSupply.work, 1.0, 0.0) + + +def _next_wealth(wealth: float, consumption: float, next_aime: FloatND) -> FloatND: + """Next-period wealth depends on next-period AIME (the chained transition). + + The economically interesting use is `pia = f(next_aime)` feeding + next_wealth via a pension correction. Here we keep the dependency simple + so the test focuses on the wiring, not the economics. + """ + return wealth - consumption + 0.1 * next_aime + + +def _utility(consumption: float, labor_supply: DiscreteAction) -> FloatND: + disutility = jnp.where(labor_supply == _LaborSupply.work, 0.5, 0.0) + return jnp.log(jnp.maximum(consumption, 1e-6)) - disutility + + +def _next_regime(age: int, final_age_alive: float) -> ScalarInt: + return jnp.where(age >= final_age_alive, _RegimeId.dead, _RegimeId.active) + + +_active = Regime( + transition=_next_regime, + actions={ + "labor_supply": DiscreteGrid(_LaborSupply), + "consumption": LinSpacedGrid(start=0.5, stop=2.0, n_points=3), + }, + states={ + "aime": LinSpacedGrid(start=0.0, stop=4.0, n_points=3), + "wealth": LinSpacedGrid(start=0.5, stop=5.0, n_points=3), + }, + state_transitions={ + "aime": _next_aime, + "wealth": _next_wealth, + }, + functions={"utility": _utility}, + active=lambda age: age < 2, +) + + +_dead = Regime(transition=None, functions={"utility": lambda: jnp.array(0.0)}) + + +def _build_model() -> Model: + return Model( + regimes={"active": _active, "dead": _dead}, + ages=AgeGrid(start=0, stop=3, step="Y"), + regime_id_class=_RegimeId, + ) + + +def test_solve_resolves_chain_via_dags() -> None: + """`solve()` runs and dags wires `next_aime → next_wealth` correctly. + + Before the fix, `_resolve_fixed_params` raised + `InvalidParamsError: Missing required parameter: + 'active__next_wealth__next_aime'` because `create_regime_params_template` + classified `next_aime` (a `next_` reference inside another + transition's signature) as a regime-level fixed_param. + """ + model = _build_model() + params = { + "discount_factor": 0.9, + "final_age_alive": 1.0, + } + period_to_regime_to_V_arr = model.solve(params=params) + for regime_to_V_arr in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V_arr.values(): + assert not jnp.any(jnp.isnan(V_arr)) From d66d85a580c9b437112ec2510113d6f4e62817b0 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 11:30:42 +0200 Subject: [PATCH 41/80] Boilerplate refresh: dags module, current pixi/hook pins, drop stale entries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - AGENTS.md: include @.ai-instructions/modules/dags.md (pylcm is dags-built; the convention/idioms are project-relevant). - .pre-commit-config.yaml: narrow GFM mdformat `files:` to root-level (AGENTS|CLAUDE|README); the `modules/.*|profiles/.*` patterns were intended for the .ai-instructions repo and never matched here. Bump check-jsonschema 0.37.1 → 0.37.2 and ruff v0.15.11 → v0.15.12 to current. - .gitignore: drop unused `# pytask` section (pylcm doesn't use pytask). - .github/workflows/main.yml: pixi-version v0.66.0 → v0.67.2 to match the locally installed pixi. All other pinned hooks/actions verified at-or-newer than the boilerplate template baseline. --- .github/workflows/main.yml | 10 +++++----- .gitignore | 3 --- .pre-commit-config.yaml | 6 +++--- AGENTS.md | 1 + 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d0ea4a07..80b272c1 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: tests-cpu @@ -59,7 +59,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: type-checking @@ -82,7 +82,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: tests-cuda12 @@ -101,7 +101,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: tests-cuda12 @@ -116,7 +116,7 @@ jobs: # - uses: actions/checkout@v6 # - uses: prefix-dev/setup-pixi@v0.9.5 # with: - # pixi-version: v0.66.0 + # pixi-version: v0.67.2 # cache: true # cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} # environments: tests-cpu diff --git a/.gitignore b/.gitignore index db181e4d..03acd3c8 100644 --- a/.gitignore +++ b/.gitignore @@ -31,9 +31,6 @@ docs/_build/ .pixi/ node_modules/ -# pytask -.pytask.sqlite3 - # Python __pycache__/ *.py[cod] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 270b9b25..f51ffb31 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,11 +50,11 @@ repos: hooks: - id: yamllint - repo: https://github.com/python-jsonschema/check-jsonschema - rev: 0.37.1 + rev: 0.37.2 hooks: - id: check-github-workflows - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.11 + rev: v0.15.12 hooks: - id: ruff-check args: @@ -86,7 +86,7 @@ repos: args: - --wrap - '88' - files: (AGENTS\.md|CLAUDE\.md|README\.md|modules/.*\.md|profiles/.*\.md) + files: (AGENTS\.md|CLAUDE\.md|README\.md) - id: mdformat additional_dependencies: - mdformat-myst diff --git a/AGENTS.md b/AGENTS.md index f9777113..7aea2844 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,5 +1,6 @@ @.ai-instructions/profiles/tier-a.md @.ai-instructions/modules/jax.md @.ai-instructions/modules/pandas.md @.ai-instructions/modules/plotting.md +@.ai-instructions/modules/dags.md # PyLCM From f18d27c7b78ece836bf35bdfddd65d9e35432fc6 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 11:33:21 +0200 Subject: [PATCH 42/80] Bump .ai-instructions: pyproject-fmt + ruff + check-jsonschema rev pins --- .ai-instructions | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ai-instructions b/.ai-instructions index d15d554a..528a0118 160000 --- a/.ai-instructions +++ b/.ai-instructions @@ -1 +1 @@ -Subproject commit d15d554a54785e84cb80165443fa432a22adc45c +Subproject commit 528a0118ff9e02233bfc073da891b60e81b34754 From e45cf4fb6a1847739d9140c031fa69ce2deb5cbc Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 13:47:53 +0200 Subject: [PATCH 43/80] regime_template: reject next_ params on regular DAG functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Companion to the `next_` exemption: the same exemption that lets state transitions consume each other's outputs would silently filter a typo'd `next_` parameter out of the template for any non-transition function (utility, helpers, custom H), and dags would wire that param to the actual transition output at evaluation time — yielding a wrong result with no error. `_fail_if_non_transition_consumes_next_state` catches the mistake at template-construction time. State transitions and constraints are exempt: constraints legitimately depend on transition outputs (issue #230, e.g. `borrowing_constraint( next_assets) -> next_assets >= 0`). New tests: - non-transition function with `next_` param raises InvalidNameError - constraint with `next_` param is allowed (regression guard for #230) --- src/lcm/params/regime_template.py | 47 ++++++++++++++++- .../test_create_regime_params_template.py | 51 +++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/src/lcm/params/regime_template.py b/src/lcm/params/regime_template.py index 8f9765a5..04628c20 100644 --- a/src/lcm/params/regime_template.py +++ b/src/lcm/params/regime_template.py @@ -50,6 +50,7 @@ def create_regime_params_template( """ H_variables = {*regime.functions, "period", "age", "E_next_V"} next_state_names = {f"next_{name}" for name in regime.states} + constraint_names = set(regime.constraints) variables = ( H_variables | set(regime.actions) | set(regime.states) | next_state_names ) @@ -64,13 +65,22 @@ def create_regime_params_template( else: tree = dt.create_tree_with_input_types({name: func}) + path = tree_path_from_qname(name) + + _fail_if_non_transition_consumes_next_state( + func_name=name, + path=path, + param_names=set(tree), + next_state_names=next_state_names, + constraint_names=constraint_names, + ) + # H is exempt from param-template extraction for state/action names # that appear in its signature: pylcm wires those values through # `states_actions_params` at call time, so they must not surface as # user-facing params in the template. params = {k: v for k, v in sorted(tree.items()) if k not in variables} - path = tree_path_from_qname(name) template_key = f"to_{path[1]}_{path[0]}" if len(path) > 1 else name if template_key in function_params: @@ -168,6 +178,41 @@ def _collect_all_functions_for_template( return result +def _fail_if_non_transition_consumes_next_state( + *, + func_name: str, + path: tuple[str, ...], + param_names: set[str], + next_state_names: set[str], + constraint_names: set[str], +) -> None: + """Reject `next_` parameters on regular DAG functions. + + The `next_` exemption from fixed_param extraction lets state + transitions consume each other's outputs (dags resolves the chain at + evaluation time). Constraints also legitimately depend on transition + outputs (e.g. `borrowing_constraint(next_assets)` — see issue #230). + For everything else (utility, helpers, custom H), a `next_` + parameter name is almost always a typo where the user meant the + current-period ``. Without this guard, the typo'd param would + be silently filtered from the template and wired to the transition + output via dags, yielding a wrong result with no error. Catch the + mistake at template-construction time. + """ + head = path[0] + if head == "next_regime" or head in next_state_names or head in constraint_names: + return + typos = sorted(param_names & next_state_names) + if typos: + raise InvalidNameError( + f"Function {func_name!r} has parameter(s) {typos} matching " + f"reserved `next_` transition-output names. Drop the " + f"'next_' prefix to use the current-period state, or move the " + f"logic into a state transition (or constraint) if the " + f"next-period value is genuinely needed." + ) + + def _validate_no_shadowing( function_params: dict[FunctionName, dict[str, str]], regime: Regime, diff --git a/tests/regime_building/test_create_regime_params_template.py b/tests/regime_building/test_create_regime_params_template.py index 62d5a5cb..2052961f 100644 --- a/tests/regime_building/test_create_regime_params_template.py +++ b/tests/regime_building/test_create_regime_params_template.py @@ -1,3 +1,6 @@ +import pytest + +from lcm.exceptions import InvalidNameError from lcm.grids import DiscreteGrid from lcm.interfaces import SolveSimulateFunctionPair from lcm.params.regime_template import ( @@ -152,6 +155,54 @@ def test_regular_function_taking_state_as_argument_no_error(binary_category_clas ) +def test_non_transition_function_with_next_state_param_raises( + binary_category_class, +): + """A non-transition function declaring a `next_` parameter must error. + + Without this guard, a typo like `def utility(consumption, next_wealth)` + (intended: `wealth`) would silently be wired to the `next_wealth` + transition output via dags, returning a wrong result rather than raising + a missing-param error. + """ + + def bad_utility(a, wealth, next_wealth): # noqa: ARG001 + return None + + regime = RegimeMock( + actions={"a": DiscreteGrid(binary_category_class)}, + states={"wealth": DiscreteGrid(binary_category_class)}, + state_transitions={"wealth": lambda wealth: wealth}, + transition=lambda: 0, + functions={"utility": bad_utility}, + ) + with pytest.raises(InvalidNameError, match="next_wealth"): + create_regime_params_template(regime) # ty: ignore[invalid-argument-type] + + +def test_constraint_consuming_next_state_param_is_allowed(binary_category_class): + """Constraints may depend on transition outputs (issue #230). + + The `next_` validator skips constraints because checks like + `borrowing_constraint(next_assets) -> next_assets >= 0` are the + intended use of the chained-transition resolution. + """ + + def borrowing_constraint(next_wealth): # noqa: ARG001 + return None + + regime = RegimeMock( + actions={"a": DiscreteGrid(binary_category_class)}, + states={"wealth": DiscreteGrid(binary_category_class)}, + state_transitions={"wealth": lambda wealth: wealth}, + transition=lambda: 0, + functions={"utility": lambda a, wealth: None}, # noqa: ARG005 + constraints={"borrowing_constraint": borrowing_constraint}, + ) + # Must not raise; constraint legitimately consumes `next_wealth`. + create_regime_params_template(regime) # ty: ignore[invalid-argument-type] + + def test_state_transition_consuming_other_next_state_is_not_a_param( binary_category_class, ): From c585f4b821f2951dcd6d37625db28a98702cf82d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 17:50:13 +0200 Subject: [PATCH 44/80] Bump aca-model benchmark pin to 83f22500 (post pension correction) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The benchmarks feature was pinned to 134286108 (pre-pension-correction aca-model). The aca-baseline benchmark on top of that rev surfaces the worst-shock / low-asset / mid-AIME corner NaN at age 51 in retiree_nomc_inelig_canwork — exactly the bug the per-target pension imputation correction (aca-model 4ae4446) was written to fix. Bump to 83f22500 ('Bump pyproject-fmt + ruff-pre-commit pins'), the current tip of feature/runtime-consumption-points, which carries the correction. --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index e507ccc8..c384e3ca 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=83f22500e97a6675aa4cd15235dea359dae94f2d#83f22500e97a6675aa4cd15235dea359dae94f2d - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=83f22500e97a6675aa4cd15235dea359dae94f2d#83f22500e97a6675aa4cd15235dea359dae94f2d name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev130+g3769d2d63.d20260429 - sha256: 365fd6893bcba5ab807032371430b19a9a9056c80ef3a9a201b56d34ced99e0e + version: 0.0.2.dev152+ge45cf4fb6.d20260504 + sha256: b7ae3c66cce67b2575697d5e1c3e1b3cf8ed3f3ea41d9a5a5e609536b1ca5d89 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 3fe01162..e1a6a008 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "83f22500e97a6675aa4cd15235dea359dae94f2d" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From f0dd7b5b9cb4a8e1b63831b54446dbae9fbde668 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 18:39:52 +0200 Subject: [PATCH 45/80] Revert aca-model pin: this branch lacks Model.n_subjects (introduced on #340) aca-model 83f22500's create_benchmark_model requires n_subjects=, which in turn requires pylcm's Model.n_subjects argument. That work lives on feat/simulate-aot-n-subjects (#340) and is not in this branch yet, so keep the pin at 134286108 here. The pension-correction bump only takes effect on #340 once that PR's benchmark call site is updated to pass n_subjects=_N_SUBJECTS. --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index c384e3ca..5c8127d0 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=83f22500e97a6675aa4cd15235dea359dae94f2d#83f22500e97a6675aa4cd15235dea359dae94f2d + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=83f22500e97a6675aa4cd15235dea359dae94f2d#83f22500e97a6675aa4cd15235dea359dae94f2d +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev152+ge45cf4fb6.d20260504 - sha256: b7ae3c66cce67b2575697d5e1c3e1b3cf8ed3f3ea41d9a5a5e609536b1ca5d89 + version: 0.0.2.dev153+gc585f4b82.d20260504 + sha256: 365fd6893bcba5ab807032371430b19a9a9056c80ef3a9a201b56d34ced99e0e requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index e1a6a008..3fe01162 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "83f22500e97a6675aa4cd15235dea359dae94f2d" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From c00b610a8a45fb8ceed993f2802a95dd52c00dd3 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 18:41:19 +0200 Subject: [PATCH 46/80] Bump aca-model pin to 83f22500 on #340 (carries pension correction) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The merge from #339 reverted the pin to 134286108, which is pre- n_subjects-requirement and would TypeError on the benchmark's create_benchmark_model(n_subjects=_N_SUBJECTS) call. Restore the 83f22500 pin (post-pension-correction, accepts n_subjects) — only this branch's pylcm has Model.n_subjects, so the bump is safe here. --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index 5c8127d0..c4749747 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=83f22500e97a6675aa4cd15235dea359dae94f2d#83f22500e97a6675aa4cd15235dea359dae94f2d - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=83f22500e97a6675aa4cd15235dea359dae94f2d#83f22500e97a6675aa4cd15235dea359dae94f2d name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev153+gc585f4b82.d20260504 - sha256: 365fd6893bcba5ab807032371430b19a9a9056c80ef3a9a201b56d34ced99e0e + version: 0.0.2.dev184+g588c9c413.d20260504 + sha256: b7ae3c66cce67b2575697d5e1c3e1b3cf8ed3f3ea41d9a5a5e609536b1ca5d89 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 3fe01162..e1a6a008 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "83f22500e97a6675aa4cd15235dea359dae94f2d" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From e89d5e4144affcde0987d44d82626017b57bbc47 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 18:54:46 +0200 Subject: [PATCH 47/80] Revert "regime_template: reject next_ params on regular DAG functions" This reverts commit e45cf4fb6a1847739d9140c031fa69ce2deb5cbc. --- src/lcm/params/regime_template.py | 47 +---------------- .../test_create_regime_params_template.py | 51 ------------------- 2 files changed, 1 insertion(+), 97 deletions(-) diff --git a/src/lcm/params/regime_template.py b/src/lcm/params/regime_template.py index 04628c20..8f9765a5 100644 --- a/src/lcm/params/regime_template.py +++ b/src/lcm/params/regime_template.py @@ -50,7 +50,6 @@ def create_regime_params_template( """ H_variables = {*regime.functions, "period", "age", "E_next_V"} next_state_names = {f"next_{name}" for name in regime.states} - constraint_names = set(regime.constraints) variables = ( H_variables | set(regime.actions) | set(regime.states) | next_state_names ) @@ -65,22 +64,13 @@ def create_regime_params_template( else: tree = dt.create_tree_with_input_types({name: func}) - path = tree_path_from_qname(name) - - _fail_if_non_transition_consumes_next_state( - func_name=name, - path=path, - param_names=set(tree), - next_state_names=next_state_names, - constraint_names=constraint_names, - ) - # H is exempt from param-template extraction for state/action names # that appear in its signature: pylcm wires those values through # `states_actions_params` at call time, so they must not surface as # user-facing params in the template. params = {k: v for k, v in sorted(tree.items()) if k not in variables} + path = tree_path_from_qname(name) template_key = f"to_{path[1]}_{path[0]}" if len(path) > 1 else name if template_key in function_params: @@ -178,41 +168,6 @@ def _collect_all_functions_for_template( return result -def _fail_if_non_transition_consumes_next_state( - *, - func_name: str, - path: tuple[str, ...], - param_names: set[str], - next_state_names: set[str], - constraint_names: set[str], -) -> None: - """Reject `next_` parameters on regular DAG functions. - - The `next_` exemption from fixed_param extraction lets state - transitions consume each other's outputs (dags resolves the chain at - evaluation time). Constraints also legitimately depend on transition - outputs (e.g. `borrowing_constraint(next_assets)` — see issue #230). - For everything else (utility, helpers, custom H), a `next_` - parameter name is almost always a typo where the user meant the - current-period ``. Without this guard, the typo'd param would - be silently filtered from the template and wired to the transition - output via dags, yielding a wrong result with no error. Catch the - mistake at template-construction time. - """ - head = path[0] - if head == "next_regime" or head in next_state_names or head in constraint_names: - return - typos = sorted(param_names & next_state_names) - if typos: - raise InvalidNameError( - f"Function {func_name!r} has parameter(s) {typos} matching " - f"reserved `next_` transition-output names. Drop the " - f"'next_' prefix to use the current-period state, or move the " - f"logic into a state transition (or constraint) if the " - f"next-period value is genuinely needed." - ) - - def _validate_no_shadowing( function_params: dict[FunctionName, dict[str, str]], regime: Regime, diff --git a/tests/regime_building/test_create_regime_params_template.py b/tests/regime_building/test_create_regime_params_template.py index 2052961f..62d5a5cb 100644 --- a/tests/regime_building/test_create_regime_params_template.py +++ b/tests/regime_building/test_create_regime_params_template.py @@ -1,6 +1,3 @@ -import pytest - -from lcm.exceptions import InvalidNameError from lcm.grids import DiscreteGrid from lcm.interfaces import SolveSimulateFunctionPair from lcm.params.regime_template import ( @@ -155,54 +152,6 @@ def test_regular_function_taking_state_as_argument_no_error(binary_category_clas ) -def test_non_transition_function_with_next_state_param_raises( - binary_category_class, -): - """A non-transition function declaring a `next_` parameter must error. - - Without this guard, a typo like `def utility(consumption, next_wealth)` - (intended: `wealth`) would silently be wired to the `next_wealth` - transition output via dags, returning a wrong result rather than raising - a missing-param error. - """ - - def bad_utility(a, wealth, next_wealth): # noqa: ARG001 - return None - - regime = RegimeMock( - actions={"a": DiscreteGrid(binary_category_class)}, - states={"wealth": DiscreteGrid(binary_category_class)}, - state_transitions={"wealth": lambda wealth: wealth}, - transition=lambda: 0, - functions={"utility": bad_utility}, - ) - with pytest.raises(InvalidNameError, match="next_wealth"): - create_regime_params_template(regime) # ty: ignore[invalid-argument-type] - - -def test_constraint_consuming_next_state_param_is_allowed(binary_category_class): - """Constraints may depend on transition outputs (issue #230). - - The `next_` validator skips constraints because checks like - `borrowing_constraint(next_assets) -> next_assets >= 0` are the - intended use of the chained-transition resolution. - """ - - def borrowing_constraint(next_wealth): # noqa: ARG001 - return None - - regime = RegimeMock( - actions={"a": DiscreteGrid(binary_category_class)}, - states={"wealth": DiscreteGrid(binary_category_class)}, - state_transitions={"wealth": lambda wealth: wealth}, - transition=lambda: 0, - functions={"utility": lambda a, wealth: None}, # noqa: ARG005 - constraints={"borrowing_constraint": borrowing_constraint}, - ) - # Must not raise; constraint legitimately consumes `next_wealth`. - create_regime_params_template(regime) # ty: ignore[invalid-argument-type] - - def test_state_transition_consuming_other_next_state_is_not_a_param( binary_category_class, ): From e92eeeca2a33e7e8b54ed7541d35e751fcfa6842 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 20:48:50 +0200 Subject: [PATCH 48/80] Bump aca-model pin to 3453080 (filters stale benchmark_params key) --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index c4749747..af799833 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=83f22500e97a6675aa4cd15235dea359dae94f2d#83f22500e97a6675aa4cd15235dea359dae94f2d + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=3453080fd08afa049483f6ddda215a998a55b757#3453080fd08afa049483f6ddda215a998a55b757 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=83f22500e97a6675aa4cd15235dea359dae94f2d#83f22500e97a6675aa4cd15235dea359dae94f2d +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=3453080fd08afa049483f6ddda215a998a55b757#3453080fd08afa049483f6ddda215a998a55b757 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev184+g588c9c413.d20260504 - sha256: b7ae3c66cce67b2575697d5e1c3e1b3cf8ed3f3ea41d9a5a5e609536b1ca5d89 + version: 0.0.2.dev189+gefeedbf75.d20260504 + sha256: 82d8bf94d16bbabdd3dd3f642cc6a28f0d42192f60a8c7e16bd83ee1026d9fe1 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index e1a6a008..71d3b6c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "83f22500e97a6675aa4cd15235dea359dae94f2d" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "3453080fd08afa049483f6ddda215a998a55b757" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From c117f4cdb8e6fd8275eea959fca7adad964c3c04 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 21:04:42 +0200 Subject: [PATCH 49/80] Bump aca-model pin to b2e90bb (synthesise shifted imputation arrays) --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index af799833..cb5b244b 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=3453080fd08afa049483f6ddda215a998a55b757#3453080fd08afa049483f6ddda215a998a55b757 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=b2e90bb58a1c6721046a3e860a95a29485b25117#b2e90bb58a1c6721046a3e860a95a29485b25117 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=3453080fd08afa049483f6ddda215a998a55b757#3453080fd08afa049483f6ddda215a998a55b757 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=b2e90bb58a1c6721046a3e860a95a29485b25117#b2e90bb58a1c6721046a3e860a95a29485b25117 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev189+gefeedbf75.d20260504 - sha256: 82d8bf94d16bbabdd3dd3f642cc6a28f0d42192f60a8c7e16bd83ee1026d9fe1 + version: 0.0.2.dev189+ge92eeeca2.d20260504 + sha256: 1f023d258794ddbf0191f6ef516539809a1e33f8b095645de46845a1ce573176 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 71d3b6c1..55015c16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "3453080fd08afa049483f6ddda215a998a55b757" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "b2e90bb58a1c6721046a3e860a95a29485b25117" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From e422876934c5d6792ea6c7859d3272a3bd533261 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 21:18:58 +0200 Subject: [PATCH 50/80] Bump aca-model pin to 35eddcc (declare target_his derived categorical) --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index cb5b244b..e82c2f83 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=b2e90bb58a1c6721046a3e860a95a29485b25117#b2e90bb58a1c6721046a3e860a95a29485b25117 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=35eddcc9ee06b960c30c6ea09e3f07541f3144a6#35eddcc9ee06b960c30c6ea09e3f07541f3144a6 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=b2e90bb58a1c6721046a3e860a95a29485b25117#b2e90bb58a1c6721046a3e860a95a29485b25117 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=35eddcc9ee06b960c30c6ea09e3f07541f3144a6#35eddcc9ee06b960c30c6ea09e3f07541f3144a6 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev189+ge92eeeca2.d20260504 - sha256: 1f023d258794ddbf0191f6ef516539809a1e33f8b095645de46845a1ce573176 + version: 0.0.2.dev190+gc117f4cdb.d20260504 + sha256: 935c7a53314794cfdd4f71bf017c5c92b32a9c7ceb63230aa9d0af512fd54bda requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 55015c16..375c84e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "b2e90bb58a1c6721046a3e860a95a29485b25117" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "35eddcc9ee06b960c30c6ea09e3f07541f3144a6" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From 4b9bea35a8cd490b8471d727c1d822a4c05e02fd Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 21:33:02 +0200 Subject: [PATCH 51/80] Bump aca-model pin to 64d6567 (rename shifted-array level to target_his) --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index e82c2f83..59863a71 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=35eddcc9ee06b960c30c6ea09e3f07541f3144a6#35eddcc9ee06b960c30c6ea09e3f07541f3144a6 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=64d656791230ebf20622c247bf2935880de2fcfd#64d656791230ebf20622c247bf2935880de2fcfd - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=35eddcc9ee06b960c30c6ea09e3f07541f3144a6#35eddcc9ee06b960c30c6ea09e3f07541f3144a6 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=64d656791230ebf20622c247bf2935880de2fcfd#64d656791230ebf20622c247bf2935880de2fcfd name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev190+gc117f4cdb.d20260504 - sha256: 935c7a53314794cfdd4f71bf017c5c92b32a9c7ceb63230aa9d0af512fd54bda + version: 0.0.2.dev191+ge42287693.d20260504 + sha256: 1d637215998bf403e114fb5af8f298687f723eb24920bc026f03c0139a5f2e37 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 375c84e9..293eef8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "35eddcc9ee06b960c30c6ea09e3f07541f3144a6" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "64d656791230ebf20622c247bf2935880de2fcfd" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From 9c7edb6c6273d1bcb74716f619e9ce07c8324a6e Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 22:03:58 +0200 Subject: [PATCH 52/80] Bump aca-model pin to f09b5e3 (per-target next_assets, dead-target terminal version) --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index 59863a71..a65be014 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=64d656791230ebf20622c247bf2935880de2fcfd#64d656791230ebf20622c247bf2935880de2fcfd + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=64d656791230ebf20622c247bf2935880de2fcfd#64d656791230ebf20622c247bf2935880de2fcfd +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev191+ge42287693.d20260504 - sha256: 1d637215998bf403e114fb5af8f298687f723eb24920bc026f03c0139a5f2e37 + version: 0.0.2.dev192+g4b9bea35a.d20260504 + sha256: 44c6bd65422fdc0a7d3167cf852107aeca15bf6687a44b57a6749ad553943f11 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 293eef8f..bb10a893 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "64d656791230ebf20622c247bf2935880de2fcfd" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "f09b5e34102ff42f739b95be5a9d388795b734a1" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From e6066fac77def95bab11ec76cf8a8725e34a3d96 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 22:40:39 +0200 Subject: [PATCH 53/80] Roll #340 aca-model pin back to 63d2a38 (pre-pension-correction) The pension correction (aca-model 4ae4446 onwards) hits a pylcm architectural mismatch: solve uses per-target unqualified keys (`state_transitions[target_X]` is a dict where `next_aime`, `next_assets`, etc. are bare names), while simulate uses `flatten_regime_namespace` which qualifies everything to `__next_aime`. `imputed_pension_wealth_next_period`'s unqualified `next_aime` parameter resolves cleanly in solve but leaks as a kernel input in simulate. Fixing that asymmetry is a separate pylcm change (compile_all_simulate_functions should mirror solve's per-target dispatch). Pin to 63d2a38 (revert of MAX_CONSUMPTION margin, post-n_subjects- requirement) to unblock the benchmark CI on #340 without exercising the pension correction. The full correction can be re-enabled once the simulate-path mismatch is fixed. --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index a65be014..c00a3867 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=63d2a3819d08cf33f95c0149b6d3531b5292e729#63d2a3819d08cf33f95c0149b6d3531b5292e729 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=63d2a3819d08cf33f95c0149b6d3531b5292e729#63d2a3819d08cf33f95c0149b6d3531b5292e729 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev192+g4b9bea35a.d20260504 - sha256: 44c6bd65422fdc0a7d3167cf852107aeca15bf6687a44b57a6749ad553943f11 + version: 0.0.2.dev194+g3a6b1ecdd + sha256: c8f3e15320f6a2ee0895a47b2763b68ec56509b420d4aafa14ec19fef3dac441 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index bb10a893..6088bec9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "f09b5e34102ff42f739b95be5a9d388795b734a1" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "63d2a3819d08cf33f95c0149b6d3531b5292e729" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From a908c8405eb4e8fb5a653ca4db89f44d0675edfb Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Tue, 5 May 2026 05:28:19 +0200 Subject: [PATCH 54/80] get_next_state_function_for_simulation: per-target DAG mirrors solve MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The simulate path flattened all transitions into one DAG keyed by `__`, which prevented an unqualified `next_` parameter on a transition or auxiliary function from resolving across the per-target boundary. The solve path doesn't have this problem because it builds one DAG per target with bare `next_` keys. Switch the simulate compile to mirror that structure: one DAG per target with unqualified keys, then merge per-target outputs into a single flat dict keyed by `__next_` for the downstream `_update_states_for_subjects` consumer. Stochastic-transition wrappers keep their target-qualified `key___next_` and `weight___next_` arg names so multi-target callers still draw distinct realisations per target. This unblocks the aca-model pension imputation correction, where `imputed_pension_wealth_next_period(next_aime, ...)` consumes unqualified `next_aime` from the per-target DAG. Extends test_chained_state_transitions with a `model.simulate(...)` case — the original test only exercised solve, which masked this asymmetry. --- src/lcm/regime_building/next_state.py | 170 +++++++++++++++--------- tests/test_chained_state_transitions.py | 43 +++++- 2 files changed, 148 insertions(+), 65 deletions(-) diff --git a/src/lcm/regime_building/next_state.py b/src/lcm/regime_building/next_state.py index a8636992..4e0ba4b1 100644 --- a/src/lcm/regime_building/next_state.py +++ b/src/lcm/regime_building/next_state.py @@ -1,12 +1,13 @@ """Generate function that compute the next states for solution and simulation.""" +import inspect from collections.abc import Callable from types import MappingProxyType import jax import pandas as pd from dags import concatenate_functions, with_signature -from dags.tree import qname_from_tree_path, tree_path_from_qname +from dags.tree import qname_from_tree_path from jax import Array from lcm.grids import Grid @@ -69,6 +70,18 @@ def get_next_state_function_for_simulation( ) -> NextStateSimulationFunction: """Get function that computes the next states during the simulation. + Builds one DAG per target regime using unqualified `next_` keys, mirroring + the per-target structure of {func}`get_next_state_function_for_solution`. This + lets a transition function or auxiliary regime function consume another + transition's `next_` output via plain name resolution within the same + target's DAG. Per-target outputs are then merged into a single flat dict keyed + by `__next_`, matching the shape consumed downstream by + {func}`lcm.simulation.transitions._update_states_for_subjects`. + + Stochastic-transition wrappers expose `key___next_` and + `weight___next_` as external arguments so callers can pass a + distinct random key and pre-computed weight per target. + Args: transitions: Nested mapping of target regime names to transition functions. functions: Immutable mapping of auxiliary functions of a regime. @@ -78,30 +91,28 @@ def get_next_state_function_for_simulation( Returns: Function that computes the next states. Depends on states and actions of the - current period, and the regime parameters ("params"). If target is "simulate", - the function also depends on the dictionary of random keys ("keys"), which - corresponds to the names of stochastic next functions. + current period, and the regime parameters ("params"). The function also + depends on the dictionary of random keys ("keys") for stochastic transitions. """ - flat_transitions = flatten_regime_namespace(transitions) - - # For the simulation target, we need to extend the functions dictionary with - # stochastic next states functions and their weights. - extended_transitions = _extend_transitions_for_simulation( - all_grids=all_grids, - flat_transitions=flat_transitions, - variable_info=variable_info, - stochastic_transition_names=stochastic_transition_names, - ) - functions_to_concatenate = extended_transitions | dict(functions) + per_target_funcs: dict[RegimeName, Callable[..., dict[str, Array]]] = {} + for target, target_trans in transitions.items(): + extended = _extend_target_transitions_for_simulation( + target=target, + target_trans=target_trans, + all_grids=all_grids, + variable_info=variable_info, + stochastic_transition_names=stochastic_transition_names, + ) + per_target_funcs[target] = concatenate_functions( + functions=dict(extended) | dict(functions), + targets=list(extended.keys()), + return_type="dict", + enforce_signature=False, + set_annotations=True, + ) - return concatenate_functions( - functions=functions_to_concatenate, - targets=list(flat_transitions.keys()), - return_type="dict", - enforce_signature=False, - set_annotations=True, - ) + return _build_combined_simulation_function(per_target_funcs=per_target_funcs) def get_next_stochastic_weights_function( @@ -137,64 +148,101 @@ def get_next_stochastic_weights_function( ) -def _extend_transitions_for_simulation( +def _extend_target_transitions_for_simulation( *, + target: RegimeName, + target_trans: MappingProxyType[TransitionFunctionName, Callable[..., Array]], all_grids: MappingProxyType[RegimeName, MappingProxyType[StateOrActionName, Grid]], - flat_transitions: FunctionsMapping, variable_info: pd.DataFrame, stochastic_transition_names: frozenset[TransitionFunctionName], ) -> dict[TransitionFunctionName, Callable[..., Array]]: - """Extend the functions dictionary for the simulation target. + """Replace stochastic transitions for one target with realisation wrappers. + + Deterministic transitions are passed through unchanged. Stochastic transitions + are replaced by wrappers that draw a realisation from a precomputed weight + vector and a random key. The wrapper's external argument names use + target-qualified form (`key___`, + `weight___`) so multi-target callers can supply distinct + random keys per target. The dict key keeps the unqualified `next_` so + other transitions or regime functions in the same target's DAG can resolve + it by name. Args: + target: Target regime name. + target_trans: Mapping of unqualified `next_` transition names to + functions, restricted to one target regime. all_grids: Immutable mapping of regime names to Grid spec objects. - flat_transitions: Flattened mapping of transition names to functions. variable_info: Variable info of the current regime. stochastic_transition_names: Frozenset of stochastic transition function names. Returns: - Extended functions dictionary. + Extended transitions dictionary keyed by unqualified `next_` names. """ shock_names: set[ShockName] = set(variable_info.query("is_shock").index.to_list()) flat_grids = flatten_regime_namespace(all_grids) - discrete_stochastic_targets = [ - func_name - for func_name in flat_transitions - if tree_path_from_qname(func_name)[-1] in stochastic_transition_names - and tree_path_from_qname(func_name)[-1].removeprefix("next_") not in shock_names - ] - continuous_stochastic_targets = [ - func_name - for func_name in flat_transitions - if tree_path_from_qname(func_name)[-1] in stochastic_transition_names - and tree_path_from_qname(func_name)[-1].removeprefix("next_") in shock_names - ] - # Handle stochastic next states functions - # ---------------------------------------------------------------------------------- - # We generate stochastic next states functions that simulate the next state given - # a random key (think of a seed) and the weights corresponding to the labels of the - # stochastic variable. The weights are computed using the stochastic weight - # functions, which we add the to functions dict. `dags.concatenate_functions` then - # generates a function that computes the weights and simulates the next state in - # one go. - # ---------------------------------------------------------------------------------- - discrete_stochastic_next = { - name: _create_discrete_stochastic_next_func( - name=name, labels=flat_grids[name.replace("next_", "")].to_jax() - ) - for name in discrete_stochastic_targets - } - continuous_stochastic_next = { - name: _create_continuous_stochastic_next_func(name=name, flat_grids=flat_grids) - for name in continuous_stochastic_targets + extended: dict[TransitionFunctionName, Callable[..., Array]] = dict(target_trans) + for next_state_name in target_trans: + if next_state_name not in stochastic_transition_names: + continue + qname = qname_from_tree_path((target, next_state_name)) + raw_state_name = next_state_name.removeprefix("next_") + if raw_state_name in shock_names: + extended[next_state_name] = _create_continuous_stochastic_next_func( + name=qname, flat_grids=flat_grids + ) + else: + extended[next_state_name] = _create_discrete_stochastic_next_func( + name=qname, + labels=flat_grids[ + qname_from_tree_path((target, raw_state_name)) + ].to_jax(), + ) + return extended + + +def _build_combined_simulation_function( + *, + per_target_funcs: dict[RegimeName, Callable[..., dict[str, Array]]], +) -> NextStateSimulationFunction: + """Combine per-target simulation DAGs into one function with qualified outputs. + + Each per-target callable returns `{next_: array}` (unqualified). The + combined callable returns `{__next_: array}` after dispatching + inputs to the relevant per-target function based on its signature. + + Args: + per_target_funcs: Mapping of target regime names to per-target simulation + DAGs returning unqualified `{next_: array}` outputs. + + Returns: + A single callable that takes the union of all per-target inputs and + returns target-qualified outputs. + + """ + target_args: dict[RegimeName, tuple[str, ...]] = { + target: tuple(inspect.signature(func).parameters) + for target, func in per_target_funcs.items() } + all_args: list[str] = sorted({arg for args in target_args.values() for arg in args}) - # Overwrite regime transitions with generated stochastic next states functions - # ---------------------------------------------------------------------------------- - return ( - dict(flat_transitions) | discrete_stochastic_next | continuous_stochastic_next + @with_signature( + args=dict.fromkeys(all_args, "Array"), + return_annotation="dict[str, Array]", + enforce=False, ) + def combined(*args: Array, **kwargs: Array) -> dict[str, Array]: + if args: + kwargs = {**dict(zip(all_args, args, strict=False)), **kwargs} + out: dict[str, Array] = {} + for target, func in per_target_funcs.items(): + target_kwargs = {arg: kwargs[arg] for arg in target_args[target]} + target_out = func(**target_kwargs) + for next_state_name, value in target_out.items(): + out[qname_from_tree_path((target, next_state_name))] = value + return out + + return combined # ty: ignore[invalid-return-type] def _create_discrete_stochastic_next_func( diff --git a/tests/test_chained_state_transitions.py b/tests/test_chained_state_transitions.py index 4d8a61d4..a563b9ad 100644 --- a/tests/test_chained_state_transitions.py +++ b/tests/test_chained_state_transitions.py @@ -2,10 +2,16 @@ dags resolves dependencies between state-transition functions when they appear in the merged transitions+functions dict that -`get_next_state_function_for_solution` builds. The blocker fixed here is in -the upstream `create_regime_params_template`: it must not classify -`next_` names as regime-level fixed_params, otherwise param resolution -fails before dags ever runs. +`get_next_state_function_for_solution` builds. The blocker fixed in +`create_regime_params_template`: it must not classify `next_` names +as regime-level fixed_params, otherwise param resolution fails before dags +ever runs. + +The same chained-resolution must also work in the simulation path. Earlier, +`get_next_state_function_for_simulation` flattened transitions into a single +DAG keyed by `__`, which prevented an unqualified +`next_` parameter from resolving across the per-target boundary. The +fix mirrors the solve path's per-target structure. """ import jax.numpy as jnp @@ -98,3 +104,32 @@ def test_solve_resolves_chain_via_dags() -> None: for regime_to_V_arr in period_to_regime_to_V_arr.values(): for V_arr in regime_to_V_arr.values(): assert not jnp.any(jnp.isnan(V_arr)) + + +def test_simulate_resolves_chain_via_dags() -> None: + """`simulate()` runs and the simulation DAG resolves `next_aime → next_wealth`. + + The old `get_next_state_function_for_simulation` flattened transitions + into one DAG keyed by `__`, so an unqualified + `next_aime` parameter on `_next_wealth` could not resolve. The per-target + rewrite mirrors the solve path's structure. + """ + model = _build_model() + params = { + "discount_factor": 0.9, + "final_age_alive": 1.0, + } + initial_conditions = { + "age": jnp.array([0.0, 0.0]), + "aime": jnp.array([0.0, 1.0]), + "wealth": jnp.array([2.0, 3.0]), + "regime": jnp.array([_RegimeId.active, _RegimeId.active]), + } + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + ) + df = result.to_dataframe().query('regime == "active"') + assert not df["wealth"].isna().any() + assert not df["aime"].isna().any() From 18d4ade9269f43ae2cb5137feaeec2921b72fbb6 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Tue, 5 May 2026 05:28:51 +0200 Subject: [PATCH 55/80] Revert aca-model rollback: restore f09b5e3 pin (with pension correction) The previous commit fixes the pylcm simulate-path asymmetry that motivated the rollback. Restore the post-pension-correction aca-model pin so the benchmark CI exercises the corrected model and confirms the original NaN-at-age-51 motivation is resolved. Pin: f09b5e3 (per-target next_assets, dead-target terminal version, pension imputation correction wired). --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index c00a3867..4115222a 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=63d2a3819d08cf33f95c0149b6d3531b5292e729#63d2a3819d08cf33f95c0149b6d3531b5292e729 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=63d2a3819d08cf33f95c0149b6d3531b5292e729#63d2a3819d08cf33f95c0149b6d3531b5292e729 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev194+g3a6b1ecdd - sha256: c8f3e15320f6a2ee0895a47b2763b68ec56509b420d4aafa14ec19fef3dac441 + version: 0.0.2.dev195+ga908c8405.d20260505 + sha256: 44c6bd65422fdc0a7d3167cf852107aeca15bf6687a44b57a6749ad553943f11 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 6088bec9..bb10a893 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "63d2a3819d08cf33f95c0149b6d3531b5292e729" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "f09b5e34102ff42f739b95be5a9d388795b734a1" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From 6c64a77bba6ff5d4c994a8684237cfd354c34d83 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Tue, 5 May 2026 05:28:19 +0200 Subject: [PATCH 56/80] get_next_state_function_for_simulation: per-target DAG mirrors solve MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The simulate path flattened all transitions into one DAG keyed by `__`, which prevented an unqualified `next_` parameter on a transition or auxiliary function from resolving across the per-target boundary. The solve path doesn't have this problem because it builds one DAG per target with bare `next_` keys. Switch the simulate compile to mirror that structure: one DAG per target with unqualified keys, then merge per-target outputs into a single flat dict keyed by `__next_` for the downstream `_update_states_for_subjects` consumer. Stochastic-transition wrappers keep their target-qualified `key___next_` and `weight___next_` arg names so multi-target callers still draw distinct realisations per target. This unblocks the aca-model pension imputation correction, where `imputed_pension_wealth_next_period(next_aime, ...)` consumes unqualified `next_aime` from the per-target DAG. Extends test_chained_state_transitions with a `model.simulate(...)` case — the original test only exercised solve, which masked this asymmetry. --- src/lcm/regime_building/next_state.py | 170 +++++++++++++++--------- tests/test_chained_state_transitions.py | 43 +++++- 2 files changed, 148 insertions(+), 65 deletions(-) diff --git a/src/lcm/regime_building/next_state.py b/src/lcm/regime_building/next_state.py index a8636992..4e0ba4b1 100644 --- a/src/lcm/regime_building/next_state.py +++ b/src/lcm/regime_building/next_state.py @@ -1,12 +1,13 @@ """Generate function that compute the next states for solution and simulation.""" +import inspect from collections.abc import Callable from types import MappingProxyType import jax import pandas as pd from dags import concatenate_functions, with_signature -from dags.tree import qname_from_tree_path, tree_path_from_qname +from dags.tree import qname_from_tree_path from jax import Array from lcm.grids import Grid @@ -69,6 +70,18 @@ def get_next_state_function_for_simulation( ) -> NextStateSimulationFunction: """Get function that computes the next states during the simulation. + Builds one DAG per target regime using unqualified `next_` keys, mirroring + the per-target structure of {func}`get_next_state_function_for_solution`. This + lets a transition function or auxiliary regime function consume another + transition's `next_` output via plain name resolution within the same + target's DAG. Per-target outputs are then merged into a single flat dict keyed + by `__next_`, matching the shape consumed downstream by + {func}`lcm.simulation.transitions._update_states_for_subjects`. + + Stochastic-transition wrappers expose `key___next_` and + `weight___next_` as external arguments so callers can pass a + distinct random key and pre-computed weight per target. + Args: transitions: Nested mapping of target regime names to transition functions. functions: Immutable mapping of auxiliary functions of a regime. @@ -78,30 +91,28 @@ def get_next_state_function_for_simulation( Returns: Function that computes the next states. Depends on states and actions of the - current period, and the regime parameters ("params"). If target is "simulate", - the function also depends on the dictionary of random keys ("keys"), which - corresponds to the names of stochastic next functions. + current period, and the regime parameters ("params"). The function also + depends on the dictionary of random keys ("keys") for stochastic transitions. """ - flat_transitions = flatten_regime_namespace(transitions) - - # For the simulation target, we need to extend the functions dictionary with - # stochastic next states functions and their weights. - extended_transitions = _extend_transitions_for_simulation( - all_grids=all_grids, - flat_transitions=flat_transitions, - variable_info=variable_info, - stochastic_transition_names=stochastic_transition_names, - ) - functions_to_concatenate = extended_transitions | dict(functions) + per_target_funcs: dict[RegimeName, Callable[..., dict[str, Array]]] = {} + for target, target_trans in transitions.items(): + extended = _extend_target_transitions_for_simulation( + target=target, + target_trans=target_trans, + all_grids=all_grids, + variable_info=variable_info, + stochastic_transition_names=stochastic_transition_names, + ) + per_target_funcs[target] = concatenate_functions( + functions=dict(extended) | dict(functions), + targets=list(extended.keys()), + return_type="dict", + enforce_signature=False, + set_annotations=True, + ) - return concatenate_functions( - functions=functions_to_concatenate, - targets=list(flat_transitions.keys()), - return_type="dict", - enforce_signature=False, - set_annotations=True, - ) + return _build_combined_simulation_function(per_target_funcs=per_target_funcs) def get_next_stochastic_weights_function( @@ -137,64 +148,101 @@ def get_next_stochastic_weights_function( ) -def _extend_transitions_for_simulation( +def _extend_target_transitions_for_simulation( *, + target: RegimeName, + target_trans: MappingProxyType[TransitionFunctionName, Callable[..., Array]], all_grids: MappingProxyType[RegimeName, MappingProxyType[StateOrActionName, Grid]], - flat_transitions: FunctionsMapping, variable_info: pd.DataFrame, stochastic_transition_names: frozenset[TransitionFunctionName], ) -> dict[TransitionFunctionName, Callable[..., Array]]: - """Extend the functions dictionary for the simulation target. + """Replace stochastic transitions for one target with realisation wrappers. + + Deterministic transitions are passed through unchanged. Stochastic transitions + are replaced by wrappers that draw a realisation from a precomputed weight + vector and a random key. The wrapper's external argument names use + target-qualified form (`key___`, + `weight___`) so multi-target callers can supply distinct + random keys per target. The dict key keeps the unqualified `next_` so + other transitions or regime functions in the same target's DAG can resolve + it by name. Args: + target: Target regime name. + target_trans: Mapping of unqualified `next_` transition names to + functions, restricted to one target regime. all_grids: Immutable mapping of regime names to Grid spec objects. - flat_transitions: Flattened mapping of transition names to functions. variable_info: Variable info of the current regime. stochastic_transition_names: Frozenset of stochastic transition function names. Returns: - Extended functions dictionary. + Extended transitions dictionary keyed by unqualified `next_` names. """ shock_names: set[ShockName] = set(variable_info.query("is_shock").index.to_list()) flat_grids = flatten_regime_namespace(all_grids) - discrete_stochastic_targets = [ - func_name - for func_name in flat_transitions - if tree_path_from_qname(func_name)[-1] in stochastic_transition_names - and tree_path_from_qname(func_name)[-1].removeprefix("next_") not in shock_names - ] - continuous_stochastic_targets = [ - func_name - for func_name in flat_transitions - if tree_path_from_qname(func_name)[-1] in stochastic_transition_names - and tree_path_from_qname(func_name)[-1].removeprefix("next_") in shock_names - ] - # Handle stochastic next states functions - # ---------------------------------------------------------------------------------- - # We generate stochastic next states functions that simulate the next state given - # a random key (think of a seed) and the weights corresponding to the labels of the - # stochastic variable. The weights are computed using the stochastic weight - # functions, which we add the to functions dict. `dags.concatenate_functions` then - # generates a function that computes the weights and simulates the next state in - # one go. - # ---------------------------------------------------------------------------------- - discrete_stochastic_next = { - name: _create_discrete_stochastic_next_func( - name=name, labels=flat_grids[name.replace("next_", "")].to_jax() - ) - for name in discrete_stochastic_targets - } - continuous_stochastic_next = { - name: _create_continuous_stochastic_next_func(name=name, flat_grids=flat_grids) - for name in continuous_stochastic_targets + extended: dict[TransitionFunctionName, Callable[..., Array]] = dict(target_trans) + for next_state_name in target_trans: + if next_state_name not in stochastic_transition_names: + continue + qname = qname_from_tree_path((target, next_state_name)) + raw_state_name = next_state_name.removeprefix("next_") + if raw_state_name in shock_names: + extended[next_state_name] = _create_continuous_stochastic_next_func( + name=qname, flat_grids=flat_grids + ) + else: + extended[next_state_name] = _create_discrete_stochastic_next_func( + name=qname, + labels=flat_grids[ + qname_from_tree_path((target, raw_state_name)) + ].to_jax(), + ) + return extended + + +def _build_combined_simulation_function( + *, + per_target_funcs: dict[RegimeName, Callable[..., dict[str, Array]]], +) -> NextStateSimulationFunction: + """Combine per-target simulation DAGs into one function with qualified outputs. + + Each per-target callable returns `{next_: array}` (unqualified). The + combined callable returns `{__next_: array}` after dispatching + inputs to the relevant per-target function based on its signature. + + Args: + per_target_funcs: Mapping of target regime names to per-target simulation + DAGs returning unqualified `{next_: array}` outputs. + + Returns: + A single callable that takes the union of all per-target inputs and + returns target-qualified outputs. + + """ + target_args: dict[RegimeName, tuple[str, ...]] = { + target: tuple(inspect.signature(func).parameters) + for target, func in per_target_funcs.items() } + all_args: list[str] = sorted({arg for args in target_args.values() for arg in args}) - # Overwrite regime transitions with generated stochastic next states functions - # ---------------------------------------------------------------------------------- - return ( - dict(flat_transitions) | discrete_stochastic_next | continuous_stochastic_next + @with_signature( + args=dict.fromkeys(all_args, "Array"), + return_annotation="dict[str, Array]", + enforce=False, ) + def combined(*args: Array, **kwargs: Array) -> dict[str, Array]: + if args: + kwargs = {**dict(zip(all_args, args, strict=False)), **kwargs} + out: dict[str, Array] = {} + for target, func in per_target_funcs.items(): + target_kwargs = {arg: kwargs[arg] for arg in target_args[target]} + target_out = func(**target_kwargs) + for next_state_name, value in target_out.items(): + out[qname_from_tree_path((target, next_state_name))] = value + return out + + return combined # ty: ignore[invalid-return-type] def _create_discrete_stochastic_next_func( diff --git a/tests/test_chained_state_transitions.py b/tests/test_chained_state_transitions.py index 4d8a61d4..a563b9ad 100644 --- a/tests/test_chained_state_transitions.py +++ b/tests/test_chained_state_transitions.py @@ -2,10 +2,16 @@ dags resolves dependencies between state-transition functions when they appear in the merged transitions+functions dict that -`get_next_state_function_for_solution` builds. The blocker fixed here is in -the upstream `create_regime_params_template`: it must not classify -`next_` names as regime-level fixed_params, otherwise param resolution -fails before dags ever runs. +`get_next_state_function_for_solution` builds. The blocker fixed in +`create_regime_params_template`: it must not classify `next_` names +as regime-level fixed_params, otherwise param resolution fails before dags +ever runs. + +The same chained-resolution must also work in the simulation path. Earlier, +`get_next_state_function_for_simulation` flattened transitions into a single +DAG keyed by `__`, which prevented an unqualified +`next_` parameter from resolving across the per-target boundary. The +fix mirrors the solve path's per-target structure. """ import jax.numpy as jnp @@ -98,3 +104,32 @@ def test_solve_resolves_chain_via_dags() -> None: for regime_to_V_arr in period_to_regime_to_V_arr.values(): for V_arr in regime_to_V_arr.values(): assert not jnp.any(jnp.isnan(V_arr)) + + +def test_simulate_resolves_chain_via_dags() -> None: + """`simulate()` runs and the simulation DAG resolves `next_aime → next_wealth`. + + The old `get_next_state_function_for_simulation` flattened transitions + into one DAG keyed by `__`, so an unqualified + `next_aime` parameter on `_next_wealth` could not resolve. The per-target + rewrite mirrors the solve path's structure. + """ + model = _build_model() + params = { + "discount_factor": 0.9, + "final_age_alive": 1.0, + } + initial_conditions = { + "age": jnp.array([0.0, 0.0]), + "aime": jnp.array([0.0, 1.0]), + "wealth": jnp.array([2.0, 3.0]), + "regime": jnp.array([_RegimeId.active, _RegimeId.active]), + } + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + ) + df = result.to_dataframe().query('regime == "active"') + assert not df["wealth"].isna().any() + assert not df["aime"].isna().any() From c969b1a3c60c5b614f499b435e92f25b209660b7 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 05:13:49 +0200 Subject: [PATCH 57/80] next_state: real signature for combined; fix trivially-passing test Two cleanups in the per-target simulate path landed via #339/#340. (1) `_build_combined_simulation_function` previously wrapped a `(*args, **kwargs)` shim with `with_signature`, then zipped positional args back to names. Functionally fine, but the inner function's signature did not match the advertised one, requiring a `# ty: ignore` and protocol-mismatch warnings. Synthesise a real function with named positional-or-keyword parameters via `exec()` (same pattern as `dataclasses` and `attrs`). `vmap_1d` and other introspecting callers now see a faithful signature. (2) `test_get_next_state_function_with_simulate_target` claimed to assert qualified output keys (`mock__next_a`, `mock__next_b`) but compared against unqualified `{"a": ..., "b": ...}`. The test passed only because `pybaum.tree_equal` ignores dict keys when leaf counts and values match. Rewrite to assert keys directly via `set(got.keys())` and check both regime outputs; drop the unused stochastic scaffolding (`f_weight_b`, `f_b` returning None, key kwarg). --- src/lcm/regime_building/next_state.py | 30 ++++++++++++++++-------- tests/test_next_state.py | 33 ++++++++++++++------------- 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/src/lcm/regime_building/next_state.py b/src/lcm/regime_building/next_state.py index 4e0ba4b1..b384e486 100644 --- a/src/lcm/regime_building/next_state.py +++ b/src/lcm/regime_building/next_state.py @@ -1,7 +1,7 @@ """Generate function that compute the next states for solution and simulation.""" import inspect -from collections.abc import Callable +from collections.abc import Callable, Mapping from types import MappingProxyType import jax @@ -224,16 +224,11 @@ def _build_combined_simulation_function( target: tuple(inspect.signature(func).parameters) for target, func in per_target_funcs.items() } - all_args: list[str] = sorted({arg for args in target_args.values() for arg in args}) - - @with_signature( - args=dict.fromkeys(all_args, "Array"), - return_annotation="dict[str, Array]", - enforce=False, + all_args: tuple[str, ...] = tuple( + sorted({arg for args in target_args.values() for arg in args}) ) - def combined(*args: Array, **kwargs: Array) -> dict[str, Array]: - if args: - kwargs = {**dict(zip(all_args, args, strict=False)), **kwargs} + + def _dispatch(kwargs: Mapping[str, Array]) -> dict[str, Array]: out: dict[str, Array] = {} for target, func in per_target_funcs.items(): target_kwargs = {arg: kwargs[arg] for arg in target_args[target]} @@ -242,6 +237,21 @@ def combined(*args: Array, **kwargs: Array) -> dict[str, Array]: out[qname_from_tree_path((target, next_state_name))] = value return out + # Generate a real function with named positional-or-keyword parameters so + # `vmap_1d` (and any other introspecting caller) sees a faithful signature + # rather than a `(*args, **kwargs)` shim. Mirrors the strategy used by + # `dataclasses` and `attrs` to synthesise typed `__init__` methods. + src = ( + f"def combined({', '.join(all_args)}) -> 'dict[str, Array]':\n" + f" return _dispatch({{{', '.join(f'{a!r}: {a}' for a in all_args)}}})\n" + ) + namespace: dict[str, object] = {"_dispatch": _dispatch} + exec(src, namespace) # noqa: S102 + combined = namespace["combined"] + combined.__annotations__ = { + **dict.fromkeys(all_args, "Array"), + "return": "dict[str, Array]", + } return combined # ty: ignore[invalid-return-type] diff --git a/tests/test_next_state.py b/tests/test_next_state.py index 162650cb..a3b98fa0 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -3,7 +3,6 @@ import jax.numpy as jnp import pandas as pd -from pybaum import tree_equal from lcm.ages import AgeGrid from lcm.grids import DiscreteGrid @@ -13,10 +12,7 @@ get_next_state_function_for_simulation, get_next_state_function_for_solution, ) -from lcm.typing import ( - ContinuousState, - FloatND, -) +from lcm.typing import ContinuousState from tests.test_models.deterministic.regression import dead, working_life @@ -59,14 +55,18 @@ def test_get_next_state_function_with_solve_target(): def test_get_next_state_function_with_simulate_target(): + """Outputs are namespaced by target regime: `__`. + + The combined function dispatches inputs to the per-target DAG and + qualifies each output with its target name, matching the format + `_update_states_for_subjects` consumes. + """ + def f_a(state: ContinuousState) -> ContinuousState: - return state[0] + return state * 2.0 def f_b(state: ContinuousState) -> ContinuousState: - return None # ty: ignore[invalid-return-type] - - def f_weight_b(state: ContinuousState) -> FloatND: - return jnp.array([0.0, 1.0]) + return state + 1.0 @dataclass class MockCategory: @@ -76,11 +76,12 @@ class MockCategory: all_grids = MappingProxyType( {"mock": MappingProxyType({"b": DiscreteGrid(MockCategory)})} ) - variable_info = pd.DataFrame({"is_shock": [False]}) + variable_info = pd.DataFrame({"is_shock": [False]}, index=["b"]) transitions = MappingProxyType( {"mock": MappingProxyType({"next_a": f_a, "next_b": f_b})} ) - functions = MappingProxyType({"utility": lambda: 0, "f_weight_b": f_weight_b}) + functions = MappingProxyType({"utility": lambda: 0}) + got_func = get_next_state_function_for_simulation( transitions=transitions, # ty: ignore[invalid-argument-type] functions=functions, # ty: ignore[invalid-argument-type] @@ -88,11 +89,11 @@ class MockCategory: variable_info=variable_info, ) - key = jnp.arange(2, dtype="uint32") - got = got_func(state=jnp.arange(2), key_b=key) + got = got_func(state=jnp.array([1.0, 2.0])) - expected = {"a": jnp.array([0]), "b": jnp.array([1])} - assert tree_equal(expected, got) + assert set(got.keys()) == {"mock__next_a", "mock__next_b"} + assert jnp.array_equal(got["mock__next_a"], jnp.array([2.0, 4.0])) + assert jnp.array_equal(got["mock__next_b"], jnp.array([2.0, 3.0])) def test_create_stochastic_next_func(): From 8a2ad4f50559422f56db01998135d88191623e9a Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 05:41:35 +0200 Subject: [PATCH 58/80] Address #342 review: simulate-path uses concatenate_functions; cleanups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit mj023's inline (next_state.py): drop hand-rolled `_build_combined_simulation_function` (`with_signature` over a `(*args, **kwargs)` shim that zipped positional args back to names). Use `concatenate_functions` directly on `per_target_funcs`. The combined function now returns the nested form `{target_regime_name: {next_: array}}` instead of the flat `__` shape, which is more natural to construct and exactly what the consumer needs. Move the per-target flattening into the consumer `_update_states_for_subjects`: iterate the outer regime keys, strip `next_` from inner keys, rebuild the flat `__` lookup into `all_states`. Update the test fixture and the `NextStateSimulationFunction` protocol's return type accordingly, plus broaden the `vmap_1d` `FunctionWithArrayReturn` typevar bound to admit the new nested-mapping return. timmens' inline (regime_template.py): fold `next_state_names` into `H_variables` rather than carrying it as a separate "exemption". The docstring no longer needs the multi-paragraph rationale; the unified `H_variables` set documents itself: regime functions, `period`, `age`, `E_next_V`, and `next_` outputs are all internal wiring that pylcm resolves at evaluation time, never user-facing fixed_params. timmens' inline (next_state.py:99): rename `target_trans` to `target_transitions` in the loop variable and the `_extend_target_transitions_for_simulation` signature. timmens' inline (test_chained_state_transitions.py): rewrite both tests to assert behavior in user-facing terms instead of rehearsing the prior bug. `test_solve_..._returns_finite_value_function` checks that the active regime's V is finite. `test_simulate_..._yields_ expected_next_wealth` checks that `next_wealth_t = wealth_t - consumption_t + 0.1 * next_aime_t` holds period-over-period in the simulated DataFrame, which can only succeed if the chained dependency `next_aime → next_wealth` was wired correctly. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/params/regime_template.py | 27 +++--- src/lcm/regime_building/next_state.py | 75 ++++------------ src/lcm/simulation/transitions.py | 34 +++++--- src/lcm/typing.py | 7 +- src/lcm/utils/dispatchers.py | 5 +- tests/simulation/test_update_states.py | 12 ++- tests/test_chained_state_transitions.py | 111 +++++++++++++----------- 7 files changed, 127 insertions(+), 144 deletions(-) diff --git a/src/lcm/params/regime_template.py b/src/lcm/params/regime_template.py index 8f9765a5..51d3c29b 100644 --- a/src/lcm/params/regime_template.py +++ b/src/lcm/params/regime_template.py @@ -22,16 +22,10 @@ def create_regime_params_template( ) -> RegimeParamsTemplate: """Create parameter template from a regime specification. - Discover parameters from function signatures via `dags.tree`. Parameters are - function arguments that are not states, actions, other regime functions, - `next_` transition outputs, or special variables (period, age, - E_next_V). - - The `next_` exemption lets a state transition consume the output of - another state transition: dags resolves the chain at evaluation time - (`get_next_state_function_for_solution` merges all transitions and DAG - functions into a single dict before calling `concatenate_functions`), so - these names must not surface as user-facing fixed_params. + Discover parameters from function signatures via `dags.tree`. Parameters + are function arguments that are not states, actions, regime functions, + `next_` outputs, or special variables (`period`, `age`, + `E_next_V`). For `SolveSimulateFunctionPair` entries, the template contains the **union** of both variants' parameters so the user can provide a single flat params @@ -48,11 +42,14 @@ def create_regime_params_template( The regime parameter template with type annotations as values. """ - H_variables = {*regime.functions, "period", "age", "E_next_V"} - next_state_names = {f"next_{name}" for name in regime.states} - variables = ( - H_variables | set(regime.actions) | set(regime.states) | next_state_names - ) + H_variables = { + *regime.functions, + "period", + "age", + "E_next_V", + *(f"next_{name}" for name in regime.states), + } + variables = H_variables | set(regime.actions) | set(regime.states) function_params: dict[FunctionName, dict[str, str]] = {} diff --git a/src/lcm/regime_building/next_state.py b/src/lcm/regime_building/next_state.py index 4e0ba4b1..537a5e5d 100644 --- a/src/lcm/regime_building/next_state.py +++ b/src/lcm/regime_building/next_state.py @@ -1,6 +1,5 @@ """Generate function that compute the next states for solution and simulation.""" -import inspect from collections.abc import Callable from types import MappingProxyType @@ -74,9 +73,8 @@ def get_next_state_function_for_simulation( the per-target structure of {func}`get_next_state_function_for_solution`. This lets a transition function or auxiliary regime function consume another transition's `next_` output via plain name resolution within the same - target's DAG. Per-target outputs are then merged into a single flat dict keyed - by `__next_`, matching the shape consumed downstream by - {func}`lcm.simulation.transitions._update_states_for_subjects`. + target's DAG. The combined function returns a nested mapping keyed by target + regime name, with each inner dict using unqualified `next_` keys. Stochastic-transition wrappers expose `key___next_` and `weight___next_` as external arguments so callers can pass a @@ -93,13 +91,14 @@ def get_next_state_function_for_simulation( Function that computes the next states. Depends on states and actions of the current period, and the regime parameters ("params"). The function also depends on the dictionary of random keys ("keys") for stochastic transitions. + Returns `{target_regime_name: {next_: array}}`. """ per_target_funcs: dict[RegimeName, Callable[..., dict[str, Array]]] = {} - for target, target_trans in transitions.items(): + for target, target_transitions in transitions.items(): extended = _extend_target_transitions_for_simulation( target=target, - target_trans=target_trans, + target_transitions=target_transitions, all_grids=all_grids, variable_info=variable_info, stochastic_transition_names=stochastic_transition_names, @@ -112,7 +111,13 @@ def get_next_state_function_for_simulation( set_annotations=True, ) - return _build_combined_simulation_function(per_target_funcs=per_target_funcs) + return concatenate_functions( + functions=per_target_funcs, + targets=list(per_target_funcs.keys()), + return_type="dict", + enforce_signature=False, + set_annotations=True, + ) def get_next_stochastic_weights_function( @@ -151,7 +156,7 @@ def get_next_stochastic_weights_function( def _extend_target_transitions_for_simulation( *, target: RegimeName, - target_trans: MappingProxyType[TransitionFunctionName, Callable[..., Array]], + target_transitions: MappingProxyType[TransitionFunctionName, Callable[..., Array]], all_grids: MappingProxyType[RegimeName, MappingProxyType[StateOrActionName, Grid]], variable_info: pd.DataFrame, stochastic_transition_names: frozenset[TransitionFunctionName], @@ -169,8 +174,8 @@ def _extend_target_transitions_for_simulation( Args: target: Target regime name. - target_trans: Mapping of unqualified `next_` transition names to - functions, restricted to one target regime. + target_transitions: Mapping of unqualified `next_` transition names + to functions, restricted to one target regime. all_grids: Immutable mapping of regime names to Grid spec objects. variable_info: Variable info of the current regime. stochastic_transition_names: Frozenset of stochastic transition function names. @@ -181,8 +186,10 @@ def _extend_target_transitions_for_simulation( """ shock_names: set[ShockName] = set(variable_info.query("is_shock").index.to_list()) flat_grids = flatten_regime_namespace(all_grids) - extended: dict[TransitionFunctionName, Callable[..., Array]] = dict(target_trans) - for next_state_name in target_trans: + extended: dict[TransitionFunctionName, Callable[..., Array]] = dict( + target_transitions + ) + for next_state_name in target_transitions: if next_state_name not in stochastic_transition_names: continue qname = qname_from_tree_path((target, next_state_name)) @@ -201,50 +208,6 @@ def _extend_target_transitions_for_simulation( return extended -def _build_combined_simulation_function( - *, - per_target_funcs: dict[RegimeName, Callable[..., dict[str, Array]]], -) -> NextStateSimulationFunction: - """Combine per-target simulation DAGs into one function with qualified outputs. - - Each per-target callable returns `{next_: array}` (unqualified). The - combined callable returns `{__next_: array}` after dispatching - inputs to the relevant per-target function based on its signature. - - Args: - per_target_funcs: Mapping of target regime names to per-target simulation - DAGs returning unqualified `{next_: array}` outputs. - - Returns: - A single callable that takes the union of all per-target inputs and - returns target-qualified outputs. - - """ - target_args: dict[RegimeName, tuple[str, ...]] = { - target: tuple(inspect.signature(func).parameters) - for target, func in per_target_funcs.items() - } - all_args: list[str] = sorted({arg for args in target_args.values() for arg in args}) - - @with_signature( - args=dict.fromkeys(all_args, "Array"), - return_annotation="dict[str, Array]", - enforce=False, - ) - def combined(*args: Array, **kwargs: Array) -> dict[str, Array]: - if args: - kwargs = {**dict(zip(all_args, args, strict=False)), **kwargs} - out: dict[str, Array] = {} - for target, func in per_target_funcs.items(): - target_kwargs = {arg: kwargs[arg] for arg in target_args[target]} - target_out = func(**target_kwargs) - for next_state_name, value in target_out.items(): - out[qname_from_tree_path((target, next_state_name))] = value - return out - - return combined # ty: ignore[invalid-return-type] - - def _create_discrete_stochastic_next_func( *, name: str, labels: DiscreteState ) -> StochasticNextFunction: diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index 18bdbb6b..c494d996 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -19,6 +19,8 @@ from lcm.typing import ( ActionName, Bool1D, + ContinuousState, + DiscreteState, FlatRegimeParams, Int1D, RegimeName, @@ -257,19 +259,23 @@ def random_id( def _update_states_for_subjects( *, all_states: MappingProxyType[str, Array], - computed_next_states: MappingProxyType[str, Array], + computed_next_states: MappingProxyType[ + RegimeName, MappingProxyType[str, DiscreteState | ContinuousState] + ], subject_indices: Bool1D, ) -> MappingProxyType[str, Array]: """Update the global states dictionary with next states for specific subjects. - The transition functions add a 'next_' prefix to state variable names. This function - removes that prefix and updates only the entries corresponding to the specified - subjects, leaving other subjects' states unchanged. + Outputs from `get_next_state_function_for_simulation` are nested by target + regime, with inner keys carrying the `next_` prefix + (`{target: {next_: array}}`). Strip the prefix and combine with the + target name into the flat `__` key used in `all_states`, + updating only the entries corresponding to the specified subjects. Args: all_states: Current states for all subjects across all regimes. - computed_next_states: Newly computed states (with 'next_' prefix) for specific - subjects. + computed_next_states: Newly computed states, nested by target regime + and keyed by `next_`, for specific subjects. subject_indices: Indices of subjects whose states should be updated. Returns: @@ -277,13 +283,13 @@ def _update_states_for_subjects( """ updated_states = dict(all_states) - for next_state_name, next_state_values in computed_next_states.items(): - # Namespaced outputs: "regime__next_wealth" → "regime__wealth" - state_name = next_state_name.replace("__next_", "__", 1) - updated_states[state_name] = jnp.where( - subject_indices, - next_state_values, - all_states[state_name], - ) + for target, target_next_states in computed_next_states.items(): + for next_state_name, next_state_values in target_next_states.items(): + state_name = f"{target}__{next_state_name.removeprefix('next_')}" + updated_states[state_name] = jnp.where( + subject_indices, + next_state_values, + all_states[state_name], + ) return MappingProxyType(updated_states) diff --git a/src/lcm/typing.py b/src/lcm/typing.py index c73b4c33..310c7bc6 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -188,14 +188,17 @@ def __call__(self, **kwargs: Array) -> Array: ... class NextStateSimulationFunction(Protocol): """The function that computes the next states during the simulation. - Only used for type checking. + Returns a nested mapping `{target_regime: {next_: array}}`. Only + used for type checking. """ def __call__( self, **kwargs: Array | Period | Age, - ) -> MappingProxyType[str, DiscreteState | ContinuousState]: ... + ) -> MappingProxyType[ + RegimeName, MappingProxyType[str, DiscreteState | ContinuousState] + ]: ... class ActiveFunction(Protocol): diff --git a/src/lcm/utils/dispatchers.py b/src/lcm/utils/dispatchers.py index f019eb59..22367d64 100644 --- a/src/lcm/utils/dispatchers.py +++ b/src/lcm/utils/dispatchers.py @@ -17,7 +17,10 @@ "FunctionWithArrayReturn", bound=Callable[ ..., - Array | tuple[Array, Array] | MappingProxyType[str, Array], + Array + | tuple[Array, Array] + | MappingProxyType[str, Array] + | MappingProxyType[str, MappingProxyType[str, Array]], ], ) diff --git a/tests/simulation/test_update_states.py b/tests/simulation/test_update_states.py index 69809fd5..f626fc84 100644 --- a/tests/simulation/test_update_states.py +++ b/tests/simulation/test_update_states.py @@ -14,7 +14,7 @@ def test_update_states_strips_next_prefix(): ) computed_next_states = MappingProxyType( { - "working__next_wealth": jnp.array([15.0, 25.0, 35.0]), + "working": MappingProxyType({"next_wealth": jnp.array([15.0, 25.0, 35.0])}), } ) subject_indices = jnp.array([True, False, True]) @@ -38,8 +38,12 @@ def test_update_states_multiple_regimes_and_states(): ) computed_next_states = MappingProxyType( { - "working__next_wealth": jnp.array([15.0, 25.0]), - "working__next_health": jnp.array([1.5, 2.5]), + "working": MappingProxyType( + { + "next_wealth": jnp.array([15.0, 25.0]), + "next_health": jnp.array([1.5, 2.5]), + } + ), } ) subject_indices = jnp.array([True, True]) @@ -64,7 +68,7 @@ def test_update_states_no_subjects_selected(): ) computed_next_states = MappingProxyType( { - "r__next_wealth": jnp.array([99.0, 99.0]), + "r": MappingProxyType({"next_wealth": jnp.array([99.0, 99.0])}), } ) subject_indices = jnp.array([False, False]) diff --git a/tests/test_chained_state_transitions.py b/tests/test_chained_state_transitions.py index a563b9ad..df1954ac 100644 --- a/tests/test_chained_state_transitions.py +++ b/tests/test_chained_state_transitions.py @@ -1,20 +1,13 @@ -"""End-to-end check that one state transition can consume another's output. - -dags resolves dependencies between state-transition functions when they -appear in the merged transitions+functions dict that -`get_next_state_function_for_solution` builds. The blocker fixed in -`create_regime_params_template`: it must not classify `next_` names -as regime-level fixed_params, otherwise param resolution fails before dags -ever runs. - -The same chained-resolution must also work in the simulation path. Earlier, -`get_next_state_function_for_simulation` flattened transitions into a single -DAG keyed by `__`, which prevented an unqualified -`next_` parameter from resolving across the per-target boundary. The -fix mirrors the solve path's per-target structure. +"""A state transition can consume another transition's output. + +The model defines `next_aime` and `next_wealth`, with `next_wealth` +referencing `next_aime` in its signature. dags resolves the chain at +evaluation time. These tests assert the chain produces the +mathematically expected next-period values in solve and simulate. """ import jax.numpy as jnp +import numpy as np from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical from lcm.typing import DiscreteAction, FloatND, ScalarInt @@ -33,17 +26,12 @@ class _RegimeId: def _next_aime(aime: float, labor_supply: DiscreteAction) -> FloatND: - """AIME accumulates only when working.""" + """AIME accumulates by 1 unit when working, 0 otherwise.""" return aime + jnp.where(labor_supply == _LaborSupply.work, 1.0, 0.0) def _next_wealth(wealth: float, consumption: float, next_aime: FloatND) -> FloatND: - """Next-period wealth depends on next-period AIME (the chained transition). - - The economically interesting use is `pia = f(next_aime)` feeding - next_wealth via a pension correction. Here we keep the dependency simple - so the test focuses on the wiring, not the economics. - """ + """Next-period wealth depends on next-period AIME (the chained transition).""" return wealth - consumption + 0.1 * next_aime @@ -86,50 +74,69 @@ def _build_model() -> Model: ) -def test_solve_resolves_chain_via_dags() -> None: - """`solve()` runs and dags wires `next_aime → next_wealth` correctly. +def test_solve_with_chained_transitions_returns_finite_value_function() -> None: + """`solve()` returns a finite value function for the active regime. - Before the fix, `_resolve_fixed_params` raised - `InvalidParamsError: Missing required parameter: - 'active__next_wealth__next_aime'` because `create_regime_params_template` - classified `next_aime` (a `next_` reference inside another - transition's signature) as a regime-level fixed_param. + With `discount_factor=0.9` and a 2-period horizon, each state in the + active regime must produce a finite expected value: the chain + `next_aime → next_wealth` resolves and feeds back into the agent's + next-period continuation value. """ model = _build_model() - params = { - "discount_factor": 0.9, - "final_age_alive": 1.0, - } + params = {"discount_factor": 0.9, "final_age_alive": 1.0} + period_to_regime_to_V_arr = model.solve(params=params) - for regime_to_V_arr in period_to_regime_to_V_arr.values(): - for V_arr in regime_to_V_arr.values(): - assert not jnp.any(jnp.isnan(V_arr)) + V_active = period_to_regime_to_V_arr[0]["active"] + assert jnp.all(jnp.isfinite(V_active)) -def test_simulate_resolves_chain_via_dags() -> None: - """`simulate()` runs and the simulation DAG resolves `next_aime → next_wealth`. - The old `get_next_state_function_for_simulation` flattened transitions - into one DAG keyed by `__`, so an unqualified - `next_aime` parameter on `_next_wealth` could not resolve. The per-target - rewrite mirrors the solve path's structure. +def test_simulate_with_chained_transitions_yields_expected_next_wealth() -> None: + """`next_wealth_t = wealth_t - c_t + 0.1 * next_aime_t` holds in simulation. + + For each subject, `next_aime` is the value of `_next_aime(aime, ls)` at + the chosen labor supply, and `next_wealth` must equal + `wealth - c + 0.1 * next_aime` exactly. Solving for the optimum, the + test then checks that successive `wealth` values in the simulated + DataFrame satisfy this identity, which can only hold if the chained + dependency was wired correctly. """ model = _build_model() - params = { - "discount_factor": 0.9, - "final_age_alive": 1.0, - } + params = {"discount_factor": 0.9, "final_age_alive": 1.0} initial_conditions = { "age": jnp.array([0.0, 0.0]), "aime": jnp.array([0.0, 1.0]), "wealth": jnp.array([2.0, 3.0]), "regime": jnp.array([_RegimeId.active, _RegimeId.active]), } - result = model.simulate( - params=params, - initial_conditions=initial_conditions, - period_to_regime_to_V_arr=None, + + df = ( + model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + ) + .to_dataframe() + .query('regime == "active"') + .sort_values(["subject_id", "period"]) + .reset_index(drop=True) ) - df = result.to_dataframe().query('regime == "active"') - assert not df["wealth"].isna().any() - assert not df["aime"].isna().any() + + for subject_id in df["subject_id"].unique(): + rows = df.loc[df["subject_id"] == subject_id].sort_values("period") + for i in range(len(rows) - 1): + prev = rows.iloc[i] + curr = rows.iloc[i + 1] + work = prev["labor_supply"] == "work" + expected_next_aime = float(prev["aime"]) + (1.0 if work else 0.0) + expected_next_wealth = ( + float(prev["wealth"]) + - float(prev["consumption"]) + + 0.1 * expected_next_aime + ) + np.testing.assert_allclose( + float(curr["aime"]), expected_next_aime, atol=1e-6 + ) + np.testing.assert_allclose( + float(curr["wealth"]), expected_next_wealth, atol=1e-6 + ) From 17347c8fdb335f77d28f9a0fdf099f37ae04d5d9 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 06:04:33 +0200 Subject: [PATCH 59/80] Get rid of H_variables entirely in regime_template. --- src/lcm/params/regime_template.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/lcm/params/regime_template.py b/src/lcm/params/regime_template.py index 51d3c29b..67ddfe75 100644 --- a/src/lcm/params/regime_template.py +++ b/src/lcm/params/regime_template.py @@ -17,21 +17,18 @@ ) -def create_regime_params_template( - regime: Regime, -) -> RegimeParamsTemplate: +def create_regime_params_template(regime: Regime) -> RegimeParamsTemplate: """Create parameter template from a regime specification. Discover parameters from function signatures via `dags.tree`. Parameters are function arguments that are not states, actions, regime functions, - `next_` outputs, or special variables (`period`, `age`, - `E_next_V`). + `next_` outputs, or special variables (`period`, `age`, `E_next_V`). For `SolveSimulateFunctionPair` entries, the template contains the **union** of both variants' parameters so the user can provide a single flat params dict that satisfies both phases. - Grids with runtime-supplied values (IrregSpacedGrid without points, + Grids with runtime-supplied values (`IrregSpacedGrid` without points, `_ShockGrid` without full shock_params) add entries to the template under pseudo-function keys matching the state or action name. @@ -42,14 +39,15 @@ def create_regime_params_template( The regime parameter template with type annotations as values. """ - H_variables = { + variables = { + *set(regime.states), + *set(regime.actions), *regime.functions, + *(f"next_{name}" for name in regime.states), "period", "age", "E_next_V", - *(f"next_{name}" for name in regime.states), } - variables = H_variables | set(regime.actions) | set(regime.states) function_params: dict[FunctionName, dict[str, str]] = {} From 076b9b696fe08a06d230ec1f75778bbd9f5cfe29 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 06:09:04 +0200 Subject: [PATCH 60/80] Bump .ai-instructions: TDD always; behavior-focused docstrings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 5cd2187 on .ai-instructions/main adds a "Test-Driven Development — always" section to AGENTS.md (with two further subsections on behavior-focused docstrings and concrete-value assertions). pylcm's own .ai-instructions submodule is independent of aca-dev's, so this bump is needed for agents working in pylcm directly to see the new guidance — prompted by the pylcm #342 review feedback that motivated the TDD section in the first place. Also picks up 528a011..135a3cd: pinned-tool-version bump and Tier B gitignore additions for pytask.lock files. Co-Authored-By: Claude Opus 4.7 (1M context) --- .ai-instructions | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ai-instructions b/.ai-instructions index 528a0118..5cd21877 160000 --- a/.ai-instructions +++ b/.ai-instructions @@ -1 +1 @@ -Subproject commit 528a0118ff9e02233bfc073da891b60e81b34754 +Subproject commit 5cd218770c463b02698f90129443799abdc00864 From 164a88bb6a9d29addc5323113cc5c4c7cbaf3d81 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 06:23:24 +0200 Subject: [PATCH 61/80] AGENTS.md: inline TDD-always testing section directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The TDD guidance landed canonically in .ai-instructions/AGENTS.md (5cd2187), but reaches pylcm only via the @-include chain through the .ai-instructions submodule. Inline directly in pylcm/AGENTS.md so the policy is load-bearing in the file an agent sees first when working in pylcm, not contingent on the submodule pointer being current. Restructure: promote `Testing` to its own top-level section before `Development Notes` (was a `### Testing Style` subsection inside it). Three new TDD subsections sit above the existing pytest-mechanics bullets: - Test-Driven Development — always (red-green-refactor cycle, applied to features / bug fixes / refactors). - Test docstrings — describe behavior, not history (pretend the reader has never seen the PR). - Concrete-value assertions (assert what the result is, not just that it didn't crash). Verbatim from .ai-instructions/AGENTS.md so the two stay in sync. Co-Authored-By: Claude Opus 4.7 (1M context) --- AGENTS.md | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 66 insertions(+), 5 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 7aea2844..6a575e9d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -294,6 +294,72 @@ initial_conditions = { - `model.n_periods` - Number of periods in the model (derived from `ages`) - `model.regime_names_to_ids` - Immutable mapping from regime names to integer indices +## Testing + +### Test-Driven Development — always + +**Always write the test first, watch it fail, then implement.** No exceptions for new +behavior or bug fixes. Tests are not an afterthought, they are the spec. + +The cycle: + +1. **Red.** Write a failing test that asserts the desired behavior in user-facing terms. + Run it. Confirm it fails for the *right* reason (the missing behavior — not a typo, + not an import error). +1. **Green.** Write the smallest amount of code that makes the test pass. +1. **Refactor.** Clean up while keeping the test green. + +Apply per case: + +- **New feature** → red-green-refactor. +- **Bug fix** → reproduce as a failing test before writing the fix. The test then + prevents regression. +- **Refactor (no behavior change)** → existing tests are the spec. Keep them green + before, during, and after. No new test needed if behavior is unchanged; if you find a + behavior gap, fill it with a new test *before* refactoring. + +### Test docstrings — describe behavior, not history + +Test docstrings state what *should* be true, in user-facing terms. Pretend the reader +has never seen the PR. They should not need to. + +```python +# Good — behavior, in plain language +def test_simulate_with_chained_transitions_yields_expected_next_wealth(): + """`next_wealth_t = wealth_t - c_t + 0.1 * next_aime_t` holds in simulation.""" + + +# Bad — rehearses the prior bug or implementation history +def test_solve_resolves_chain_via_dags(): + """Before the fix, `_resolve_fixed_params` raised + `InvalidParamsError: Missing required parameter: ...` because + `create_regime_params_template` classified ...""" +``` + +Rule of thumb: **would the docstring still make sense in 9 months without the PR +context?** If not, rewrite it. + +### Concrete-value assertions + +Assert *what* the result is, not just that it didn't crash. + +```python +# Good — analytical value with explicit tolerance +np.testing.assert_allclose(curr["wealth"], expected_next_wealth, atol=1e-6) + +# Bad — passes whether the math is right or not +assert not jnp.any(jnp.isnan(V_arr)) +assert df["wealth"].notna().all() +``` + +`not isnan` and `no exception raised` belong in CI smoke tests, not in the unit tests +for the feature itself. + +### Mechanics + +- Use plain pytest functions, never test classes (`class TestFoo`) +- Use `@pytest.mark.parametrize` for test variations + ## Development Notes ### JAX Integration @@ -401,11 +467,6 @@ Code structure should be self-evident from function names and ordering. display math, and `[text](url)` for links. Never use rST-style ``` `` code `` ```, `:math:`, `:func:`, or `` `link `_ ``. -### Testing Style - -- Use plain pytest functions, never test classes (`class TestFoo`) -- Use `@pytest.mark.parametrize` for test variations - ### Plotting - Always use **plotly** for visualizations, never matplotlib. Use `plotly.graph_objects` From 5261e29c02c86f6f2cf9b0b809ce33180a5d2703 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 06:29:24 +0200 Subject: [PATCH 62/80] Address #339 review: drop field-count test; tighten claudish docstrings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit timmens' inline (test_nan_diagnostics.py): delete `test_diagnostic_row_holds_only_python_scalars`. The test asserted that `_DiagnosticRow.__dataclass_fields__` equals a fixed three-name set with a docstring rehearsing why those specific fields were removed. That's testing the implementation, not behavior — an OOM regression would not be caught by counting dataclass fields. The load-bearing constraint ("no device-backed references on the row") already lives in the `_DiagnosticRow` docstring, where readers will find it. timmens' inline (solve_brute.py:431): tighten `_DiagnosticRow` and `_emit_post_loop_diagnostics` docstrings. Both rehearsed the prior design ("The earlier design captured ...", "16 GB device that was OOMing on the previous stack-and-flush pattern") in second-paragraph "before the fix" framing. Drop those paragraphs; keep the forward-looking constraint ("only Python-scalar metadata, no device references") and the actual mechanism ("two `.item()` calls decide whether to enter the per-row failure path"). Same anti-pattern the new TDD-always section in AGENTS.md just codified for tests — apply it to source docstrings as well. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/solution/solve_brute.py | 23 ++++++++--------------- tests/test_nan_diagnostics.py | 22 ---------------------- 2 files changed, 8 insertions(+), 37 deletions(-) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 653b013d..791da20a 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -419,16 +419,11 @@ class _DiagnosticRow: """Metadata captured during the backward-induction loop. Holds only Python-scalar metadata — no device-array references — so - every (regime, period) row stays at a few bytes. The expensive bits - (state-action space, next-period V mapping, params, the - `compute_intermediates` closure) are reconstructed lazily on the + every (regime, period) row stays at a few bytes regardless of grid + size. State-action space, next-period V mapping, regime params, and + the `compute_intermediates` closure are reconstructed lazily on the failure path from `internal_regimes`, `internal_params`, and the - partial `solution` that has been built up to that point. - - The earlier design captured those device-backed objects directly on - each row, which pinned every period's V template in device memory - until the post-loop flush — at production grid sizes that hits OOM - well before the loop completes. + partial `solution` built up to that point. """ regime_name: RegimeName @@ -454,12 +449,10 @@ def _emit_post_loop_diagnostics( ) -> None: """Flush async diagnostics: raise on NaN, warn on Inf, log debug stats. - Two host transfers (the `.item()` calls on the running scalars) - decide whether we enter the per-row failure-path localisation. On - a healthy solve neither inner walk runs and no per-row scalar is - materialised — the property that lets a production-sized solve at - `log_level="warning"` fit on a 16 GB device that was OOMing on the - previous stack-and-flush pattern. + The two `.item()` calls on the running scalars decide whether to + enter the per-row failure-path localisation. On a healthy solve + neither inner walk runs and no per-row scalar is materialised, so + device memory stays bounded by the V templates currently in flight. """ if running_any_nan.item(): _raise_first_nan_row( diff --git a/tests/test_nan_diagnostics.py b/tests/test_nan_diagnostics.py index 779f90c8..47c42f32 100644 --- a/tests/test_nan_diagnostics.py +++ b/tests/test_nan_diagnostics.py @@ -146,28 +146,6 @@ def borrowing_constraint( return model, params -def test_diagnostic_row_holds_only_python_scalars() -> None: - """`_DiagnosticRow` must not pin device-backed objects. - - Earlier the row stored `state_action_space`, `next_regime_to_V_arr`, - `regime_params`, and a `compute_intermediates` closure (which itself - captured the state_action_space). Across periods these refs accumulated, - pinning every period's V template in device memory until the post-loop - flush — at production grid sizes that hits OOM well before the loop - completes. The failure path now reconstructs those objects from `solution`, - `internal_regimes`, and `internal_params` instead. - """ - from lcm.solution.solve_brute import _DiagnosticRow # noqa: PLC0415 - - expected = {"regime_name", "period", "age"} - actual = set(_DiagnosticRow.__dataclass_fields__) - assert actual == expected, ( - f"_DiagnosticRow must hold only {expected}; got {actual}. Adding " - "device-backed fields here pins per-period V templates in device " - "memory and re-introduces the OOM during long backward inductions." - ) - - def test_nan_diagnostics_end_to_end() -> None: """Real model: `model.solve()` attaches a diagnostics dict when V has NaN. From 110cc0b594039a0aaf255a595301368ee3b2d027 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 06:39:58 +0200 Subject: [PATCH 63/80] regime_template: collapse H_variables into single variables set MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per timmens' #342 review (and a polish pass on the prior fold-into- H_variables commit): drop the H_variables intermediate. There's no remaining reason to keep two sets — every name in the union is "internal wiring that pylcm resolves at evaluation time, never user-facing fixed_params" — so build `variables` directly with all six categories (states, actions, regime functions, next_ outputs, period, age, E_next_V). Also: collapse the function signature to one line, tighten the docstring (`(period, age, E_next_V)` on one line, backtick-quote `IrregSpacedGrid`). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/params/regime_template.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/lcm/params/regime_template.py b/src/lcm/params/regime_template.py index 51d3c29b..67ddfe75 100644 --- a/src/lcm/params/regime_template.py +++ b/src/lcm/params/regime_template.py @@ -17,21 +17,18 @@ ) -def create_regime_params_template( - regime: Regime, -) -> RegimeParamsTemplate: +def create_regime_params_template(regime: Regime) -> RegimeParamsTemplate: """Create parameter template from a regime specification. Discover parameters from function signatures via `dags.tree`. Parameters are function arguments that are not states, actions, regime functions, - `next_` outputs, or special variables (`period`, `age`, - `E_next_V`). + `next_` outputs, or special variables (`period`, `age`, `E_next_V`). For `SolveSimulateFunctionPair` entries, the template contains the **union** of both variants' parameters so the user can provide a single flat params dict that satisfies both phases. - Grids with runtime-supplied values (IrregSpacedGrid without points, + Grids with runtime-supplied values (`IrregSpacedGrid` without points, `_ShockGrid` without full shock_params) add entries to the template under pseudo-function keys matching the state or action name. @@ -42,14 +39,15 @@ def create_regime_params_template( The regime parameter template with type annotations as values. """ - H_variables = { + variables = { + *set(regime.states), + *set(regime.actions), *regime.functions, + *(f"next_{name}" for name in regime.states), "period", "age", "E_next_V", - *(f"next_{name}" for name in regime.states), } - variables = H_variables | set(regime.actions) | set(regime.states) function_params: dict[FunctionName, dict[str, str]] = {} From a7b9e9a0097911eedfb9b2474968121de7824833 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 06:42:31 +0200 Subject: [PATCH 64/80] solve_brute: drop misleading "~2 MB each" magic number from comment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit V_arr shape is model-dependent (states × grid points × regimes), so the per-period isnan/isinf intermediate buffer size is too. The qualitative point — these allocations stack up across periods if not freed — is what matters; the size figure was a distraction that implied a fixed scale only true on whichever model the comment was written against. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/solution/solve_brute.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 791da20a..fb21d7c4 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -83,12 +83,13 @@ def solve( # Per-period `block_until_ready()` after the running update forces # the device kernel to finish before the next period dispatches. # This frees the per-period `isnan(V_arr)` / `isinf(V_arr)` - # intermediate buffers (~2 MB each at production grid sizes) so - # they don't stack up. `block_until_ready` is a *device-only* sync - # — no host transfer, no PCIe round-trip — so it doesn't - # re-introduce the per-period host stalls that #334 removed; if - # `max_Q_over_a` (the dominant per-period kernel) is in flight, - # the call returns immediately when the small reduction is done. + # intermediate buffers (V_arr-shaped, so model-dependent) so they + # don't stack up across the loop. `block_until_ready` is a + # *device-only* sync — no host transfer, no PCIe round-trip — so + # it doesn't re-introduce the per-period host stalls that #334 + # removed; if `max_Q_over_a` (the dominant per-period kernel) is + # in flight, the call returns immediately when the small reduction + # is done. # # One host transfer per stat at end of solve (`.item()` on the # running scalars) decides whether to enter the failure-path From 16b570f756d2a536a49906e6e4cc005b083bf487 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 06:53:35 +0200 Subject: [PATCH 65/80] =?UTF-8?q?AGENTS.md:=20docstring=20style=20?= =?UTF-8?q?=E2=80=94=20describe=20state,=20no=20PR=20refs,=20bulleted=20li?= =?UTF-8?q?sts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Codifies three patterns observed across this PR's review cleanups and #339's, where source docstrings and inline comments rehearsed prior implementations, cited PR numbers, and quoted model-specific figures. None of that survives the 9-month-without-PR-context test. The new "## Docstring Style" section sits next to "## Testing" and adds three subsections, each with good/bad examples drawn from real recent code: - Describe state, not history (no "earlier", "previously", "the old design", "before the fix"). - No PR numbers, no model-specific magic numbers (PRs rot; "~2 MB at production grid sizes" only holds on one model/box). - Bulleted lists for enumerated cases (one bullet per case beats running prose for log levels, regime kinds, dispatch strategies). Cross-references the existing "Test docstrings — describe behavior, not history" subsection rather than duplicating the same rule with test-specific framing. Also bump .ai-instructions submodule pointer to 609ac4a, which landed the same docstring-style guidance on the canonical version. Co-Authored-By: Claude Opus 4.7 (1M context) --- .ai-instructions | 2 +- AGENTS.md | 78 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/.ai-instructions b/.ai-instructions index 5cd21877..609ac4a8 160000 --- a/.ai-instructions +++ b/.ai-instructions @@ -1 +1 @@ -Subproject commit 5cd218770c463b02698f90129443799abdc00864 +Subproject commit 609ac4a8d9262f93594f36ea382d30cd94ea07a4 diff --git a/AGENTS.md b/AGENTS.md index 6a575e9d..1929fced 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -360,6 +360,84 @@ for the feature itself. - Use plain pytest functions, never test classes (`class TestFoo`) - Use `@pytest.mark.parametrize` for test variations +## Docstring Style + +Docstrings and inline comments describe the code's *current* state in user-facing terms. +The 9-month-without-PR-context reader is the audience: a docstring that survives that +test stays useful; one that rehearses the diff or the prior implementation rots +immediately. + +This applies to **all** docstrings and comments — source and tests. For tests +specifically, see also "Test docstrings — describe behavior, not history" above. + +### Describe state, not history + +State what is true now. Don't reference prior designs, removed code, or what was +changed. Words like "earlier", "previously", "now", "formerly", "the old", "before the +fix" are red flags. + +```python +# Good — forward-looking constraint +class _DiagnosticRow: + """Metadata captured during the backward-induction loop. + + Holds only Python-scalar metadata — no device-array references — + so every (regime, period) row stays at a few bytes regardless of + grid size. + """ + + +# Bad — rehearses prior design +class _DiagnosticRow: + """Metadata captured during the backward-induction loop. + + Holds only Python-scalar metadata. The earlier design captured + state_action_space and a closure directly on each row, which + pinned every period's V template in device memory until the + post-loop flush. + """ +``` + +### No PR numbers, no model-specific magic numbers + +PR references (`#334 removed the host stalls`, `the bug was fixed in #42`) rot as the +codebase evolves and provide no useful signal to a reader who isn't already in context. +Magic numbers tied to a specific model size or hardware +(`~2 MB at production grid sizes`, `fits on a 16 GB device`) imply a fixed scale that's +only true on whichever model/box the comment was written against. State the qualitative +dependency instead. + +```python +# Good — qualitative dependency +# Frees per-period intermediate buffers (V_arr-shaped, so +# model-dependent) so they don't stack up across the loop. + +# Bad — PR reference + magic number +# Frees per-period intermediate buffers (~2 MB each at production +# grid sizes) so we don't re-introduce the host stalls that #334 +# removed. +``` + +### Bulleted lists for enumerated cases + +When describing a fixed set of cases (log levels, regime kinds, parameter types, +dispatch strategies), use one bullet per case rather than running prose. Bullets scan; +prose hides cases. + +```python +# Good — scannable +# Gate falls out of the public log level: +# - `"off"` ⇒ nothing (skips even the NaN fail-fast) +# - `"warning"` / `"progress"` ⇒ NaN/Inf only +# - `"debug"` ⇒ adds the min/max/mean trio + + +# Bad — buried in prose +# Gate falls out of the public log level: `"off"` ⇒ nothing, +# `"warning"` / `"progress"` ⇒ NaN/Inf only, `"debug"` ⇒ adds the +# min/max/mean trio. `"off"` skips even the NaN fail-fast. +``` + ## Development Notes ### JAX Integration From ff652618bc3a9f152a45f04985d8558d49a5df3a Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 06:54:17 +0200 Subject: [PATCH 66/80] =?UTF-8?q?solve=5Fbrute:=20apply=20docstring=20styl?= =?UTF-8?q?e=20=E2=80=94=20drop=20PR=20ref,=20magic=20number;=20bullet=20l?= =?UTF-8?q?ist?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three follow-ups to the AGENTS.md docstring-style section just inlined on this branch via the cascade-merge: - Drop "the per-period host stalls that #334 removed" from the diagnostics-accumulator comment. PR refs rot. Restate as "doesn't introduce a host stall" with the same forward-looking explanation. - Drop "(~2 MB each at production grid sizes)" — V_arr shape is model-dependent (states × grid points × regimes), so any fixed byte figure misleads on every other model. Restated as "(V_arr-shaped, so model-dependent)" earlier in the same comment; this commit's change is the second history-framed phrase from the same block. - Reflow the log-level prose paragraph into a bulleted list. The three cases (`"off"` / `"warning"` ∪ `"progress"` / `"debug"`) read faster as bullets, and the `"off"` qualifier ("skips even the NaN fail-fast") fits inline rather than needing the separate sentence about contracts and estimation loops. Also drop "without the deferred-stack fan-in that previously OOMed at production sizes" from `tests/solution/test_diagnostics.py` — same "previously" anti-pattern. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/solution/solve_brute.py | 16 +++++++--------- tests/solution/test_diagnostics.py | 3 +-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index fb21d7c4..d8152e13 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -86,21 +86,19 @@ def solve( # intermediate buffers (V_arr-shaped, so model-dependent) so they # don't stack up across the loop. `block_until_ready` is a # *device-only* sync — no host transfer, no PCIe round-trip — so - # it doesn't re-introduce the per-period host stalls that #334 - # removed; if `max_Q_over_a` (the dominant per-period kernel) is - # in flight, the call returns immediately when the small reduction - # is done. + # it doesn't introduce a host stall: if `max_Q_over_a` (the + # dominant per-period kernel) is in flight, the call returns + # immediately when the small reduction is done. # # One host transfer per stat at end of solve (`.item()` on the # running scalars) decides whether to enter the failure-path # localisation. On a healthy solve no per-row materialisation # happens. # - # Gate falls out of the public log level: `"off"` ⇒ nothing, - # `"warning"` / `"progress"` ⇒ NaN/Inf only, `"debug"` ⇒ adds the - # min/max/mean trio. `"off"` skips even the NaN fail-fast — that - # is the documented contract of `"off"` (suppress all output) and - # is what makes the level useful for tight estimation loops. + # Gate falls out of the public log level: + # - `"off"` ⇒ nothing (skips even the NaN fail-fast) + # - `"warning"` / `"progress"` ⇒ NaN/Inf only + # - `"debug"` ⇒ adds the min/max/mean trio diagnostics_enabled = logger.isEnabledFor(logging.WARNING) stats_enabled = logger.isEnabledFor(logging.DEBUG) diagnostic_rows: list[_DiagnosticRow] = [] diff --git a/tests/solution/test_diagnostics.py b/tests/solution/test_diagnostics.py index dbdc9497..262ae0e1 100644 --- a/tests/solution/test_diagnostics.py +++ b/tests/solution/test_diagnostics.py @@ -1,8 +1,7 @@ """Tests for the post-loop diagnostics path in `solve_brute.solve`. These cover: -- happy path at `log_level="warning"` runs without raising and without - the deferred-stack fan-in that previously OOMed at production sizes; +- happy path at `log_level="warning"` runs without raising; - NaN-bearing solves raise `InvalidValueFunctionError` and the message identifies the offending `(regime, age)`; - `log_level="debug"` emits one stat line per `(regime, period)`; From e9b7cc5824c9ecc7e8e7a4def6cfbdedb1312380 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 07:02:30 +0200 Subject: [PATCH 67/80] test_next_state: update to assert nested output shape from #339 merge The cascade-merge from #339 changed `get_next_state_function_for_simulation` to return a nested mapping `{target: {next_: array}}` (the flat `__` form went away with `_build_combined_simulation_function`). The simulate-target test on this branch was set up against the flat shape (added as `c969b1a` before the merge); update its assertions to the nested form so the test still validates the actual output structure. --- tests/test_next_state.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_next_state.py b/tests/test_next_state.py index a3b98fa0..a1f688c5 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -55,10 +55,11 @@ def test_get_next_state_function_with_solve_target(): def test_get_next_state_function_with_simulate_target(): - """Outputs are namespaced by target regime: `__`. + """Outputs are nested by target regime: `{target: {next_state: array}}`. The combined function dispatches inputs to the per-target DAG and - qualifies each output with its target name, matching the format + returns a mapping from target regime name to that target's + `{next_: array}` outputs, matching what `_update_states_for_subjects` consumes. """ @@ -91,9 +92,10 @@ class MockCategory: got = got_func(state=jnp.array([1.0, 2.0])) - assert set(got.keys()) == {"mock__next_a", "mock__next_b"} - assert jnp.array_equal(got["mock__next_a"], jnp.array([2.0, 4.0])) - assert jnp.array_equal(got["mock__next_b"], jnp.array([2.0, 3.0])) + assert set(got.keys()) == {"mock"} + assert set(got["mock"].keys()) == {"next_a", "next_b"} + assert jnp.array_equal(got["mock"]["next_a"], jnp.array([2.0, 4.0])) + assert jnp.array_equal(got["mock"]["next_b"], jnp.array([2.0, 3.0])) def test_create_stochastic_next_func(): From 838473ed7d71d6b380627c2a0777791acf8db7ab Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 09:39:15 +0200 Subject: [PATCH 68/80] validate_initial_conditions: per-constraint admissibility in error message Replace the "Active constraints: ..." list (which printed *all* regime constraints whether binding or not) with a per-constraint boolean column appended to the infeasible-subjects DataFrame. For each constraint and each infeasible subject, the entry is True when that constraint individually admits at least one action and False when it rejects every action by itself. The distinction matters for diagnosis. With a single binding constraint, exactly one column reads False. With a joint-rejection case (each constraint admits some action; their intersection is empty), every column reads True and the user knows the issue is the intersection, not any individual constraint. Implementation: `_per_constraint_feasibility` builds a per-constraint feasibility function via `_get_feasibility(constraints={name:func})` and reuses `_batched_feasibility_check` to run it on the infeasible subjects only. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/simulation/initial_conditions.py | 81 ++++++++++++++++++++++-- 1 file changed, 74 insertions(+), 7 deletions(-) diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 127059e1..00a7298f 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -667,13 +667,71 @@ def _check_combo(action_kw: dict[str, Array]) -> Array: if not infeasible_indices: return None + per_constraint_admits_any = _per_constraint_feasibility( + internal_regime=internal_regime, + subject_states=subject_states, + action_kwargs=action_kwargs, + filtered_params=filtered_params, + flat_actions=flat_actions, + idx_arr=idx_arr, + infeasible_indices=infeasible_indices, + ) + return _format_infeasibility_message( infeasible_indices=infeasible_indices, internal_regime=internal_regime, regime_name=regime_name, initial_states=initial_states, state_names=state_names, + per_constraint_admits_any=per_constraint_admits_any, + ) + + +def _per_constraint_feasibility( + *, + internal_regime: InternalRegime, + subject_states: Mapping[str, Array], + action_kwargs: Mapping[str, Array], + filtered_params: Mapping[str, object], + flat_actions: Mapping[ActionName, Array], + idx_arr: Array, + infeasible_indices: Sequence[int], +) -> dict[str, np.ndarray]: + """Per-constraint feasibility for the infeasible subjects. + + For each constraint, returns a boolean array (one entry per infeasible + subject) indicating whether that constraint *individually* admits at + least one action. Combined with the regime's feasibility verdict, this + distinguishes "constraint X rejects every action by itself" from + "constraints jointly reject everything despite each admitting some". + """ + constraints = internal_regime.simulate_functions.constraints + functions = internal_regime.simulate_functions.functions + if not constraints or not subject_states: + return {} + + infeasible_positions = np.flatnonzero( + np.isin(np.asarray(idx_arr), np.asarray(infeasible_indices)) ) + infeasible_states = { + name: arr[infeasible_positions] for name, arr in subject_states.items() + } + + out: dict[str, np.ndarray] = {} + for name, constraint_func in constraints.items(): + single_feasibility = _get_feasibility( + functions=functions, + constraints=MappingProxyType({name: constraint_func}), + ) + any_feasible = _batched_feasibility_check( + feasibility_func=single_feasibility, + subject_states=infeasible_states, + action_kwargs=action_kwargs, + filtered_params=filtered_params, + flat_actions=flat_actions, + ) + out[name] = np.asarray(any_feasible) + return out def _raise_feasibility_type_error( @@ -729,6 +787,7 @@ def _format_infeasibility_message( regime_name: RegimeName, initial_states: Mapping[str, Array], state_names: Sequence[str], + per_constraint_admits_any: Mapping[str, np.ndarray], ) -> str: """Format an error message for infeasible subjects. @@ -738,6 +797,12 @@ def _format_infeasibility_message( regime_name: Name of the regime. initial_states: Mapping of state names to arrays. state_names: List of state variable names. + per_constraint_admits_any: Mapping from constraint name to a boolean + array (one entry per infeasible subject) — True where that + constraint *individually* admits at least one action. False + entries identify constraints that reject every action on their + own; rows with all-True entries are infeasible only because the + constraints jointly reject the action set. Returns: Formatted error message string. @@ -759,9 +824,10 @@ def _format_infeasibility_message( if isinstance(grid, DiscreteGrid) and name in state_df.columns: state_df[name] = [grid.categories[int(v)] for v in state_df[name]] - # Constraint names - constraint_names = list(internal_regime.simulate_functions.constraints.keys()) - constraints_str = "\n".join(f" - {name}" for name in constraint_names) + # Append one boolean column per constraint: True = admits ≥ 1 action, + # False = rejects every action by itself for that subject. + for name, mask in per_constraint_admits_any.items(): + state_df[name] = list(mask) # Truncate for large groups n = len(infeasible_indices) @@ -775,10 +841,11 @@ def _format_infeasibility_message( return ( f"All actions are infeasible for {n} subject(s) " f"in regime '{regime_name}'.\n\n" - f"Active constraints:\n{constraints_str}\n\n" - f"Infeasible subjects:\n{table_str}\n\n" - f"No action combination satisfies all constraints for these " - f"initial states." + f"Per-constraint admissibility (True = constraint admits ≥ 1 " + f"action by itself; False = constraint rejects every action):\n" + f"{table_str}\n\n" + f"No action combination satisfies all constraints jointly for " + f"these initial states." ) From e4cae2aa57d4bf568b8ebbade55d44571e3a086f Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 09:50:52 +0200 Subject: [PATCH 69/80] _per_constraint_feasibility: filter args per single-constraint feasibility --- src/lcm/simulation/initial_conditions.py | 48 ++++++++++++++++++++---- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 00a7298f..6ef75a89 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -670,8 +670,7 @@ def _check_combo(action_kw: dict[str, Array]) -> Array: per_constraint_admits_any = _per_constraint_feasibility( internal_regime=internal_regime, subject_states=subject_states, - action_kwargs=action_kwargs, - filtered_params=filtered_params, + regime_params=regime_params, flat_actions=flat_actions, idx_arr=idx_arr, infeasible_indices=infeasible_indices, @@ -687,12 +686,28 @@ def _check_combo(action_kw: dict[str, Array]) -> Array: ) +def _admits_any_action( + *, + feasibility_func: Callable[..., Array], + action_kwargs: Mapping[str, Array], + params: Mapping[str, object], +) -> bool: + """Return True iff the feasibility function admits ≥ 1 action under params.""" + if action_kwargs: + + def _check_combo(action_kw: dict[str, Array]) -> Array: + return feasibility_func(**action_kw, **params) + + per_combo = jax.vmap(_check_combo)(action_kwargs) + return bool(jnp.any(per_combo)) + return bool(feasibility_func(**params)) + + def _per_constraint_feasibility( *, internal_regime: InternalRegime, subject_states: Mapping[str, Array], - action_kwargs: Mapping[str, Array], - filtered_params: Mapping[str, object], + regime_params: Mapping[str, object], flat_actions: Mapping[ActionName, Array], idx_arr: Array, infeasible_indices: Sequence[int], @@ -704,6 +719,11 @@ def _per_constraint_feasibility( least one action. Combined with the regime's feasibility verdict, this distinguishes "constraint X rejects every action by itself" from "constraints jointly reject everything despite each admitting some". + + Each constraint's feasibility function has its own argument set (a + subset of the combined feasibility's union); filter `subject_states`, + `action_kwargs`, and `filtered_params` per constraint so dags doesn't + raise on stray kwargs. """ constraints = internal_regime.simulate_functions.constraints functions = internal_regime.simulate_functions.functions @@ -723,11 +743,25 @@ def _per_constraint_feasibility( functions=functions, constraints=MappingProxyType({name: constraint_func}), ) + accepted = get_union_of_args([single_feasibility]) + single_states = {k: v for k, v in infeasible_states.items() if k in accepted} + single_actions = {k: v for k, v in flat_actions.items() if k in accepted} + single_params = {k: v for k, v in regime_params.items() if k in accepted} + n = len(infeasible_indices) + if not single_states: + # Action-only / parameter-only constraint — identical for all subjects. + admits_any = _admits_any_action( + feasibility_func=single_feasibility, + action_kwargs=single_actions, + params=single_params, + ) + out[name] = np.full(n, admits_any, dtype=bool) + continue any_feasible = _batched_feasibility_check( feasibility_func=single_feasibility, - subject_states=infeasible_states, - action_kwargs=action_kwargs, - filtered_params=filtered_params, + subject_states=single_states, + action_kwargs=single_actions, + filtered_params=single_params, flat_actions=flat_actions, ) out[name] = np.asarray(any_feasible) From 50f78f0898e436585aadf0447425bc9fceaf1df2 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 14:12:01 +0200 Subject: [PATCH 70/80] Address #340 review: docstring style, TDD, x64+AOT guard, period dtype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Docstring/comment cleanup (drop counterfactuals and history rehashes per AGENTS.md "Describe state, not history"): - discrete.py to_jax: rephrase "would otherwise produce" as positive constraint - initial_conditions.py: replace "Without this..." with the invariant the cast enforces - compile.py module docstring: drop "Mirrors the pattern in solve_brute" cross-reference - compile.py argmax-over-active-periods comment: drop aca-model-specific forced-canwork / pre-FRA example - compile.py _build_next_state_args: drop the now-stale comment about Python-int period (all sites use jnp.int32) Bullet-list dispatch cases (per AGENTS.md "Bulleted lists for enumerated cases"): - model.py n_subjects field: 4 dispatch cases as bullets, plus an explicit param-shape stability contract for MSM-style use - model.py _resolve_simulate_internal_regimes: 3 cases as bullets TDD/test style (per AGENTS.md "Testing"): - test_simulate_aot.py: split test_n_subjects_none_keeps_lazy_behavior into two single-assertion tests - test_simulate_aot.py: rename test_simulate_caches_recompiled_size_no_ second_warning -> test_simulate_warns_only_once_per_mismatching_size (the implementation only adds to _warned_n_subjects, not the cache) - test_simulate_aot.py: split first/second-call compile counting into test_simulate_first_matching_call_populates_aot_cache and test_simulate_second_matching_call_does_not_invoke_compile, dropping the weak `n_first > 0` smoke check - test_simulate_aot.py: docstrings on the validation tests, return-type annotations on counting_compile and _build_initial_conditions - test_int_dtype_invariants.py: per-test docstrings - compile.py _get_regime_V_shapes: docstring matching the file's other private helpers - model.py: class-body annotations and field docstrings for _simulate_compile_cache, _warned_n_subjects, _simulate_compile_lock x64 + AOT fail-fast (issue 19): - model_processing._fail_if_x64_with_aot raises ModelInitializationError when n_subjects is set under jax_enable_x64=True; the AOT path pins integer dtypes to int32, x64 promotes to int64, so the cached signature would not match runtime - AOT tests get an autouse fixture that disables x64 for the duration period dtype consistency (issue 20): - _build_next_state_args / _build_crtp_args: lower with period=jnp.int32(0) to match the argmax path - transitions.calculate_next_states / calculate_next_regime_membership: pass period=jnp.int32(period) at runtime so the runtime call matches the AOT abstract signature Concurrency (issue 18): - Model gains a threading.Lock guarding check-then-set on _simulate_compile_cache and _warned_n_subjects - __getstate__ / __setstate__ exclude the lock and the per-process AOT cache from pickling (compiled programs can't survive process boundaries anyway) Dedup-key invariant (issue 17): - _collect_unique_simulate_functions: comment documents the dedup contract — pylcm's process_regimes ships per-regime callables for next_state and crtp, so identity-based dedup is collision-free - new test_simulate_functions_use_per_regime_callables in test_regime_processing.py pins the invariant against future regression --- src/lcm/grids/discrete.py | 12 +-- src/lcm/model.py | 99 ++++++++++++++----- src/lcm/model_processing.py | 23 +++++ src/lcm/simulation/compile.py | 39 +++++--- src/lcm/simulation/initial_conditions.py | 10 +- src/lcm/simulation/transitions.py | 4 +- .../regime_building/test_regime_processing.py | 58 +++++++++++ tests/simulation/test_simulate_aot.py | 77 ++++++++++++--- tests/test_int_dtype_invariants.py | 3 + 9 files changed, 255 insertions(+), 70 deletions(-) diff --git a/src/lcm/grids/discrete.py b/src/lcm/grids/discrete.py index acdb850d..72ded5ea 100644 --- a/src/lcm/grids/discrete.py +++ b/src/lcm/grids/discrete.py @@ -51,12 +51,10 @@ def to_jax(self) -> Int1D: """Convert the grid to a Jax array. Discrete state/action codes are pinned to `int32` regardless of the - ambient `jax_enable_x64` setting. `jnp.array([...])` would otherwise - produce `int32` in 32-bit mode and `int64` in x64 mode, and - downstream values (transitions, V-array indexing, action lookups) - inherit that ambiguity — which silently splits the JIT cache into - per-period int32/int64 variants and breaks any AOT-compiled - program that ships a single signature. `int32` covers any realistic - category count and matches the `MISSING_CAT_CODE` sentinel. + ambient `jax_enable_x64` setting. A single integer dtype across + transitions, V-array indexing, and action lookups keeps the JIT cache + unsplit and lets AOT-compiled programs ship one signature. `int32` + accommodates any realistic category count and matches the + `MISSING_CAT_CODE` sentinel. """ return jnp.array(self.codes, dtype=jnp.int32) diff --git a/src/lcm/model.py b/src/lcm/model.py index 8935535a..8bec9b4b 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -2,6 +2,7 @@ import dataclasses import logging +import threading from collections.abc import Mapping from pathlib import Path from types import MappingProxyType @@ -90,16 +91,35 @@ class Model: n_subjects: int | None = None """Expected simulate batch size; enables AOT compile of simulate functions. - When set, the first matching `simulate(...)` call AOT-compiles all simulate - functions for batch shape `n_subjects` in parallel. Subsequent calls with the - same size reuse the compiled programs. Calls with a mismatching size warn - once per size and fall back to the runtime-traced path. `None` keeps the - purely lazy behaviour. + Dispatch by call shape: + + - `None`: purely lazy behaviour, no AOT. + - First `simulate(...)` with `actual_n == n_subjects`: AOT-compiles all + simulate functions for that batch shape in parallel and caches them. + - Subsequent `simulate(...)` with the same matching size: reuses the + cached compiled programs. + - `simulate(...)` with a mismatching size: warns once per size and falls + back to the runtime-traced path. + + Param-shape contract: the cache is keyed only on `n_subjects`. The shapes + and dtypes of `internal_params` leaves at the first matching call become + part of the AOT signature; subsequent calls must keep them stable. MSM- + style estimation (varying values, fixed shapes) is the target use case; + construct a fresh `Model` whenever a param array's shape or dtype changes. """ _params_template: ParamsTemplate """Template for the model parameters.""" + _simulate_compile_cache: dict[int, MappingProxyType[RegimeName, InternalRegime]] + """AOT-compiled `internal_regimes` per matching `n_subjects`.""" + + _warned_n_subjects: set[int] + """Mismatching `actual_n_subjects` already warned about (one warning each).""" + + _simulate_compile_lock: threading.Lock + """Guards check-then-set on `_simulate_compile_cache` and `_warned_n_subjects`.""" + def __init__( self, *, @@ -139,10 +159,9 @@ def __init__( self.n_periods = ages.n_periods self.fixed_params = ensure_containers_are_immutable(fixed_params) self.n_subjects = n_subjects - self._simulate_compile_cache: dict[ - int, MappingProxyType[RegimeName, InternalRegime] - ] = {} - self._warned_n_subjects: set[int] = set() + self._simulate_compile_cache = {} + self._warned_n_subjects = set() + self._simulate_compile_lock = threading.Lock() validate_model_inputs( n_periods=self.n_periods, @@ -172,6 +191,25 @@ def __init__( regime_names_to_ids=self.regime_names_to_ids, ) + def __getstate__(self) -> dict[str, object]: + """Drop AOT compile state from the pickle. + + The threading lock isn't pickleable, and the cached compiled programs + can't survive a process boundary anyway. + """ + state = self.__dict__.copy() + state.pop("_simulate_compile_lock", None) + state.pop("_simulate_compile_cache", None) + state.pop("_warned_n_subjects", None) + return state + + def __setstate__(self, state: dict[str, object]) -> None: + """Restore AOT compile state to a fresh empty cache.""" + self.__dict__.update(state) + self._simulate_compile_cache = {} + self._warned_n_subjects = set() + self._simulate_compile_lock = threading.Lock() + def get_params_template(self) -> UserFacingParamsTemplate: """Get a human-readable params template. @@ -273,35 +311,44 @@ def _resolve_simulate_internal_regimes( ) -> MappingProxyType[RegimeName, InternalRegime]: """Return internal_regimes to use for simulate; AOT cache when matching. - Returns the original `internal_regimes` when `n_subjects` is `None` or - when the actual batch size mismatches the declared one (logging a - warning once per mismatching size). Otherwise builds and caches the - AOT-compiled regimes for the matching size. + Three dispatch cases: + + - `n_subjects is None`: return the original `internal_regimes` + (purely lazy path). + - `actual_n_subjects != n_subjects`: return the original + `internal_regimes` and log a warning the first time each + mismatching size is seen. + - `actual_n_subjects == n_subjects`: return the cached AOT-compiled + regimes, building them on the first call. """ if self.n_subjects is None: return self.internal_regimes if actual_n_subjects != self.n_subjects: - if actual_n_subjects not in self._warned_n_subjects: + with self._simulate_compile_lock: + already_warned = actual_n_subjects in self._warned_n_subjects + if not already_warned: + self._warned_n_subjects.add(actual_n_subjects) + if not already_warned: log.warning( "simulate called with n_subjects=%d but model declared " "n_subjects=%d; falling back to runtime compile.", actual_n_subjects, self.n_subjects, ) - self._warned_n_subjects.add(actual_n_subjects) return self.internal_regimes - if self.n_subjects not in self._simulate_compile_cache: - self._simulate_compile_cache[self.n_subjects] = ( - compile_all_simulate_functions( - internal_regimes=self.internal_regimes, - internal_params=internal_params, - ages=self.ages, - n_subjects=self.n_subjects, - max_compilation_workers=max_compilation_workers, - logger=log, + with self._simulate_compile_lock: + if self.n_subjects not in self._simulate_compile_cache: + self._simulate_compile_cache[self.n_subjects] = ( + compile_all_simulate_functions( + internal_regimes=self.internal_regimes, + internal_params=internal_params, + ages=self.ages, + n_subjects=self.n_subjects, + max_compilation_workers=max_compilation_workers, + logger=log, + ) ) - ) - return self._simulate_compile_cache[self.n_subjects] + return self._simulate_compile_cache[self.n_subjects] def simulate( self, diff --git a/src/lcm/model_processing.py b/src/lcm/model_processing.py index d141c6a3..947a5428 100644 --- a/src/lcm/model_processing.py +++ b/src/lcm/model_processing.py @@ -10,6 +10,7 @@ from collections.abc import Callable, Mapping from types import MappingProxyType +import jax from dags import get_ancestors from dags.tree import QNAME_DELIMITER, qname_from_tree_path from jax import Array @@ -151,6 +152,7 @@ def validate_model_inputs( ) -> None: """Validate model constructor inputs.""" _fail_if_invalid_n_subjects(n_subjects=n_subjects) + _fail_if_x64_with_aot(n_subjects=n_subjects) # Early exit if regimes are not lcm.Regime instances if not all(isinstance(regime, Regime) for regime in regimes.values()): @@ -217,6 +219,27 @@ def _fail_if_invalid_n_subjects(*, n_subjects: int | None) -> None: raise ValueError(msg) +def _fail_if_x64_with_aot(*, n_subjects: int | None) -> None: + """Reject `n_subjects` set under `jax_enable_x64=True`. + + The AOT path pins integer dtypes to int32 (see `DiscreteGrid.to_jax`, + `build_initial_states`); under x64 mode, JAX's defaults promote int + intermediates to int64, so the cached AOT signature would not match the + runtime values. Use the lazy path (`n_subjects=None`) under x64 instead. + """ + if n_subjects is None: + return + if jax.config.read("jax_enable_x64"): + msg = ( + "n_subjects requires jax_enable_x64=False. The AOT simulate path pins " + "integer dtypes to int32; x64 mode promotes int intermediates to int64 " + "and breaks the cached AOT signature. Either disable x64 with " + "`jax.config.update('jax_enable_x64', False)` or use the lazy path " + "by leaving n_subjects unset." + ) + raise ModelInitializationError(msg) + + def _validate_all_variables_used(regimes: Mapping[RegimeName, Regime]) -> list[str]: """Validate that all states and actions are used somewhere in each regime. diff --git a/src/lcm/simulation/compile.py b/src/lcm/simulation/compile.py index c325b27f..282c39a7 100644 --- a/src/lcm/simulation/compile.py +++ b/src/lcm/simulation/compile.py @@ -6,9 +6,9 @@ simulate call sites then pick them up transparently — no signature changes downstream. -Mirrors the pattern in `solve_brute._compile_all_functions`: deduplicate by -callable identity, sequentially lower (tracing is not thread-safe), then -parallel-compile via `ThreadPoolExecutor` (XLA releases the GIL). +Compilation deduplicates callables by identity (only one program per unique +callable), lowers them sequentially (JAX tracing is not thread-safe), then +parallel-compiles them via a `ThreadPoolExecutor` (XLA releases the GIL). """ import dataclasses @@ -166,13 +166,11 @@ def _collect_unique_simulate_functions( # `sf.argmax_and_max_Q_over_a` has entries for *every* period # (pylcm builds them across the full age grid), but the regime is # only dispatched at runtime for periods in `regime.active_periods`. - # The unused entries can carry a stale `complete_targets` set - # whose shape doesn't match the regime's actual transitions - # (e.g. a forced-canwork regime's argmax for a pre-FRA period - # has choose targets in scope, even though the regime never - # reaches that period at runtime). Tracing those would surface - # `next_` bookkeeping inconsistencies that the lazy path - # never trips. Restrict AOT to active periods to mirror runtime. + # Inactive-period entries can carry a `complete_targets` set whose + # shape doesn't match the regime's actual transitions for that + # period; tracing them would surface `next_` bookkeeping + # mismatches the lazy path never reaches. Restrict AOT to active + # periods to mirror runtime. for period in regime.active_periods: argmax_func = sf.argmax_and_max_Q_over_a[period] active_next = _active_regimes_at_period( @@ -198,6 +196,14 @@ def _collect_unique_simulate_functions( ) unique[key] = (jax.jit(argmax_func), args, label) + # Dedup contract for `next_state` / `crtp`: pylcm's `process_regimes` + # builds these per regime (via `_build_next_state_vmapped` and the + # regime-specific transition-probs builder), so each regime ships a + # distinct callable object. Two regimes collide on the dedup key only + # when they truly share the same compiled program (and thus the same + # arg signature). The invariant is pinned by + # `test_simulate_functions_use_per_regime_callables` in + # `tests/regime_building/test_regime_processing.py`. if not regime.terminal: args = _build_next_state_args( internal_regime=regime, @@ -287,6 +293,12 @@ def _get_regime_V_shapes( internal_regimes: MappingProxyType[RegimeName, InternalRegime], internal_params: InternalParams, ) -> dict[RegimeName, tuple[int, ...]]: + """Return per-regime V-array shape (one length per state grid). + + Used to construct zero-shaped templates for `next_regime_to_V_arr` + when lowering each period's argmax — the abstract signature only + needs the shapes, not the values. + """ shapes: dict[RegimeName, tuple[int, ...]] = {} for regime_name, regime in internal_regimes.items(): space = regime.state_action_space( @@ -365,14 +377,11 @@ def _build_next_state_args( n_initial_states=n_subjects, ) - # `period` is passed as a plain Python int by `calculate_next_states` - # (transitions.py), which traces as the default-precision int. Match that - # here so the lowered shape signature lines up with the runtime call. return { **subject_states, **subject_actions, **stoch_keys, - "period": 0, + "period": jnp.int32(0), "age": ages.values[0], **regime_params, } @@ -394,7 +403,7 @@ def _build_crtp_args( return { **subject_states, **subject_actions, - "period": 0, + "period": jnp.int32(0), "age": ages.values[0], **regime_params, } diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 6ef75a89..69ca8d9d 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -66,13 +66,9 @@ def build_initial_states( key = f"{regime_name}__{state_name}" grid = internal_regime.grids[state_name] if isinstance(grid, DiscreteGrid): - # Match the grid's index dtype so the state is index-stable - # across the simulate loop. Without this, period-0 dispatch - # carries the user-supplied dtype (often int32) but post- - # transition states are promoted to the grid dtype (int64 - # under x64), forcing JAX to compile two argmax variants - # per regime and breaking AOT-compiled programs that key - # on a single signature. + # Cast user-supplied discrete states to the grid's index + # dtype so every period's argmax sees a single signature + # for that state. target_dtype = grid.to_jax().dtype if state_name in initial_states: flat[key] = initial_states[state_name].astype(target_dtype) diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index c494d996..d1a1a4d3 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -126,7 +126,7 @@ def calculate_next_states( **state_action_space.states, **optimal_actions, **stochastic_variables_keys, - period=period, + period=jnp.int32(period), age=age, **regime_params, ) @@ -187,7 +187,7 @@ def calculate_next_regime_membership( internal_regime.simulate_functions.compute_regime_transition_probs( # ty: ignore[call-non-callable] **state_action_space.states, **optimal_actions, - period=period, + period=jnp.int32(period), age=age, **regime_params, ) diff --git a/tests/regime_building/test_regime_processing.py b/tests/regime_building/test_regime_processing.py index d51f59e2..794390fb 100644 --- a/tests/regime_building/test_regime_processing.py +++ b/tests/regime_building/test_regime_processing.py @@ -7,6 +7,7 @@ from numpy.testing import assert_array_equal from pandas.testing import assert_frame_equal +from lcm import Regime, categorical from lcm.ages import AgeGrid from lcm.grids import DiscreteGrid, LinSpacedGrid from lcm.regime_building.processing import ( @@ -177,6 +178,63 @@ def wealth_constraint(wealth): assert got.index.is_unique +def test_simulate_functions_use_per_regime_callables(): + """Each non-terminal regime gets a distinct `next_state` / `crtp` callable. + + The simulate-AOT path in `lcm.simulation.compile` deduplicates by callable + identity for `next_state` and `compute_regime_transition_probs`. That is + only safe if `process_regimes` ships a fresh callable per regime — two + regimes sharing one callable would compile against the first regime's + state-action shapes and silently apply that program to the second. + """ + + def next_x(x): + return x + + def regime_transition(age, final_age): + return jnp.where(age >= final_age, 1, 0) + + @categorical(ordered=False) + class TwoRegimeId: + early: int + late: int + + early = Regime( + transition=regime_transition, + states={"x": LinSpacedGrid(start=0, stop=10, n_points=4)}, + state_transitions={"x": next_x}, + functions={"utility": lambda x: x}, + active=lambda age: age < 1, + ) + late = Regime( + transition=regime_transition, + states={"x": LinSpacedGrid(start=0, stop=10, n_points=6)}, + state_transitions={"x": next_x}, + functions={"utility": lambda x: x}, + active=lambda age: age >= 1, + ) + + regimes = {"early": early, "late": late} + internal_regimes = process_regimes( + regimes=regimes, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_names_to_ids=MappingProxyType({"early": 0, "late": 1}), + enable_jit=True, + ) + + early_next_state = internal_regimes["early"].simulate_functions.next_state + late_next_state = internal_regimes["late"].simulate_functions.next_state + assert id(early_next_state) != id(late_next_state) + + early_crtp = internal_regimes[ + "early" + ].simulate_functions.compute_regime_transition_probs + late_crtp = internal_regimes[ + "late" + ].simulate_functions.compute_regime_transition_probs + assert id(early_crtp) != id(late_crtp) + + def test_rename_params_to_qnames_with_partial(): """Regression: dags >=0.5.1 renames bound partial keywords to qualified names.""" diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py index 3b8e57a1..5d7037b3 100644 --- a/tests/simulation/test_simulate_aot.py +++ b/tests/simulation/test_simulate_aot.py @@ -4,13 +4,21 @@ parallel-compiles all simulate functions for batch shape `N`. Subsequent calls with size `N` reuse the cache; calls with a mismatching size warn once per size and fall back to the runtime-traced path. + +The AOT path pins integer dtypes to int32, which is incompatible with +`jax_enable_x64=True`. The conftest enables x64 by default, so every test in +this file disables it via the autouse `_disable_x64` fixture. """ import logging +from collections.abc import Iterator +from typing import Any import jax.numpy as jnp import jax.stages import pytest +from jax import Array +from jax import config as jax_config from lcm import Model from lcm.ages import AgeGrid @@ -22,6 +30,17 @@ ) +@pytest.fixture(autouse=True) +def _disable_x64() -> Iterator[None]: + """Disable x64 for AOT tests (conftest enables it globally).""" + previous = jax_config.read("jax_enable_x64") + jax_config.update("jax_enable_x64", val=False) + try: + yield + finally: + jax_config.update("jax_enable_x64", val=previous) + + def _build_test_model(*, n_periods: int, n_subjects: int | None = None) -> Model: """Construct the small 2-regime regression model with optional n_subjects.""" final_age_alive = 18 + n_periods - 2 @@ -38,7 +57,7 @@ def _build_test_model(*, n_periods: int, n_subjects: int | None = None) -> Model ) -def _build_initial_conditions(*, n_subjects: int) -> dict: +def _build_initial_conditions(*, n_subjects: int) -> dict[str, Array]: """Subject array of size `n_subjects` matching the regression test model.""" wealths = jnp.linspace(20.0, 320.0, num=n_subjects) return { @@ -50,17 +69,34 @@ def _build_initial_conditions(*, n_subjects: int) -> dict: @pytest.mark.parametrize("invalid", [0, -3]) def test_n_subjects_validation_rejects_non_positive(invalid: int) -> None: + """`Model(n_subjects=0)` and negative values raise `ValueError`.""" with pytest.raises(ValueError, match="n_subjects"): _build_test_model(n_periods=3, n_subjects=invalid) def test_n_subjects_validation_rejects_non_int() -> None: + """`Model(n_subjects=1.5)` raises `TypeError`.""" with pytest.raises(TypeError, match="n_subjects"): _build_test_model(n_periods=3, n_subjects=1.5) # ty: ignore[invalid-argument-type] -def test_n_subjects_none_keeps_lazy_behavior() -> None: - """Without n_subjects, simulate works and no AOT cache is populated.""" +def test_n_subjects_none_leaves_aot_cache_empty_after_simulate() -> None: + """`Model(n_subjects=None)` keeps `_simulate_compile_cache` empty after simulate.""" + n_periods = 3 + model = _build_test_model(n_periods=n_periods, n_subjects=None) + params = get_params(n_periods=n_periods) + + model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=_build_initial_conditions(n_subjects=4), + ) + + assert dict(model._simulate_compile_cache) == {} + + +def test_n_subjects_none_yields_simulate_result_sized_to_actual() -> None: + """`Model(n_subjects=None).simulate(...)` returns a result sized to the input.""" n_periods = 3 model = _build_test_model(n_periods=n_periods, n_subjects=None) params = get_params(n_periods=n_periods) @@ -72,14 +108,12 @@ def test_n_subjects_none_keeps_lazy_behavior() -> None: ) assert result.n_subjects == 4 - assert model.n_subjects is None - assert not getattr(model, "_simulate_compile_cache", {}) -def test_simulate_compiles_only_once_with_matching_n_subjects( +def test_simulate_second_matching_call_does_not_invoke_compile( monkeypatch: pytest.MonkeyPatch, ) -> None: - """First simulate call AOT-compiles; second call hits the cache.""" + """Matching second `simulate(...)` invokes `Lowered.compile` zero times.""" n_periods = 3 n_subjects = 4 model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) @@ -89,7 +123,9 @@ def test_simulate_compiles_only_once_with_matching_n_subjects( counter = {"count": 0} original_compile = jax.stages.Lowered.compile - def counting_compile(self: jax.stages.Lowered, *args, **kwargs): + def counting_compile( + self: jax.stages.Lowered, *args: Any, **kwargs: Any + ) -> jax.stages.Compiled: counter["count"] += 1 return original_compile(self, *args, **kwargs) @@ -102,7 +138,6 @@ def counting_compile(self: jax.stages.Lowered, *args, **kwargs): period_to_regime_to_V_arr=period_to_regime_to_V_arr, initial_conditions=initial_conditions, ) - n_first = counter["count"] counter["count"] = 0 model.simulate( @@ -110,10 +145,26 @@ def counting_compile(self: jax.stages.Lowered, *args, **kwargs): period_to_regime_to_V_arr=period_to_regime_to_V_arr, initial_conditions=initial_conditions, ) - n_second = counter["count"] - assert n_first > 0, "First simulate must trigger compilation." - assert n_second == 0, "Second simulate must hit the AOT cache." + assert counter["count"] == 0 + + +def test_simulate_first_matching_call_populates_aot_cache() -> None: + """Matching first `simulate(...)` populates the cache for that size.""" + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + + assert n_subjects not in model._simulate_compile_cache + + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=_build_initial_conditions(n_subjects=n_subjects), + ) + assert n_subjects in model._simulate_compile_cache @@ -148,7 +199,7 @@ def test_simulate_warns_on_n_subjects_mismatch( assert actual_n not in model._simulate_compile_cache -def test_simulate_caches_recompiled_size_no_second_warning( +def test_simulate_warns_only_once_per_mismatching_size( caplog: pytest.LogCaptureFixture, ) -> None: """Two calls with the same mismatching size produce only one WARNING.""" diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py index 3147de3f..282b24e9 100644 --- a/tests/test_int_dtype_invariants.py +++ b/tests/test_int_dtype_invariants.py @@ -10,6 +10,7 @@ def test_discrete_grid_to_jax_is_int32() -> None: + """Every `DiscreteGrid.to_jax()` in the model returns an `int32` array.""" model = get_model(n_periods=3) for regime in model.regimes.values(): for grid in {**regime.states, **regime.actions}.values(): @@ -21,6 +22,7 @@ def test_discrete_grid_to_jax_is_int32() -> None: def test_build_initial_states_discrete_dtype_is_int32() -> None: + """`build_initial_states` casts every discrete state array to `int32`.""" model = get_model(n_periods=3) initial_states = { "wealth": jnp.array([20.0, 50.0]), @@ -38,4 +40,5 @@ def test_build_initial_states_discrete_dtype_is_int32() -> None: def test_missing_cat_code_is_int32_minimum() -> None: + """`MISSING_CAT_CODE` equals `iinfo(int32).min` — never a real category code.""" assert jnp.iinfo(jnp.int32).min == MISSING_CAT_CODE From 7e815320d33baf197fb01857159dfdf7fd210956 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 15:06:45 +0200 Subject: [PATCH 71/80] Package A: int dtype barriers at the API boundary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pin every int that crosses into pylcm to int32, with overflow checks at the boundary helpers. After this, the AOT simulate path works under both `jax_enable_x64=True` and `=False` — the int-side hazard is gone, so the construction-time guard is dropped. Boundaries: - New `lcm.dtypes.safe_to_int32(value, *, name)`: host-side cast that raises `ValueError` with the leaf's qualified name on int32 overflow, rather than silently wrapping. Reused at every int boundary helper. - `lcm.params.processing._cast_int_leaves_to_int32`: applied in `broadcast_to_template` over every resolved params leaf. Walks `MappingLeaf` / `SequenceLeaf` recursively. Casts typed JAX/numpy int arrays to int32; leaves Python int/bool scalars alone so JAX's weak-typing rules can still promote them per call site (e.g. `discount_factor: 1` keeps working in float-typed functions). - `lcm.simulation.transitions._update_states_for_subjects`: casts `next_state_values` to the storage dtype before `jnp.where`, but only when the kinds match (int->int, float->float). Cross-kind pairs (int storage + float transition output, possible when a user passes int initial conditions for a continuous state) keep JAX's promotion semantics; the cross-kind boundary cast is Package B. Guard removal: - `_fail_if_x64_with_aot` deleted. The boundary casts now keep `internal_params` and the simulate state pool int32 regardless of `jax_enable_x64`, so the AOT cache signature is stable under x64. `tests/simulation/test_simulate_aot.py` drops its autouse `_disable_x64` fixture; AOT tests run under the conftest's default x64 setting. Tests: - New `tests/test_dtypes.py` (5 tests): `safe_to_int32` round-trip, in-range int64 array, overflow with qualified-name error, underflow. - Extended `tests/test_int_dtype_invariants.py` with six tests covering the four new boundaries: - `_update_states_for_subjects` keeps storage int32 when next-state output is int64 - `process_params` casts a typed int64 array leaf to int32 - `process_params` raises `ValueError` naming the param qname on int array overflow - `process_params` casts int arrays inside a `MappingLeaf` to int32 - `process_params` passes Python int leaves through unchanged for JAX weak-typing - `simulate(...)` accepts int64 `regime` initial conditions and round-trips to the same DataFrame as int32 Out of scope (Package B): - Continuous-state float dtype normalisation at boundaries. - Float scalar params -> single canonical float dtype. - Cross-kind cast in `_update_states_for_subjects` (int storage + float transition output, when user supplies int initial conditions for a continuous state — pre-existing behaviour preserved). --- src/lcm/dtypes.py | 46 ++++++++ src/lcm/model_processing.py | 23 ---- src/lcm/params/processing.py | 43 ++++++++ src/lcm/simulation/transitions.py | 12 ++- tests/simulation/test_simulate_aot.py | 21 +--- tests/test_dtypes.py | 41 ++++++++ tests/test_int_dtype_invariants.py | 144 +++++++++++++++++++++++++- 7 files changed, 287 insertions(+), 43 deletions(-) create mode 100644 src/lcm/dtypes.py create mode 100644 tests/test_dtypes.py diff --git a/src/lcm/dtypes.py b/src/lcm/dtypes.py new file mode 100644 index 00000000..75dd591a --- /dev/null +++ b/src/lcm/dtypes.py @@ -0,0 +1,46 @@ +"""Boundary-cast helpers that pin user-supplied data to canonical pylcm dtypes. + +These helpers run **outside JIT** at every API boundary (params, initial +conditions, regime-id arrays). They check the value fits the target dtype +and raise a clearly-named error if not. Inside-JIT casts (e.g. on +transition outputs landing in the simulate state pool) keep the silent +saturation/wrap semantics — overflow there means a broken user transition, +which is out of scope for the boundary helpers. +""" + +import jax.numpy as jnp +import numpy as np +from jax import Array + +_INT32_MIN = int(np.iinfo(np.int32).min) +_INT32_MAX = int(np.iinfo(np.int32).max) + + +def safe_to_int32(value: object, *, name: str) -> Array: + """Cast a scalar, sequence, or array to `jnp.int32`, checking int32 range. + + Args: + value: A Python int, numpy/JAX integer scalar, or array-like of + integer values. + name: Qualified name of the leaf — surfaced in the error message + so the user can locate the offending input. + + Returns: + A `jnp.int32` array (0-d if `value` was a scalar). + + Raises: + ValueError: If any element of `value` is outside the int32 range + `[-2**31, 2**31 - 1]`. The message names the leaf via `name`. + + """ + np_value = np.asarray(value) + if np_value.size > 0: + lo = int(np_value.min()) + hi = int(np_value.max()) + if lo < _INT32_MIN or hi > _INT32_MAX: + msg = ( + f"{name}: int32 overflow — value range [{lo}, {hi}] " + f"exceeds [{_INT32_MIN}, {_INT32_MAX}]." + ) + raise ValueError(msg) + return jnp.asarray(np_value, dtype=jnp.int32) diff --git a/src/lcm/model_processing.py b/src/lcm/model_processing.py index 947a5428..d141c6a3 100644 --- a/src/lcm/model_processing.py +++ b/src/lcm/model_processing.py @@ -10,7 +10,6 @@ from collections.abc import Callable, Mapping from types import MappingProxyType -import jax from dags import get_ancestors from dags.tree import QNAME_DELIMITER, qname_from_tree_path from jax import Array @@ -152,7 +151,6 @@ def validate_model_inputs( ) -> None: """Validate model constructor inputs.""" _fail_if_invalid_n_subjects(n_subjects=n_subjects) - _fail_if_x64_with_aot(n_subjects=n_subjects) # Early exit if regimes are not lcm.Regime instances if not all(isinstance(regime, Regime) for regime in regimes.values()): @@ -219,27 +217,6 @@ def _fail_if_invalid_n_subjects(*, n_subjects: int | None) -> None: raise ValueError(msg) -def _fail_if_x64_with_aot(*, n_subjects: int | None) -> None: - """Reject `n_subjects` set under `jax_enable_x64=True`. - - The AOT path pins integer dtypes to int32 (see `DiscreteGrid.to_jax`, - `build_initial_states`); under x64 mode, JAX's defaults promote int - intermediates to int64, so the cached AOT signature would not match the - runtime values. Use the lazy path (`n_subjects=None`) under x64 instead. - """ - if n_subjects is None: - return - if jax.config.read("jax_enable_x64"): - msg = ( - "n_subjects requires jax_enable_x64=False. The AOT simulate path pins " - "integer dtypes to int32; x64 mode promotes int intermediates to int64 " - "and breaks the cached AOT signature. Either disable x64 with " - "`jax.config.update('jax_enable_x64', False)` or use the lazy path " - "by leaving n_subjects unset." - ) - raise ModelInitializationError(msg) - - def _validate_all_variables_used(regimes: Mapping[RegimeName, Regime]) -> list[str]: """Validate that all states and actions are used somewhere in each regime. diff --git a/src/lcm/params/processing.py b/src/lcm/params/processing.py index 680e46e7..03060407 100644 --- a/src/lcm/params/processing.py +++ b/src/lcm/params/processing.py @@ -4,10 +4,14 @@ from types import MappingProxyType from typing import Any, cast +import numpy as np from dags.tree import QNAME_DELIMITER, qname_from_tree_path, tree_path_from_qname +from lcm.dtypes import safe_to_int32 from lcm.exceptions import InvalidNameError, InvalidParamsError from lcm.interfaces import InternalRegime +from lcm.params.mapping_leaf import MappingLeaf +from lcm.params.sequence_leaf import SequenceLeaf from lcm.typing import ( InternalParams, ParamsTemplate, @@ -110,12 +114,51 @@ def broadcast_to_template( if unknown: raise InvalidParamsError(f"Unknown keys: {sorted(unknown)}") + for regime, leaves in result.items(): + for param_qname, value in leaves.items(): + leaves[param_qname] = _cast_int_leaves_to_int32( + value, name=f"{regime}{QNAME_DELIMITER}{param_qname}" + ) + return cast( "InternalParams", MappingProxyType({k: MappingProxyType(v) for k, v in result.items()}), ) +def _cast_int_leaves_to_int32(value: Any, *, name: str) -> Any: # noqa: ANN401 + """Normalise typed integer arrays in a params value to `jnp.int32`. + + Only typed JAX or numpy integer arrays are cast — Python `int` / `bool` + leaves stay unmodified. JAX's weak-typing rules promote raw Python ints + correctly to whichever dtype the surrounding operation needs (e.g. + `discount_factor: 1` works in a float-typed function), so casting them + eagerly to `int32` would force premature dtype commitment. Typed + arrays (`jnp.array(..., dtype=jnp.int64)`) are strongly typed by JAX + and would otherwise leak their dtype into the AOT signature. + + Walks `MappingLeaf` and `SequenceLeaf` recursively. Float and + non-numeric leaves pass through — float normalisation is Package B. + """ + if isinstance(value, MappingLeaf): + return MappingLeaf( + { + k: _cast_int_leaves_to_int32(v, name=f"{name}.{k}") + for k, v in value.data.items() + } + ) + if isinstance(value, SequenceLeaf): + return SequenceLeaf( + [ + _cast_int_leaves_to_int32(v, name=f"{name}[{i}]") + for i, v in enumerate(value.data) + ] + ) + if hasattr(value, "dtype") and np.issubdtype(value.dtype, np.integer): + return safe_to_int32(value, name=name) + return value + + def _find_candidates( *, qname: str, diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index d1a1a4d3..db1fec82 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -286,9 +286,19 @@ def _update_states_for_subjects( for target, target_next_states in computed_next_states.items(): for next_state_name, next_state_values in target_next_states.items(): state_name = f"{target}__{next_state_name.removeprefix('next_')}" + target_dtype = all_states[state_name].dtype + # Preserve storage dtype only when the transition output is the + # same numeric kind. Across kinds (e.g. int storage + float + # transition output) leave JAX's promotion in place; the + # cross-kind boundary cast belongs to Package B. + new_values = ( + next_state_values.astype(target_dtype) + if next_state_values.dtype.kind == target_dtype.kind + else next_state_values + ) updated_states[state_name] = jnp.where( subject_indices, - next_state_values, + new_values, all_states[state_name], ) diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py index 5d7037b3..29254a43 100644 --- a/tests/simulation/test_simulate_aot.py +++ b/tests/simulation/test_simulate_aot.py @@ -3,22 +3,18 @@ When `Model(n_subjects=N)` is set, the first matching `simulate(...)` call parallel-compiles all simulate functions for batch shape `N`. Subsequent calls with size `N` reuse the cache; calls with a mismatching size warn once per size -and fall back to the runtime-traced path. - -The AOT path pins integer dtypes to int32, which is incompatible with -`jax_enable_x64=True`. The conftest enables x64 by default, so every test in -this file disables it via the autouse `_disable_x64` fixture. +and fall back to the runtime-traced path. AOT works under both `x64=False` +and `x64=True` because integer leaves are normalised to `int32` at every +boundary by `lcm.params.processing` and the simulate state pool. """ import logging -from collections.abc import Iterator from typing import Any import jax.numpy as jnp import jax.stages import pytest from jax import Array -from jax import config as jax_config from lcm import Model from lcm.ages import AgeGrid @@ -30,17 +26,6 @@ ) -@pytest.fixture(autouse=True) -def _disable_x64() -> Iterator[None]: - """Disable x64 for AOT tests (conftest enables it globally).""" - previous = jax_config.read("jax_enable_x64") - jax_config.update("jax_enable_x64", val=False) - try: - yield - finally: - jax_config.update("jax_enable_x64", val=previous) - - def _build_test_model(*, n_periods: int, n_subjects: int | None = None) -> Model: """Construct the small 2-regime regression model with optional n_subjects.""" final_age_alive = 18 + n_periods - 2 diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py new file mode 100644 index 00000000..77b0c69f --- /dev/null +++ b/tests/test_dtypes.py @@ -0,0 +1,41 @@ +"""Tests for `lcm.dtypes` boundary-cast helpers.""" + +import jax.numpy as jnp +import numpy as np +import pytest + +from lcm.dtypes import safe_to_int32 + + +def test_safe_to_int32_casts_python_int_in_range() -> None: + """A Python int within int32 range becomes a `jnp.int32` 0-d array.""" + out = safe_to_int32(7, name="x") + assert out.dtype == jnp.int32 + assert int(out) == 7 + + +def test_safe_to_int32_casts_int64_array_in_range() -> None: + """An int64 array within int32 range becomes int32 with the same values.""" + arr = jnp.asarray([0, 1, -3], dtype=jnp.int64) + out = safe_to_int32(arr, name="x") + assert out.dtype == jnp.int32 + np.testing.assert_array_equal(np.asarray(out), [0, 1, -3]) + + +def test_safe_to_int32_raises_on_python_int_overflow() -> None: + """A Python int above int32 max raises `ValueError` naming the leaf.""" + with pytest.raises(ValueError, match="my_param"): + safe_to_int32(2**32, name="my_param") + + +def test_safe_to_int32_raises_on_array_overflow() -> None: + """An int64 array containing values above int32 max raises with the leaf name.""" + arr = jnp.asarray([1, 2, 2**32], dtype=jnp.int64) + with pytest.raises(ValueError, match="regime"): + safe_to_int32(arr, name="regime") + + +def test_safe_to_int32_raises_on_underflow() -> None: + """A Python int below int32 min raises `ValueError` naming the leaf.""" + with pytest.raises(ValueError, match="offset"): + safe_to_int32(-(2**40), name="offset") diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py index 282b24e9..9de96eea 100644 --- a/tests/test_int_dtype_invariants.py +++ b/tests/test_int_dtype_invariants.py @@ -1,12 +1,26 @@ """Integer dtypes are pinned to int32 across pylcm regardless of x64 mode.""" +from types import MappingProxyType + import jax.numpy as jnp +import pytest +from lcm import Model +from lcm.ages import AgeGrid +from lcm.params import MappingLeaf +from lcm.params.processing import process_params from lcm.simulation.initial_conditions import ( MISSING_CAT_CODE, build_initial_states, ) -from tests.test_models.deterministic.regression import get_model +from lcm.simulation.transitions import _update_states_for_subjects +from tests.test_models.deterministic.regression import ( + RegimeId, + dead, + get_model, + get_params, + working_life, +) def test_discrete_grid_to_jax_is_int32() -> None: @@ -42,3 +56,131 @@ def test_build_initial_states_discrete_dtype_is_int32() -> None: def test_missing_cat_code_is_int32_minimum() -> None: """`MISSING_CAT_CODE` equals `iinfo(int32).min` — never a real category code.""" assert jnp.iinfo(jnp.int32).min == MISSING_CAT_CODE + + +def test_update_states_for_subjects_preserves_storage_dtype() -> None: + """A transition that returns int64 cannot promote the storage pool to int64.""" + all_states = MappingProxyType( + {"work__health": jnp.asarray([0, 1, 0, 1], dtype=jnp.int32)} + ) + int64_next = jnp.asarray([1, 1, 1, 1], dtype=jnp.int64) + computed = MappingProxyType({"work": MappingProxyType({"next_health": int64_next})}) + subjects = jnp.asarray([True, False, True, False]) + + updated = _update_states_for_subjects( + all_states=all_states, + computed_next_states=computed, + subject_indices=subjects, + ) + + assert updated["work__health"].dtype == jnp.int32 + + +def test_process_params_passes_python_int_through_for_jax_weak_typing() -> None: + """Python `int` params stay weak-typed so JAX promotes them per call site.""" + template = MappingProxyType({"regime_a": MappingProxyType({"final_age": "int"})}) + user_params = {"regime_a": {"final_age": 65}} + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + # Python int stays Python int; JAX weak-typing handles promotion at JIT. + assert out["regime_a"]["final_age"] == 65 + assert isinstance(out["regime_a"]["final_age"], int) + + +def test_process_params_casts_int64_array_to_int32() -> None: + """A `jnp.int64` array param leaf is normalised to `jnp.int32`.""" + template = MappingProxyType({"regime_a": MappingProxyType({"schedule": "Array"})}) + user_params = {"regime_a": {"schedule": jnp.asarray([0, 1, 2], dtype=jnp.int64)}} + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + schedule = out["regime_a"]["schedule"] + assert schedule.dtype == jnp.int32 # ty: ignore[unresolved-attribute] + + +def test_process_params_int_array_overflow_raises_with_qualified_name() -> None: + """An out-of-int32-range int array surfaces the param's qualified name.""" + template = MappingProxyType({"regime_a": MappingProxyType({"big_param": "Array"})}) + user_params = {"regime_a": {"big_param": jnp.asarray([0, 2**40], dtype=jnp.int64)}} + + with pytest.raises(ValueError, match="big_param"): + process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + +def test_process_params_casts_int_array_inside_mapping_leaf_to_int32() -> None: + """`MappingLeaf` int arrays land at `jnp.int32` after params processing.""" + template = MappingProxyType( + {"regime_a": MappingProxyType({"sched": "MappingLeaf"})} + ) + user_params = { + "regime_a": { + "sched": MappingLeaf( + { + "low": jnp.asarray([0, 1], dtype=jnp.int64), + "high": jnp.asarray([10, 20], dtype=jnp.int64), + } + ) + } + } + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + leaf = out["regime_a"]["sched"] + assert leaf.data["low"].dtype == jnp.int32 # ty: ignore[unresolved-attribute] + assert leaf.data["high"].dtype == jnp.int32 # ty: ignore[unresolved-attribute] + + +def test_simulate_accepts_int64_regime_initial_condition_and_round_trips() -> None: + """`regime` as `jnp.int64` simulates the same as `jnp.int32`.""" + n_periods = 3 + final_age_alive = 18 + n_periods - 2 + model = Model( + regimes={ + "working_life": working_life.replace( + active=lambda age: age <= final_age_alive, + ), + "dead": dead, + }, + ages=AgeGrid(start=18, stop=final_age_alive + 1, step="Y"), + regime_id_class=RegimeId, + ) + params = get_params(n_periods=n_periods) + + common = { + "wealth": jnp.linspace(20.0, 80.0, num=4), + "age": jnp.full((4,), 18.0), + } + initial_conditions_int32 = { + **common, + "regime": jnp.asarray([RegimeId.working_life] * 4, dtype=jnp.int32), + } + initial_conditions_int64 = { + **common, + "regime": jnp.asarray([RegimeId.working_life] * 4, dtype=jnp.int64), + } + + df_int32 = model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=initial_conditions_int32, + ).to_dataframe() + df_int64 = model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=initial_conditions_int64, + ).to_dataframe() + + assert df_int64["regime"].equals(df_int32["regime"]) From 63b740cefb0fae7078fb57770ad0f6f20e93f1c9 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 18:26:30 +0200 Subject: [PATCH 72/80] Fix Package A 32-bit precision tests: build overflow fixtures with numpy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The two overflow tests built their fixtures via `jnp.asarray([..., 2**32], dtype=jnp.int64)`. Under `jax_enable_x64=False` (the GPU 32-bit-precision job), JAX truncates the requested int64 to int32 at construction time and raises its own `OverflowError` before `safe_to_int32` ever sees the value — so the test asserts a `ValueError` that is never reached. Use `np.asarray(..., dtype=np.int64)` for these fixtures instead. Numpy honours the explicit dtype regardless of the JAX precision setting, so our boundary helper receives a real int64 and produces its own qualified-name `ValueError`. The same pattern (use numpy for overflow fixtures) will land in Package B for the float-overflow test. --- tests/test_dtypes.py | 5 ++++- tests/test_int_dtype_invariants.py | 7 +++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 77b0c69f..3f903ebb 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -30,7 +30,10 @@ def test_safe_to_int32_raises_on_python_int_overflow() -> None: def test_safe_to_int32_raises_on_array_overflow() -> None: """An int64 array containing values above int32 max raises with the leaf name.""" - arr = jnp.asarray([1, 2, 2**32], dtype=jnp.int64) + # Use numpy here: `jnp.asarray(..., dtype=jnp.int64)` truncates to int32 + # under `jax_enable_x64=False` and trips JAX's own overflow guard before + # `safe_to_int32` ever sees the value. + arr = np.asarray([1, 2, 2**32], dtype=np.int64) with pytest.raises(ValueError, match="regime"): safe_to_int32(arr, name="regime") diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py index 9de96eea..885134cf 100644 --- a/tests/test_int_dtype_invariants.py +++ b/tests/test_int_dtype_invariants.py @@ -3,6 +3,7 @@ from types import MappingProxyType import jax.numpy as jnp +import numpy as np import pytest from lcm import Model @@ -108,11 +109,13 @@ def test_process_params_casts_int64_array_to_int32() -> None: def test_process_params_int_array_overflow_raises_with_qualified_name() -> None: """An out-of-int32-range int array surfaces the param's qualified name.""" template = MappingProxyType({"regime_a": MappingProxyType({"big_param": "Array"})}) - user_params = {"regime_a": {"big_param": jnp.asarray([0, 2**40], dtype=jnp.int64)}} + # Numpy here: under `jax_enable_x64=False`, `jnp.asarray(..., dtype=int64)` + # of an out-of-int32 value raises before our helper sees it. + user_params = {"regime_a": {"big_param": np.asarray([0, 2**40], dtype=np.int64)}} with pytest.raises(ValueError, match="big_param"): process_params( - params=user_params, + params=user_params, # ty: ignore[invalid-argument-type] params_template=template, # ty: ignore[invalid-argument-type] ) From 568cbcb9af5be3d17e5cf26f8e87d3a2a5bc7645 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 22:28:44 +0200 Subject: [PATCH 73/80] Address #340 review-2: counterfactuals, multi-assertion tests, dedup keys MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Docstring style: - `dtypes.py` module docstring: drop "out of scope for the boundary helpers" / "broken user transition" framing — describe what the module does instead. - `_cast_int_leaves_to_int32` docstring: drop "would force premature dtype commitment" counterfactual and "Package B" project-jargon reference. Replace with a positive description (cast / pass-through by leaf kind). - `__getstate__` docstring: lead with what is returned (a copy of `__dict__` minus the per-process AOT compile state) before explaining mechanism. Public-API docs: - `process_params` Raises section now lists the `ValueError` that the int boundary cast can surface, plus a paragraph documenting the dtype-normalisation step. Module docstring covers the same. Concurrency docstring: - `_simulate_compile_lock` field docstring rephrased to make it explicit that the consequent `log.warning` is intentionally outside the lock. Test coverage and structure: - Add `test_process_params_casts_int_array_inside_sequence_leaf_to_int32` to mirror the `MappingLeaf` test for `SequenceLeaf` int contents. - Add `test_unpickled_model_can_simulate_with_aot`: full pickle round- trip through `cloudpickle`, then re-run simulate to confirm the AOT cache is reset and re-populated post-unpickle. - Parametrise `test_process_params_casts_int_array_inside_mapping_leaf_to_int32` over `low`/`high` (one assertion per test). - Parametrise `test_safe_to_int32_*` over `python-int` / `int64-array` inputs and split into `_returns_int32` (dtype) and `_preserves_in_range_values` (value preservation). - Split `test_simulate_warns_on_n_subjects_mismatch` into four single- assertion tests behind a shared fixture (warning count, declared-N in message, actual-N in message, no cache entry). - Replace `test_simulate_functions_use_per_regime_callables` 2-assertion body with a parametrised version over `next_state` / `compute_regime_transition_probs`, using a shared fixture. New docstring describes user-visible behaviour rather than the AOT-dedup rationale. - Round-trip int64-regime test now uses `pd.testing.assert_frame_equal` across all output columns, not `.equals()` on one column. Structural fixes: - `_collect_unique_simulate_functions` now keys `next_state` and `compute_regime_transition_probs` by `(kind, regime_name, callable-id)` instead of `(kind, callable-id)`. Two regimes that share a callable identity now still get distinct compiled programs (carried-forward issue 17 from prior review). Removes the need for the comment that pointed at a specific test file path. - `_cast_int_leaves_to_int32` swaps the `hasattr(value, "dtype")` duck-type check for `isinstance(value, (Array, np.ndarray))` — closes the PR-#302 review thread on duck-typing arrays. - `draw_key_from_dict`: pin `regime_ids` to `jnp.int32` so the "all integers are int32" invariant holds throughout simulate, not just at the AOT-traced boundaries. --- src/lcm/dtypes.py | 14 ++- src/lcm/model.py | 15 ++- src/lcm/params/processing.py | 44 ++++++-- src/lcm/simulation/compile.py | 19 ++-- src/lcm/simulation/transitions.py | 5 +- .../regime_building/test_regime_processing.py | 48 ++++---- tests/simulation/test_simulate_aot.py | 106 +++++++++++++++--- tests/test_dtypes.py | 32 ++++-- tests/test_int_dtype_invariants.py | 44 +++++++- 9 files changed, 236 insertions(+), 91 deletions(-) diff --git a/src/lcm/dtypes.py b/src/lcm/dtypes.py index 75dd591a..51dd958c 100644 --- a/src/lcm/dtypes.py +++ b/src/lcm/dtypes.py @@ -1,11 +1,13 @@ """Boundary-cast helpers that pin user-supplied data to canonical pylcm dtypes. -These helpers run **outside JIT** at every API boundary (params, initial -conditions, regime-id arrays). They check the value fits the target dtype -and raise a clearly-named error if not. Inside-JIT casts (e.g. on -transition outputs landing in the simulate state pool) keep the silent -saturation/wrap semantics — overflow there means a broken user transition, -which is out of scope for the boundary helpers. +Used at every API boundary that accepts user data (params, initial +conditions, regime-id arrays) — always called from Python, never inside +JIT. Each helper validates that the value fits the target dtype and +raises a clearly-named error if not. + +Casts further down the simulate stack (e.g. transition outputs landing +in the state pool) use plain `.astype` and rely on the boundary cast +above them having already pinned the canonical dtype. """ import jax.numpy as jnp diff --git a/src/lcm/model.py b/src/lcm/model.py index 8bec9b4b..1897272f 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -118,7 +118,12 @@ class Model: """Mismatching `actual_n_subjects` already warned about (one warning each).""" _simulate_compile_lock: threading.Lock - """Guards check-then-set on `_simulate_compile_cache` and `_warned_n_subjects`.""" + """Serialises mutations of `_simulate_compile_cache` and `_warned_n_subjects`. + + The check-then-set on each container is held under this lock. The + consequent `log.warning` call sits outside the lock so concurrent + simulate() calls don't serialise on logging I/O. + """ def __init__( self, @@ -192,10 +197,12 @@ def __init__( ) def __getstate__(self) -> dict[str, object]: - """Drop AOT compile state from the pickle. + """Return a copy of `__dict__` with per-process AOT compile state removed. - The threading lock isn't pickleable, and the cached compiled programs - can't survive a process boundary anyway. + Drops `_simulate_compile_lock` (a `threading.Lock`, not pickleable), + `_simulate_compile_cache` (compiled XLA programs that can't survive + a process boundary), and `_warned_n_subjects` (its companion set). + `__setstate__` restores all three to their fresh state. """ state = self.__dict__.copy() state.pop("_simulate_compile_lock", None) diff --git a/src/lcm/params/processing.py b/src/lcm/params/processing.py index 03060407..64f02cc2 100644 --- a/src/lcm/params/processing.py +++ b/src/lcm/params/processing.py @@ -1,4 +1,11 @@ -"""Process user-provided params into internal params.""" +"""Process user-provided params into internal params. + +`process_params` resolves user-supplied parameters against the model's +template, then runs a boundary-cast pass that normalises typed integer +leaves to `jnp.int32` (and integer arrays inside `MappingLeaf` / +`SequenceLeaf`). Out-of-range values surface as `ValueError` with the +offending leaf's qualified name. +""" from collections.abc import Mapping from types import MappingProxyType @@ -6,6 +13,7 @@ import numpy as np from dags.tree import QNAME_DELIMITER, qname_from_tree_path, tree_path_from_qname +from jax import Array from lcm.dtypes import safe_to_int32 from lcm.exceptions import InvalidNameError, InvalidParamsError @@ -38,7 +46,11 @@ def process_params( - Regime level: `{"regime_0": {"arg_0": 0.0}}` — propagates within regime_0 - Function level: `{"regime_0": {"func": {"arg_0": 0.0}}}` — direct specification - The output always matches the params_template skeleton. + The output always matches the params_template skeleton. Typed integer + arrays in the user input — including those inside `MappingLeaf` / + `SequenceLeaf` containers — are cast to `jnp.int32` so the AOT signature + is stable across calls; Python scalars pass through to keep JAX weak- + typing semantics. Args: params: User-provided parameters dictionary. @@ -50,6 +62,8 @@ def process_params( Raises: InvalidParamsError: If params contains unexpected keys or type mismatches. InvalidNameError: If the same parameter is specified at multiple levels. + ValueError: If a typed integer leaf carries a value outside the + int32 range; the message names the offending parameter qname. """ return broadcast_to_template(params=params, template=params_template, required=True) @@ -129,16 +143,20 @@ def broadcast_to_template( def _cast_int_leaves_to_int32(value: Any, *, name: str) -> Any: # noqa: ANN401 """Normalise typed integer arrays in a params value to `jnp.int32`. - Only typed JAX or numpy integer arrays are cast — Python `int` / `bool` - leaves stay unmodified. JAX's weak-typing rules promote raw Python ints - correctly to whichever dtype the surrounding operation needs (e.g. - `discount_factor: 1` works in a float-typed function), so casting them - eagerly to `int32` would force premature dtype commitment. Typed - arrays (`jnp.array(..., dtype=jnp.int64)`) are strongly typed by JAX - and would otherwise leak their dtype into the AOT signature. + Casts: + + - Typed JAX or numpy integer arrays (`jnp.array(..., dtype=jnp.int64)`, + `np.array(...)`) — strongly typed by JAX, so cast to `int32` to keep + the AOT signature stable. + - Integer leaves inside `MappingLeaf` / `SequenceLeaf` — recurse. + + Passes through unchanged: - Walks `MappingLeaf` and `SequenceLeaf` recursively. Float and - non-numeric leaves pass through — float normalisation is Package B. + - Python `int` / `bool` scalars — JAX's weak-typing rules let them + promote to whatever dtype the surrounding operation needs (e.g. + `discount_factor: 1` in a float-typed function). + - Float and non-numeric typed leaves — handled by a separate float- + normalisation pass. """ if isinstance(value, MappingLeaf): return MappingLeaf( @@ -154,7 +172,9 @@ def _cast_int_leaves_to_int32(value: Any, *, name: str) -> Any: # noqa: ANN401 for i, v in enumerate(value.data) ] ) - if hasattr(value, "dtype") and np.issubdtype(value.dtype, np.integer): + if isinstance(value, (Array, np.ndarray)) and np.issubdtype( + value.dtype, np.integer + ): return safe_to_int32(value, name=name) return value diff --git a/src/lcm/simulation/compile.py b/src/lcm/simulation/compile.py index 282c39a7..2333743f 100644 --- a/src/lcm/simulation/compile.py +++ b/src/lcm/simulation/compile.py @@ -196,14 +196,9 @@ def _collect_unique_simulate_functions( ) unique[key] = (jax.jit(argmax_func), args, label) - # Dedup contract for `next_state` / `crtp`: pylcm's `process_regimes` - # builds these per regime (via `_build_next_state_vmapped` and the - # regime-specific transition-probs builder), so each regime ships a - # distinct callable object. Two regimes collide on the dedup key only - # when they truly share the same compiled program (and thus the same - # arg signature). The invariant is pinned by - # `test_simulate_functions_use_per_regime_callables` in - # `tests/regime_building/test_regime_processing.py`. + # `next_state` / `crtp` are keyed per-regime: each regime's lower-args + # depend on its own state-action shapes, so even when two regimes + # share a callable identity, their compiled programs are distinct. if not regime.terminal: args = _build_next_state_args( internal_regime=regime, @@ -211,7 +206,7 @@ def _collect_unique_simulate_functions( ages=ages, n_subjects=n_subjects, ) - key = ("next_state", _func_dedup_key(func=sf.next_state)) + key = ("next_state", regime_name, _func_dedup_key(func=sf.next_state)) func_keys[(regime_name, "next_state", None)] = key if key not in unique: # Re-wrap with `jax.jit`: when `fixed_params` are partialled @@ -230,7 +225,11 @@ def _collect_unique_simulate_functions( ages=ages, n_subjects=n_subjects, ) - key = ("crtp", _func_dedup_key(func=sf.compute_regime_transition_probs)) + key = ( + "crtp", + regime_name, + _func_dedup_key(func=sf.compute_regime_transition_probs), + ) func_keys[(regime_name, "crtp", None)] = key if key not in unique: unique[key] = ( diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index db1fec82..ba7cc39c 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -237,8 +237,9 @@ def draw_key_from_dict( """ regime_names = list(d) regime_transition_probs = jnp.array(list(d.values())).T - regime_ids = jnp.array( - [regime_names_to_ids[regime_name] for regime_name in regime_names] + regime_ids = jnp.asarray( + [regime_names_to_ids[regime_name] for regime_name in regime_names], + dtype=jnp.int32, ) def random_id( diff --git a/tests/regime_building/test_regime_processing.py b/tests/regime_building/test_regime_processing.py index 794390fb..b2def942 100644 --- a/tests/regime_building/test_regime_processing.py +++ b/tests/regime_building/test_regime_processing.py @@ -4,12 +4,14 @@ import jax.numpy as jnp import numpy as np import pandas as pd +import pytest from numpy.testing import assert_array_equal from pandas.testing import assert_frame_equal from lcm import Regime, categorical from lcm.ages import AgeGrid from lcm.grids import DiscreteGrid, LinSpacedGrid +from lcm.interfaces import InternalRegime from lcm.regime_building.processing import ( _rename_params_to_qnames, process_regimes, @@ -178,15 +180,9 @@ def wealth_constraint(wealth): assert got.index.is_unique -def test_simulate_functions_use_per_regime_callables(): - """Each non-terminal regime gets a distinct `next_state` / `crtp` callable. - - The simulate-AOT path in `lcm.simulation.compile` deduplicates by callable - identity for `next_state` and `compute_regime_transition_probs`. That is - only safe if `process_regimes` ships a fresh callable per regime — two - regimes sharing one callable would compile against the first regime's - state-action shapes and silently apply that program to the second. - """ +@pytest.fixture(name="two_non_terminal_internal_regimes") +def _two_non_terminal_internal_regimes() -> MappingProxyType[str, InternalRegime]: + """Two non-terminal regimes that share underlying user functions.""" def next_x(x): return x @@ -213,26 +209,30 @@ class TwoRegimeId: functions={"utility": lambda x: x}, active=lambda age: age >= 1, ) - - regimes = {"early": early, "late": late} - internal_regimes = process_regimes( - regimes=regimes, + return process_regimes( + regimes={"early": early, "late": late}, ages=AgeGrid(start=0, stop=2, step="Y"), regime_names_to_ids=MappingProxyType({"early": 0, "late": 1}), enable_jit=True, ) - early_next_state = internal_regimes["early"].simulate_functions.next_state - late_next_state = internal_regimes["late"].simulate_functions.next_state - assert id(early_next_state) != id(late_next_state) - - early_crtp = internal_regimes[ - "early" - ].simulate_functions.compute_regime_transition_probs - late_crtp = internal_regimes[ - "late" - ].simulate_functions.compute_regime_transition_probs - assert id(early_crtp) != id(late_crtp) + +@pytest.mark.parametrize( + "attr", + ["next_state", "compute_regime_transition_probs"], +) +def test_simulate_functions_use_per_regime_callables( + two_non_terminal_internal_regimes: MappingProxyType[str, InternalRegime], + attr: str, +) -> None: + """Two regimes built from shared user functions get distinct simulate callables.""" + early_func = getattr( + two_non_terminal_internal_regimes["early"].simulate_functions, attr + ) + late_func = getattr( + two_non_terminal_internal_regimes["late"].simulate_functions, attr + ) + assert id(early_func) != id(late_func) def test_rename_params_to_qnames_with_partial(): diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py index 29254a43..ccd47cae 100644 --- a/tests/simulation/test_simulate_aot.py +++ b/tests/simulation/test_simulate_aot.py @@ -9,8 +9,11 @@ """ import logging +import threading +from dataclasses import dataclass from typing import Any +import cloudpickle import jax.numpy as jnp import jax.stages import pytest @@ -153,14 +156,25 @@ def test_simulate_first_matching_call_populates_aot_cache() -> None: assert n_subjects in model._simulate_compile_cache -def test_simulate_warns_on_n_subjects_mismatch( +_DECLARED_N = 4 +_ACTUAL_N = 7 + + +@dataclass(frozen=True) +class _MismatchOutcome: + """Captured simulate-with-mismatch artefacts for assertion.""" + + warnings: list[logging.LogRecord] + model: Model + + +@pytest.fixture(name="mismatch_outcome") +def _mismatch_outcome( caplog: pytest.LogCaptureFixture, -) -> None: - """Mismatching size logs WARNING naming both N and M, falls back to lazy path.""" +) -> _MismatchOutcome: + """Run one mismatching `simulate(...)` and capture the WARNING records.""" n_periods = 3 - declared_n = 4 - actual_n = 7 - model = _build_test_model(n_periods=n_periods, n_subjects=declared_n) + model = _build_test_model(n_periods=n_periods, n_subjects=_DECLARED_N) params = get_params(n_periods=n_periods) period_to_regime_to_V_arr = model.solve(params=params) @@ -168,20 +182,45 @@ def test_simulate_warns_on_n_subjects_mismatch( model.simulate( params=params, period_to_regime_to_V_arr=period_to_regime_to_V_arr, - initial_conditions=_build_initial_conditions(n_subjects=actual_n), + initial_conditions=_build_initial_conditions(n_subjects=_ACTUAL_N), ) - mismatch_warnings = [ + warnings = [ r for r in caplog.records if r.levelno == logging.WARNING and "n_subjects" in r.getMessage() ] - assert len(mismatch_warnings) == 1 - msg = mismatch_warnings[0].getMessage() - assert str(declared_n) in msg - assert str(actual_n) in msg - # Cache is NOT populated for mismatching size — fallback path was taken. - assert actual_n not in model._simulate_compile_cache + return _MismatchOutcome(warnings=warnings, model=model) + + +def test_simulate_mismatch_emits_one_warning( + mismatch_outcome: _MismatchOutcome, +) -> None: + """A single mismatching call logs exactly one WARNING.""" + assert len(mismatch_outcome.warnings) == 1 + + +def test_simulate_mismatch_warning_names_declared_n( + mismatch_outcome: _MismatchOutcome, +) -> None: + """The mismatch warning message contains the declared `n_subjects`.""" + msg = mismatch_outcome.warnings[0].getMessage() + assert str(_DECLARED_N) in msg + + +def test_simulate_mismatch_warning_names_actual_n( + mismatch_outcome: _MismatchOutcome, +) -> None: + """The mismatch warning message contains the actual `n_subjects`.""" + msg = mismatch_outcome.warnings[0].getMessage() + assert str(_ACTUAL_N) in msg + + +def test_simulate_mismatch_does_not_populate_cache( + mismatch_outcome: _MismatchOutcome, +) -> None: + """A mismatching `n_subjects` falls back to the lazy path — no cache entry.""" + assert _ACTUAL_N not in mismatch_outcome.model._simulate_compile_cache def test_simulate_warns_only_once_per_mismatching_size( @@ -189,12 +228,10 @@ def test_simulate_warns_only_once_per_mismatching_size( ) -> None: """Two calls with the same mismatching size produce only one WARNING.""" n_periods = 3 - declared_n = 4 - actual_n = 7 - model = _build_test_model(n_periods=n_periods, n_subjects=declared_n) + model = _build_test_model(n_periods=n_periods, n_subjects=_DECLARED_N) params = get_params(n_periods=n_periods) period_to_regime_to_V_arr = model.solve(params=params) - initial_conditions = _build_initial_conditions(n_subjects=actual_n) + initial_conditions = _build_initial_conditions(n_subjects=_ACTUAL_N) with caplog.at_level(logging.WARNING, logger="lcm"): model.simulate( @@ -214,3 +251,36 @@ def test_simulate_warns_only_once_per_mismatching_size( if r.levelno == logging.WARNING and "n_subjects" in r.getMessage() ] assert len(mismatch_warnings) == 1 + + +def test_unpickled_model_can_simulate_with_aot() -> None: + """A cloudpickle round-tripped `Model` still drives `simulate(...)` with AOT.""" + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + initial_conditions = _build_initial_conditions(n_subjects=n_subjects) + + # Populate the AOT cache before pickling — confirms __getstate__ drops it. + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + assert n_subjects in model._simulate_compile_cache + + restored = cloudpickle.loads(cloudpickle.dumps(model)) + + # The restored Model starts with empty AOT state and a fresh lock. + assert dict(restored._simulate_compile_cache) == {} + assert restored._warned_n_subjects == set() + assert isinstance(restored._simulate_compile_lock, type(threading.Lock())) + + # Simulate works post-unpickle and re-populates the cache for that size. + restored.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + assert n_subjects in restored._simulate_compile_cache diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 3f903ebb..43894cbd 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -7,19 +7,31 @@ from lcm.dtypes import safe_to_int32 -def test_safe_to_int32_casts_python_int_in_range() -> None: - """A Python int within int32 range becomes a `jnp.int32` 0-d array.""" - out = safe_to_int32(7, name="x") +@pytest.mark.parametrize( + "value", + [7, np.asarray([0, 1, -3], dtype=np.int64)], + ids=["python-int", "int64-array"], +) +def test_safe_to_int32_returns_int32(value: object) -> None: + """`safe_to_int32` returns a `jnp.int32` array for any in-range int input.""" + out = safe_to_int32(value, name="x") assert out.dtype == jnp.int32 - assert int(out) == 7 -def test_safe_to_int32_casts_int64_array_in_range() -> None: - """An int64 array within int32 range becomes int32 with the same values.""" - arr = jnp.asarray([0, 1, -3], dtype=jnp.int64) - out = safe_to_int32(arr, name="x") - assert out.dtype == jnp.int32 - np.testing.assert_array_equal(np.asarray(out), [0, 1, -3]) +@pytest.mark.parametrize( + ("value", "expected"), + [ + (7, 7), + (np.asarray([0, 1, -3], dtype=np.int64), [0, 1, -3]), + ], + ids=["python-int", "int64-array"], +) +def test_safe_to_int32_preserves_in_range_values( + value: object, expected: object +) -> None: + """`safe_to_int32` preserves element values for in-range inputs.""" + out = safe_to_int32(value, name="x") + np.testing.assert_array_equal(np.asarray(out), expected) def test_safe_to_int32_raises_on_python_int_overflow() -> None: diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py index 885134cf..d174d515 100644 --- a/tests/test_int_dtype_invariants.py +++ b/tests/test_int_dtype_invariants.py @@ -4,12 +4,14 @@ import jax.numpy as jnp import numpy as np +import pandas as pd import pytest from lcm import Model from lcm.ages import AgeGrid from lcm.params import MappingLeaf from lcm.params.processing import process_params +from lcm.params.sequence_leaf import SequenceLeaf from lcm.simulation.initial_conditions import ( MISSING_CAT_CODE, build_initial_states, @@ -120,7 +122,8 @@ def test_process_params_int_array_overflow_raises_with_qualified_name() -> None: ) -def test_process_params_casts_int_array_inside_mapping_leaf_to_int32() -> None: +@pytest.mark.parametrize("key", ["low", "high"]) +def test_process_params_casts_int_array_inside_mapping_leaf_to_int32(key: str) -> None: """`MappingLeaf` int arrays land at `jnp.int32` after params processing.""" template = MappingProxyType( {"regime_a": MappingProxyType({"sched": "MappingLeaf"})} @@ -141,9 +144,40 @@ def test_process_params_casts_int_array_inside_mapping_leaf_to_int32() -> None: params_template=template, # ty: ignore[invalid-argument-type] ) - leaf = out["regime_a"]["sched"] - assert leaf.data["low"].dtype == jnp.int32 # ty: ignore[unresolved-attribute] - assert leaf.data["high"].dtype == jnp.int32 # ty: ignore[unresolved-attribute] + assert ( + out["regime_a"]["sched"].data[key].dtype # ty: ignore[unresolved-attribute] + == jnp.int32 + ) + + +@pytest.mark.parametrize("index", [0, 1]) +def test_process_params_casts_int_array_inside_sequence_leaf_to_int32( + index: int, +) -> None: + """`SequenceLeaf` int arrays land at `jnp.int32` after params processing.""" + template = MappingProxyType( + {"regime_a": MappingProxyType({"sched": "SequenceLeaf"})} + ) + user_params = { + "regime_a": { + "sched": SequenceLeaf( + [ + jnp.asarray([0, 1], dtype=jnp.int64), + jnp.asarray([10, 20], dtype=jnp.int64), + ] + ) + } + } + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + assert ( + out["regime_a"]["sched"].data[index].dtype # ty: ignore[unresolved-attribute] + == jnp.int32 + ) def test_simulate_accepts_int64_regime_initial_condition_and_round_trips() -> None: @@ -186,4 +220,4 @@ def test_simulate_accepts_int64_regime_initial_condition_and_round_trips() -> No initial_conditions=initial_conditions_int64, ).to_dataframe() - assert df_int64["regime"].equals(df_int32["regime"]) + pd.testing.assert_frame_equal(df_int64, df_int32, check_dtype=False) From 2f19aa16c6e2bc057c529b790959ed247c4e443a Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 06:03:20 +0200 Subject: [PATCH 74/80] compile: free lower-args after lowering, free Lowered after compile `compile_all_simulate_functions` previously held every entry's concrete lower-args (V-shaped templates, per-regime subject-state / action zeros, regime-params view) in `unique[key][1]` until the function returned, and every `Lowered` HLO module in `lowered[key]` until the slowest parallel compile finished. With per-target next-state DAGs and an AOT cache that pins every kernel alive, the overlap of in-flight lower-args + lowered HLO + compiled kernels was the dominant contributor to the simulate-side compile-phase peak. Drop the args from `unique[key]` immediately after each lowering, and `del lowered[k]` as soon as the corresponding `Compiled` lands in `compiled`. The dedup keys, the parallel pool semantics, and the swap-in step are unchanged; existing tests (n_subjects mismatch, unpickled-model AOT round-trip, dtype invariants) remain green. --- src/lcm/simulation/compile.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/lcm/simulation/compile.py b/src/lcm/simulation/compile.py index 2333743f..ad8caeeb 100644 --- a/src/lcm/simulation/compile.py +++ b/src/lcm/simulation/compile.py @@ -98,7 +98,12 @@ def compile_all_simulate_functions( start = time.monotonic() # `func` is a `jax.jit`-wrapped callable; ty sees only the abstract # Callable type, so it can't see `.lower(...)`. - lowered[key] = func.lower(**args) # ty: ignore[unresolved-attribute] + lowered[key] = func.lower(**args) # ty: ignore[unresolved-attribute, invalid-argument-type] + # Drop the concrete lower-args once the `Lowered` object has captured + # its abstract values. This releases V-shaped templates, per-regime + # subject-state/action zeros, and the regime-params view before the + # parallel compile pool starts piling Compiled kernels onto the heap. + unique[key] = (func, None, label) logger.info( " lowered in %s", format_duration(seconds=time.monotonic() - start) ) @@ -129,6 +134,11 @@ def _compile_and_log( for future in as_completed(futures): k, c = future.result() compiled[k] = c + # Release the HLO module held by the `Lowered` object now that + # its `Compiled` counterpart is in `compiled`; otherwise every + # lowered intermediate stays resident until the slowest compile + # finishes. + del lowered[k] return _swap_in_compiled( internal_regimes=internal_regimes, @@ -145,7 +155,7 @@ def _collect_unique_simulate_functions( n_subjects: int, regime_V_shapes: dict[RegimeName, tuple[int, ...]], ) -> tuple[ - dict[Hashable, tuple[Callable, dict, str]], + dict[Hashable, tuple[Callable, dict | None, str]], dict[tuple[RegimeName, str, int | None], Hashable], ]: """Walk every regime/period and dedup the simulate functions to compile. @@ -156,7 +166,7 @@ def _collect_unique_simulate_functions( separate compiled programs whose signature matches what runtime actually dispatches. """ - unique: dict[Hashable, tuple[Callable, dict, str]] = {} + unique: dict[Hashable, tuple[Callable, dict | None, str]] = {} func_keys: dict[tuple[RegimeName, str, int | None], Hashable] = {} for regime_name, regime in internal_regimes.items(): From 143a3ae6369da1c6f0714d200f99ff0f16b29f5f Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 07:39:40 +0200 Subject: [PATCH 75/80] solve_brute: rename diag_params to effective_regime_params MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The local in `_raise_at` holds the regime's params after merging `resolved_fixed_params` in by hand — the same merge the live solve loop performs implicitly via partialled closures. `effective_regime_params` captures that intent; `diag_params` only said "I'm passing this to the diagnostic call". Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/solution/solve_brute.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index ba792609..77e63e24 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -516,7 +516,7 @@ def _raise_at( # closures, but we have to do it by hand here to call the diagnostic # directly. Same merge order as `interfaces.state_action_space` and # `simulation.result`. - diag_params = MappingProxyType( + effective_regime_params = MappingProxyType( {**internal_regime.resolved_fixed_params, **regime_params} ) state_action_space = internal_regime.state_action_space(regime_params=regime_params) @@ -538,7 +538,7 @@ def _raise_at( compute_intermediates=compute_intermediates, state_action_space=state_action_space, next_regime_to_V_arr=next_regime_to_V_arr, - internal_params=diag_params, + internal_params=effective_regime_params, period=row.period, ) From 1e26926f5cfa6aafda0347384ff552bd7877e4b7 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 10:20:38 +0200 Subject: [PATCH 76/80] process_params: cast Python int leaves to jnp.int32 Drops the "Python scalar pass-through to keep JAX weak-typing semantics" line from `_cast_int_leaves_to_int32`: a Python `int` arriving at a DAG function as a Python scalar now becomes `jnp.int32(value)` so downstream code sees a JAX-typed scalar rather than a Python int that JAX would promote per call site. This finishes the "no Python scalars inside JIT'd loops" goal that motivated the dtype-barrier work. `bool` is short-circuited (the float-side cast pass on #345 will handle it; bool is a Python `int` subclass so the bool branch must come before the int one). Module + function docstrings refreshed accordingly. Test `test_process_params_passes_python_int_through_for_jax_weak_typing` renamed and flipped to assert `dtype == jnp.int32`. `ScalarInt` keeps the `int | Int32[Scalar, ""]` union: tightening to JAX-only cascades into 13 call-site mismatches at internal Python metadata sites (`n_points`, `period`) that legitimately pass Python `int` outside the JIT'd DAG. Tightening the alias is a separate audit and follow-up. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/params/processing.py | 38 ++++++++++++++++++------------ tests/test_int_dtype_invariants.py | 10 ++++---- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/lcm/params/processing.py b/src/lcm/params/processing.py index 64f02cc2..ae937c44 100644 --- a/src/lcm/params/processing.py +++ b/src/lcm/params/processing.py @@ -1,10 +1,11 @@ """Process user-provided params into internal params. `process_params` resolves user-supplied parameters against the model's -template, then runs a boundary-cast pass that normalises typed integer -leaves to `jnp.int32` (and integer arrays inside `MappingLeaf` / -`SequenceLeaf`). Out-of-range values surface as `ValueError` with the -offending leaf's qualified name. +template, then runs a boundary-cast pass that normalises every integer +leaf — Python `int`, typed JAX integer arrays, numpy integer arrays, +and integers inside `MappingLeaf` / `SequenceLeaf` — to `jnp.int32`. +Out-of-range values surface as `ValueError` with the offending leaf's +qualified name. """ from collections.abc import Mapping @@ -46,11 +47,11 @@ def process_params( - Regime level: `{"regime_0": {"arg_0": 0.0}}` — propagates within regime_0 - Function level: `{"regime_0": {"func": {"arg_0": 0.0}}}` — direct specification - The output always matches the params_template skeleton. Typed integer - arrays in the user input — including those inside `MappingLeaf` / - `SequenceLeaf` containers — are cast to `jnp.int32` so the AOT signature - is stable across calls; Python scalars pass through to keep JAX weak- - typing semantics. + The output always matches the params_template skeleton. Every integer + leaf — Python `int`, typed JAX or numpy integer arrays, and integers + inside `MappingLeaf` / `SequenceLeaf` — is cast to `jnp.int32` so the + AOT signature is stable across calls. Python `bool` and float leaves + are handled by the float-side cast pass. Args: params: User-provided parameters dictionary. @@ -141,20 +142,21 @@ def broadcast_to_template( def _cast_int_leaves_to_int32(value: Any, *, name: str) -> Any: # noqa: ANN401 - """Normalise typed integer arrays in a params value to `jnp.int32`. + """Normalise integer leaves in a params value to `jnp.int32`. Casts: + - Python `int` scalars — to `jnp.int32` so the DAG sees a JAX scalar + with a pinned dtype rather than a Python int that JAX would + otherwise promote per call site. - Typed JAX or numpy integer arrays (`jnp.array(..., dtype=jnp.int64)`, - `np.array(...)`) — strongly typed by JAX, so cast to `int32` to keep - the AOT signature stable. + `np.array(...)`) — cast to `int32` to keep the AOT signature stable. - Integer leaves inside `MappingLeaf` / `SequenceLeaf` — recurse. Passes through unchanged: - - Python `int` / `bool` scalars — JAX's weak-typing rules let them - promote to whatever dtype the surrounding operation needs (e.g. - `discount_factor: 1` in a float-typed function). + - Python `bool` scalars — handled by the float-side cast pass once + it lands. - Float and non-numeric typed leaves — handled by a separate float- normalisation pass. """ @@ -172,6 +174,12 @@ def _cast_int_leaves_to_int32(value: Any, *, name: str) -> Any: # noqa: ANN401 for i, v in enumerate(value.data) ] ) + # `bool` is a subclass of `int`, so test for it first and short-circuit + # — bool handling lands with the float-side cast pass, not here. + if isinstance(value, bool): + return value + if isinstance(value, int): + return safe_to_int32(value, name=name) if isinstance(value, (Array, np.ndarray)) and np.issubdtype( value.dtype, np.integer ): diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py index d174d515..9e9c6011 100644 --- a/tests/test_int_dtype_invariants.py +++ b/tests/test_int_dtype_invariants.py @@ -79,8 +79,8 @@ def test_update_states_for_subjects_preserves_storage_dtype() -> None: assert updated["work__health"].dtype == jnp.int32 -def test_process_params_passes_python_int_through_for_jax_weak_typing() -> None: - """Python `int` params stay weak-typed so JAX promotes them per call site.""" +def test_process_params_casts_python_int_to_int32() -> None: + """A Python `int` param leaf is cast to `jnp.int32`.""" template = MappingProxyType({"regime_a": MappingProxyType({"final_age": "int"})}) user_params = {"regime_a": {"final_age": 65}} @@ -89,9 +89,9 @@ def test_process_params_passes_python_int_through_for_jax_weak_typing() -> None: params_template=template, # ty: ignore[invalid-argument-type] ) - # Python int stays Python int; JAX weak-typing handles promotion at JIT. - assert out["regime_a"]["final_age"] == 65 - assert isinstance(out["regime_a"]["final_age"], int) + final_age = out["regime_a"]["final_age"] + assert int(final_age) == 65 + assert final_age.dtype == jnp.int32 # ty: ignore[unresolved-attribute] def test_process_params_casts_int64_array_to_int32() -> None: From 14f81fc723df3fc1d3d7313b251bd9741e82d5ce Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 17:19:46 +0200 Subject: [PATCH 77/80] solve: kick off simulate AOT compile in a background thread MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When `Model(n_subjects=N)` is set, simulate-side XLA compilation used to run lazily on the first matching `simulate(...)` call — strictly after `solve(...)` returned. On production aca-baseline that adds several minutes to the end-to-end wall clock for nothing: solve is GPU-bound, simulate compile is CPU-bound XLA work, so they overlap trivially. Add `_maybe_start_simulate_compile_async` and call it from `solve(...)` right after parameters are processed. It spawns a single-worker `ThreadPoolExecutor` that runs `compile_all_simulate_functions` in the background and parks the result on `_simulate_compile_future`. `_resolve_simulate_internal_regimes` awaits the future before populating the cache, so the lazy fallback path (no `solve` call, direct `simulate(...)`) still works. `__getstate__` / `__setstate__` drop the future on the way out and reset to `None` on the way in — `concurrent.futures.Future` is tied to its originating thread pool and can't survive a process boundary. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/model.py | 76 +++++++++++++++++++++++++-- tests/simulation/test_simulate_aot.py | 17 ++++++ 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/src/lcm/model.py b/src/lcm/model.py index 1897272f..a44abb83 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -4,6 +4,7 @@ import logging import threading from collections.abc import Mapping +from concurrent.futures import Future, ThreadPoolExecutor from pathlib import Path from types import MappingProxyType @@ -117,8 +118,21 @@ class Model: _warned_n_subjects: set[int] """Mismatching `actual_n_subjects` already warned about (one warning each).""" + _simulate_compile_future: ( + Future[MappingProxyType[RegimeName, InternalRegime]] | None + ) + """Pending background AOT compile started by `solve(...)`, or `None`. + + `solve(...)` kicks off `compile_all_simulate_functions` in a single + background thread so XLA compilation overlaps with the GPU-bound + backward induction. `simulate(...)` awaits the future before + dispatching the AOT-compiled program. Cleared after the result lands + in `_simulate_compile_cache`. + """ + _simulate_compile_lock: threading.Lock - """Serialises mutations of `_simulate_compile_cache` and `_warned_n_subjects`. + """Serialises mutations of `_simulate_compile_cache`, `_warned_n_subjects`, + and `_simulate_compile_future`. The check-then-set on each container is held under this lock. The consequent `log.warning` call sits outside the lock so concurrent @@ -166,6 +180,7 @@ def __init__( self.n_subjects = n_subjects self._simulate_compile_cache = {} self._warned_n_subjects = set() + self._simulate_compile_future = None self._simulate_compile_lock = threading.Lock() validate_model_inputs( @@ -201,13 +216,16 @@ def __getstate__(self) -> dict[str, object]: Drops `_simulate_compile_lock` (a `threading.Lock`, not pickleable), `_simulate_compile_cache` (compiled XLA programs that can't survive - a process boundary), and `_warned_n_subjects` (its companion set). + a process boundary), `_warned_n_subjects` (its companion set), and + `_simulate_compile_future` (a `Future` tied to the originating thread + pool). `__setstate__` restores all three to their fresh state. """ state = self.__dict__.copy() state.pop("_simulate_compile_lock", None) state.pop("_simulate_compile_cache", None) state.pop("_warned_n_subjects", None) + state.pop("_simulate_compile_future", None) return state def __setstate__(self, state: dict[str, object]) -> None: @@ -215,6 +233,7 @@ def __setstate__(self, state: dict[str, object]) -> None: self.__dict__.update(state) self._simulate_compile_cache = {} self._warned_n_subjects = set() + self._simulate_compile_future = None self._simulate_compile_lock = threading.Lock() def get_params_template(self) -> UserFacingParamsTemplate: @@ -278,6 +297,11 @@ def solve( internal_params=internal_params, ages=self.ages, ) + self._maybe_start_simulate_compile_async( + internal_params=internal_params, + max_compilation_workers=max_compilation_workers, + logger=get_logger(log_level=log_level), + ) try: period_to_regime_to_V_arr = solve( internal_params=internal_params, @@ -308,6 +332,41 @@ def solve( ) return period_to_regime_to_V_arr + def _maybe_start_simulate_compile_async( + self, + *, + internal_params: InternalParams, + max_compilation_workers: int | None, + logger: logging.Logger, + ) -> None: + """Spawn `compile_all_simulate_functions` in a background thread. + + Called from `solve(...)` so the simulate-side XLA compilation runs in + parallel with the GPU-bound backward induction. No-op when + `n_subjects is None`, when the cache for this size is already + populated, or when a compile is already in flight. + """ + if self.n_subjects is None: + return + with self._simulate_compile_lock: + if self.n_subjects in self._simulate_compile_cache: + return + if self._simulate_compile_future is not None: + return + executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="lcm-simulate-compile" + ) + self._simulate_compile_future = executor.submit( + compile_all_simulate_functions, + internal_regimes=self.internal_regimes, + internal_params=internal_params, + ages=self.ages, + n_subjects=self.n_subjects, + max_compilation_workers=max_compilation_workers, + logger=logger, + ) + executor.shutdown(wait=False) + def _resolve_simulate_internal_regimes( self, *, @@ -326,7 +385,8 @@ def _resolve_simulate_internal_regimes( `internal_regimes` and log a warning the first time each mismatching size is seen. - `actual_n_subjects == n_subjects`: return the cached AOT-compiled - regimes, building them on the first call. + regimes. If `solve(...)` started a background compile, await it + here; otherwise compile synchronously. """ if self.n_subjects is None: return self.internal_regimes @@ -343,6 +403,16 @@ def _resolve_simulate_internal_regimes( self.n_subjects, ) return self.internal_regimes + with self._simulate_compile_lock: + if self.n_subjects in self._simulate_compile_cache: + return self._simulate_compile_cache[self.n_subjects] + future = self._simulate_compile_future + if future is not None: + compiled = future.result() + with self._simulate_compile_lock: + self._simulate_compile_cache[self.n_subjects] = compiled + self._simulate_compile_future = None + return compiled with self._simulate_compile_lock: if self.n_subjects not in self._simulate_compile_cache: self._simulate_compile_cache[self.n_subjects] = ( diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py index ccd47cae..660d99db 100644 --- a/tests/simulation/test_simulate_aot.py +++ b/tests/simulation/test_simulate_aot.py @@ -156,6 +156,23 @@ def test_simulate_first_matching_call_populates_aot_cache() -> None: assert n_subjects in model._simulate_compile_cache +def test_solve_with_n_subjects_kicks_off_background_simulate_compile() -> None: + """`solve(...)` spawns the simulate AOT compile in the background. + + The follow-on `simulate(...)` then awaits the in-flight `Future` instead + of compiling synchronously, so XLA compilation overlaps with the + GPU-bound backward induction in production. + """ + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + + assert model._simulate_compile_future is None + model.solve(params=params) + assert model._simulate_compile_future is not None + + _DECLARED_N = 4 _ACTUAL_N = 7 From 8df832a6c2e56b29f6dd32d29885fdcffd4bada4 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 18:17:08 +0200 Subject: [PATCH 78/80] simulate orchestrates simulate-AOT compile, not solve MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit solve() no longer touches simulate-side compile state. simulate() is the sole driver: spawns the AOT compile in a background thread when n_subjects is set and the batch shape matches, then runs solve (if period_to_regime_to_V_arr is None) and awaits the future at the state-action-space dispatch point. Both public methods share an internal _solve_compiled() body for the snapshot/error handling. Drops _simulate_compile_future from instance state — the future lives in a local variable on the simulate() stack, so there's no per-process state to gate against. The lock keeps protecting _simulate_compile_cache and _warned_n_subjects; the rest of the "maybe spawn" logic collapses into a single inline check at the simulate() call site. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/model.py | 176 ++++++++++++-------------- tests/simulation/test_simulate_aot.py | 13 +- 2 files changed, 84 insertions(+), 105 deletions(-) diff --git a/src/lcm/model.py b/src/lcm/model.py index a44abb83..bc05f0b4 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -118,21 +118,9 @@ class Model: _warned_n_subjects: set[int] """Mismatching `actual_n_subjects` already warned about (one warning each).""" - _simulate_compile_future: ( - Future[MappingProxyType[RegimeName, InternalRegime]] | None - ) - """Pending background AOT compile started by `solve(...)`, or `None`. - - `solve(...)` kicks off `compile_all_simulate_functions` in a single - background thread so XLA compilation overlaps with the GPU-bound - backward induction. `simulate(...)` awaits the future before - dispatching the AOT-compiled program. Cleared after the result lands - in `_simulate_compile_cache`. - """ - _simulate_compile_lock: threading.Lock - """Serialises mutations of `_simulate_compile_cache`, `_warned_n_subjects`, - and `_simulate_compile_future`. + """Serialises mutations of `_simulate_compile_cache` and + `_warned_n_subjects`. The check-then-set on each container is held under this lock. The consequent `log.warning` call sits outside the lock so concurrent @@ -180,7 +168,6 @@ def __init__( self.n_subjects = n_subjects self._simulate_compile_cache = {} self._warned_n_subjects = set() - self._simulate_compile_future = None self._simulate_compile_lock = threading.Lock() validate_model_inputs( @@ -216,16 +203,13 @@ def __getstate__(self) -> dict[str, object]: Drops `_simulate_compile_lock` (a `threading.Lock`, not pickleable), `_simulate_compile_cache` (compiled XLA programs that can't survive - a process boundary), `_warned_n_subjects` (its companion set), and - `_simulate_compile_future` (a `Future` tied to the originating thread - pool). + a process boundary), and `_warned_n_subjects` (its companion set). `__setstate__` restores all three to their fresh state. """ state = self.__dict__.copy() state.pop("_simulate_compile_lock", None) state.pop("_simulate_compile_cache", None) state.pop("_warned_n_subjects", None) - state.pop("_simulate_compile_future", None) return state def __setstate__(self, state: dict[str, object]) -> None: @@ -233,7 +217,6 @@ def __setstate__(self, state: dict[str, object]) -> None: self.__dict__.update(state) self._simulate_compile_cache = {} self._warned_n_subjects = set() - self._simulate_compile_future = None self._simulate_compile_lock = threading.Lock() def get_params_template(self) -> UserFacingParamsTemplate: @@ -297,17 +280,34 @@ def solve( internal_params=internal_params, ages=self.ages, ) - self._maybe_start_simulate_compile_async( + return self._solve_compiled( internal_params=internal_params, + params=params, + log=get_logger(log_level=log_level), + log_level=log_level, + log_path=log_path, + log_keep_n_latest=log_keep_n_latest, max_compilation_workers=max_compilation_workers, - logger=get_logger(log_level=log_level), ) + + def _solve_compiled( + self, + *, + internal_params: InternalParams, + params: UserParams, + log: logging.Logger, + log_level: LogLevel, + log_path: str | Path | None, + log_keep_n_latest: int, + max_compilation_workers: int | None, + ) -> MappingProxyType[int, MappingProxyType[RegimeName, FloatND]]: + """Run backward induction, persisting a snapshot on debug or NaN failure.""" try: period_to_regime_to_V_arr = solve( internal_params=internal_params, ages=self.ages, internal_regimes=self.internal_regimes, - logger=get_logger(log_level=log_level), + logger=log, enable_jit=self.enable_jit, max_compilation_workers=max_compilation_workers, ) @@ -332,61 +332,55 @@ def solve( ) return period_to_regime_to_V_arr - def _maybe_start_simulate_compile_async( + def _spawn_simulate_compile( self, *, + n_subjects: int, internal_params: InternalParams, max_compilation_workers: int | None, logger: logging.Logger, - ) -> None: - """Spawn `compile_all_simulate_functions` in a background thread. + ) -> Future[MappingProxyType[RegimeName, InternalRegime]]: + """Submit `compile_all_simulate_functions` to a single-thread executor. - Called from `solve(...)` so the simulate-side XLA compilation runs in - parallel with the GPU-bound backward induction. No-op when - `n_subjects is None`, when the cache for this size is already - populated, or when a compile is already in flight. + Caller decides whether to spawn (`n_subjects` set, batch shape + matches, no cache hit). The returned `Future` runs in parallel with + whatever the caller does next — typically `_solve_compiled(...)`. """ - if self.n_subjects is None: - return - with self._simulate_compile_lock: - if self.n_subjects in self._simulate_compile_cache: - return - if self._simulate_compile_future is not None: - return - executor = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="lcm-simulate-compile" - ) - self._simulate_compile_future = executor.submit( - compile_all_simulate_functions, - internal_regimes=self.internal_regimes, - internal_params=internal_params, - ages=self.ages, - n_subjects=self.n_subjects, - max_compilation_workers=max_compilation_workers, - logger=logger, - ) - executor.shutdown(wait=False) + executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="lcm-simulate-compile" + ) + future = executor.submit( + compile_all_simulate_functions, + internal_regimes=self.internal_regimes, + internal_params=internal_params, + ages=self.ages, + n_subjects=n_subjects, + max_compilation_workers=max_compilation_workers, + logger=logger, + ) + executor.shutdown(wait=False) + return future def _resolve_simulate_internal_regimes( self, *, + compile_future: Future[MappingProxyType[RegimeName, InternalRegime]] | None, actual_n_subjects: int, - internal_params: InternalParams, log: logging.Logger, - max_compilation_workers: int | None, ) -> MappingProxyType[RegimeName, InternalRegime]: """Return internal_regimes to use for simulate; AOT cache when matching. - Three dispatch cases: + Dispatch by `n_subjects` and batch-shape match: - `n_subjects is None`: return the original `internal_regimes` (purely lazy path). - - `actual_n_subjects != n_subjects`: return the original - `internal_regimes` and log a warning the first time each - mismatching size is seen. - - `actual_n_subjects == n_subjects`: return the cached AOT-compiled - regimes. If `solve(...)` started a background compile, await it - here; otherwise compile synchronously. + - `actual_n_subjects != n_subjects`: warn once per mismatching size, + return the original `internal_regimes`. + - `actual_n_subjects == n_subjects`, `compile_future is not None`: + await it and cache the result. + - `actual_n_subjects == n_subjects`, `compile_future is None`: cache + must already hold the entry (caller spawned only on cache miss); + return the cached compiled regimes. """ if self.n_subjects is None: return self.internal_regimes @@ -403,28 +397,12 @@ def _resolve_simulate_internal_regimes( self.n_subjects, ) return self.internal_regimes - with self._simulate_compile_lock: - if self.n_subjects in self._simulate_compile_cache: - return self._simulate_compile_cache[self.n_subjects] - future = self._simulate_compile_future - if future is not None: - compiled = future.result() + if compile_future is not None: + compiled = compile_future.result() with self._simulate_compile_lock: self._simulate_compile_cache[self.n_subjects] = compiled - self._simulate_compile_future = None return compiled with self._simulate_compile_lock: - if self.n_subjects not in self._simulate_compile_cache: - self._simulate_compile_cache[self.n_subjects] = ( - compile_all_simulate_functions( - internal_regimes=self.internal_regimes, - internal_params=internal_params, - ages=self.ages, - n_subjects=self.n_subjects, - max_compilation_workers=max_compilation_workers, - logger=log, - ) - ) return self._simulate_compile_cache[self.n_subjects] def simulate( @@ -507,33 +485,35 @@ def simulate( ages=self.ages, ) log = get_logger(log_level=log_level) - if period_to_regime_to_V_arr is None: - try: - period_to_regime_to_V_arr = solve( + actual_n_subjects = len(next(iter(initial_conditions.values()))) + n_subjects = self.n_subjects + compile_future: Future[MappingProxyType[RegimeName, InternalRegime]] | None = ( + None + ) + if n_subjects is not None and n_subjects == actual_n_subjects: + with self._simulate_compile_lock: + needs_compile = n_subjects not in self._simulate_compile_cache + if needs_compile: + compile_future = self._spawn_simulate_compile( + n_subjects=n_subjects, internal_params=internal_params, - ages=self.ages, - internal_regimes=self.internal_regimes, - logger=log, - enable_jit=self.enable_jit, max_compilation_workers=max_compilation_workers, + logger=log, ) - except InvalidValueFunctionError as exc: - if log_path is not None and exc.partial_solution is not None: - snap_dir = save_solve_snapshot( - model=self, - params=params, - period_to_regime_to_V_arr=exc.partial_solution, # ty: ignore[invalid-argument-type] - log_path=Path(log_path), - log_keep_n_latest=log_keep_n_latest, - ) - exc.add_note(f"Snapshot saved to {snap_dir}") - raise - actual_n_subjects = len(next(iter(initial_conditions.values()))) + if period_to_regime_to_V_arr is None: + period_to_regime_to_V_arr = self._solve_compiled( + internal_params=internal_params, + params=params, + log=log, + log_level=log_level, + log_path=log_path, + log_keep_n_latest=log_keep_n_latest, + max_compilation_workers=max_compilation_workers, + ) simulate_internal_regimes = self._resolve_simulate_internal_regimes( + compile_future=compile_future, actual_n_subjects=actual_n_subjects, - internal_params=internal_params, log=log, - max_compilation_workers=max_compilation_workers, ) result = simulate( internal_params=internal_params, diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py index 660d99db..77f01814 100644 --- a/tests/simulation/test_simulate_aot.py +++ b/tests/simulation/test_simulate_aot.py @@ -156,21 +156,20 @@ def test_simulate_first_matching_call_populates_aot_cache() -> None: assert n_subjects in model._simulate_compile_cache -def test_solve_with_n_subjects_kicks_off_background_simulate_compile() -> None: - """`solve(...)` spawns the simulate AOT compile in the background. +def test_solve_does_not_populate_simulate_compile_cache() -> None: + """`solve(...)` does not touch simulate-side compile state. - The follow-on `simulate(...)` then awaits the in-flight `Future` instead - of compiling synchronously, so XLA compilation overlaps with the - GPU-bound backward induction in production. + Simulate AOT compilation is driven entirely by `simulate(...)`; calling + `solve(...)` alone leaves `_simulate_compile_cache` empty. """ n_periods = 3 n_subjects = 4 model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) params = get_params(n_periods=n_periods) - assert model._simulate_compile_future is None model.solve(params=params) - assert model._simulate_compile_future is not None + + assert dict(model._simulate_compile_cache) == {} _DECLARED_N = 4 From a7a87cd8effb3982d674c65d0f00e955e500c235 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 9 May 2026 15:28:33 +0200 Subject: [PATCH 79/80] simulate: swap AOT-compiled regimes for lazy ones on the result MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `SimulationResult.to_pickle()` (and any cloudpickle.dumps on the result) hit `cannot pickle 'jaxlib._jax.LoadedExecutable'` when the result carried the AOT-compiled `internal_regimes`. The compiled callables (`argmax_and_max_Q_over_a`, `next_state`, `compute_regime_transition_probs`) wrap a `LoadedExecutable` that can't survive a process boundary. `to_dataframe` only reads `simulate_functions.functions / constraints / transitions / stochastic_transition_names` — none of which the AOT pass replaces. So after `simulate(...)` runs, the result has no use for the compiled callables: `model.simulate()` swaps them out for the lazy `self.internal_regimes` before returning. Add a TDD test that round-trips the result through cloudpickle under `n_subjects` matching, which is the failure mode pytask hit on HPC. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/model.py | 7 +++++++ tests/simulation/test_simulate_aot.py | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/src/lcm/model.py b/src/lcm/model.py index bc05f0b4..e31f2c37 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -526,6 +526,13 @@ def simulate( simulation_output_dtypes=self.simulation_output_dtypes, seed=seed, ) + # AOT-compiled regimes carry `jax.stages.Compiled` callables that + # wrap an unpicklable `LoadedExecutable`. `to_dataframe` only reads + # the lazy DAG functions / constraints / transitions on + # `simulate_functions`, never the compiled callables — so swap in + # the lazy regimes to keep the result cloudpickle-safe. + if simulate_internal_regimes is not self.internal_regimes: + result._internal_regimes = self.internal_regimes # noqa: SLF001 if log_level == "debug" and log_path is not None: save_simulate_snapshot( model=self, diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py index 77f01814..698256e4 100644 --- a/tests/simulation/test_simulate_aot.py +++ b/tests/simulation/test_simulate_aot.py @@ -269,6 +269,30 @@ def test_simulate_warns_only_once_per_mismatching_size( assert len(mismatch_warnings) == 1 +def test_simulate_result_pickles_when_n_subjects_matches() -> None: + """`simulate(...)` returns a result that round-trips through cloudpickle. + + With `n_subjects` matching the batch shape, the simulate path runs + AOT-compiled callables that wrap `LoadedExecutable` (unpicklable). + `to_dataframe` doesn't need those callables, so the returned result + must carry the lazy regimes — otherwise downstream pickling + (e.g. pytask handing the result to the next task) fails. + """ + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + + result = model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=_build_initial_conditions(n_subjects=n_subjects), + ) + + restored = cloudpickle.loads(cloudpickle.dumps(result)) + assert restored.n_subjects == n_subjects + + def test_unpickled_model_can_simulate_with_aot() -> None: """A cloudpickle round-tripped `Model` still drives `simulate(...)` with AOT.""" n_periods = 3 From ebb5e56c365b8f824931c2e32116f0f3300909c7 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 10 May 2026 13:14:29 +0200 Subject: [PATCH 80/80] simulate: AOT-compile blocks before solve to avoid contention `simulate(...)` previously kicked off `compile_all_simulate_functions` in a single-thread background executor and ran solve concurrently; `_resolve_simulate_internal_regimes` then awaited the future at the state-action-space dispatch point. With realistic worker counts the parallel XLA compile pool stayed busy through a substantial chunk of the backward-induction loop, contending for CPU and XLA front-end resources and stretching mid-loop ages by an order of magnitude. Drop the future / executor entirely. simulate() now calls compile_all_simulate_functions inline before _solve_compiled, so the entire AOT compile (including its own internal worker pool) finishes before backward induction starts. Same total compile work; predictable timing; lower transient host-RAM peak because the AOT pool's intermediate Lowered objects are released before solve allocates its per-period V buffers. _resolve_simulate_internal_regimes loses its compile_future parameter and only consults the cache. _spawn_simulate_compile is gone, as are the `Future` and `ThreadPoolExecutor` imports. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/model.py | 62 ++++++++++-------------------------------------- 1 file changed, 12 insertions(+), 50 deletions(-) diff --git a/src/lcm/model.py b/src/lcm/model.py index e31f2c37..374d40bf 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -4,7 +4,6 @@ import logging import threading from collections.abc import Mapping -from concurrent.futures import Future, ThreadPoolExecutor from pathlib import Path from types import MappingProxyType @@ -96,7 +95,8 @@ class Model: - `None`: purely lazy behaviour, no AOT. - First `simulate(...)` with `actual_n == n_subjects`: AOT-compiles all - simulate functions for that batch shape in parallel and caches them. + simulate functions for that batch shape (blocks before solve runs) + and caches them. - Subsequent `simulate(...)` with the same matching size: reuses the cached compiled programs. - `simulate(...)` with a mismatching size: warns once per size and falls @@ -157,8 +157,8 @@ def __init__( already has a conflicting entry. n_subjects: Expected simulate batch size; if set, the first matching `simulate(...)` call AOT-compiles all simulate functions for - batch shape `n_subjects` in parallel. `None` keeps the purely - lazy behaviour. + batch shape `n_subjects` before backward induction starts. + `None` keeps the purely lazy behaviour. """ self.description = description @@ -332,39 +332,9 @@ def _solve_compiled( ) return period_to_regime_to_V_arr - def _spawn_simulate_compile( - self, - *, - n_subjects: int, - internal_params: InternalParams, - max_compilation_workers: int | None, - logger: logging.Logger, - ) -> Future[MappingProxyType[RegimeName, InternalRegime]]: - """Submit `compile_all_simulate_functions` to a single-thread executor. - - Caller decides whether to spawn (`n_subjects` set, batch shape - matches, no cache hit). The returned `Future` runs in parallel with - whatever the caller does next — typically `_solve_compiled(...)`. - """ - executor = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="lcm-simulate-compile" - ) - future = executor.submit( - compile_all_simulate_functions, - internal_regimes=self.internal_regimes, - internal_params=internal_params, - ages=self.ages, - n_subjects=n_subjects, - max_compilation_workers=max_compilation_workers, - logger=logger, - ) - executor.shutdown(wait=False) - return future - def _resolve_simulate_internal_regimes( self, *, - compile_future: Future[MappingProxyType[RegimeName, InternalRegime]] | None, actual_n_subjects: int, log: logging.Logger, ) -> MappingProxyType[RegimeName, InternalRegime]: @@ -376,11 +346,8 @@ def _resolve_simulate_internal_regimes( (purely lazy path). - `actual_n_subjects != n_subjects`: warn once per mismatching size, return the original `internal_regimes`. - - `actual_n_subjects == n_subjects`, `compile_future is not None`: - await it and cache the result. - - `actual_n_subjects == n_subjects`, `compile_future is None`: cache - must already hold the entry (caller spawned only on cache miss); - return the cached compiled regimes. + - `actual_n_subjects == n_subjects`: return the cached compiled + regimes (caller must have populated the cache before calling). """ if self.n_subjects is None: return self.internal_regimes @@ -397,11 +364,6 @@ def _resolve_simulate_internal_regimes( self.n_subjects, ) return self.internal_regimes - if compile_future is not None: - compiled = compile_future.result() - with self._simulate_compile_lock: - self._simulate_compile_cache[self.n_subjects] = compiled - return compiled with self._simulate_compile_lock: return self._simulate_compile_cache[self.n_subjects] @@ -487,19 +449,20 @@ def simulate( log = get_logger(log_level=log_level) actual_n_subjects = len(next(iter(initial_conditions.values()))) n_subjects = self.n_subjects - compile_future: Future[MappingProxyType[RegimeName, InternalRegime]] | None = ( - None - ) if n_subjects is not None and n_subjects == actual_n_subjects: with self._simulate_compile_lock: needs_compile = n_subjects not in self._simulate_compile_cache if needs_compile: - compile_future = self._spawn_simulate_compile( - n_subjects=n_subjects, + compiled = compile_all_simulate_functions( + internal_regimes=self.internal_regimes, internal_params=internal_params, + ages=self.ages, + n_subjects=n_subjects, max_compilation_workers=max_compilation_workers, logger=log, ) + with self._simulate_compile_lock: + self._simulate_compile_cache[n_subjects] = compiled if period_to_regime_to_V_arr is None: period_to_regime_to_V_arr = self._solve_compiled( internal_params=internal_params, @@ -511,7 +474,6 @@ def simulate( max_compilation_workers=max_compilation_workers, ) simulate_internal_regimes = self._resolve_simulate_internal_regimes( - compile_future=compile_future, actual_n_subjects=actual_n_subjects, log=log, )