From 1c65f45dd8b2153bf794223ecf952f8a6f5e31bc Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 06:29:29 +0200 Subject: [PATCH 01/23] Support runtime-supplied points on continuous-action IrregSpacedGrids MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends the existing runtime-points mechanism (previously state-only) to continuous action grids. With this change, an action declared as `IrregSpacedGrid(n_points=N)` adds an `{action_name: {"points": "Float1D"}}` entry to the regime params template, and `state_action_space()` substitutes the runtime-supplied points into `continuous_actions` at solve / simulate time. Motivation: aca-dev's structural retirement model has a `consumption` action grid whose lower bound is the per-iteration `consumption_floor` parameter. Without this change the c-grid bounds would have to be fixed at build time, which forces either an over-wide grid (wasted density) or model rebuilds per estimation iteration (recompilation). Mirrors the existing state-grid treatment: - `regime_template.py`: walks `regime.actions` alongside `regime.states`, factoring the shared shadowing check into helpers. - `interfaces.InternalRegime.state_action_space()`: builds both state and continuous-action replacements in a single sweep over `self.grids`, then calls `_base_state_action_space.replace(...)` with whichever side actually had substitutions. - `pandas_utils._is_runtime_grid_param`: also recognises action grids so column extraction in `to_dataframe()` keeps working. Tests (TDD): four new tests in `tests/test_runtime_params.py`, mirroring the state-grid counterparts — params-template entry, solve, runtime-vs-fixed equivalence, and a sanity check that varying runtime points actually changes V. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/interfaces.py | 68 ++++++++++++++++------- src/lcm/pandas_utils.py | 15 +++--- src/lcm/params/regime_template.py | 52 +++++++++++++----- tests/test_runtime_params.py | 90 +++++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 39 deletions(-) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 9ff454b2..6c80ddb9 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -242,12 +242,12 @@ class InternalRegime: """Flat resolved fixed params for this regime, used by to_dataframe targets.""" def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpace: - """Return the state-action space with runtime state grids filled in. + """Return the state-action space with runtime grids filled in. - For IrregSpacedGrid with runtime-supplied points, the grid points come from - params as `{state_name}__points`. For _ShockGrid with runtime-supplied params, - the grid points are computed from shock params in the params dict or - resolved_fixed_params. + For IrregSpacedGrid (state or continuous action) with runtime-supplied + points, the grid points come from params as `{name}__points`. For + `_ShockGrid` with runtime-supplied params, the grid points are computed + from shock params in the params dict or `resolved_fixed_params`. Args: regime_params: Flat regime parameters supplied at runtime. @@ -257,35 +257,63 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac """ all_params = {**self.resolved_fixed_params, **regime_params} - replacements: dict[str, ContinuousState | DiscreteState] = {} - for state_name, spec in self.grids.items(): - if state_name not in self._base_state_action_space.states: + state_replacements: dict[str, ContinuousState | DiscreteState] = {} + action_replacements: dict[str, ContinuousAction] = {} + for name, spec in self.grids.items(): + in_states = name in self._base_state_action_space.states + in_continuous_actions = ( + name in self._base_state_action_space.continuous_actions + ) + if not (in_states or in_continuous_actions): continue if isinstance(spec, IrregSpacedGrid) and spec.pass_points_at_runtime: - points_key = f"{state_name}__points" + points_key = f"{name}__points" if points_key not in all_params: continue - replacements[state_name] = cast( - "ContinuousState", all_params[points_key] - ) - elif isinstance(spec, _ShockGrid) and spec.params_to_pass_at_runtime: + if in_states: + state_replacements[name] = cast( + "ContinuousState", all_params[points_key] + ) + else: + action_replacements[name] = cast( + "ContinuousAction", all_params[points_key] + ) + elif ( + in_states + and isinstance(spec, _ShockGrid) + and spec.params_to_pass_at_runtime + ): all_present = all( - f"{state_name}__{p}" in all_params - for p in spec.params_to_pass_at_runtime + f"{name}__{p}" in all_params for p in spec.params_to_pass_at_runtime ) if not all_present: continue shock_kw: dict[str, float] = dict(spec.params) for p in spec.params_to_pass_at_runtime: - shock_kw[p] = cast("float", all_params[f"{state_name}__{p}"]) - replacements[state_name] = spec.compute_gridpoints(**shock_kw) + shock_kw[p] = cast("float", all_params[f"{name}__{p}"]) + state_replacements[name] = spec.compute_gridpoints(**shock_kw) - if not replacements: + if not state_replacements and not action_replacements: return self._base_state_action_space - new_states = dict(self._base_state_action_space.states) | replacements + new_states = ( + MappingProxyType( + dict(self._base_state_action_space.states) | state_replacements + ) + if state_replacements + else None + ) + new_continuous_actions = ( + MappingProxyType( + dict(self._base_state_action_space.continuous_actions) + | action_replacements + ) + if action_replacements + else None + ) return self._base_state_action_space.replace( - states=MappingProxyType(new_states) + states=new_states, + continuous_actions=new_continuous_actions, ) diff --git a/src/lcm/pandas_utils.py b/src/lcm/pandas_utils.py index 805932e2..2d696a45 100644 --- a/src/lcm/pandas_utils.py +++ b/src/lcm/pandas_utils.py @@ -504,12 +504,15 @@ def _resolve_per_target_template_key( def _is_runtime_grid_param(*, func_name: FunctionName, regime: Regime) -> bool: """Check if a template function key refers to a runtime grid param.""" - if func_name not in regime.states: - return False - grid = regime.states[func_name] - return (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime) or ( - isinstance(grid, _ShockGrid) and bool(grid.params_to_pass_at_runtime) - ) + if func_name in regime.states: + grid = regime.states[func_name] + return (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime) or ( + isinstance(grid, _ShockGrid) and bool(grid.params_to_pass_at_runtime) + ) + if func_name in regime.actions: + grid = regime.actions[func_name] + return isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime + return False def _fail_if_period_level(sr: pd.Series) -> None: diff --git a/src/lcm/params/regime_template.py b/src/lcm/params/regime_template.py index 299b2c27..2b9c748e 100644 --- a/src/lcm/params/regime_template.py +++ b/src/lcm/params/regime_template.py @@ -70,27 +70,53 @@ def create_regime_params_template( _validate_no_shadowing(function_params, regime) + _add_runtime_grid_params(function_params, regime) + + return MappingProxyType( + {k: MappingProxyType(v) for k, v in function_params.items()} + ) + + +def _add_runtime_grid_params( + function_params: dict[FunctionName, dict[str, str]], + regime: Regime, +) -> None: + """Add runtime-supplied state/action grid params to the template in place.""" for state_name, grid in regime.states.items(): if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: - if state_name in function_params: - raise InvalidNameError( - f"IrregSpacedGrid state '{state_name}' (with runtime-supplied " - f"points) conflicts with a function of the same name in the regime." - ) + _fail_if_runtime_grid_shadows_function( + function_params=function_params, name=state_name, kind="state" + ) function_params[state_name] = {"points": "Float1D"} elif isinstance(grid, _ShockGrid) and grid.params_to_pass_at_runtime: - if state_name in function_params: - raise InvalidNameError( - f"_ShockGrid state '{state_name}' (with runtime-supplied params) " - f"conflicts with a function of the same name in the regime." - ) + _fail_if_runtime_grid_shadows_function( + function_params=function_params, + name=state_name, + kind="_ShockGrid state", + ) function_params[state_name] = dict.fromkeys( grid.params_to_pass_at_runtime, "float" ) - return MappingProxyType( - {k: MappingProxyType(v) for k, v in function_params.items()} - ) + for action_name, grid in regime.actions.items(): + if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: + _fail_if_runtime_grid_shadows_function( + function_params=function_params, name=action_name, kind="action" + ) + function_params[action_name] = {"points": "Float1D"} + + +def _fail_if_runtime_grid_shadows_function( + *, + function_params: dict[FunctionName, dict[str, str]], + name: str, + kind: str, +) -> None: + if name in function_params: + raise InvalidNameError( + f"IrregSpacedGrid {kind} '{name}' (with runtime-supplied " + f"points/params) conflicts with a function of the same name in the regime." + ) def _collect_all_functions_for_template( diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index 9d588649..ceaee4cf 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -141,3 +141,93 @@ def test_runtime_grid_matches_fixed(): for period in V_fixed: if "alive" in V_fixed[period] and "alive" in V_runtime[period]: aaae(V_fixed[period]["alive"], V_runtime[period]["alive"]) + + +def _make_action_grid_model(*, consumption_grid): + """Create a 2-regime model where consumption is the runtime-points action grid.""" + alive = Regime( + functions={"utility": _utility}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=5)}, + state_transitions={"wealth": _next_wealth}, + actions={"consumption": consumption_grid}, + constraints={"borrowing_constraint": _borrowing_constraint}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + return Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_runtime_action_grid_in_params_template(): + """IrregSpacedGrid action with runtime-supplied points adds 'points' to template.""" + model = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + alive_template = model._params_template["alive"] + assert "consumption" in alive_template + assert "points" in alive_template["consumption"] + + +def test_solve_with_runtime_action_grid(): + """Solve should work when action grid points are provided via params.""" + model = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + params = { + "discount_factor": 0.95, + "interest_rate": 0.05, + "alive": {"consumption": {"points": jnp.linspace(0.1, 5.0, 5)}}, + } + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + assert len(period_to_regime_to_V_arr) > 0 + + +def test_runtime_action_grid_matches_fixed(): + """Runtime action grid with same points gives same V as a fixed action grid.""" + points = jnp.linspace(0.1, 5.0, 5) + + model_fixed = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(points=list(points.tolist())), + ) + params_fixed = {"discount_factor": 0.95, "interest_rate": 0.05} + V_fixed = model_fixed.solve(params=params_fixed, log_level="off") + + model_runtime = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + params_runtime = { + "discount_factor": 0.95, + "interest_rate": 0.05, + "alive": {"consumption": {"points": points}}, + } + V_runtime = model_runtime.solve(params=params_runtime, log_level="off") + + for period in V_fixed: + if "alive" in V_fixed[period] and "alive" in V_runtime[period]: + aaae(V_fixed[period]["alive"], V_runtime[period]["alive"]) + + +def test_runtime_action_grid_changes_solution(): + """Different runtime action points should yield different V (sanity check).""" + model = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + base = {"discount_factor": 0.95, "interest_rate": 0.05} + V_low = model.solve( + params=base | {"alive": {"consumption": {"points": jnp.linspace(0.1, 1.0, 5)}}}, + log_level="off", + ) + V_high = model.solve( + params=base | {"alive": {"consumption": {"points": jnp.linspace(0.1, 5.0, 5)}}}, + log_level="off", + ) + # Period 0 alive value should differ when the action support differs + assert not jnp.allclose(V_low[0]["alive"], V_high[0]["alive"]) From 1792279387abea0004821b5cf706ee60e9752f7c Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 07:24:08 +0200 Subject: [PATCH 02/23] benchmarks: bump aca-model to runtime-consumption-points version MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit aca-model now declares `consumption` as `IrregSpacedGrid(n_points=N)` with runtime-supplied points. The bench builder now passes `model=model` to `get_benchmark_params` so consumption gridpoints are injected into params before solving. aca-model rev: adc8a19 → 4123fe9 (feature/runtime-consumption-points) Co-Authored-By: Claude Opus 4.7 (1M context) --- benchmarks/bench_aca_baseline.py | 2 +- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index 3ef4c9b2..477d5635 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -55,7 +55,7 @@ def _build() -> tuple[object, object, object]: ) model = create_benchmark_model() - _, model_params = get_benchmark_params() + _, model_params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=_N_SUBJECTS, seed=0 ) diff --git a/pixi.lock b/pixi.lock index 5b3f6f2e..3f235281 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=adc8a19328608781a5cb2a65ab2d93d580163aae#adc8a19328608781a5cb2a65ab2d93d580163aae + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=4123fe9739c1c4bccebaa149985d0415a4272ef1#4123fe9739c1c4bccebaa149985d0415a4272ef1 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=adc8a19328608781a5cb2a65ab2d93d580163aae#adc8a19328608781a5cb2a65ab2d93d580163aae +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=4123fe9739c1c4bccebaa149985d0415a4272ef1#4123fe9739c1c4bccebaa149985d0415a4272ef1 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev142+g95ab1b648.d20260423 - sha256: ae4d75e092f6528909d9f185ed13eccc4aabeae96f5c6d41987cbb2704afa7e7 + version: 0.0.2.dev125+g1c65f45dd.d20260429 + sha256: 9d509ac58b2af5658d439c586c2971ead9988f34017c4b0b64c4dd6db51b27aa requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index ced35ccc..7efd8a17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "adc8a19328608781a5cb2a65ab2d93d580163aae" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "4123fe9739c1c4bccebaa149985d0415a4272ef1" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From cf00e99b277e95f903eb3f2b782a288297ec8a63 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 13:55:45 +0200 Subject: [PATCH 03/23] Fail loudly when reading runtime IrregSpacedGrid before substitution `IrregSpacedGrid(n_points=N)` declares a continuous grid whose values are supplied at runtime via `params[regime][grid_name]['points']`. Substitution happens inside `InternalRegime.state_action_space(regime_params=...)` at solve / simulate time. Any code path that calls `to_jax()` on the base grid before substitution silently got `jnp.full(N, jnp.nan)` and went on to compute against the placeholder. That is exactly what fired in `validate_initial_conditions` for `task_simulate_aca`: the validator built the action grid by calling `internal_regime.grids[name].to_jax()` (placeholder NaNs), then asked `borrowing_constraint(consumption=NaN, wealth=W)` whether each gridpoint was feasible. NaN comparisons are False, so every action was reported infeasible for every subject in every initial regime. Make the invariant explicit: `IrregSpacedGrid.to_jax()` raises `GridInitializationError` for runtime-supplied grids, with a message pointing the caller at `state_action_space(regime_params=...)` for real values or `.n_points` for shape. Confine the legitimate "placeholder needed for AOT tracing" caller (the base state-action space) to a private helper in `state_action_space.py` that uses NaN explicitly. Reroute `_check_regime_feasibility` through the substituted state-action space. Add regression tests covering both runtime action and runtime state grids round-tripping `simulate(check_initial_conditions=True)`, and unit tests pinning down the new raise + the existing NaN-source mechanics in `map_coordinates` / `get_irreg_coordinate`. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 19 +- src/lcm/simulation/initial_conditions.py | 21 +- src/lcm/state_action_space.py | 23 +- tests/test_runtime_params.py | 10 +- tests/test_single_feasible_action.py | 562 +++++++++++++++++++++++ 5 files changed, 624 insertions(+), 11 deletions(-) create mode 100644 tests/test_single_feasible_action.py diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index c81f9c6a..db646ca2 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -188,9 +188,24 @@ def pass_points_at_runtime(self) -> bool: return self.points is None def to_jax(self) -> Float1D: - """Convert the grid to a Jax array.""" + """Convert the grid to a Jax array. + + Raises `GridInitializationError` for runtime-supplied grids + (`pass_points_at_runtime=True`). Substitution happens at solve / + simulate time via `InternalRegime.state_action_space(regime_params=...)`; + any code path that reads the base grid's points before substitution is + a bug. + """ if self.points is None: - return jnp.full(self.n_points, jnp.nan) + raise GridInitializationError( + f"IrregSpacedGrid was declared with n_points={self.n_points} " + f"and no points; values are supplied at runtime via " + f"params['']['']['points']. Reading the grid " + f"before substitution is a bug — call " + f"`internal_regime.state_action_space(regime_params=...)` and " + f"read points from there, or use `.n_points` if only the shape " + f"is needed." + ) return jnp.asarray(self.points) @overload diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 1be89309..2cdd3c3b 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -7,7 +7,7 @@ from collections.abc import Callable, Mapping, Sequence from types import MappingProxyType -from typing import Never +from typing import Never, cast import jax import numpy as np @@ -25,6 +25,7 @@ from lcm.regime_building.Q_and_F import _get_feasibility from lcm.typing import ( ActionName, + FlatRegimeParams, InternalParams, RegimeName, RegimeNamesToIds, @@ -579,11 +580,21 @@ def _check_regime_feasibility( # noqa: C901 if not action_names: return None - grids = MappingProxyType( - {name: spec.to_jax() for name, spec in internal_regime.grids.items()} + # Build the state-action space with runtime-supplied grid points + # substituted. The base grid's `to_jax()` raises for runtime-supplied + # `IrregSpacedGrid`s declared with `pass_points_at_runtime=True`, so the + # validator must read points from `state_action_space(regime_params=...)`. + state_action_space = internal_regime.state_action_space( + regime_params=cast("FlatRegimeParams", MappingProxyType(dict(regime_params))), + ) + action_grids: dict[str, Array] = { + **state_action_space.discrete_actions, + **state_action_space.continuous_actions, + } + flat_actions = _build_flat_action_grid( + action_names=action_names, + grids=MappingProxyType(action_grids), ) - - flat_actions = _build_flat_action_grid(action_names=action_names, grids=grids) filtered_params = {k: v for k, v in regime_params.items() if k in accepted} state_names = list(internal_regime.variable_info.query("is_state").index) diff --git a/src/lcm/state_action_space.py b/src/lcm/state_action_space.py index ad6f6c1f..28cbce72 100644 --- a/src/lcm/state_action_space.py +++ b/src/lcm/state_action_space.py @@ -1,13 +1,29 @@ from types import MappingProxyType +import jax.numpy as jnp import pandas as pd from jax import Array -from lcm.grids import Grid +from lcm.grids import Grid, IrregSpacedGrid from lcm.interfaces import StateActionSpace from lcm.typing import StateName, StateOrActionName +def _grid_to_jax_or_placeholder(grid: Grid) -> Array: + """Return the grid's points, or a NaN placeholder for runtime-supplied grids. + + `IrregSpacedGrid.to_jax()` raises when its points haven't been supplied — that + is the right behaviour everywhere except here: the base state-action space + needs a *shape-correct* array to wire through pytree structures and AOT + tracing before runtime substitution by + `InternalRegime.state_action_space(regime_params=...)`. NaN (rather than + zero) makes any accidental computation against the placeholder fail loudly. + """ + if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: + return jnp.full(grid.n_points, jnp.nan) + return grid.to_jax() + + def create_state_action_space( *, variable_info: pd.DataFrame, @@ -33,7 +49,8 @@ def create_state_action_space( """ if states is None: _states = { - sn: grids[sn].to_jax() for sn in variable_info.query("is_state").index + sn: _grid_to_jax_or_placeholder(grids[sn]) + for sn in variable_info.query("is_state").index } else: _validate_all_states_present( @@ -47,7 +64,7 @@ def create_state_action_space( for name in variable_info.query("is_action & is_discrete").index } continuous_actions = { - name: grids[name].to_jax() + name: _grid_to_jax_or_placeholder(grids[name]) for name in variable_info.query("is_action & is_continuous").index } state_and_discrete_action_names = tuple( diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index ceaee4cf..0dab6acd 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -69,7 +69,15 @@ def test_runtime_grid_creation(): grid = IrregSpacedGrid(n_points=5) assert grid.pass_points_at_runtime assert grid.n_points == 5 - assert grid.to_jax().shape == (5,) # placeholder zeros + + +def test_runtime_grid_to_jax_raises_before_substitution(): + """`to_jax()` on a runtime grid is a bug — substitution happens at + solve/simulate time via `InternalRegime.state_action_space(regime_params=...)`. + The error message must point the caller at that API.""" + grid = IrregSpacedGrid(n_points=5) + with pytest.raises(Exception, match="state_action_space"): + grid.to_jax() def test_fixed_grid_not_runtime(): diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py new file mode 100644 index 00000000..70b6dbfe --- /dev/null +++ b/tests/test_single_feasible_action.py @@ -0,0 +1,562 @@ +"""Reproduce the aca-model NaN failure (issue OpenSourceEconomics/aca-model#9): +solve raises `InvalidValueFunctionError: ... regime 'dead': all values are NaN`. + +Two hypotheses are exercised here: + +1. **single/all infeasible action**: when constraints leave a single (or no) + consumption gridpoint feasible. `max(Q, where=F, initial=-inf)` is supposed + to mask infeasible Q values, but if Q itself contains NaN at infeasible + cells (e.g. `log(0)` at `consumption=0`), the mask is enough — proven by + the tests in this file. `-inf` from all-infeasible cells does not cascade + to NaN by itself either. + +2. **CRRA bequest with pref_type indexing under jnp.where**: the bequest + function evaluates *both* branches of `jnp.where(jnp.isclose(gamma, 1), + log_branch, power_branch)`. For a parameter set where gamma is exactly 1 + for one preference type but not the other, the power_branch divides by + `1 - gamma = 0` for the gamma=1 type. NaN/Inf from the *unselected* + branch leaks through `jnp.where`'s gradient/forward pass under JIT + tracing in some configurations. + +The model uses an `IrregSpacedGrid` consumption action with runtime-supplied +points to also exercise the `feature/runtime-action-grids` path (PR #338). +""" + +import jax.numpy as jnp +import pytest + +from lcm import AgeGrid, LinSpacedGrid, Model, Regime, categorical +from lcm.grids import IrregSpacedGrid +from lcm.typing import ContinuousAction, ContinuousState, FloatND + + +@categorical(ordered=False) +class RegimeId: + alive: int + dead: int + + +def _utility(consumption: ContinuousAction) -> FloatND: + """CRRA-like utility. log requires consumption > 0.""" + return jnp.log(consumption) + + +def _next_wealth( + wealth: ContinuousState, + consumption: ContinuousAction, +) -> ContinuousState: + return wealth - consumption + + +def _borrowing_constraint( + consumption: ContinuousAction, wealth: ContinuousState +) -> FloatND: + return consumption <= wealth + + +def _next_regime(age: int, last_alive_age: int) -> FloatND: + """Deterministic regime transition: alive→alive while age= last_alive_age, RegimeId.dead, RegimeId.alive) + + +def _build_model( + *, + wealth_lo: float, + wealth_hi: float, + n_wealth: int, + consumption_lo: float, + consumption_hi: float, + n_consumption: int, + n_periods: int, +) -> tuple[Model, dict]: + """Build a 2-regime (alive, dead) model with runtime consumption points. + + The wealth grid is fixed-LinSpaced (so changing it doesn't perturb the + runtime-action-grid path). The consumption grid is the runtime-supplied + IrregSpacedGrid. + """ + last_alive_age = n_periods - 2 # alive at ages 0..n-2; dead at n-1 + alive = Regime( + functions={"utility": _utility}, + states={"wealth": LinSpacedGrid(start=wealth_lo, stop=wealth_hi, n_points=n_wealth)}, + state_transitions={"wealth": _next_wealth}, + actions={"consumption": IrregSpacedGrid(n_points=n_consumption)}, + constraints={"borrowing_constraint": _borrowing_constraint}, + transition=_next_regime, + active=lambda age: age <= last_alive_age, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age > last_alive_age, + ) + model = Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=n_periods - 1, step="Y"), + regime_id_class=RegimeId, + ) + consumption_points = jnp.linspace(consumption_lo, consumption_hi, n_consumption) + params = { + "discount_factor": 0.95, + "alive": { + "consumption": {"points": consumption_points}, + "next_regime": {"last_alive_age": last_alive_age}, + }, + } + return model, params + + +def test_baseline_no_nan(): + """Healthy regime: at every wealth, every consumption point is feasible. + + consumption_lo == wealth_lo so the smallest action is always feasible at + every wealth gridpoint. `consumption_hi <= wealth_lo` ensures even the + largest consumption is feasible at the lowest wealth state. + """ + model, params = _build_model( + wealth_lo=10.0, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=1.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)) + + +def test_some_states_have_only_one_feasible_action(): + """At low-wealth states, only consumption[0] satisfies `consumption <= wealth`. + + consumption_lo < wealth_lo (so smallest action feasible at every wealth), + but consumption[1:] > wealth_lo (so at the lowest wealth gridpoint only + consumption[0] is feasible). + """ + model, params = _build_model( + wealth_lo=1.0, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=0.5, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + # At wealth = 1.0, only consumption = 0.5 is feasible (1.625, 2.75, 3.875, + # 5.0 all > 1.0). + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)), ( + "single-feasible-action state should not produce NaN V" + ) + + +def test_some_states_have_no_feasible_action(): + """At sufficiently low wealth, *every* consumption gridpoint is infeasible. + + consumption_lo > wealth_lo means at the lowest wealth state no + consumption point satisfies the borrowing constraint. + """ + model, params = _build_model( + wealth_lo=0.1, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=1.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + # At wealth = 0.1, no consumption point is feasible. max returns -inf. + # The next period interpolates V over the resulting -inf cells. + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)), ( + "all-infeasible state should yield -inf, not NaN" + ) + + +def test_log_zero_consumption_propagates_nan_via_max_when_unconstrained(): + """U(c=0) = -inf is fine, but U evaluated at infeasible negative wealth + via `next_wealth = wealth - c` going through interpolation in the next + period should not pollute V.""" + # consumption_lo = 0 → log(0) = -inf at the smallest action, regardless + # of feasibility. `where=F_arr` should mask this out of the max. + model, params = _build_model( + wealth_lo=1.0, + wealth_hi=10.0, + n_wealth=5, + consumption_lo=0.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + # log(0) = -inf is not NaN, but combined with where-mask edge + # cases this could in principle leak; check it does not. + assert not jnp.any(jnp.isnan(V_arr)) + + +@pytest.mark.parametrize( + ("wealth_lo", "consumption_lo", "label"), + [ + (1.0, 0.5, "single-feasible"), + (0.1, 1.0, "all-infeasible"), + ], +) +def test_simulate_with_constrained_action_grid(wealth_lo, consumption_lo, label): + """End-to-end solve+simulate for both regimes.""" + model, params = _build_model( + wealth_lo=wealth_lo, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=consumption_lo, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + initial_conditions = { + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([wealth_lo, 5.0, 20.0]), + "regime": jnp.array( + [RegimeId.alive, RegimeId.alive, RegimeId.alive], dtype=jnp.int32 + ), + } + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + check_initial_conditions=False, + log_level="off", + ) + df = result.to_dataframe() + assert not df["value"].isna().any(), ( + f"{label}: simulated value column should not contain NaN" + ) + + +# --------------------------------------------------------------------------- +# Replicas of the aca-baseline failure path: dead regime with a CRRA bequest +# whose `gamma` is per-pref_type, evaluated through `jnp.where`. +# --------------------------------------------------------------------------- + + +@categorical(ordered=False) +class PrefType: + type_0: int + type_1: int + + +@categorical(ordered=False) +class AliveDeadRegimeId: + alive: int + dead: int + + +def _crra_bequest( + assets: ContinuousState, + pref_type, + bequest_shifter: float, + consumption_weight, + coefficient_rra, +) -> FloatND: + """Replica of aca_model.agent.preferences.bequest, simplified. + + `consumption_weight` and `coefficient_rra` are FloatND indexed by + `pref_type`. Both branches of the `jnp.where` are traced. + """ + alpha = consumption_weight[pref_type] + gamma = coefficient_rra[pref_type] + assets_shifted = jnp.maximum(0.0, assets) + bequest_shifter + one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) + val = jnp.where( + jnp.isclose(gamma, 1.0), + jnp.log(assets_shifted), + assets_shifted ** (one_minus_gamma * alpha) / one_minus_gamma, + ) + return val + + +def _alive_utility( + consumption: ContinuousAction, pref_type, consumption_weight +) -> FloatND: + """Make pref_type matter in alive's utility too (otherwise pylcm complains).""" + alpha = consumption_weight[pref_type] + return alpha * jnp.log(consumption) + + +def _next_assets(assets: ContinuousState, consumption: ContinuousAction) -> ContinuousState: + return assets - consumption + + +def _alive_borrow(consumption: ContinuousAction, assets: ContinuousState) -> FloatND: + return consumption <= assets + + +def _alive_to_dead(age: int, last_alive_age: int) -> FloatND: + """Deterministic regime transition; returns a scalar regime ID.""" + return jnp.where( + age >= last_alive_age, AliveDeadRegimeId.dead, AliveDeadRegimeId.alive + ) + + +def _build_alive_dead_model( + *, + coefficient_rra: tuple[float, float], + consumption_weight: tuple[float, float], + n_periods: int = 3, +) -> tuple[Model, dict]: + last_alive_age = n_periods - 2 + from lcm import DiscreteGrid + + alive = Regime( + functions={"utility": _alive_utility}, + states={ + "assets": LinSpacedGrid(start=1.0, stop=20.0, n_points=5), + "pref_type": DiscreteGrid(PrefType, batch_size=1), + }, + state_transitions={"assets": _next_assets, "pref_type": None}, + actions={"consumption": IrregSpacedGrid(n_points=5)}, + constraints={"borrowing_constraint": _alive_borrow}, + transition=_alive_to_dead, + active=lambda age: age <= last_alive_age, + ) + dead = Regime( + transition=None, + functions={"utility": _crra_bequest}, + states={ + "assets": LinSpacedGrid(start=1.0, stop=20.0, n_points=5), + "pref_type": DiscreteGrid(PrefType, batch_size=1), + }, + active=lambda _age: True, + ) + model = Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=n_periods - 1, step="Y"), + regime_id_class=AliveDeadRegimeId, + ) + cw_arr = jnp.asarray(consumption_weight) + params = { + "discount_factor": 0.95, + "alive": { + "utility": {"consumption_weight": cw_arr}, + "consumption": {"points": jnp.linspace(0.5, 5.0, 5)}, + "next_regime": {"last_alive_age": last_alive_age}, + }, + "dead": { + "utility": { + "bequest_shifter": 100.0, + "consumption_weight": cw_arr, + "coefficient_rra": jnp.asarray(coefficient_rra), + }, + }, + } + return model, params + + +def test_bequest_gamma_close_to_one_is_safe(): + """gamma=0.999077 (the benchmark value for type_1) should not produce NaN.""" + model, params = _build_alive_dead_model( + coefficient_rra=(3.84, 0.999077), + consumption_weight=(0.68, 0.88), + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)) + + +def test_bequest_gamma_exactly_one_for_one_type_only(): + """gamma=1.0 exactly for one pref type triggers `jnp.where(isclose, log, power)`. + + The unselected `power` branch divides by `1 - gamma = 0` for that type. JAX + evaluates both branches of `jnp.where`; the non-finite from the unselected + branch is masked, but `0/0` produces NaN that may not be masked correctly + when the operand to `jnp.where` is itself NaN under XLA's nan-prop rules. + """ + model, params = _build_alive_dead_model( + coefficient_rra=(3.84, 1.0), # type_1 hits the log branch + consumption_weight=(0.68, 0.88), + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for period, regime_to_V in period_to_regime_to_V_arr.items(): + for regime, V_arr in regime_to_V.items(): + assert not jnp.any(jnp.isnan(V_arr)), ( + f"NaN in V[{regime}, period={period}] when one type has gamma=1.0" + ) + + +# --------------------------------------------------------------------------- +# Direct probe: `map_coordinates` produces NaN at ±inf / NaN coordinates. +# This is the concrete NaN source — `lower_weight = 1 - inf = -inf` and +# `upper_weight = inf` combined with positive grid values gives `inf - inf = +# NaN`. The aca-baseline NaN-in-V at age 51 is most plausibly traced back to +# *some* upstream computation (next_assets / next_aime, or a state coordinate +# from `get_irreg_coordinate` / `get_*_coordinate` that divides by zero on a +# degenerate grid segment) producing inf, which then poisons the value +# function via this interpolation path. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("bad_coord", [jnp.inf, -jnp.inf, jnp.nan]) +def test_map_coordinates_returns_nan_for_non_finite_coordinate(bad_coord): + """`map_coordinates` cannot recover from a non-finite continuous-state + coordinate: the linear-interp weights become `inf` and `1 - inf = -inf`, + and `inf * V[k] - inf * V[k-1]` reduces to NaN. + + Implication for callers: any path that can feed `inf` or `NaN` into the + coordinate finder (e.g. division by zero in a state transition, an + overflow when V values are O(1e8), or a `0/0` in a degenerate + IrregSpacedGrid segment) will produce NaN in V. + """ + from lcm.regime_building.ndimage import map_coordinates + + V_arr = jnp.array([1.0, 5.0, 12.0]) + out = map_coordinates(V_arr, coordinates=[jnp.array(bad_coord)]) + assert jnp.isnan(out) + + +def test_irreg_coordinate_divides_by_zero_on_duplicate_grid_points(): + """`get_irreg_coordinate` divides by `upper_point - lower_point`. If a + runtime-supplied points array contains duplicates, the divisor is 0. + + Reproduces a class of failures that is *only* possible under + `feature/runtime-action-grids` / runtime-state-grids when the caller + constructs the points from a parameter that can collapse (e.g. + `geomspace(consumption_floor, MAX, n_points)` with `consumption_floor == + MAX`, or any param-driven `linspace` whose endpoints can coincide). + """ + from lcm.grids.coordinates import get_irreg_coordinate + + # Duplicate adjacent points where the query value equals the duplicate. + # `searchsorted([0, 1, 1], 1.0, side='right')=3` → clipped to n-1=2, + # idx_lower=1, lower_point=points[1]=1.0, upper_point=points[2]=1.0, + # step_size=0 → decimal_part = 0/0 = nan. + points = jnp.array([0.0, 1.0, 1.0]) + coord = get_irreg_coordinate(value=jnp.array(1.0), points=points) + assert not jnp.isfinite(coord), ( + "Duplicate adjacent grid points cause `step_size = 0` in " + "`get_irreg_coordinate`; current behaviour is to silently return " + "inf/nan, which then poisons V interpolation downstream." + ) + + +# --------------------------------------------------------------------------- +# `validate_initial_conditions` uses `_base_state_action_space` directly, +# which still holds the placeholder zeros for runtime-supplied +# `IrregSpacedGrid`. With a feasibility constraint that the all-zero +# placeholder fails, every subject is reported infeasible — even though the +# real (post-substitution) grid would pass. This affects runtime grids +# regardless of whether they are state or action grids. +# --------------------------------------------------------------------------- + + +def _runtime_state_grid_model() -> tuple[Model, dict, dict]: + """A 2-regime model with a runtime-supplied IrregSpacedGrid *state*.""" + + @categorical(ordered=False) + class RuntimeRegimeId: + alive: int + dead: int + + def utility(consumption, wealth): + return jnp.log(consumption) + 0.0 * wealth + + def next_wealth(wealth, consumption): + return wealth - consumption + + def borrow(consumption, wealth): + # The validator sees `wealth` as a per-subject array with the + # subject-supplied initial values, but `consumption` as the *grid* + # (placeholder zeros for runtime grids). With a feasibility check + # that requires `consumption > 0`, every action gridpoint is + # infeasible until the runtime points replace the placeholder. + return consumption > 0 + + def next_regime(age, last_alive_age): + return jnp.where(age >= last_alive_age, RuntimeRegimeId.dead, RuntimeRegimeId.alive) + + last_alive_age = 1 + alive = Regime( + functions={"utility": utility}, + states={"wealth": IrregSpacedGrid(n_points=4)}, # runtime state grid + state_transitions={"wealth": next_wealth}, + actions={"consumption": LinSpacedGrid(start=0.5, stop=5.0, n_points=5)}, + constraints={"borrow": borrow}, + transition=next_regime, + active=lambda age: age <= last_alive_age, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age > last_alive_age, + ) + model = Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RuntimeRegimeId, + ) + params = { + "discount_factor": 0.95, + "alive": { + "wealth": {"points": jnp.linspace(1.0, 10.0, 4)}, + "next_regime": {"last_alive_age": last_alive_age}, + }, + } + initial_conditions = { + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([2.0, 5.0, 9.0]), + "regime": jnp.array( + [RuntimeRegimeId.alive, RuntimeRegimeId.alive, RuntimeRegimeId.alive], + dtype=jnp.int32, + ), + } + return model, params, initial_conditions + + +def test_runtime_action_grid_passes_initial_conditions_validation(): + """`feature/runtime-action-grids` regression: initial-conditions + feasibility check must use the *substituted* action grid, not the + `_base_state_action_space` placeholder zeros.""" + model, params = _build_model( + wealth_lo=10.0, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=1.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + initial_conditions = { + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([10.0, 15.0, 20.0]), + "regime": jnp.array( + [RegimeId.alive, RegimeId.alive, RegimeId.alive], dtype=jnp.int32 + ), + } + # `check_initial_conditions=True` (the default) must pass — the + # runtime-supplied consumption points are well-formed. + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + ) + assert result.n_subjects == 3 + + +def test_runtime_state_grid_passes_initial_conditions_validation(): + """Same regression for runtime-supplied *state* grids.""" + model, params, initial_conditions = _runtime_state_grid_model() + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + ) + assert result.n_subjects == 3 From db98cde25b64ce1ba506b7b32a8c0c3c881b938f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 12:12:19 +0000 Subject: [PATCH 04/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_single_feasible_action.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py index 70b6dbfe..20b19152 100644 --- a/tests/test_single_feasible_action.py +++ b/tests/test_single_feasible_action.py @@ -78,7 +78,9 @@ def _build_model( last_alive_age = n_periods - 2 # alive at ages 0..n-2; dead at n-1 alive = Regime( functions={"utility": _utility}, - states={"wealth": LinSpacedGrid(start=wealth_lo, stop=wealth_hi, n_points=n_wealth)}, + states={ + "wealth": LinSpacedGrid(start=wealth_lo, stop=wealth_hi, n_points=n_wealth) + }, state_transitions={"wealth": _next_wealth}, actions={"consumption": IrregSpacedGrid(n_points=n_consumption)}, constraints={"borrowing_constraint": _borrowing_constraint}, @@ -290,7 +292,9 @@ def _alive_utility( return alpha * jnp.log(consumption) -def _next_assets(assets: ContinuousState, consumption: ContinuousAction) -> ContinuousState: +def _next_assets( + assets: ContinuousState, consumption: ContinuousAction +) -> ContinuousState: return assets - consumption @@ -479,7 +483,9 @@ def borrow(consumption, wealth): return consumption > 0 def next_regime(age, last_alive_age): - return jnp.where(age >= last_alive_age, RuntimeRegimeId.dead, RuntimeRegimeId.alive) + return jnp.where( + age >= last_alive_age, RuntimeRegimeId.dead, RuntimeRegimeId.alive + ) last_alive_age = 1 alive = Regime( From 03ba800ff8d9ccc95b75b5ce89026f8c96868611 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 14:35:44 +0200 Subject: [PATCH 05/23] Fix remaining ruff check errors in test_single_feasible_action.py Move late `DiscreteGrid`, `map_coordinates`, and `get_irreg_coordinate` imports to the module top level (PLC0415), drop the unnecessary `val` assignment before return (RET504), and mark the unused `wealth` arg in the local `borrow` constraint as `# noqa: ARG001`. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_single_feasible_action.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py index 20b19152..04628ed2 100644 --- a/tests/test_single_feasible_action.py +++ b/tests/test_single_feasible_action.py @@ -25,8 +25,10 @@ import jax.numpy as jnp import pytest -from lcm import AgeGrid, LinSpacedGrid, Model, Regime, categorical +from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical from lcm.grids import IrregSpacedGrid +from lcm.grids.coordinates import get_irreg_coordinate +from lcm.regime_building.ndimage import map_coordinates from lcm.typing import ContinuousAction, ContinuousState, FloatND @@ -276,12 +278,11 @@ def _crra_bequest( gamma = coefficient_rra[pref_type] assets_shifted = jnp.maximum(0.0, assets) + bequest_shifter one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) - val = jnp.where( + return jnp.where( jnp.isclose(gamma, 1.0), jnp.log(assets_shifted), assets_shifted ** (one_minus_gamma * alpha) / one_minus_gamma, ) - return val def _alive_utility( @@ -316,7 +317,6 @@ def _build_alive_dead_model( n_periods: int = 3, ) -> tuple[Model, dict]: last_alive_age = n_periods - 2 - from lcm import DiscreteGrid alive = Regime( functions={"utility": _alive_utility}, @@ -418,8 +418,6 @@ def test_map_coordinates_returns_nan_for_non_finite_coordinate(bad_coord): overflow when V values are O(1e8), or a `0/0` in a degenerate IrregSpacedGrid segment) will produce NaN in V. """ - from lcm.regime_building.ndimage import map_coordinates - V_arr = jnp.array([1.0, 5.0, 12.0]) out = map_coordinates(V_arr, coordinates=[jnp.array(bad_coord)]) assert jnp.isnan(out) @@ -435,8 +433,6 @@ def test_irreg_coordinate_divides_by_zero_on_duplicate_grid_points(): `geomspace(consumption_floor, MAX, n_points)` with `consumption_floor == MAX`, or any param-driven `linspace` whose endpoints can coincide). """ - from lcm.grids.coordinates import get_irreg_coordinate - # Duplicate adjacent points where the query value equals the duplicate. # `searchsorted([0, 1, 1], 1.0, side='right')=3` → clipped to n-1=2, # idx_lower=1, lower_point=points[1]=1.0, upper_point=points[2]=1.0, @@ -474,7 +470,7 @@ def utility(consumption, wealth): def next_wealth(wealth, consumption): return wealth - consumption - def borrow(consumption, wealth): + def borrow(consumption, wealth): # noqa: ARG001 # The validator sees `wealth` as a per-subject array with the # subject-supplied initial values, but `consumption` as the *grid* # (placeholder zeros for runtime grids). With a feasibility check From 3769d2d63a0f708b2bf4b8a11b603b686282b996 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 17:28:20 +0200 Subject: [PATCH 06/23] Raise on regime-function-output indexed by discrete state in a consumer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A regime function whose output is then re-indexed by a discrete state inside another consumer (function, constraint, or transition) is a silent footgun: pylcm broadcasts function outputs to per-cell scalars before consumption, so the indexing silently produces NaN at runtime instead of the intended scalar. The aca-baseline benchmark hit this via `bequest(... utility_scale_factor[pref_type])` where `utility_scale_factor` is registered as a regime function — the dead regime's V came back all-NaN with no actionable error. Adds an AST-walking validator in `validate_logical_consistency` that inspects every consumer (functions, constraints, transition) for a `Subscript(Name=X, slice=Name=Y)` pattern where `X` is in `regime.functions` and `Y` is a `DiscreteGrid` state. If any clash is found, raises `RegimeInitializationError` listing each clash and pointing the user at the safe pattern (function takes the state, returns a scalar — see `discount_factor`). Three TDD tests in `tests/test_function_output_state_indexing.py`: - the clash raises (functions case) - the safe pattern (function takes the state, scalar return) builds - the check applies to constraints too Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/regime_building/validation.py | 80 ++++++++++ tests/test_function_output_state_indexing.py | 153 +++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 tests/test_function_output_state_indexing.py diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index 0754c0a7..ddc63631 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -5,6 +5,9 @@ """ +import ast +import inspect +import textwrap from collections.abc import Callable, Mapping from typing import TypeAliasType @@ -125,6 +128,7 @@ def validate_logical_consistency(regime: Regime) -> None: error_messages.extend(_validate_active(regime.active)) error_messages.extend(_validate_state_transitions(regime)) + error_messages.extend(_validate_function_output_state_indexing(regime)) states_and_actions_overlap = set(regime.states) & set(regime.actions) if states_and_actions_overlap: @@ -138,6 +142,82 @@ def validate_logical_consistency(regime: Regime) -> None: raise RegimeInitializationError(msg) +def _validate_function_output_state_indexing(regime: Regime) -> list[str]: + """Detect the regime-function-output / state-indexed-input name clash. + + A regime function whose output is then re-indexed by a discrete state inside + another consumer (function, constraint, or transition) is a silent footgun: + pylcm broadcasts function outputs to per-cell scalars before consumption, so + the indexing produces NaN at runtime instead of the intended scalar. + + The safe pattern is to take the state as input on the producing function and + return the scalar directly (see `aca_model.agent.preferences.discount_factor`). + """ + function_output_names = set(regime.functions) + discrete_state_names = { + name for name, grid in regime.states.items() if isinstance(grid, DiscreteGrid) + } + if not function_output_names or not discrete_state_names: + return [] + + consumers: list[tuple[str, Callable]] = [] + consumers.extend(regime.functions.items()) + consumers.extend(regime.constraints.items()) + if callable(regime.transition): + consumers.append(("regime_transition", regime.transition)) + + errors: list[str] = [] + for consumer_name, func in consumers: + clashes = _find_function_output_state_indexing( + func=func, + function_output_names=function_output_names, + discrete_state_names=discrete_state_names, + ) + for func_output_name, state_name in clashes: + errors.append( + f"Consumer '{consumer_name}' indexes regime function output " + f"'{func_output_name}' by discrete state '{state_name}' " + f"(`{func_output_name}[{state_name}]`). pylcm broadcasts " + f"function outputs to per-cell scalars before consumption, so " + f"this indexing silently produces NaN. Refactor " + f"'{func_output_name}' to take '{state_name}' as input and " + f"return the scalar directly." + ) + return errors + + +def _find_function_output_state_indexing( + *, + func: Callable, + function_output_names: set[str], + discrete_state_names: set[str], +) -> list[tuple[str, str]]: + """Return `(function_output_name, state_name)` clashes inside `func`'s body.""" + try: + source = textwrap.dedent(inspect.getsource(func)) + except OSError, TypeError: + return [] + try: + tree = ast.parse(source) + except SyntaxError: + return [] + + clashes: list[tuple[str, str]] = [] + for node in ast.walk(tree): + if not isinstance(node, ast.Subscript): + continue + if not isinstance(node.value, ast.Name): + continue + if node.value.id not in function_output_names: + continue + if not isinstance(node.slice, ast.Name): + continue + if node.slice.id not in discrete_state_names: + continue + clashes.append((node.value.id, node.slice.id)) + return clashes + + def collect_state_transitions( states: Mapping[StateName, Grid], state_transitions: Mapping[ diff --git a/tests/test_function_output_state_indexing.py b/tests/test_function_output_state_indexing.py new file mode 100644 index 00000000..f326833c --- /dev/null +++ b/tests/test_function_output_state_indexing.py @@ -0,0 +1,153 @@ +"""Tests for the regime-function-output / state-indexed-input name clash. + +A regime function whose output is then re-indexed by a discrete state +inside another function is a silent footgun: pylcm broadcasts function +outputs to per-cell scalars before consumption, so the indexing produces +NaN at runtime instead of the intended scalar. + +The validation layer must raise on construction with a clear message +pointing the user at the safe pattern (see `discount_factor` in +`aca_model.agent.preferences`: take the state as input, return a scalar). +""" + +import jax.numpy as jnp +import pytest + +from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical +from lcm.exceptions import RegimeInitializationError +from lcm.typing import ContinuousAction, DiscreteState, FloatND + + +@categorical(ordered=False) +class PrefType: + type_0: int + type_1: int + + +@categorical(ordered=False) +class RegimeId: + alive: int + dead: int + + +def _per_type_scale(some_param: FloatND) -> FloatND: + """Returns one scale per pref_type — depends only on a per-type Series param.""" + return jnp.abs(1.0 / (1.0 - some_param)) + + +def _utility_with_state_indexed_function_output( + consumption: ContinuousAction, + pref_type: DiscreteState, + per_type_scale: FloatND, +) -> FloatND: + # The clash: per_type_scale is registered as a regime function output + # (returns shape (n_pref_types,)), but here it is consumed and indexed + # by the `pref_type` discrete state. + return per_type_scale[pref_type] * jnp.log(consumption + 1.0) + + +def _next_regime(period: int) -> FloatND: + return jnp.where(period >= 1, RegimeId.dead, RegimeId.alive) + + +def _make_clashing_model() -> Model: + alive = Regime( + functions={ + "utility": _utility_with_state_indexed_function_output, + "per_type_scale": _per_type_scale, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + return Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_function_output_indexed_by_state_raises(): + """A regime function output indexed by a discrete state inside another regime + function must raise on construction (silent NaN bug otherwise).""" + with pytest.raises( + RegimeInitializationError, + match=r"per_type_scale.*pref_type", + ): + _make_clashing_model() + + +def _utility_safe( + consumption: ContinuousAction, + pref_type: DiscreteState, # noqa: ARG001 + per_type_scale: FloatND, +) -> FloatND: + # Safe variant: per_type_scale is consumed as a scalar (no [pref_type] indexing). + return per_type_scale * jnp.log(consumption + 1.0) + + +def _per_type_scale_safe(pref_type: DiscreteState, some_param: FloatND) -> FloatND: + """Safe variant: takes pref_type, returns scalar (mirrors discount_factor).""" + return jnp.abs(1.0 / (1.0 - some_param[pref_type])) + + +def test_safe_pattern_does_not_raise(): + """The safe pattern (function takes the state, returns a scalar) builds fine.""" + alive = Regime( + functions={ + "utility": _utility_safe, + "per_type_scale": _per_type_scale_safe, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_constraint_indexing_function_output_by_state_raises(): + """The check applies to regime constraints too, not only `functions`.""" + + def _constraint_indexing_function_output( + consumption: ContinuousAction, + pref_type: DiscreteState, + per_type_scale: FloatND, + ) -> FloatND: + return consumption <= per_type_scale[pref_type] + + with pytest.raises( + RegimeInitializationError, + match=r"per_type_scale.*pref_type", + ): + Regime( + functions={ + "utility": lambda consumption, pref_type, per_type_scale: jnp.log( # noqa: ARG005 + consumption + 1.0 + ), + "per_type_scale": _per_type_scale, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + constraints={"feasibility": _constraint_indexing_function_output}, + transition=_next_regime, + active=lambda age: age < 2, + ) From 282542f25e58d72172e39fb79aebee94d2057590 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 18:32:41 +0200 Subject: [PATCH 07/23] benchmarks: bump aca-model to dead-regime-NaN fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit aca-model `feature/runtime-consumption-points` 4123fe9 → 1342861 (refactors `utility_scale_factor` to take `pref_type` and return a scalar, eliminating the regime-function-output / state-indexed-input clash that produced NaN in the dead regime's V). Co-Authored-By: Claude Opus 4.7 (1M context) --- pixi.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index 3f235281..e507ccc8 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=4123fe9739c1c4bccebaa149985d0415a4272ef1#4123fe9739c1c4bccebaa149985d0415a4272ef1 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=4123fe9739c1c4bccebaa149985d0415a4272ef1#4123fe9739c1c4bccebaa149985d0415a4272ef1 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev125+g1c65f45dd.d20260429 - sha256: 9d509ac58b2af5658d439c586c2971ead9988f34017c4b0b64c4dd6db51b27aa + version: 0.0.2.dev130+g3769d2d63.d20260429 + sha256: 365fd6893bcba5ab807032371430b19a9a9056c80ef3a9a201b56d34ced99e0e requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 7efd8a17..3fe01162 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "4123fe9739c1c4bccebaa149985d0415a4272ef1" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } From 72c83f79faa3b7536bdc6b331d2ffba711ead8a5 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 21:02:11 +0200 Subject: [PATCH 08/23] Substitute runtime-supplied action gridpoints in simulate's state-action space `create_regime_state_action_space` (used during forward simulation) was calling `create_state_action_space` directly, which leaves `pass_points_at_runtime=True` IrregSpacedGrid action grids as their NaN placeholder. The placeholder fed straight into `argmax_and_max_Q_over_a` and `_lookup_values_from_indices`, so optimal actions came back NaN, the source regime's `next_state` propagated NaN into every target regime's namespaced state, and `validate_V` raised on the first downstream regime whose utility depended on those states (the dead regime in aca-model: assets/pref_type both NaN). Route through `internal_regime.state_action_space(regime_params=...)` (the same path solve uses) and overlay the per-subject states. Add a TDD regression test in tests/test_runtime_params.py covering the simulate path. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/simulation/simulate.py | 1 + src/lcm/simulation/transitions.py | 25 ++++++----- tests/test_runtime_params.py | 74 +++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 10 deletions(-) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 9c017db3..d54d2674 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -244,6 +244,7 @@ def _simulate_regime_in_period( state_action_space = create_regime_state_action_space( internal_regime=internal_regime, states=states, + regime_params=internal_params[regime_name], ) # Compute optimal actions # We need to pass the value function array of the next period to the diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index 67e6df67..23bfd58c 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -15,7 +15,6 @@ from lcm.interfaces import InternalRegime, StateActionSpace from lcm.simulation.random import generate_simulation_keys -from lcm.state_action_space import create_state_action_space from lcm.typing import ( ActionName, Bool1D, @@ -31,29 +30,35 @@ def create_regime_state_action_space( *, internal_regime: InternalRegime, states: MappingProxyType[str, Array], + regime_params: FlatRegimeParams, ) -> StateActionSpace: """Create the state-action space containing only the relevant subjects in a regime. + Continuous action grids declared with `pass_points_at_runtime=True` are + completed from `regime_params` (via + `InternalRegime.state_action_space`) — otherwise they would carry the + NaN placeholder used during compilation, which propagates into + `optimal_actions` and ultimately `next_states`. + Args: internal_regime: The internal regime instance. states: The current states of all subjects. + regime_params: Flat regime parameters supplied at runtime, used to + substitute runtime-supplied action gridpoints. Returns: The state-action space for the subjects in the regime. """ - relevant_state_names = internal_regime.variable_info.query("is_state").index + base = internal_regime.state_action_space(regime_params=regime_params) - states_for_state_action_space = { - sn: states[f"{internal_regime.name}__{sn}"] for sn in relevant_state_names - } - - return create_state_action_space( - variable_info=internal_regime.variable_info, - grids=internal_regime.grids, - states=states_for_state_action_space, + relevant_state_names = internal_regime.variable_info.query("is_state").index + states_for_state_action_space = MappingProxyType( + {sn: states[f"{internal_regime.name}__{sn}"] for sn in relevant_state_names} ) + return base.replace(states=states_for_state_action_space) + def calculate_next_states( *, diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index 0dab6acd..a401608d 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -239,3 +239,77 @@ def test_runtime_action_grid_changes_solution(): ) # Period 0 alive value should differ when the action support differs assert not jnp.allclose(V_low[0]["alive"], V_high[0]["alive"]) + + +def _make_action_grid_model_with_stateful_dead(*, consumption_grid): + """Variant where `dead` has a `wealth` state so its utility depends on it. + + Mirrors the aca-model dead regime (carries assets / pref_type so the + bequest function can read them). Used to surface NaN propagation + when the simulate path forgets to substitute runtime-supplied action + gridpoints. + """ + + def _alive_utility( + consumption: ContinuousAction, wealth: ContinuousState + ) -> FloatND: + return jnp.log(consumption + 1) + 0.01 * wealth + + def _dead_utility(wealth: ContinuousState) -> FloatND: + return jnp.log(wealth + 1) + + alive = Regime( + functions={"utility": _alive_utility}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=5)}, + state_transitions={ + "wealth": { + "alive": _next_wealth, + "dead": _next_wealth, + }, + }, + actions={"consumption": consumption_grid}, + constraints={"borrowing_constraint": _borrowing_constraint}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": _dead_utility}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=5)}, + active=lambda _age: True, + ) + return Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_simulate_with_runtime_action_grid_no_nan() -> None: + """Simulate must substitute runtime-supplied action gridpoints into the + state-action space; otherwise the action grid is filled with NaN + placeholders, optimal_actions become NaN, next_states propagate NaN to + the dead regime, and `validate_V` raises. + """ + model = _make_action_grid_model_with_stateful_dead( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + params = { + "discount_factor": 0.95, + "interest_rate": 0.05, + "alive": {"consumption": {"points": jnp.linspace(0.1, 5.0, 5)}}, + } + initial_conditions = { + "regime": jnp.array([RegimeId.alive, RegimeId.alive, RegimeId.alive]), + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([2.0, 5.0, 9.0]), + } + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + check_initial_conditions=False, + ) + df = result.to_dataframe() + assert not df["value"].isna().any() From db6214f275d3863fc11a79395ec6b9a0e300db79 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 30 Apr 2026 15:02:27 +0200 Subject: [PATCH 09/23] Guard log-spaced grids against non-positive start; reject non-finite grid values `LogSpacedGrid` previously inherited only the generic continuous-grid checks (start < stop, n_points > 0). With `start <= 0`, `to_jax()` silently returned NaN/-inf, and the bug would only surface deep inside an interpolation kernel. Now refuses at construction. While here, tighten two adjacent silent-failure modes: - `_validate_continuous_grid` rejects non-finite `start`/`stop`. `start >= stop` is False for NaN, so a NaN bound previously slipped through every check. - `_validate_irreg_spaced_grid` rejects non-finite points. The ascending-order test uses `>=`, which is False for NaN, so a NaN point previously passed the order check silently. Both matter for runtime-supplied grids: e.g. `geomspace(consumption_floor, MAX, N)` with a bad `consumption_floor` produces all-NaN points, and we want that caught at the grid layer rather than as a downstream V_arr NaN diagnostic. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 58 +++++++++++++++++++++++++++++++------ tests/test_grids.py | 30 +++++++++++++++++++ 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index db646ca2..8550188d 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -1,4 +1,5 @@ import dataclasses +import math from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @@ -116,12 +117,24 @@ def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: class LogSpacedGrid(UniformContinuousGrid): """A logarithmically spaced grid of continuous values. + Requires `start > 0`; otherwise `log(start)` is undefined and `to_jax()` + silently returns NaN/-inf, which propagates through every downstream + interpolation. + Example: -------- Let `start = 1`, `stop = 100`, and `n_points = 3`. The grid is `[1, 10, 100]`. """ + def __post_init__(self) -> None: + _validate_continuous_grid( + start=self.start, + stop=self.stop, + n_points=self.n_points, + requires_positive_start=True, + ) + def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" return grid_coordinates.logspace( @@ -228,6 +241,7 @@ def _validate_continuous_grid( start: float, stop: float, n_points: int, + requires_positive_start: bool = False, ) -> None: """Validate the continuous grid parameters. @@ -235,6 +249,8 @@ def _validate_continuous_grid( start: The start value of the grid. stop: The stop value of the grid. n_points: The number of points in the grid. + requires_positive_start: If True, also require `start > 0` (used by + log-spaced grids since `log(x)` is undefined for `x <= 0`). Raises: GridInitializationError: If the grid parameters are invalid. @@ -250,6 +266,15 @@ def _validate_continuous_grid( if not valid_stop_type: error_messages.append("stop must be a scalar int or float value") + # Reject NaN/inf early — `start >= stop` returns False for NaN, so an + # un-finite start would otherwise pass silently and produce a broken grid. + if valid_start_type and not math.isfinite(start): + error_messages.append(f"start must be finite, got {start}") + valid_start_type = False + if valid_stop_type and not math.isfinite(stop): + error_messages.append(f"stop must be finite, got {stop}") + valid_stop_type = False + if not isinstance(n_points, int) or n_points < 1: error_messages.append( f"n_points must be an int greater than 0 but is {n_points}", @@ -258,6 +283,12 @@ def _validate_continuous_grid( if valid_start_type and valid_stop_type and start >= stop: error_messages.append("start must be less than stop") + if valid_start_type and requires_positive_start and start <= 0: + error_messages.append( + f"start must be > 0 for a log-spaced grid (got {start}); " + f"`log(x)` is undefined for `x <= 0`." + ) + if error_messages: msg = format_messages(error_messages) raise GridInitializationError(msg) @@ -290,15 +321,24 @@ def _validate_irreg_spaced_grid(points: Sequence[float] | Float1D) -> None: f"Non-numeric elements found at indices: {non_numeric}" ) else: - # Check that points are in ascending order - for i in range(len(points) - 1): - if points[i] >= points[i + 1]: - error_messages.append( - "Points must be in strictly ascending order. " - f"Found points[{i}]={points[i]} >= " - f"points[{i + 1}]={points[i + 1]}" - ) - break + # Reject NaN/inf — comparisons with NaN are False, so the + # ascending-order check below would silently let them through. + non_finite = [(i, p) for i, p in enumerate(points) if not math.isfinite(p)] + if non_finite: + error_messages.append( + f"All elements of points must be finite. " + f"Non-finite elements found at: {non_finite}" + ) + else: + # Check that points are in strictly ascending order + for i in range(len(points) - 1): + if points[i] >= points[i + 1]: + error_messages.append( + "Points must be in strictly ascending order. " + f"Found points[{i}]={points[i]} >= " + f"points[{i + 1}]={points[i + 1]}" + ) + break if error_messages: msg = format_messages(error_messages) diff --git a/tests/test_grids.py b/tests/test_grids.py index 045f829b..da2370d9 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -217,6 +217,36 @@ def test_logspace_grid_invalid_start(): LogSpacedGrid(start=1, stop=0, n_points=10) +def test_logspace_grid_rejects_zero_start(): + with pytest.raises(GridInitializationError, match="log-spaced grid"): + LogSpacedGrid(start=0, stop=10, n_points=5) + + +def test_logspace_grid_rejects_negative_start(): + with pytest.raises(GridInitializationError, match="log-spaced grid"): + LogSpacedGrid(start=-1.0, stop=10, n_points=5) + + +def test_validate_continuous_grid_rejects_nan_start(): + with pytest.raises(GridInitializationError, match="start must be finite"): + _validate_continuous_grid(start=float("nan"), stop=10, n_points=5) + + +def test_validate_continuous_grid_rejects_inf_stop(): + with pytest.raises(GridInitializationError, match="stop must be finite"): + _validate_continuous_grid(start=1, stop=float("inf"), n_points=5) + + +def test_irreg_spaced_grid_rejects_nan_points(): + with pytest.raises(GridInitializationError, match="must be finite"): + IrregSpacedGrid(points=(1.0, float("nan"), 3.0)) + + +def test_irreg_spaced_grid_rejects_inf_points(): + with pytest.raises(GridInitializationError, match="must be finite"): + IrregSpacedGrid(points=(1.0, 2.0, float("inf"))) + + def test_replace_mixin(): grid = LinSpacedGrid(start=1, stop=5, n_points=5) new_grid = grid.replace(start=0) From f56b9be523be22ecdf331d6c6ca0e21e29ff6a3b Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 30 Apr 2026 18:24:08 +0200 Subject: [PATCH 10/23] Address PR review: docstrings, type hints, validation, drop separator banners MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - tests/test_single_feasible_action.py: drop three decorative section banners (AGENTS.md prohibits `# ---...---` separators); fold the banner prose into the docstrings of the tests/helpers below. - tests/test_single_feasible_action.py: type-annotate `_crra_bequest` and `_alive_utility`'s pref_type / consumption_weight / coefficient_rra arguments (DiscreteState / FloatND). - tests/test_runtime_params.py: type-annotate `_make_action_grid_model` and `_make_action_grid_model_with_stateful_dead`. - src/lcm/simulation/transitions.py: re-run `_validate_all_states_present` in the new `create_regime_state_action_space` (the substitution switch from `create_state_action_space(states=...)` to `base.replace(states=...)` had silently dropped this check). - src/lcm/params/regime_template.py: docstring on `_fail_if_runtime_grid_shadows_function`; fix stale phrasing in `create_regime_params_template` ("matching the state name" → "matching the state or action name"). - src/lcm/interfaces.py: comment why the `_ShockGrid` substitution branch is gated on `in_states` only (state-only by design, AGENTS.md forbids ShockGrids as actions; gate is the explicit enforcement of that invariant). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/interfaces.py | 5 +++ src/lcm/params/regime_template.py | 19 ++++++++- src/lcm/simulation/transitions.py | 11 +++-- tests/test_runtime_params.py | 6 ++- tests/test_single_feasible_action.py | 62 +++++++++++----------------- 5 files changed, 59 insertions(+), 44 deletions(-) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 6c80ddb9..59617eb6 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -278,6 +278,11 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac action_replacements[name] = cast( "ContinuousAction", all_params[points_key] ) + # `_ShockGrid` is state-only by construction (intrinsic + # transitions, forbidden as actions per AGENTS.md). The + # `in_states` gate makes that invariant explicit — a + # `_ShockGrid` reaching the action branch would be a model + # bug, not something this method should silently substitute. elif ( in_states and isinstance(spec, _ShockGrid) diff --git a/src/lcm/params/regime_template.py b/src/lcm/params/regime_template.py index 2b9c748e..85a66b47 100644 --- a/src/lcm/params/regime_template.py +++ b/src/lcm/params/regime_template.py @@ -32,7 +32,7 @@ def create_regime_params_template( Grids with runtime-supplied values (IrregSpacedGrid without points, `_ShockGrid` without full shock_params) add entries to the template under - pseudo-function keys matching the state name. + pseudo-function keys matching the state or action name. Args: regime: The regime as provided by the user. @@ -112,6 +112,23 @@ def _fail_if_runtime_grid_shadows_function( name: str, kind: str, ) -> None: + """Raise if a runtime grid name collides with an existing function name. + + Runtime-supplied state and action grids contribute pseudo-function entries + to the params template (keyed by the state or action name). Letting such a + pseudo-function entry shadow a real regime function would silently break + parameter resolution, so we reject it at template-construction time. + + Args: + function_params: Template entries collected so far, keyed by + (pseudo-)function name. + name: State or action name being added. + kind: `"state"` or `"action"`, surfaced in the error message. + + Raises: + InvalidNameError: If `name` already exists in `function_params`. + + """ if name in function_params: raise InvalidNameError( f"IrregSpacedGrid {kind} '{name}' (with runtime-supplied " diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index 23bfd58c..65c8068b 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -15,6 +15,7 @@ from lcm.interfaces import InternalRegime, StateActionSpace from lcm.simulation.random import generate_simulation_keys +from lcm.state_action_space import _validate_all_states_present from lcm.typing import ( ActionName, Bool1D, @@ -53,11 +54,15 @@ def create_regime_state_action_space( base = internal_regime.state_action_space(regime_params=regime_params) relevant_state_names = internal_regime.variable_info.query("is_state").index - states_for_state_action_space = MappingProxyType( - {sn: states[f"{internal_regime.name}__{sn}"] for sn in relevant_state_names} + states_for_state_action_space = { + sn: states[f"{internal_regime.name}__{sn}"] for sn in relevant_state_names + } + _validate_all_states_present( + provided_states=states_for_state_action_space, + required_state_names=set(relevant_state_names), ) - return base.replace(states=states_for_state_action_space) + return base.replace(states=MappingProxyType(states_for_state_action_space)) def calculate_next_states( diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index a401608d..f68adec4 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -151,7 +151,7 @@ def test_runtime_grid_matches_fixed(): aaae(V_fixed[period]["alive"], V_runtime[period]["alive"]) -def _make_action_grid_model(*, consumption_grid): +def _make_action_grid_model(*, consumption_grid: IrregSpacedGrid) -> Model: """Create a 2-regime model where consumption is the runtime-points action grid.""" alive = Regime( functions={"utility": _utility}, @@ -241,7 +241,9 @@ def test_runtime_action_grid_changes_solution(): assert not jnp.allclose(V_low[0]["alive"], V_high[0]["alive"]) -def _make_action_grid_model_with_stateful_dead(*, consumption_grid): +def _make_action_grid_model_with_stateful_dead( + *, consumption_grid: IrregSpacedGrid +) -> Model: """Variant where `dead` has a `wealth` state so its utility depends on it. Mirrors the aca-model dead regime (carries assets / pref_type so the diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py index 04628ed2..b5ce7a3f 100644 --- a/tests/test_single_feasible_action.py +++ b/tests/test_single_feasible_action.py @@ -29,7 +29,7 @@ from lcm.grids import IrregSpacedGrid from lcm.grids.coordinates import get_irreg_coordinate from lcm.regime_building.ndimage import map_coordinates -from lcm.typing import ContinuousAction, ContinuousState, FloatND +from lcm.typing import ContinuousAction, ContinuousState, DiscreteState, FloatND @categorical(ordered=False) @@ -244,12 +244,6 @@ def test_simulate_with_constrained_action_grid(wealth_lo, consumption_lo, label) ) -# --------------------------------------------------------------------------- -# Replicas of the aca-baseline failure path: dead regime with a CRRA bequest -# whose `gamma` is per-pref_type, evaluated through `jnp.where`. -# --------------------------------------------------------------------------- - - @categorical(ordered=False) class PrefType: type_0: int @@ -264,10 +258,10 @@ class AliveDeadRegimeId: def _crra_bequest( assets: ContinuousState, - pref_type, + pref_type: DiscreteState, bequest_shifter: float, - consumption_weight, - coefficient_rra, + consumption_weight: FloatND, + coefficient_rra: FloatND, ) -> FloatND: """Replica of aca_model.agent.preferences.bequest, simplified. @@ -286,7 +280,9 @@ def _crra_bequest( def _alive_utility( - consumption: ContinuousAction, pref_type, consumption_weight + consumption: ContinuousAction, + pref_type: DiscreteState, + consumption_weight: FloatND, ) -> FloatND: """Make pref_type matter in alive's utility too (otherwise pylcm complains).""" alpha = consumption_weight[pref_type] @@ -395,28 +391,19 @@ def test_bequest_gamma_exactly_one_for_one_type_only(): ) -# --------------------------------------------------------------------------- -# Direct probe: `map_coordinates` produces NaN at ±inf / NaN coordinates. -# This is the concrete NaN source — `lower_weight = 1 - inf = -inf` and -# `upper_weight = inf` combined with positive grid values gives `inf - inf = -# NaN`. The aca-baseline NaN-in-V at age 51 is most plausibly traced back to -# *some* upstream computation (next_assets / next_aime, or a state coordinate -# from `get_irreg_coordinate` / `get_*_coordinate` that divides by zero on a -# degenerate grid segment) producing inf, which then poisons the value -# function via this interpolation path. -# --------------------------------------------------------------------------- - - @pytest.mark.parametrize("bad_coord", [jnp.inf, -jnp.inf, jnp.nan]) def test_map_coordinates_returns_nan_for_non_finite_coordinate(bad_coord): """`map_coordinates` cannot recover from a non-finite continuous-state coordinate: the linear-interp weights become `inf` and `1 - inf = -inf`, - and `inf * V[k] - inf * V[k-1]` reduces to NaN. + and `inf * V[k] - inf * V[k-1]` reduces to NaN. The aca-baseline NaN + at age 51 traces back to some upstream computation (next_assets, + next_aime, or a coordinate finder that divides by zero on a degenerate + grid segment) feeding inf into this path. Implication for callers: any path that can feed `inf` or `NaN` into the - coordinate finder (e.g. division by zero in a state transition, an - overflow when V values are O(1e8), or a `0/0` in a degenerate - IrregSpacedGrid segment) will produce NaN in V. + coordinate finder (division by zero in a state transition, overflow + when V values are O(1e8), or `0/0` in a degenerate IrregSpacedGrid + segment) will produce NaN in V. """ V_arr = jnp.array([1.0, 5.0, 12.0]) out = map_coordinates(V_arr, coordinates=[jnp.array(bad_coord)]) @@ -446,18 +433,17 @@ def test_irreg_coordinate_divides_by_zero_on_duplicate_grid_points(): ) -# --------------------------------------------------------------------------- -# `validate_initial_conditions` uses `_base_state_action_space` directly, -# which still holds the placeholder zeros for runtime-supplied -# `IrregSpacedGrid`. With a feasibility constraint that the all-zero -# placeholder fails, every subject is reported infeasible — even though the -# real (post-substitution) grid would pass. This affects runtime grids -# regardless of whether they are state or action grids. -# --------------------------------------------------------------------------- - - def _runtime_state_grid_model() -> tuple[Model, dict, dict]: - """A 2-regime model with a runtime-supplied IrregSpacedGrid *state*.""" + """Build a 2-regime model with a runtime-supplied IrregSpacedGrid *state*. + + Reproduces the failure mode where `validate_initial_conditions` reads + `_base_state_action_space` directly — still holding the placeholder + zeros for runtime-supplied `IrregSpacedGrid`s. With a feasibility + constraint that the all-zero placeholder fails, every subject is + reported infeasible even though the real (post-substitution) grid + would pass. The same mechanism affects runtime grids whether they + are states or actions. + """ @categorical(ordered=False) class RuntimeRegimeId: From 3b7be82be2e9e7b82ea80173ce4c5e82148eb132 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:42:03 +0200 Subject: [PATCH 11/23] LogSpacedGrid docstring: drop redundant rationale The validator's error message already explains why; the class docstring only needs the contract. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 8550188d..20811ae9 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -117,9 +117,7 @@ def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: class LogSpacedGrid(UniformContinuousGrid): """A logarithmically spaced grid of continuous values. - Requires `start > 0`; otherwise `log(start)` is undefined and `to_jax()` - silently returns NaN/-inf, which propagates through every downstream - interpolation. + Requires `start > 0`. Example: -------- From 8589146010a60274083075f22d5e23f6e7d73a94 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:44:14 +0200 Subject: [PATCH 12/23] IrregSpacedGrid.to_jax docstring: shorter, point at the substituted-grid path Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 20811ae9..664dc829 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -202,10 +202,9 @@ def to_jax(self) -> Float1D: """Convert the grid to a Jax array. Raises `GridInitializationError` for runtime-supplied grids - (`pass_points_at_runtime=True`). Substitution happens at solve / - simulate time via `InternalRegime.state_action_space(regime_params=...)`; - any code path that reads the base grid's points before substitution is - a bug. + (`pass_points_at_runtime=True`). To get the substituted points, + call `internal_regime.state_action_space(regime_params=...)` and + read from `.states[name]` or `.continuous_actions[name]`. """ if self.points is None: raise GridInitializationError( From 541392c0f649741991bf81c7bd9a4e41d6491424 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:44:50 +0200 Subject: [PATCH 13/23] IrregSpacedGrid.to_jax error message: same shape as docstring Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 664dc829..1e4b1140 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -208,13 +208,13 @@ def to_jax(self) -> Float1D: """ if self.points is None: raise GridInitializationError( - f"IrregSpacedGrid was declared with n_points={self.n_points} " - f"and no points; values are supplied at runtime via " - f"params['']['']['points']. Reading the grid " - f"before substitution is a bug — call " + f"IrregSpacedGrid declared with n_points={self.n_points} and " + f"no points; values are supplied at runtime via " + f"params['']['']['points']. To get the " + f"substituted points, call " f"`internal_regime.state_action_space(regime_params=...)` and " - f"read points from there, or use `.n_points` if only the shape " - f"is needed." + f"read from `.states[name]` or `.continuous_actions[name]`. " + f"Use `.n_points` if only the shape is needed." ) return jnp.asarray(self.points) From 54d22b00b7062403790fb7fa9a0795e1335e9649 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:46:11 +0200 Subject: [PATCH 14/23] Use jnp.isfinite in grid validators; drop math import Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/grids/continuous.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 1e4b1140..9a147c44 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -1,5 +1,4 @@ import dataclasses -import math from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @@ -265,10 +264,10 @@ def _validate_continuous_grid( # Reject NaN/inf early — `start >= stop` returns False for NaN, so an # un-finite start would otherwise pass silently and produce a broken grid. - if valid_start_type and not math.isfinite(start): + if valid_start_type and not jnp.isfinite(start): error_messages.append(f"start must be finite, got {start}") valid_start_type = False - if valid_stop_type and not math.isfinite(stop): + if valid_stop_type and not jnp.isfinite(stop): error_messages.append(f"stop must be finite, got {stop}") valid_stop_type = False @@ -320,7 +319,7 @@ def _validate_irreg_spaced_grid(points: Sequence[float] | Float1D) -> None: else: # Reject NaN/inf — comparisons with NaN are False, so the # ascending-order check below would silently let them through. - non_finite = [(i, p) for i, p in enumerate(points) if not math.isfinite(p)] + non_finite = [(i, p) for i, p in enumerate(points) if not jnp.isfinite(p)] if non_finite: error_messages.append( f"All elements of points must be finite. " From ba38876a5c0065d581890cbaf7225546b4c7db00 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:51:24 +0200 Subject: [PATCH 15/23] Drop cryptic aca_model reference from validator docstring Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/regime_building/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index ddc63631..96df0dce 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -151,7 +151,7 @@ def _validate_function_output_state_indexing(regime: Regime) -> list[str]: the indexing produces NaN at runtime instead of the intended scalar. The safe pattern is to take the state as input on the producing function and - return the scalar directly (see `aca_model.agent.preferences.discount_factor`). + return the scalar directly. """ function_output_names = set(regime.functions) discrete_state_names = { From 4d11dead41e59dbc14ea3daf3cf35303cecc4b3b Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:55:59 +0200 Subject: [PATCH 16/23] Validator: also flag function-output indexed by a derived categorical MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Derived categoricals (`regime.derived_categoricals`, function outputs that pylcm treats as categoricals — see https://pylcm.readthedocs.io/en/latest/pandas-interop/#derived-categoricals) suffer the same per-cell broadcast clash as discrete states. Extend `discrete_state_names` in `_validate_function_output_state_indexing` to include them; add a TDD test. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/regime_building/validation.py | 2 +- tests/test_function_output_state_indexing.py | 47 ++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index 96df0dce..9f883e0c 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -156,7 +156,7 @@ def _validate_function_output_state_indexing(regime: Regime) -> list[str]: function_output_names = set(regime.functions) discrete_state_names = { name for name, grid in regime.states.items() if isinstance(grid, DiscreteGrid) - } + } | set(regime.derived_categoricals) if not function_output_names or not discrete_state_names: return [] diff --git a/tests/test_function_output_state_indexing.py b/tests/test_function_output_state_indexing.py index f326833c..b4d690ed 100644 --- a/tests/test_function_output_state_indexing.py +++ b/tests/test_function_output_state_indexing.py @@ -123,6 +123,53 @@ def test_safe_pattern_does_not_raise(): ) +def test_function_output_indexed_by_derived_categorical_raises(): + """The check applies to derived categoricals (function outputs treated as + categoricals via `derived_categoricals`), not only states.""" + + @categorical(ordered=False) + class IsMarried: + single: int + married: int + + def _is_married(spousal_income: DiscreteState) -> DiscreteState: + return jnp.int32(spousal_income > 0) + + def _per_marital_scale(some_param: FloatND) -> FloatND: + return jnp.abs(1.0 / (1.0 - some_param)) + + def _utility_clash( + consumption: ContinuousAction, + is_married: DiscreteState, + per_marital_scale: FloatND, + ) -> FloatND: + return per_marital_scale[is_married] * jnp.log(consumption + 1.0) + + @categorical(ordered=True) + class SpousalIncome: + single: int + married_no_inc: int + married_has_inc: int + + with pytest.raises( + RegimeInitializationError, + match=r"per_marital_scale.*is_married", + ): + Regime( + functions={ + "utility": _utility_clash, + "per_marital_scale": _per_marital_scale, + "is_married": _is_married, + }, + states={"spousal_income": DiscreteGrid(SpousalIncome)}, + state_transitions={"spousal_income": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + derived_categoricals={"is_married": DiscreteGrid(IsMarried)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + + def test_constraint_indexing_function_output_by_state_raises(): """The check applies to regime constraints too, not only `functions`.""" From 01608bb049d41448e7a2f9bd9bd97de26030e425 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 06:57:50 +0200 Subject: [PATCH 17/23] create_regime_state_action_space docstring: trim rationale Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/simulation/transitions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index 65c8068b..18bdbb6b 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -37,9 +37,7 @@ def create_regime_state_action_space( Continuous action grids declared with `pass_points_at_runtime=True` are completed from `regime_params` (via - `InternalRegime.state_action_space`) — otherwise they would carry the - NaN placeholder used during compilation, which propagates into - `optimal_actions` and ultimately `next_states`. + `InternalRegime.state_action_space`). Args: internal_regime: The internal regime instance. From 6241056a354f852028ebbfa36b02c524523438d1 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 07:00:42 +0200 Subject: [PATCH 18/23] state_action_space: move private helpers below public function (deep module) Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/state_action_space.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/lcm/state_action_space.py b/src/lcm/state_action_space.py index 28cbce72..7c88e165 100644 --- a/src/lcm/state_action_space.py +++ b/src/lcm/state_action_space.py @@ -9,21 +9,6 @@ from lcm.typing import StateName, StateOrActionName -def _grid_to_jax_or_placeholder(grid: Grid) -> Array: - """Return the grid's points, or a NaN placeholder for runtime-supplied grids. - - `IrregSpacedGrid.to_jax()` raises when its points haven't been supplied — that - is the right behaviour everywhere except here: the base state-action space - needs a *shape-correct* array to wire through pytree structures and AOT - tracing before runtime substitution by - `InternalRegime.state_action_space(regime_params=...)`. NaN (rather than - zero) makes any accidental computation against the placeholder fail loudly. - """ - if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: - return jnp.full(grid.n_points, jnp.nan) - return grid.to_jax() - - def create_state_action_space( *, variable_info: pd.DataFrame, @@ -79,6 +64,21 @@ def create_state_action_space( ) +def _grid_to_jax_or_placeholder(grid: Grid) -> Array: + """Return the grid's points, or a NaN placeholder for runtime-supplied grids. + + `IrregSpacedGrid.to_jax()` raises when its points haven't been supplied — that + is the right behaviour everywhere except here: the base state-action space + needs a *shape-correct* array to wire through pytree structures and AOT + tracing before runtime substitution by + `InternalRegime.state_action_space(regime_params=...)`. NaN (rather than + zero) makes any accidental computation against the placeholder fail loudly. + """ + if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: + return jnp.full(grid.n_points, jnp.nan) + return grid.to_jax() + + def _validate_all_states_present( *, provided_states: dict[StateName, Array], required_state_names: set[StateName] ) -> None: From 2efe9e1b5bf9381badd66775069fdeaafcfd918c Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 07:07:08 +0200 Subject: [PATCH 19/23] Drop application-specific (aca) references from test docstrings pylcm is a general library; references to a particular companion application become stale fast and force readers to know unrelated projects to follow the test rationale. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_function_output_state_indexing.py | 4 +-- tests/test_runtime_params.py | 6 ++--- tests/test_single_feasible_action.py | 27 ++++++++------------ tests/test_validation_scalar_actions.py | 4 +-- 4 files changed, 17 insertions(+), 24 deletions(-) diff --git a/tests/test_function_output_state_indexing.py b/tests/test_function_output_state_indexing.py index b4d690ed..2400f33b 100644 --- a/tests/test_function_output_state_indexing.py +++ b/tests/test_function_output_state_indexing.py @@ -6,8 +6,8 @@ NaN at runtime instead of the intended scalar. The validation layer must raise on construction with a clear message -pointing the user at the safe pattern (see `discount_factor` in -`aca_model.agent.preferences`: take the state as input, return a scalar). +pointing the user at the safe pattern: take the discrete state as input +on the producing function, return a scalar. """ import jax.numpy as jnp diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index f68adec4..61b2414a 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -246,10 +246,8 @@ def _make_action_grid_model_with_stateful_dead( ) -> Model: """Variant where `dead` has a `wealth` state so its utility depends on it. - Mirrors the aca-model dead regime (carries assets / pref_type so the - bequest function can read them). Used to surface NaN propagation - when the simulate path forgets to substitute runtime-supplied action - gridpoints. + Used to surface NaN propagation when the simulate path forgets to + substitute runtime-supplied action gridpoints. """ def _alive_utility( diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py index b5ce7a3f..53a18594 100644 --- a/tests/test_single_feasible_action.py +++ b/tests/test_single_feasible_action.py @@ -1,5 +1,4 @@ -"""Reproduce the aca-model NaN failure (issue OpenSourceEconomics/aca-model#9): -solve raises `InvalidValueFunctionError: ... regime 'dead': all values are NaN`. +"""Reproduce ways `solve` can raise `InvalidValueFunctionError: all values are NaN`. Two hypotheses are exercised here: @@ -10,16 +9,16 @@ the tests in this file. `-inf` from all-infeasible cells does not cascade to NaN by itself either. -2. **CRRA bequest with pref_type indexing under jnp.where**: the bequest +2. **CRRA bequest with discrete-state indexing under jnp.where**: a bequest function evaluates *both* branches of `jnp.where(jnp.isclose(gamma, 1), log_branch, power_branch)`. For a parameter set where gamma is exactly 1 - for one preference type but not the other, the power_branch divides by + for one categorical type but not the other, the power_branch divides by `1 - gamma = 0` for the gamma=1 type. NaN/Inf from the *unselected* branch leaks through `jnp.where`'s gradient/forward pass under JIT tracing in some configurations. The model uses an `IrregSpacedGrid` consumption action with runtime-supplied -points to also exercise the `feature/runtime-action-grids` path (PR #338). +points to also exercise the runtime-action-grids substitution path. """ import jax.numpy as jnp @@ -263,7 +262,7 @@ def _crra_bequest( consumption_weight: FloatND, coefficient_rra: FloatND, ) -> FloatND: - """Replica of aca_model.agent.preferences.bequest, simplified. + """Simplified CRRA bequest used to probe NaN propagation under jnp.where. `consumption_weight` and `coefficient_rra` are FloatND indexed by `pref_type`. Both branches of the `jnp.where` are traced. @@ -395,14 +394,11 @@ def test_bequest_gamma_exactly_one_for_one_type_only(): def test_map_coordinates_returns_nan_for_non_finite_coordinate(bad_coord): """`map_coordinates` cannot recover from a non-finite continuous-state coordinate: the linear-interp weights become `inf` and `1 - inf = -inf`, - and `inf * V[k] - inf * V[k-1]` reduces to NaN. The aca-baseline NaN - at age 51 traces back to some upstream computation (next_assets, - next_aime, or a coordinate finder that divides by zero on a degenerate - grid segment) feeding inf into this path. + and `inf * V[k] - inf * V[k-1]` reduces to NaN. Implication for callers: any path that can feed `inf` or `NaN` into the coordinate finder (division by zero in a state transition, overflow - when V values are O(1e8), or `0/0` in a degenerate IrregSpacedGrid + when V values are large, or `0/0` in a degenerate IrregSpacedGrid segment) will produce NaN in V. """ V_arr = jnp.array([1.0, 5.0, 12.0]) @@ -414,11 +410,10 @@ def test_irreg_coordinate_divides_by_zero_on_duplicate_grid_points(): """`get_irreg_coordinate` divides by `upper_point - lower_point`. If a runtime-supplied points array contains duplicates, the divisor is 0. - Reproduces a class of failures that is *only* possible under - `feature/runtime-action-grids` / runtime-state-grids when the caller - constructs the points from a parameter that can collapse (e.g. - `geomspace(consumption_floor, MAX, n_points)` with `consumption_floor == - MAX`, or any param-driven `linspace` whose endpoints can coincide). + Reproduces a class of failures specific to runtime-supplied grids: when + the caller constructs points from a parameter that can collapse (e.g. a + `geomspace(lo, hi, n)` with `lo == hi`, or any param-driven `linspace` + whose endpoints can coincide). """ # Duplicate adjacent points where the query value equals the duplicate. # `searchsorted([0, 1, 1], 1.0, side='right')=3` → clipped to n-1=2, diff --git a/tests/test_validation_scalar_actions.py b/tests/test_validation_scalar_actions.py index 6855e4b9..89289f5b 100644 --- a/tests/test_validation_scalar_actions.py +++ b/tests/test_validation_scalar_actions.py @@ -7,8 +7,8 @@ Validation must do the same. This pattern arises in models that pass multi-dimensional lookup tables as parameters -via MappingLeaf — e.g. tax schedules and pension accrual tables in aca-model, or -tax-transfer schedules in ttsim/gettsim. +via MappingLeaf — e.g. tax schedules, pension accrual tables, or tax-transfer +schedules. """ import jax From f1c4d5dc2be4b8c28128a45b504276f31e5230ce Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 07:16:22 +0200 Subject: [PATCH 20/23] Validator: rename to discrete_grid_names, also include discrete actions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The variable previously named `discrete_state_names` accumulated state DiscreteGrids, derived categoricals, and now discrete actions — all three suffer the same per-cell broadcast clash when a consumer does `func_output[X]`. Renamed the variable, the two helpers (`_validate_function_output_grid_indexing`, `_find_function_output_grid_indexing`), the test module (`test_function_output_grid_indexing.py`), and the error-message wording ("discrete state" → "discrete grid"). Added a TDD test for the discrete-action case. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/regime_building/validation.py | 53 +++++++++++-------- ... => test_function_output_grid_indexing.py} | 53 ++++++++++++++++--- 2 files changed, 76 insertions(+), 30 deletions(-) rename tests/{test_function_output_state_indexing.py => test_function_output_grid_indexing.py} (78%) diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index 9f883e0c..1e2be941 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -128,7 +128,7 @@ def validate_logical_consistency(regime: Regime) -> None: error_messages.extend(_validate_active(regime.active)) error_messages.extend(_validate_state_transitions(regime)) - error_messages.extend(_validate_function_output_state_indexing(regime)) + error_messages.extend(_validate_function_output_grid_indexing(regime)) states_and_actions_overlap = set(regime.states) & set(regime.actions) if states_and_actions_overlap: @@ -142,22 +142,29 @@ def validate_logical_consistency(regime: Regime) -> None: raise RegimeInitializationError(msg) -def _validate_function_output_state_indexing(regime: Regime) -> list[str]: - """Detect the regime-function-output / state-indexed-input name clash. +def _validate_function_output_grid_indexing(regime: Regime) -> list[str]: + """Detect the regime-function-output / discrete-grid-indexed-input name clash. - A regime function whose output is then re-indexed by a discrete state inside - another consumer (function, constraint, or transition) is a silent footgun: - pylcm broadcasts function outputs to per-cell scalars before consumption, so - the indexing produces NaN at runtime instead of the intended scalar. + A regime function whose output is then re-indexed by a discrete grid (state, + action, or derived categorical) inside another consumer (function, + constraint, or transition) is a silent footgun: pylcm broadcasts function + outputs to per-cell scalars before consumption, so the indexing produces + NaN at runtime instead of the intended scalar. - The safe pattern is to take the state as input on the producing function and - return the scalar directly. + The safe pattern is to take the discrete grid as input on the producing + function and return the scalar directly. """ function_output_names = set(regime.functions) - discrete_state_names = { - name for name, grid in regime.states.items() if isinstance(grid, DiscreteGrid) - } | set(regime.derived_categoricals) - if not function_output_names or not discrete_state_names: + discrete_grid_names = ( + {name for name, grid in regime.states.items() if isinstance(grid, DiscreteGrid)} + | { + name + for name, grid in regime.actions.items() + if isinstance(grid, DiscreteGrid) + } + | set(regime.derived_categoricals) + ) + if not function_output_names or not discrete_grid_names: return [] consumers: list[tuple[str, Callable]] = [] @@ -168,31 +175,31 @@ def _validate_function_output_state_indexing(regime: Regime) -> list[str]: errors: list[str] = [] for consumer_name, func in consumers: - clashes = _find_function_output_state_indexing( + clashes = _find_function_output_grid_indexing( func=func, function_output_names=function_output_names, - discrete_state_names=discrete_state_names, + discrete_grid_names=discrete_grid_names, ) - for func_output_name, state_name in clashes: + for func_output_name, grid_name in clashes: errors.append( f"Consumer '{consumer_name}' indexes regime function output " - f"'{func_output_name}' by discrete state '{state_name}' " - f"(`{func_output_name}[{state_name}]`). pylcm broadcasts " + f"'{func_output_name}' by discrete grid '{grid_name}' " + f"(`{func_output_name}[{grid_name}]`). pylcm broadcasts " f"function outputs to per-cell scalars before consumption, so " f"this indexing silently produces NaN. Refactor " - f"'{func_output_name}' to take '{state_name}' as input and " + f"'{func_output_name}' to take '{grid_name}' as input and " f"return the scalar directly." ) return errors -def _find_function_output_state_indexing( +def _find_function_output_grid_indexing( *, func: Callable, function_output_names: set[str], - discrete_state_names: set[str], + discrete_grid_names: set[str], ) -> list[tuple[str, str]]: - """Return `(function_output_name, state_name)` clashes inside `func`'s body.""" + """Return `(function_output_name, grid_name)` clashes inside `func`'s body.""" try: source = textwrap.dedent(inspect.getsource(func)) except OSError, TypeError: @@ -212,7 +219,7 @@ def _find_function_output_state_indexing( continue if not isinstance(node.slice, ast.Name): continue - if node.slice.id not in discrete_state_names: + if node.slice.id not in discrete_grid_names: continue clashes.append((node.value.id, node.slice.id)) return clashes diff --git a/tests/test_function_output_state_indexing.py b/tests/test_function_output_grid_indexing.py similarity index 78% rename from tests/test_function_output_state_indexing.py rename to tests/test_function_output_grid_indexing.py index 2400f33b..fa0f2753 100644 --- a/tests/test_function_output_state_indexing.py +++ b/tests/test_function_output_grid_indexing.py @@ -1,12 +1,13 @@ -"""Tests for the regime-function-output / state-indexed-input name clash. +"""Tests for the regime-function-output / discrete-grid-indexed-input name clash. -A regime function whose output is then re-indexed by a discrete state -inside another function is a silent footgun: pylcm broadcasts function -outputs to per-cell scalars before consumption, so the indexing produces -NaN at runtime instead of the intended scalar. +A regime function whose output is then re-indexed by a discrete grid +(state, action, or derived categorical) inside another function is a +silent footgun: pylcm broadcasts function outputs to per-cell scalars +before consumption, so the indexing produces NaN at runtime instead of +the intended scalar. The validation layer must raise on construction with a clear message -pointing the user at the safe pattern: take the discrete state as input +pointing the user at the safe pattern: take the discrete grid as input on the producing function, return a scalar. """ @@ -15,7 +16,7 @@ from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical from lcm.exceptions import RegimeInitializationError -from lcm.typing import ContinuousAction, DiscreteState, FloatND +from lcm.typing import ContinuousAction, DiscreteAction, DiscreteState, FloatND @categorical(ordered=False) @@ -170,6 +171,44 @@ class SpousalIncome: ) +def test_function_output_indexed_by_discrete_action_raises(): + """The check applies to discrete actions, not only states/derived categoricals.""" + + @categorical(ordered=False) + class WorkChoice: + no_work: int + work: int + + def _per_choice_scale(some_param: FloatND) -> FloatND: + return jnp.abs(1.0 / (1.0 - some_param)) + + def _utility_clash( + consumption: ContinuousAction, + labor_supply: DiscreteAction, + per_choice_scale: FloatND, + ) -> FloatND: + return per_choice_scale[labor_supply] * jnp.log(consumption + 1.0) + + with pytest.raises( + RegimeInitializationError, + match=r"per_choice_scale.*labor_supply", + ): + Regime( + functions={ + "utility": _utility_clash, + "per_choice_scale": _per_choice_scale, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={ + "consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5), + "labor_supply": DiscreteGrid(WorkChoice), + }, + transition=_next_regime, + active=lambda age: age < 2, + ) + + def test_constraint_indexing_function_output_by_state_raises(): """The check applies to regime constraints too, not only `functions`.""" From d9f37ceeda2382f527dad8c99c9bc3144696a502 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 07:22:32 +0200 Subject: [PATCH 21/23] Validator: tighten to actual footgun shape; correct behaviour description MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous docstring claimed the indexing 'silently produces NaN', but a disabled-validator probe shows otherwise: - When the producer takes the discrete grid as input, its output is a per-cell scalar; `func_output[grid]` raises `IndexError: Too many indices` at trace time. This is the real footgun the validator should catch. - When the producer does NOT take the discrete grid as input, its output stays array-shaped and `func_output[grid]` is correct code that solves to sensible V values. The previous validator flagged both shapes — including the safe one — as a clash. Tighten: only fire when the producing function also takes the discrete grid as input. Update the description to match observed behaviour (IndexError, not NaN). Add a regression test that exercises the array-valued-producer + state-indexed-consumer shape and asserts it builds without raising. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lcm/regime_building/validation.py | 40 ++++++--- tests/test_function_output_grid_indexing.py | 99 ++++++++++++++------- 2 files changed, 93 insertions(+), 46 deletions(-) diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index 1e2be941..92656692 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -145,14 +145,12 @@ def validate_logical_consistency(regime: Regime) -> None: def _validate_function_output_grid_indexing(regime: Regime) -> list[str]: """Detect the regime-function-output / discrete-grid-indexed-input name clash. - A regime function whose output is then re-indexed by a discrete grid (state, - action, or derived categorical) inside another consumer (function, - constraint, or transition) is a silent footgun: pylcm broadcasts function - outputs to per-cell scalars before consumption, so the indexing produces - NaN at runtime instead of the intended scalar. - - The safe pattern is to take the discrete grid as input on the producing - function and return the scalar directly. + The unsafe pattern is: a regime function `f` takes a discrete grid `g` + (state, action, or derived categorical) as an input — so `f`'s output + is a per-cell scalar — and a consumer then indexes `f[g]`. The + consumer is indexing a 0-d array by a scalar integer, which raises + `IndexError` at trace time. The fix is to drop the redundant `[g]` + in the consumer (or refactor `f` not to take `g`). """ function_output_names = set(regime.functions) discrete_grid_names = ( @@ -167,6 +165,19 @@ def _validate_function_output_grid_indexing(regime: Regime) -> list[str]: if not function_output_names or not discrete_grid_names: return [] + # Only treat `func_output[grid]` as unsafe when the producing function + # *also* takes `grid` as an input — that is the case where the output + # is per-cell scalar and the consumer's indexing is wrong. If the + # producing function does not take `grid`, its output shape is + # whatever it computed (typically an array indexable by `grid`) and + # the consumer pattern is correct. + function_inputs: dict[str, set[str]] = {} + for name, func in regime.functions.items(): + try: + function_inputs[name] = set(inspect.signature(func).parameters) + except ValueError, TypeError: + function_inputs[name] = set() + consumers: list[tuple[str, Callable]] = [] consumers.extend(regime.functions.items()) consumers.extend(regime.constraints.items()) @@ -181,14 +192,17 @@ def _validate_function_output_grid_indexing(regime: Regime) -> list[str]: discrete_grid_names=discrete_grid_names, ) for func_output_name, grid_name in clashes: + if grid_name not in function_inputs.get(func_output_name, set()): + continue errors.append( f"Consumer '{consumer_name}' indexes regime function output " f"'{func_output_name}' by discrete grid '{grid_name}' " - f"(`{func_output_name}[{grid_name}]`). pylcm broadcasts " - f"function outputs to per-cell scalars before consumption, so " - f"this indexing silently produces NaN. Refactor " - f"'{func_output_name}' to take '{grid_name}' as input and " - f"return the scalar directly." + f"(`{func_output_name}[{grid_name}]`), but '{func_output_name}' " + f"already takes '{grid_name}' as input — its output is a " + f"per-cell scalar, so the indexing raises IndexError at trace " + f"time. Drop the redundant `[{grid_name}]` in '{consumer_name}', " + f"or refactor '{func_output_name}' not to take '{grid_name}' " + f"as input." ) return errors diff --git a/tests/test_function_output_grid_indexing.py b/tests/test_function_output_grid_indexing.py index fa0f2753..adc37816 100644 --- a/tests/test_function_output_grid_indexing.py +++ b/tests/test_function_output_grid_indexing.py @@ -1,14 +1,14 @@ """Tests for the regime-function-output / discrete-grid-indexed-input name clash. -A regime function whose output is then re-indexed by a discrete grid -(state, action, or derived categorical) inside another function is a -silent footgun: pylcm broadcasts function outputs to per-cell scalars -before consumption, so the indexing produces NaN at runtime instead of -the intended scalar. - -The validation layer must raise on construction with a clear message -pointing the user at the safe pattern: take the discrete grid as input -on the producing function, return a scalar. +The unsafe pattern is: a regime function `f` takes a discrete grid `g` (state, +action, or derived categorical) as an input — so `f`'s output is a per-cell +scalar — and a consumer then indexes `f[g]`. The consumer is indexing a 0-d +array by a scalar integer, which raises `IndexError` at trace time. The +validator catches the pattern at construction so the user gets a clear message +instead of a cryptic JAX trace error during solve. + +The fix is to drop the redundant `[g]` in the consumer (or refactor `f` not +to take `g`). """ import jax.numpy as jnp @@ -31,19 +31,21 @@ class RegimeId: dead: int -def _per_type_scale(some_param: FloatND) -> FloatND: - """Returns one scale per pref_type — depends only on a per-type Series param.""" - return jnp.abs(1.0 / (1.0 - some_param)) +def _per_type_scale_takes_pref_type( + pref_type: DiscreteState, some_param: FloatND +) -> FloatND: + """Takes pref_type — output is per-cell scalar; consumer must not re-index.""" + return jnp.abs(1.0 / (1.0 - some_param[pref_type])) -def _utility_with_state_indexed_function_output( +def _utility_redundantly_indexes( consumption: ContinuousAction, pref_type: DiscreteState, per_type_scale: FloatND, ) -> FloatND: - # The clash: per_type_scale is registered as a regime function output - # (returns shape (n_pref_types,)), but here it is consumed and indexed - # by the `pref_type` discrete state. + # The clash: per_type_scale's producer takes pref_type, so its output is a + # per-cell scalar. Indexing that scalar by pref_type again raises IndexError + # at trace time. return per_type_scale[pref_type] * jnp.log(consumption + 1.0) @@ -54,8 +56,8 @@ def _next_regime(period: int) -> FloatND: def _make_clashing_model() -> Model: alive = Regime( functions={ - "utility": _utility_with_state_indexed_function_output, - "per_type_scale": _per_type_scale, + "utility": _utility_redundantly_indexes, + "per_type_scale": _per_type_scale_takes_pref_type, }, states={"pref_type": DiscreteGrid(PrefType)}, state_transitions={"pref_type": None}, @@ -76,8 +78,8 @@ def _make_clashing_model() -> Model: def test_function_output_indexed_by_state_raises(): - """A regime function output indexed by a discrete state inside another regime - function must raise on construction (silent NaN bug otherwise).""" + """A regime function output redundantly indexed by a discrete state inside + another regime function must raise on construction.""" with pytest.raises( RegimeInitializationError, match=r"per_type_scale.*pref_type", @@ -85,7 +87,7 @@ def test_function_output_indexed_by_state_raises(): _make_clashing_model() -def _utility_safe( +def _utility_consumes_scalar( consumption: ContinuousAction, pref_type: DiscreteState, # noqa: ARG001 per_type_scale: FloatND, @@ -94,17 +96,48 @@ def _utility_safe( return per_type_scale * jnp.log(consumption + 1.0) -def _per_type_scale_safe(pref_type: DiscreteState, some_param: FloatND) -> FloatND: - """Safe variant: takes pref_type, returns scalar (mirrors discount_factor).""" - return jnp.abs(1.0 / (1.0 - some_param[pref_type])) +def test_safe_pattern_does_not_raise(): + """The safe pattern (function takes the state, returns a scalar; consumer + uses it directly) builds fine.""" + alive = Regime( + functions={ + "utility": _utility_consumes_scalar, + "per_type_scale": _per_type_scale_takes_pref_type, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) -def test_safe_pattern_does_not_raise(): - """The safe pattern (function takes the state, returns a scalar) builds fine.""" +def _per_type_scale_array_output(some_param: FloatND) -> FloatND: + """Does NOT take pref_type — output is `(n_pref_types,)`-shaped. + + A consumer indexing this output by `pref_type` is correct: the indexing + selects the per-type entry. The validator must NOT flag this case. + """ + return jnp.abs(1.0 / (1.0 - some_param)) + + +def test_array_valued_producer_indexed_by_state_does_not_raise(): + """When the producing function does NOT take the discrete grid as input, + its output stays array-shaped and `func_output[grid]` is correct code.""" alive = Regime( functions={ - "utility": _utility_safe, - "per_type_scale": _per_type_scale_safe, + "utility": _utility_redundantly_indexes, + "per_type_scale": _per_type_scale_array_output, }, states={"pref_type": DiscreteGrid(PrefType)}, state_transitions={"pref_type": None}, @@ -136,8 +169,8 @@ class IsMarried: def _is_married(spousal_income: DiscreteState) -> DiscreteState: return jnp.int32(spousal_income > 0) - def _per_marital_scale(some_param: FloatND) -> FloatND: - return jnp.abs(1.0 / (1.0 - some_param)) + def _per_marital_scale(is_married: DiscreteState, some_param: FloatND) -> FloatND: + return jnp.abs(1.0 / (1.0 - some_param[is_married])) def _utility_clash( consumption: ContinuousAction, @@ -179,8 +212,8 @@ class WorkChoice: no_work: int work: int - def _per_choice_scale(some_param: FloatND) -> FloatND: - return jnp.abs(1.0 / (1.0 - some_param)) + def _per_choice_scale(labor_supply: DiscreteAction, some_param: FloatND) -> FloatND: + return jnp.abs(1.0 / (1.0 - some_param[labor_supply])) def _utility_clash( consumption: ContinuousAction, @@ -228,7 +261,7 @@ def _constraint_indexing_function_output( "utility": lambda consumption, pref_type, per_type_scale: jnp.log( # noqa: ARG005 consumption + 1.0 ), - "per_type_scale": _per_type_scale, + "per_type_scale": _per_type_scale_takes_pref_type, }, states={"pref_type": DiscreteGrid(PrefType)}, state_transitions={"pref_type": None}, From d66d85a580c9b437112ec2510113d6f4e62817b0 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 11:30:42 +0200 Subject: [PATCH 22/23] Boilerplate refresh: dags module, current pixi/hook pins, drop stale entries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - AGENTS.md: include @.ai-instructions/modules/dags.md (pylcm is dags-built; the convention/idioms are project-relevant). - .pre-commit-config.yaml: narrow GFM mdformat `files:` to root-level (AGENTS|CLAUDE|README); the `modules/.*|profiles/.*` patterns were intended for the .ai-instructions repo and never matched here. Bump check-jsonschema 0.37.1 → 0.37.2 and ruff v0.15.11 → v0.15.12 to current. - .gitignore: drop unused `# pytask` section (pylcm doesn't use pytask). - .github/workflows/main.yml: pixi-version v0.66.0 → v0.67.2 to match the locally installed pixi. All other pinned hooks/actions verified at-or-newer than the boilerplate template baseline. --- .github/workflows/main.yml | 10 +++++----- .gitignore | 3 --- .pre-commit-config.yaml | 6 +++--- AGENTS.md | 1 + 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d0ea4a07..80b272c1 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: tests-cpu @@ -59,7 +59,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: type-checking @@ -82,7 +82,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: tests-cuda12 @@ -101,7 +101,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: tests-cuda12 @@ -116,7 +116,7 @@ jobs: # - uses: actions/checkout@v6 # - uses: prefix-dev/setup-pixi@v0.9.5 # with: - # pixi-version: v0.66.0 + # pixi-version: v0.67.2 # cache: true # cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} # environments: tests-cpu diff --git a/.gitignore b/.gitignore index db181e4d..03acd3c8 100644 --- a/.gitignore +++ b/.gitignore @@ -31,9 +31,6 @@ docs/_build/ .pixi/ node_modules/ -# pytask -.pytask.sqlite3 - # Python __pycache__/ *.py[cod] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 270b9b25..f51ffb31 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,11 +50,11 @@ repos: hooks: - id: yamllint - repo: https://github.com/python-jsonschema/check-jsonschema - rev: 0.37.1 + rev: 0.37.2 hooks: - id: check-github-workflows - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.11 + rev: v0.15.12 hooks: - id: ruff-check args: @@ -86,7 +86,7 @@ repos: args: - --wrap - '88' - files: (AGENTS\.md|CLAUDE\.md|README\.md|modules/.*\.md|profiles/.*\.md) + files: (AGENTS\.md|CLAUDE\.md|README\.md) - id: mdformat additional_dependencies: - mdformat-myst diff --git a/AGENTS.md b/AGENTS.md index f9777113..7aea2844 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,5 +1,6 @@ @.ai-instructions/profiles/tier-a.md @.ai-instructions/modules/jax.md @.ai-instructions/modules/pandas.md @.ai-instructions/modules/plotting.md +@.ai-instructions/modules/dags.md # PyLCM From f18d27c7b78ece836bf35bdfddd65d9e35432fc6 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 11:33:21 +0200 Subject: [PATCH 23/23] Bump .ai-instructions: pyproject-fmt + ruff + check-jsonschema rev pins --- .ai-instructions | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ai-instructions b/.ai-instructions index d15d554a..528a0118 160000 --- a/.ai-instructions +++ b/.ai-instructions @@ -1 +1 @@ -Subproject commit d15d554a54785e84cb80165443fa432a22adc45c +Subproject commit 528a0118ff9e02233bfc073da891b60e81b34754