From 4123fe9739c1c4bccebaa149985d0415a4272ef1 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 07:23:00 +0200 Subject: [PATCH 01/41] Anchor consumption grid lower bound to consumption_floor parameter Consumption is now declared as `IrregSpacedGrid(n_points=N)` (no fixed points). Callers inject log-spaced gridpoints from `consumption_floor` to $300k via `aca_model.consumption_grid.inject_consumption_points(params=..., model=...)` before solving. This means the lowest consumption choice equals the per-iteration floor, removing a degree of freedom from the grid and eliminating the previous mismatch where c < floor was a legal grid choice. Requires pylcm support for runtime-supplied points on continuous action grids (PR OpenSourceEconomics/pylcm#338). aca-model CI now installs pylcm from the matching `feature/runtime-action-grids` branch. Other changes: - `consumption_grid.py`: new module with `compute_consumption_points` and `inject_consumption_points` helpers. - `benchmark.get_benchmark_params(*, model=None)`: when `model` is given, returns params with consumption points injected. - `benchmark.get_benchmark_initial_conditions`: switch from `.start` / `.stop` to `to_jax().min()` / `.max()` so it works on both `LinSpacedGrid` and `PiecewiseLinSpacedGrid` (the AIME grid is now piecewise; this was a pre-existing bug surfacing as `AttributeError`). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 4 +- src/aca_model/baseline/regimes/_common.py | 17 +---- src/aca_model/benchmark.py | 20 +++++- src/aca_model/consumption_grid.py | 76 +++++++++++++++++++++++ tests/test_benchmark.py | 2 +- 5 files changed, 97 insertions(+), 22 deletions(-) create mode 100644 src/aca_model/consumption_grid.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 67c82fa..a0247b8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,10 +26,10 @@ jobs: - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - - name: Install pylcm + - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@main" + git+https://github.com/OpenSourceEconomics/pylcm.git@feature/runtime-action-grids" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 688a504..efd3b73 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -194,14 +194,6 @@ class Grids: # bend points (0 → kink_0 → kink_1 → kink_2). Total = 32. _AIME_PIECE_N_POINTS: tuple[int, int, int] = (10, 11, 11) -# Consumption grid: log-spaced from the lower bound of the -# `consumption_floor` parameter (BOUNDS in task_estimate_parameters) -# up to a high value that brackets the unconstrained optimum for the -# richest agents in the state space. Mirrors the struct-ret design -# (concentrate gridpoints where CRRA curvature is highest). 
-_CONSUMPTION_GRID_START: float = 100.0 -_CONSUMPTION_GRID_STOP: float = 300_000.0 - def build_grids( grid_config: GridConfig = GRID_CONFIG, @@ -273,14 +265,7 @@ def build_grids( ), aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params), consumption=IrregSpacedGrid( - points=tuple( - float(c) - for c in np.geomspace( - _CONSUMPTION_GRID_START, - _CONSUMPTION_GRID_STOP, - num=grid_config.n_consumption_gridpoints, - ) - ), + n_points=grid_config.n_consumption_gridpoints, ), wage_res=wage_res, hcc_persistent=hcc_persistent, diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 5a822c9..a9d7128 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -44,6 +44,7 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.model import create_model from aca_model.config import BENCHMARK_GRID_CONFIG +from aca_model.consumption_grid import inject_consumption_points _PARAMS_FILE = ( Path(__file__).resolve().parent / "_benchmark_data" / "benchmark_params.pkl" @@ -96,17 +97,26 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod ) -def get_benchmark_params() -> tuple[dict[str, Any], dict[str, Any]]: +def get_benchmark_params( + *, model: Model | None = None +) -> tuple[dict[str, Any], dict[str, Any]]: """Load the frozen `(fixed_params, params)` snapshot. Pref-type-indexed `pd.Series` in `params` are truncated to `_N_BENCHMARK_PREF_TYPES` rows so they line up with `BenchmarkPrefType`'s categories. + + When `model` is provided, consumption gridpoints are injected into + `params` for each regime that declares `consumption` as an + `IrregSpacedGrid` with runtime-supplied points. The lower bound is + read from `params["consumption_floor"]`. 
""" with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) fixed_params = data["fixed_params"] params = _truncate_pref_type_indexed(data["params"]) + if model is not None: + params = inject_consumption_points(params=params, model=model) return fixed_params, params @@ -143,10 +153,14 @@ def get_benchmark_initial_conditions( regime = rng.choice(regime_ids, size=n_subjects).astype(np.int32) # Grid ranges come from any of the five regimes (shared structure). + # Use to_jax() so the helper handles both LinSpacedGrid and + # PiecewiseLinSpacedGrid (the latter has no `.start` / `.stop`). ref_regime = model.regimes[_INITIAL_REGIMES[0]] grids = ref_regime.states - assets_lo, assets_hi = grids["assets"].start, grids["assets"].stop - aime_lo, aime_hi = grids["aime"].start, grids["aime"].stop + assets_pts = np.asarray(grids["assets"].to_jax()) + aime_pts = np.asarray(grids["aime"].to_jax()) + assets_lo, assets_hi = float(assets_pts.min()), float(assets_pts.max()) + aime_lo, aime_hi = float(aime_pts.min()), float(aime_pts.max()) hcc_p_pts = np.asarray(grids["hcc_persistent"].to_jax()) hcc_t_pts = np.asarray(grids["hcc_transitory"].to_jax()) wage_res_pts = np.asarray(grids["log_ft_wage_res"].to_jax()) diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py new file mode 100644 index 0000000..d670d3c --- /dev/null +++ b/src/aca_model/consumption_grid.py @@ -0,0 +1,76 @@ +"""Runtime-supplied gridpoints for the consumption action. + +Consumption is declared as `IrregSpacedGrid(n_points=N)` in +`baseline.regimes._common.build_grids` so the lower bound can track +the per-iteration `consumption_floor` parameter. Callers must inject +the actual gridpoints into `params` via `inject_consumption_points` +before calling `model.solve()` / `model.simulate()`. 
+""" + +from collections.abc import Mapping +from typing import Any + +import jax.numpy as jnp +from jax import Array +from lcm import IrregSpacedGrid, Model + +MAX_CONSUMPTION: float = 300_000.0 +"""Upper bound of the consumption grid in $/year. Brackets the unconstrained +CRRA optimum for the highest-asset, highest-income agents in the state space.""" + + +def compute_consumption_points( + *, consumption_floor: float, n_points: int +) -> Array: + """Return log-spaced consumption gridpoints from the floor to `MAX_CONSUMPTION`. + + Args: + consumption_floor: Lowest gridpoint, equal to the `consumption_floor` + parameter so the agent cannot pick `c < floor` even when saving + from a transfer top-up. + n_points: Total number of gridpoints. + + Returns: + 1-D float array of length `n_points`. + """ + return jnp.geomspace(consumption_floor, MAX_CONSUMPTION, num=n_points) + + +def inject_consumption_points( + *, + params: Mapping[str, Any], + model: Model, + consumption_floor: float | None = None, +) -> dict[str, Any]: + """Inject consumption gridpoints into per-regime params. + + Walks `model.regimes`, finds those with `consumption` declared as + `IrregSpacedGrid` with runtime-supplied points, and writes + `params[regime_name]["consumption"] = {"points": }`. + + Args: + params: Existing params mapping. Returned as a new dict; the input is + not mutated. + model: Model whose regime specs determine which regimes need points. + consumption_floor: Lowest gridpoint. When `None`, taken from + `params["consumption_floor"]`. + + Returns: + New params dict with consumption points injected. 
+ """ + if consumption_floor is None: + consumption_floor = float(params["consumption_floor"]) + out: dict[str, Any] = dict(params) + for regime_name, regime in model.regimes.items(): + grid = regime.actions.get("consumption") + if not ( + isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime + ): + continue + points = compute_consumption_points( + consumption_floor=consumption_floor, n_points=grid.n_points + ) + regime_entry = dict(out.get(regime_name, {})) + regime_entry["consumption"] = {"points": points} + out[regime_name] = regime_entry + return out diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 72fb473..c1e48e0 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -13,7 +13,7 @@ def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 model = create_benchmark_model() - _, params = get_benchmark_params() + _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 ) From 134286108b7445f3e17e8824bcdd1739a98b6089 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 18:31:50 +0200 Subject: [PATCH 02/41] Refactor utility_scale_factor to take pref_type, return scalar MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `utility_scale_factor` was registered as a regime function returning a (n_pref_types,) array, then re-indexed by `pref_type` inside `bequest` and `utility`. pylcm broadcasts function outputs to per-cell scalars before consumption, so that `[pref_type]` indexing produced silent NaN in the dead regime's V — surfaced as the all-NaN failure on the ASV benchmark. Mirror the `discount_factor` pattern: take the state as input, return a per-cell scalar. 
Drop the `[pref_type]` indexing on `utility_scale_factor` from `utility` and `bequest` (those still index the params-Series `consumption_weight` and `coefficient_rra`, which is the supported pattern — only DAG function outputs are pre-broadcast). The matching pylcm validator (PR #338) now raises a clear `RegimeInitializationError` when a function output is consumed via state-indexing in a downstream consumer; this aca-model change is the fix that lets the dead regime construct under that validator. Tests in `test_preferences.py` and `test_model_components.py` updated to pass scalar `utility_scale_factor` and supply the new `pref_type` arg to `utility_scale_factor`. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/agent/preferences.py | 42 ++++++++++++++++++------------ src/aca_model/consumption_grid.py | 8 ++---- tests/test_model_components.py | 8 +++--- tests/test_preferences.py | 11 ++++++-- 4 files changed, 40 insertions(+), 29 deletions(-) diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 3b0bb5e..28a8367 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -133,8 +133,11 @@ def utility( """Within-period utility: CES aggregator over consumption and leisure. u = utility_scale_factor * ((c/eq_scale)^α * l^(1-α))^(1-γ) / (1-γ) - with log case for γ=1. `consumption_weight`, `coefficient_rra`, and - `utility_scale_factor` are indexed by `pref_type`. + with log case for γ=1. `consumption_weight` and `coefficient_rra` are + pref-type-indexed Series sourced directly from params; `utility_scale_factor` + is a regime-function output (already a per-cell scalar — must NOT be + re-indexed by pref_type, see `aca_model.agent.preferences.utility_scale_factor` + for why). 
""" alpha = consumption_weight[pref_type] gamma = coefficient_rra[pref_type] @@ -147,7 +150,7 @@ def utility( jnp.log(composite), composite**one_minus_gamma / one_minus_gamma, ) - return u * utility_scale_factor[pref_type] + return u * utility_scale_factor def discount_factor( @@ -164,6 +167,7 @@ def discount_factor( def utility_scale_factor( + pref_type: DiscreteState, average_consumption: float, consumption_weight: FloatND, coefficient_rra: FloatND, @@ -174,26 +178,29 @@ def utility_scale_factor( reference_age: int, scale_reference_age: int, ) -> FloatND: - """Compute scale factor so utility is approximately 1 at typical values. - - Uses leisure at `scale_reference_age` when working `scale_reference_hours` - (after fixed costs) and average consumption. Returns one scale per - preference type, indexed by pref_type. + """Compute the scale factor so utility is approximately 1 at typical values. + + Returns the scalar for the cell's `pref_type`. Mirrors the `discount_factor` + pattern: take the state as input, return a per-cell scalar. Registering this + as a regime function and then doing `utility_scale_factor[pref_type]` in a + downstream consumer is invalid — pylcm broadcasts function outputs to + per-cell scalars before consumption, and the validator in + `lcm.regime_building.validation` raises on that clash. 
""" + alpha = consumption_weight[pref_type] + gamma = coefficient_rra[pref_type] age_offset = scale_reference_age - reference_age average_leisure = ( time_endowment - scale_reference_hours - (fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * age_offset) ) - u_cons = average_consumption**consumption_weight - u_leisure = average_leisure ** (1.0 - consumption_weight) + u_cons = average_consumption**alpha + u_leisure = average_leisure ** (1.0 - alpha) - one_minus_gamma = jnp.where( - jnp.isclose(coefficient_rra, 1.0), 1.0, 1.0 - coefficient_rra - ) + one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) raw = jnp.where( - jnp.isclose(coefficient_rra, 1.0), + jnp.isclose(gamma, 1.0), jnp.log(u_cons * u_leisure), (u_cons * u_leisure) ** one_minus_gamma / one_minus_gamma, ) @@ -237,8 +244,9 @@ def bequest( """Bequest function for terminal/dead states. bequest = scale * bwt * (max(0,a) + shifter)^(α*(1-γ)) / (1-γ) - `consumption_weight`, `coefficient_rra`, and `utility_scale_factor` - are indexed by `pref_type`. + `consumption_weight` and `coefficient_rra` are pref-type-indexed Series + from params; `utility_scale_factor` is a regime-function output (already a + per-cell scalar — must NOT be re-indexed by pref_type). 
""" alpha = consumption_weight[pref_type] gamma = coefficient_rra[pref_type] @@ -250,4 +258,4 @@ def bequest( jnp.log(assets_shifted), assets_shifted ** (one_minus_gamma * alpha) / one_minus_gamma, ) - return val * scaled_bequest_weight * utility_scale_factor[pref_type] + return val * scaled_bequest_weight * utility_scale_factor diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py index d670d3c..8ba8bc4 100644 --- a/src/aca_model/consumption_grid.py +++ b/src/aca_model/consumption_grid.py @@ -19,9 +19,7 @@ CRRA optimum for the highest-asset, highest-income agents in the state space.""" -def compute_consumption_points( - *, consumption_floor: float, n_points: int -) -> Array: +def compute_consumption_points(*, consumption_floor: float, n_points: int) -> Array: """Return log-spaced consumption gridpoints from the floor to `MAX_CONSUMPTION`. Args: @@ -63,9 +61,7 @@ def inject_consumption_points( out: dict[str, Any] = dict(params) for regime_name, regime in model.regimes.items(): grid = regime.actions.get("consumption") - if not ( - isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime - ): + if not (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime): continue points = compute_consumption_points( consumption_floor=consumption_floor, n_points=grid.n_points diff --git a/tests/test_model_components.py b/tests/test_model_components.py index cbb2f72..9876ac8 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -84,7 +84,7 @@ def test_utility_positive_leisure() -> None: consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([2.0, 2.0, 2.0]), equivalence_scale=jnp.array(1.0), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) @@ -97,7 +97,7 @@ def test_utility_log_case() -> None: consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([1.0, 1.0, 1.0]), equivalence_scale=jnp.array(1.0), 
- utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) composite = 10000.0**0.4 * 3000.0**0.6 expected = jnp.log(composite) @@ -112,7 +112,7 @@ def test_bequest_positive_assets() -> None: scaled_bequest_weight=0.5, consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([2.0, 2.0, 2.0]), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) @@ -125,7 +125,7 @@ def test_bequest_zero_assets() -> None: scaled_bequest_weight=0.5, consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([2.0, 2.0, 2.0]), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) assert result < 0 # CRRA with γ>1 gives negative values diff --git a/tests/test_preferences.py b/tests/test_preferences.py index 8c1921c..4ff2266 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -33,6 +33,7 @@ def test_utility_scale_factor_crra() -> None: result = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, @@ -43,11 +44,12 @@ def test_utility_scale_factor_crra() -> None: reference_age=REFERENCE_AGE, scale_reference_age=SCALE_REFERENCE_AGE, ) - assert jnp.isclose(result[0], 9_233_279_397_806_166.0, rtol=1e-6) + assert jnp.isclose(result, 9_233_279_397_806_166.0, rtol=1e-6) def test_utility_scale_factor_log() -> None: result = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, @@ -58,7 +60,7 @@ def test_utility_scale_factor_log() -> None: reference_age=REFERENCE_AGE, scale_reference_age=SCALE_REFERENCE_AGE, ) - assert jnp.isclose(result[0], 0.113_073_257_794_546_72, rtol=1e-6) + assert jnp.isclose(result, 0.113_073_257_794_546_72, rtol=1e-6) # --- 
scaled_bequest_weight --- @@ -105,6 +107,7 @@ def test_scaled_bequest_weight_zero() -> None: def test_utility_log_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, @@ -129,6 +132,7 @@ def test_utility_log_regression() -> None: def test_utility_crra_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, @@ -154,6 +158,7 @@ def test_utility_crra_regression() -> None: def test_utility_married_equivalence() -> None: """Married with equiv-scaled consumption should equal single utility.""" scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, @@ -190,6 +195,7 @@ def test_utility_married_equivalence() -> None: def test_bequest_log_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, @@ -222,6 +228,7 @@ def test_bequest_log_regression() -> None: def test_bequest_crra_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, From 8cd8e37d3ada3eb8f9f91d76cd678d280d10e926 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 30 Apr 2026 14:59:49 +0200 Subject: [PATCH 03/41] Halve aca-model production assets batch size to fit V100 16GB Reduce n_assets_batch_size from 2 to 1 in MODEL_CONFIG so the assets state axis is streamed one slice at a time, lowering peak GPU memory during solve on the V100-PCIE-16GB. Benchmark grid config is unchanged. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 1bae45f..2904ca4 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -36,7 +36,7 @@ class GridConfig: # `batch_size` on the assets grid: chunked vmap stride for the # outer state loop. Useful at prod sizes for memory reasons; set # to 0 in BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. - n_assets_batch_size: int = 2 + n_assets_batch_size: int = 1 MODEL_CONFIG = ModelConfig() From 84d484a5153c241ab318b785749fe8b103a5ca0d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 30 Apr 2026 17:05:11 +0200 Subject: [PATCH 04/41] =?UTF-8?q?Revert=20assets=20batch=5Fsize=20halving?= =?UTF-8?q?=20=E2=80=94=20V100=20OOM=20was=20elsewhere?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The OOM at production grid sizes came from pylcm's deferred diagnostics flush in solve_brute (`_emit_deferred_diagnostics` materialising a fused per-period reduction graph at end-of-solve), not from per-period peak. Halving the assets batch did not address that; reverting so the production loop runs at its previous throughput. Workaround for the diagnostics OOM lives in aca-estimation's simulate tasks (log_level="off"). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 2904ca4..1bae45f 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -36,7 +36,7 @@ class GridConfig: # `batch_size` on the assets grid: chunked vmap stride for the # outer state loop. Useful at prod sizes for memory reasons; set # to 0 in BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. 
- n_assets_batch_size: int = 1 + n_assets_batch_size: int = 2 MODEL_CONFIG = ModelConfig() From f0892efbbe891f198e7a42ba18147e322e7b71e7 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 10:35:14 +0200 Subject: [PATCH 05/41] Re-halve production assets batch size: V100 still OOMs per-period MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #339's per-period `block_until_ready` made the OOM surface inside the loop instead of at the post-loop diagnostic flush, but the 7.26 GiB allocation request was the same — it isn't the diagnostic accumulator, it's a real per-period `max_Q_over_a` working set at production grid sizes (`n_consumption=70`, `n_assets=24`, `n_aime=12`, plus the per-target next-V gather across reachable regimes). Cutting the assets-axis chunk back to 1 reduces the per-kernel peak. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 1bae45f..2904ca4 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -36,7 +36,7 @@ class GridConfig: # `batch_size` on the assets grid: chunked vmap stride for the # outer state loop. Useful at prod sizes for memory reasons; set # to 0 in BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. - n_assets_batch_size: int = 2 + n_assets_batch_size: int = 1 MODEL_CONFIG = ModelConfig() From 08e42cb1e669f6e43582539bf4afae3cfbedafcd Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 12:45:59 +0200 Subject: [PATCH 06/41] config: add n_aime_batch_size to splay AIME outer-loop on V100 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Production solve allocates a per-period Q intermediate of shape `(non-assets-states × actions)` per assets-batch slot. 
With `n_assets_batch_size=1` we already chunk that axis to the minimum; the remaining outer-state product (aime × wage_res × hcc × pref_type × health × ...) times the action grid still pushes past the V100 16 GB once `pref_type` is split off into its own partition lift, which removes a free factor that previously thinned the kernel. Add a sibling `n_aime_batch_size` knob (default 1, 0 in `BENCHMARK_GRID_CONFIG`) and thread it through both AIME grid types in `_build_aime_grid`. AIME has 12 prod gridpoints in the LinSpaced fallback and 32 in the PiecewiseLinSpaced production path, so a unit batch shrinks the live Q intermediate by roughly that factor — enough headroom to land back inside V100 memory. Pairs with the pylcm-side fix that stops `_DiagnosticRow` pinning per-period V templates in device memory (lazy-solve-diagnostics branch). The diagnostic leak masked the underlying batching gap; once it's gone, the Q intermediate is the next thing to size for the device. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/baseline/regimes/_common.py | 9 +++++++-- src/aca_model/benchmark.py | 3 ++- src/aca_model/config.py | 9 ++++++--- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index efd3b73..327aa15 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -287,7 +287,10 @@ def _build_aime_grid( """ if fixed_params is None or "pia_aime_grid" not in fixed_params: return LinSpacedGrid( - start=0.0, stop=8_000.0, n_points=grid_config.n_aime_gridpoints + start=0.0, + stop=8_000.0, + n_points=grid_config.n_aime_gridpoints, + batch_size=grid_config.n_aime_batch_size, ) kinks = [float(k) for k in np.asarray(fixed_params["pia_aime_grid"])] pieces = ( @@ -295,7 +298,9 @@ def _build_aime_grid( Piece(interval=f"[{kinks[1]}, {kinks[2]})", n_points=_AIME_PIECE_N_POINTS[1]), Piece(interval=f"[{kinks[2]}, {kinks[3]}]", 
n_points=_AIME_PIECE_N_POINTS[2]), ) - return PiecewiseLinSpacedGrid(pieces=pieces) + return PiecewiseLinSpacedGrid( + pieces=pieces, batch_size=grid_config.n_aime_batch_size + ) def _has_required_wage_keys(*, wage_params: Mapping[str, Any]) -> bool: diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index a9d7128..3c06670 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -76,7 +76,8 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod The benchmark uses a 2-type `BenchmarkPrefType`. No `batch_size != 0` on any grid (continuous grids inherit - `BENCHMARK_GRID_CONFIG.n_assets_batch_size = 0`). + `BENCHMARK_GRID_CONFIG.n_assets_batch_size = 0` and + `n_aime_batch_size = 0`). Args: pref_type_grid: Override for the pref_type grid. Default is a plain diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 2904ca4..37fc0c8 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -33,10 +33,12 @@ class GridConfig: n_wage_res_gridpoints: int = 5 n_hcc_persistent_gridpoints: int = 3 n_hcc_transitory_gridpoints: int = 5 - # `batch_size` on the assets grid: chunked vmap stride for the - # outer state loop. Useful at prod sizes for memory reasons; set - # to 0 in BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. + # `batch_size` on the assets / AIME grids: chunked vmap stride for the + # outer state loop. Both partition the per-period Q intermediate so it + # fits in V100 16 GB once we splay across `pref_type`. Set to 0 in + # BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. 
n_assets_batch_size: int = 1 + n_aime_batch_size: int = 1 MODEL_CONFIG = ModelConfig() @@ -50,4 +52,5 @@ class GridConfig: n_hcc_persistent_gridpoints=3, n_hcc_transitory_gridpoints=3, n_assets_batch_size=0, + n_aime_batch_size=0, ) From e08fc19705549c5e59f38e5a704b993412c491be Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 12:57:29 +0200 Subject: [PATCH 07/41] consumption_grid: read upper bound from `max_consumption` fixed param MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The grid floor already tracks the per-iteration `consumption_floor` parameter; the ceiling was a hardcoded 300k constant. Surface it as a fixed param via a marker function (`consumption_grid_upper_bound`) so callers can declare the bracket per model creation, and read it back at inject time from each regime's `resolved_fixed_params`. The marker function's output is intentionally unused — its only job is to put `max_consumption` in the regime params template so pylcm's fixed-param machinery captures it. dags.tree pruning drops the call at solve / simulate. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/baseline/regimes/_common.py | 7 +++ src/aca_model/consumption_grid.py | 75 ++++++++++++++--------- 2 files changed, 52 insertions(+), 30 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 327aa15..d46f921 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -37,6 +37,7 @@ from aca_model.baseline import health_insurance from aca_model.baseline.health_insurance import BuyPrivate from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.consumption_grid import consumption_grid_upper_bound from aca_model.environment import social_security, taxes from aca_model.environment.social_security import ClaimedSS @@ -537,6 +538,12 @@ def build_common_functions(spec: dict[str, str]) -> dict: functions["cash_on_hand"] = assets_and_income.cash_on_hand functions["transfers"] = assets_and_income.transfers + # Marker: surfaces `max_consumption` in the params template so it + # can be supplied via fixed_params and read back at inject time + # by `inject_consumption_points`. Output unused; pruned at + # solve / simulate. + functions["consumption_grid_upper_bound"] = consumption_grid_upper_bound + return functions diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py index 8ba8bc4..6238328 100644 --- a/src/aca_model/consumption_grid.py +++ b/src/aca_model/consumption_grid.py @@ -1,10 +1,12 @@ """Runtime-supplied gridpoints for the consumption action. Consumption is declared as `IrregSpacedGrid(n_points=N)` in -`baseline.regimes._common.build_grids` so the lower bound can track -the per-iteration `consumption_floor` parameter. Callers must inject -the actual gridpoints into `params` via `inject_consumption_points` -before calling `model.solve()` / `model.simulate()`. 
+`baseline.regimes._common.build_grids` so the bounds can track +runtime parameters: the lower bound from the per-iteration +`consumption_floor` parameter, the upper bound from the per-creation-time +`max_consumption` fixed param. Callers must inject the actual +gridpoints into `params` via `inject_consumption_points` before +calling `model.solve()` / `model.simulate()`. """ from collections.abc import Mapping @@ -14,31 +16,11 @@ from jax import Array from lcm import IrregSpacedGrid, Model -MAX_CONSUMPTION: float = 300_000.0 -"""Upper bound of the consumption grid in $/year. Brackets the unconstrained -CRRA optimum for the highest-asset, highest-income agents in the state space.""" - - -def compute_consumption_points(*, consumption_floor: float, n_points: int) -> Array: - """Return log-spaced consumption gridpoints from the floor to `MAX_CONSUMPTION`. - - Args: - consumption_floor: Lowest gridpoint, equal to the `consumption_floor` - parameter so the agent cannot pick `c < floor` even when saving - from a transfer top-up. - n_points: Total number of gridpoints. - - Returns: - 1-D float array of length `n_points`. - """ - return jnp.geomspace(consumption_floor, MAX_CONSUMPTION, num=n_points) - def inject_consumption_points( *, params: Mapping[str, Any], model: Model, - consumption_floor: float | None = None, ) -> dict[str, Any]: """Inject consumption gridpoints into per-regime params. @@ -46,27 +28,60 @@ def inject_consumption_points( `IrregSpacedGrid` with runtime-supplied points, and writes `params[regime_name]["consumption"] = {"points": }`. + Lower bound: `params["consumption_floor"]` (varies per iteration). + Upper bound: `max_consumption` from the regime's resolved + fixed-params (set once at model creation). + Args: params: Existing params mapping. Returned as a new dict; the input is not mutated. model: Model whose regime specs determine which regimes need points. - consumption_floor: Lowest gridpoint. When `None`, taken from - `params["consumption_floor"]`. 
Returns: New params dict with consumption points injected. """ - if consumption_floor is None: - consumption_floor = float(params["consumption_floor"]) + consumption_floor = float(params["consumption_floor"]) out: dict[str, Any] = dict(params) for regime_name, regime in model.regimes.items(): grid = regime.actions.get("consumption") if not (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime): continue - points = compute_consumption_points( - consumption_floor=consumption_floor, n_points=grid.n_points + # Runtime-points grids always have `n_points` set (the constructor + # rejects the (points=None, n_points=None) combo); narrow for ty. + assert grid.n_points is not None + max_consumption = float( + model.internal_regimes[regime_name].resolved_fixed_params["max_consumption"] + ) + points = _compute_consumption_points( + consumption_floor=consumption_floor, + max_consumption=max_consumption, + n_points=grid.n_points, ) regime_entry = dict(out.get(regime_name, {})) regime_entry["consumption"] = {"points": points} out[regime_name] = regime_entry return out + + +def consumption_grid_upper_bound(max_consumption: float) -> float: + """Surface `max_consumption` in the regime params template. + + pylcm builds the params template from each regime function's + signature. `max_consumption` is read at runtime by + `inject_consumption_points` from `resolved_fixed_params`; for + that to work via pylcm's fixed-params machinery, the key must + appear in some function's signature. This marker function is + the entry point — its output is intentionally unused, and + dags.tree pruning drops the call at solve / simulate time. 
+ """ + return max_consumption + + +def _compute_consumption_points( + *, + consumption_floor: float, + max_consumption: float, + n_points: int, +) -> Array: + """Return log-spaced consumption gridpoints from floor to max.""" + return jnp.geomspace(consumption_floor, max_consumption, num=n_points) From c1ffb2a793ea4c5488b3e99a87883b80860c6b1d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 08:19:20 +0200 Subject: [PATCH 08/41] create_model: default `max_consumption` into fixed_params The runtime-upper-bound change requires every caller to supply `max_consumption` via `fixed_params`; estimation tasks (e.g. `task_simulate_aca`) hit a `KeyError` mid-pipeline because they construct the model from data-derived `fixed_params` that have no reason to mention a grid bracket. Centralise the default in both `baseline.model.create_model` and `aca.model.create_model` so existing callers keep working with the prior 300k bracket and only opt-in callers need to override. --- src/aca_model/aca/model.py | 4 +++- src/aca_model/baseline/model.py | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 8b4507a..0e02b0c 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -11,6 +11,7 @@ from aca_model.aca import PolicyVariant from aca_model.aca.regimes import build_all_regimes +from aca_model.baseline.model import _with_max_consumption_default from aca_model.baseline.regimes import RegimeId from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig @@ -51,6 +52,7 @@ def create_model( stop=MODEL_CONFIG.end_age - 1, step="Y", ) + fixed_params = _with_max_consumption_default(fixed_params) regimes = build_all_regimes( policy=policy, grid_config=grid_config, @@ -63,6 +65,6 @@ def create_model( ages=ages, regime_id_class=RegimeId, description=f"Structural retirement model ({policy.name})", - fixed_params=fixed_params or {}, + 
fixed_params=fixed_params, derived_categoricals=derived_categoricals, ) diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index a886495..2843fa0 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -18,6 +18,11 @@ from aca_model.baseline.regimes import RegimeId, build_all_regimes from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +_DEFAULT_MAX_CONSUMPTION: float = 300_000.0 +"""Upper bound of the consumption grid in $/year. Brackets the unconstrained +CRRA optimum for the highest-asset, highest-income agents in the state space. +Callers can override by passing `max_consumption` in `fixed_params`.""" + def create_model( *, @@ -59,6 +64,7 @@ def create_model( stop=MODEL_CONFIG.end_age - 1, step="Y", ) + fixed_params = _with_max_consumption_default(fixed_params) regimes = build_all_regimes( grid_config, fixed_params=fixed_params, @@ -71,6 +77,15 @@ def create_model( ages=ages, regime_id_class=RegimeId, description="Baseline structural retirement model (pre-ACA)", - fixed_params=fixed_params or {}, + fixed_params=fixed_params, derived_categoricals=derived_categoricals, ) + + +def _with_max_consumption_default( + fixed_params: Mapping[str, Any] | None, +) -> dict[str, Any]: + """Return a copy of `fixed_params` with `max_consumption` defaulted.""" + out = dict(fixed_params or {}) + out.setdefault("max_consumption", _DEFAULT_MAX_CONSUMPTION) + return out From a21768752b15f70da2fa82910e6f4cebc26689d2 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 13:20:38 +0200 Subject: [PATCH 09/41] create_model: forward n_subjects through baseline + aca + benchmark Lets callers opt in to pylcm's simulate-AOT path (`Model(n_subjects=...)`) without bypassing the aca-model factories. 
--- src/aca_model/aca/model.py | 2 ++ src/aca_model/baseline/model.py | 2 ++ src/aca_model/benchmark.py | 10 +++++++++- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 0e02b0c..9a118c6 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -24,6 +24,7 @@ def create_model( derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] | None = None, grid_config: GridConfig = GRID_CONFIG, + n_subjects: int | None = None, ) -> Model: """Create an ACA policy variant model. @@ -67,4 +68,5 @@ def create_model( description=f"Structural retirement model ({policy.name})", fixed_params=fixed_params, derived_categoricals=derived_categoricals, + n_subjects=n_subjects, ) diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 2843fa0..9d6b04a 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -32,6 +32,7 @@ def create_model( | None = None, grid_config: GridConfig = GRID_CONFIG, pref_type_grid: DiscreteGrid | None = None, + n_subjects: int | None = None, ) -> Model: """Create the baseline structural retirement model. @@ -79,6 +80,7 @@ def create_model( description="Baseline structural retirement model (pre-ACA)", fixed_params=fixed_params, derived_categoricals=derived_categoricals, + n_subjects=n_subjects, ) diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 3c06670..35bf880 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -71,7 +71,11 @@ ) -def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Model: +def create_benchmark_model( + *, + pref_type_grid: DiscreteGrid | None = None, + n_subjects: int | None = None, +) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. The benchmark uses a 2-type `BenchmarkPrefType`. 
No `batch_size != 0` @@ -86,6 +90,9 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod (or `PARTITION_VMAP`) to get the partition-lifted kernel — the recommended production setting for aca-model at scale, but only supported on pylcm versions that expose `DispatchStrategy`. + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. When set, the + first matching `simulate(...)` call AOT-compiles all simulate + functions for that batch shape. """ if pref_type_grid is None: pref_type_grid = DiscreteGrid(BenchmarkPrefType) @@ -95,6 +102,7 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod fixed_params=fixed_params, derived_categoricals=_DERIVED_CATEGORICALS, pref_type_grid=pref_type_grid, + n_subjects=n_subjects, ) From d1eb320f0439a73a169bc5916d4af0b15208b2b6 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 13:30:56 +0200 Subject: [PATCH 10/41] create_model: require n_subjects (no default) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The aca-model factories now require `n_subjects` as a kw-only int with no default — there's never a good reason for an aca-model caller to leave it unspecified, and silently letting it default to `None` (= no AOT, lazy-compile path) was exactly how the simulate-AOT benefit went unused on the prod estimation loop. Forcing each caller to make a deliberate choice catches that. Tests pass `n_subjects=1` for bare `get_params_template()` / shock-grid-inspection paths that never simulate. 
--- src/aca_model/aca/model.py | 2 +- src/aca_model/baseline/model.py | 2 +- src/aca_model/benchmark.py | 2 +- tests/test_benchmark.py | 2 +- tests/test_model_creation.py | 14 +++++++------- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 9a118c6..a9097b4 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -18,13 +18,13 @@ def create_model( *, + n_subjects: int, policy: PolicyVariant = PolicyVariant.ACA, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] | None = None, grid_config: GridConfig = GRID_CONFIG, - n_subjects: int | None = None, ) -> Model: """Create an ACA policy variant model. diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 9d6b04a..78c90eb 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -26,13 +26,13 @@ def create_model( *, + n_subjects: int, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] | None = None, grid_config: GridConfig = GRID_CONFIG, pref_type_grid: DiscreteGrid | None = None, - n_subjects: int | None = None, ) -> Model: """Create the baseline structural retirement model. diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 35bf880..08f5ec6 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -73,8 +73,8 @@ def create_benchmark_model( *, + n_subjects: int, pref_type_grid: DiscreteGrid | None = None, - n_subjects: int | None = None, ) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. 
diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index c1e48e0..adafd66 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -12,7 +12,7 @@ @pytest.mark.long_running def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 - model = create_benchmark_model() + model = create_benchmark_model(n_subjects=n_subjects) _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 841f2bb..d154e6b 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -21,24 +21,24 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: - model = create_model() + model = create_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 def test_model_age_range() -> None: - model = create_model() + model = create_model(n_subjects=1) assert model.ages.values[0] == 51.0 assert model.ages.values[-1] == 95.0 def test_dead_regime_is_terminal() -> None: - model = create_model() + model = create_model(n_subjects=1) assert model.regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: - model = create_model() + model = create_model(n_subjects=1) for name in REGIME_SPECS: assert not model.regimes[name].terminal @@ -170,7 +170,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: - model = create_aca_model() + model = create_aca_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 @@ -211,7 +211,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: @pytest.mark.parametrize("policy", list(PolicyVariant)) def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" - model = create_aca_model(policy=policy) + model = create_aca_model(n_subjects=1, 
policy=policy) assert len(model.regimes) == 19 @@ -251,5 +251,5 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" - model = create_model() + model = create_model(n_subjects=1) assert len(model.regimes) == 19 From 9e252051ad53683a8ad65e9dba68a910240103c0 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 13:41:55 +0200 Subject: [PATCH 11/41] ci: install pylcm from feat/simulate-aot-n-subjects (carries Model.n_subjects) --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a0247b8..bd40367 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@feature/runtime-action-grids" + git+https://github.com/OpenSourceEconomics/pylcm.git@feat/simulate-aot-n-subjects" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest From cdd10169a13fbc74604f9e879276ddb4c17b53c4 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 21:02:13 +0200 Subject: [PATCH 12/41] consumption_grid: max_consumption is a required factory arg, attached to Model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The marker-function-via-DAG pattern didn't survive pylcm's pruning: `consumption_grid_upper_bound`'s output is unused, so dags.tree drops it before its `max_consumption` parameter reaches the params template, and `broadcast_to_template` has nowhere to put the value. Result: `resolved_fixed_params["max_consumption"]` was always missing, `inject_consumption_points` raised KeyError. 
Sidestep pylcm's params machinery for this knob: - Drop the `consumption_grid_upper_bound` marker function and the `_with_max_consumption_default` helper. - Add `max_consumption: float` (kw-only, required, no default) to all three factories: `baseline.create_model`, `aca.create_model`, `create_benchmark_model`. - Each factory attaches the value directly to the returned `Model` instance (`model.max_consumption = ...`). - `inject_consumption_points` reads `model.max_consumption` directly. No defaults — every caller passes the bracket explicitly. --- src/aca_model/aca/model.py | 13 ++++++--- src/aca_model/baseline/model.py | 35 +++++++++++------------ src/aca_model/baseline/regimes/_common.py | 7 ----- src/aca_model/benchmark.py | 2 ++ src/aca_model/consumption_grid.py | 33 ++++++--------------- tests/test_benchmark.py | 2 +- tests/test_model_creation.py | 14 ++++----- 7 files changed, 45 insertions(+), 61 deletions(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index a9097b4..942be22 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -11,7 +11,6 @@ from aca_model.aca import PolicyVariant from aca_model.aca.regimes import build_all_regimes -from aca_model.baseline.model import _with_max_consumption_default from aca_model.baseline.regimes import RegimeId from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig @@ -19,6 +18,7 @@ def create_model( *, n_subjects: int, + max_consumption: float, policy: PolicyVariant = PolicyVariant.ACA, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, @@ -29,6 +29,7 @@ def create_model( """Create an ACA policy variant model. Args: + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. policy: Which ACA policy combination to apply. fixed_params: Parameters to fix at model creation time. 
These are partialled into compiled functions and removed from the params @@ -43,6 +44,9 @@ def create_model( contains `pd.Series` indexed by DAG function outputs. grid_config: Continuous-grid point counts. Defaults to production values. + max_consumption: Upper bound of the runtime consumption grid in + $/year. Attached to the returned Model and read back at inject + time by `inject_consumption_points`. Returns: pylcm Model with ACA-specific function overrides. @@ -53,7 +57,6 @@ def create_model( stop=MODEL_CONFIG.end_age - 1, step="Y", ) - fixed_params = _with_max_consumption_default(fixed_params) regimes = build_all_regimes( policy=policy, grid_config=grid_config, @@ -61,12 +64,14 @@ def create_model( wage_params=wage_params, ) - return Model( + model = Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description=f"Structural retirement model ({policy.name})", - fixed_params=fixed_params, + fixed_params=fixed_params or {}, derived_categoricals=derived_categoricals, n_subjects=n_subjects, ) + model.max_consumption = max_consumption + return model diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 78c90eb..0ff7c47 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -5,7 +5,7 @@ Usage: from aca_model.baseline.model import create_model - model = create_model() + model = create_model(n_subjects=...) params = get_default_params() V = model.solve(params) """ @@ -18,15 +18,11 @@ from aca_model.baseline.regimes import RegimeId, build_all_regimes from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig -_DEFAULT_MAX_CONSUMPTION: float = 300_000.0 -"""Upper bound of the consumption grid in $/year. Brackets the unconstrained -CRRA optimum for the highest-asset, highest-income agents in the state space. 
-Callers can override by passing `max_consumption` in `fixed_params`.""" - def create_model( *, n_subjects: int, + max_consumption: float, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] @@ -37,6 +33,7 @@ def create_model( """Create the baseline structural retirement model. Args: + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. fixed_params: Parameters to fix at model creation time. These are partialled into compiled functions and removed from the params template. Pass data-derived constants here; only estimation @@ -54,6 +51,9 @@ def create_model( pref_type_grid: Optional override for the `pref_type` `DiscreteGrid`. Defaults to `DiscreteGrid(PrefType)`. Used by the benchmark to substitute a 2-type variant with `DispatchStrategy.PARTITION_SCAN`. + max_consumption: Upper bound of the runtime consumption grid in + $/year. Attached to the returned Model and read back at inject + time by `inject_consumption_points`. 
Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -65,7 +65,6 @@ def create_model( stop=MODEL_CONFIG.end_age - 1, step="Y", ) - fixed_params = _with_max_consumption_default(fixed_params) regimes = build_all_regimes( grid_config, fixed_params=fixed_params, @@ -73,21 +72,21 @@ def create_model( pref_type_grid=pref_type_grid, ) - return Model( + model = Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description="Baseline structural retirement model (pre-ACA)", - fixed_params=fixed_params, + fixed_params=fixed_params or {}, derived_categoricals=derived_categoricals, n_subjects=n_subjects, ) - - -def _with_max_consumption_default( - fixed_params: Mapping[str, Any] | None, -) -> dict[str, Any]: - """Return a copy of `fixed_params` with `max_consumption` defaulted.""" - out = dict(fixed_params or {}) - out.setdefault("max_consumption", _DEFAULT_MAX_CONSUMPTION) - return out + # Attach the consumption-grid upper bound directly to the Model + # instance. Tried surfacing it via a marker function in the regime + # DAG first — pylcm's pruning drops unused-output functions before + # their parameters reach the params template, so the value never + # made it into `resolved_fixed_params`. Direct attachment sidesteps + # the templating machinery entirely; `inject_consumption_points` + # reads `model.max_consumption`. 
+ model.max_consumption = max_consumption + return model diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index d46f921..327aa15 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -37,7 +37,6 @@ from aca_model.baseline import health_insurance from aca_model.baseline.health_insurance import BuyPrivate from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig -from aca_model.consumption_grid import consumption_grid_upper_bound from aca_model.environment import social_security, taxes from aca_model.environment.social_security import ClaimedSS @@ -538,12 +537,6 @@ def build_common_functions(spec: dict[str, str]) -> dict: functions["cash_on_hand"] = assets_and_income.cash_on_hand functions["transfers"] = assets_and_income.transfers - # Marker: surfaces `max_consumption` in the params template so it - # can be supplied via fixed_params and read back at inject time - # by `inject_consumption_points`. Output unused; pruned at - # solve / simulate. - functions["consumption_grid_upper_bound"] = consumption_grid_upper_bound - return functions diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 08f5ec6..3d7fbad 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -74,6 +74,7 @@ def create_benchmark_model( *, n_subjects: int, + max_consumption: float, pref_type_grid: DiscreteGrid | None = None, ) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. 
@@ -103,6 +104,7 @@ def create_benchmark_model( derived_categoricals=_DERIVED_CATEGORICALS, pref_type_grid=pref_type_grid, n_subjects=n_subjects, + max_consumption=max_consumption, ) diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py index 6238328..bd342ee 100644 --- a/src/aca_model/consumption_grid.py +++ b/src/aca_model/consumption_grid.py @@ -3,10 +3,11 @@ Consumption is declared as `IrregSpacedGrid(n_points=N)` in `baseline.regimes._common.build_grids` so the bounds can track runtime parameters: the lower bound from the per-iteration -`consumption_floor` parameter, the upper bound from the per-creation-time -`max_consumption` fixed param. Callers must inject the actual -gridpoints into `params` via `inject_consumption_points` before -calling `model.solve()` / `model.simulate()`. +`consumption_floor` parameter, the upper bound from a per-model +`max_consumption` knob attached to the `Model` instance by the +`create_model` factories. Callers must inject the actual gridpoints +into `params` via `inject_consumption_points` before calling +`model.solve()` / `model.simulate()`. """ from collections.abc import Mapping @@ -24,13 +25,13 @@ def inject_consumption_points( ) -> dict[str, Any]: """Inject consumption gridpoints into per-regime params. - Walks `model.regimes`, finds those with `consumption` declared as + Walks every regime, finds the action whose grid is an `IrregSpacedGrid` with runtime-supplied points, and writes `params[regime_name]["consumption"] = {"points": }`. Lower bound: `params["consumption_floor"]` (varies per iteration). - Upper bound: `max_consumption` from the regime's resolved - fixed-params (set once at model creation). + Upper bound: `model.max_consumption` (required attribute; set by + the `create_model` factory). Args: params: Existing params mapping. Returned as a new dict; the input is @@ -41,6 +42,7 @@ def inject_consumption_points( New params dict with consumption points injected. 
""" consumption_floor = float(params["consumption_floor"]) + max_consumption = float(model.max_consumption) out: dict[str, Any] = dict(params) for regime_name, regime in model.regimes.items(): grid = regime.actions.get("consumption") @@ -49,9 +51,6 @@ def inject_consumption_points( # Runtime-points grids always have `n_points` set (the constructor # rejects the (points=None, n_points=None) combo); narrow for ty. assert grid.n_points is not None - max_consumption = float( - model.internal_regimes[regime_name].resolved_fixed_params["max_consumption"] - ) points = _compute_consumption_points( consumption_floor=consumption_floor, max_consumption=max_consumption, @@ -63,20 +62,6 @@ def inject_consumption_points( return out -def consumption_grid_upper_bound(max_consumption: float) -> float: - """Surface `max_consumption` in the regime params template. - - pylcm builds the params template from each regime function's - signature. `max_consumption` is read at runtime by - `inject_consumption_points` from `resolved_fixed_params`; for - that to work via pylcm's fixed-params machinery, the key must - appear in some function's signature. This marker function is - the entry point — its output is intentionally unused, and - dags.tree pruning drops the call at solve / simulate time. 
- """ - return max_consumption - - def _compute_consumption_points( *, consumption_floor: float, diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index adafd66..6173318 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -12,7 +12,7 @@ @pytest.mark.long_running def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 - model = create_benchmark_model(n_subjects=n_subjects) + model = create_benchmark_model(n_subjects=n_subjects, max_consumption=300_000.0) _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index d154e6b..accde27 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -21,24 +21,24 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: - model = create_model(n_subjects=1) + model = create_model(n_subjects=1, max_consumption=300_000.0) assert len(model.regimes) == 19 assert model.n_periods == 45 def test_model_age_range() -> None: - model = create_model(n_subjects=1) + model = create_model(n_subjects=1, max_consumption=300_000.0) assert model.ages.values[0] == 51.0 assert model.ages.values[-1] == 95.0 def test_dead_regime_is_terminal() -> None: - model = create_model(n_subjects=1) + model = create_model(n_subjects=1, max_consumption=300_000.0) assert model.regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: - model = create_model(n_subjects=1) + model = create_model(n_subjects=1, max_consumption=300_000.0) for name in REGIME_SPECS: assert not model.regimes[name].terminal @@ -170,7 +170,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: - model = create_aca_model(n_subjects=1) + model = create_aca_model(n_subjects=1, max_consumption=300_000.0) assert len(model.regimes) == 19 assert model.n_periods == 
45 @@ -211,7 +211,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: @pytest.mark.parametrize("policy", list(PolicyVariant)) def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" - model = create_aca_model(n_subjects=1, policy=policy) + model = create_aca_model(n_subjects=1, max_consumption=300_000.0, policy=policy) assert len(model.regimes) == 19 @@ -251,5 +251,5 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" - model = create_model(n_subjects=1) + model = create_model(n_subjects=1, max_consumption=300_000.0) assert len(model.regimes) == 19 From 31a0ad20e70c9e1859f4268cd954979070d8b17f Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 3 May 2026 17:29:58 +0200 Subject: [PATCH 13/41] Move max_consumption to canonical constant; drop kwarg threading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `MAX_CONSUMPTION = 300_000.0` to `baseline/regimes/_common.py` next to the other grid bounds (assets `stop=500_000.0`, AIME `stop=8_000.0`). The two `create_model` factories and `create_benchmark_model` no longer take `max_consumption` as a kwarg; each factory reads the constant directly and attaches it onto `model.max_consumption`. `inject_consumption_points` is unchanged — it still reads `model.max_consumption` (the legitimate consumer that combines it with the per-iteration `consumption_floor`). Routed via the Model attribute rather than `fixed_params` because pylcm validates fixed_params keys against the regime DAG and rejects entries no function consumes (`InvalidParamsError: Unknown keys: ['max_consumption']`). Also pins the pylcm CI ref to 6c610d1 — the squash-merge of pylcm #341 (int32 lock-in) into feat/simulate-aot-n-subjects — to make this build deterministic against pylcm drift. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 2 +- src/aca_model/aca/model.py | 7 ++----- src/aca_model/baseline/model.py | 16 ++++------------ src/aca_model/baseline/regimes/_common.py | 12 ++++++++++++ src/aca_model/benchmark.py | 2 -- src/aca_model/consumption_grid.py | 15 ++++++++------- tests/test_benchmark.py | 2 +- tests/test_model_creation.py | 21 +++++++++++++-------- 8 files changed, 41 insertions(+), 36 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index bd40367..110aafe 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@feat/simulate-aot-n-subjects" + git+https://github.com/OpenSourceEconomics/pylcm.git@6c610d19644d3f524ad112ed16c0621ee2ecd326" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 942be22..ee8efc6 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -12,13 +12,13 @@ from aca_model.aca import PolicyVariant from aca_model.aca.regimes import build_all_regimes from aca_model.baseline.regimes import RegimeId +from aca_model.baseline.regimes._common import MAX_CONSUMPTION from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig def create_model( *, n_subjects: int, - max_consumption: float, policy: PolicyVariant = PolicyVariant.ACA, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, @@ -44,9 +44,6 @@ def create_model( contains `pd.Series` indexed by DAG function outputs. grid_config: Continuous-grid point counts. Defaults to production values. - max_consumption: Upper bound of the runtime consumption grid in - $/year. 
Attached to the returned Model and read back at inject - time by `inject_consumption_points`. Returns: pylcm Model with ACA-specific function overrides. @@ -73,5 +70,5 @@ def create_model( derived_categoricals=derived_categoricals, n_subjects=n_subjects, ) - model.max_consumption = max_consumption + model.max_consumption = MAX_CONSUMPTION return model diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 0ff7c47..fe181eb 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -16,13 +16,13 @@ from lcm import AgeGrid, DiscreteGrid, Model from aca_model.baseline.regimes import RegimeId, build_all_regimes +from aca_model.baseline.regimes._common import MAX_CONSUMPTION from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig def create_model( *, n_subjects: int, - max_consumption: float, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] @@ -51,9 +51,6 @@ def create_model( pref_type_grid: Optional override for the `pref_type` `DiscreteGrid`. Defaults to `DiscreteGrid(PrefType)`. Used by the benchmark to substitute a 2-type variant with `DispatchStrategy.PARTITION_SCAN`. - max_consumption: Upper bound of the runtime consumption grid in - $/year. Attached to the returned Model and read back at inject - time by `inject_consumption_points`. Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -81,12 +78,7 @@ def create_model( derived_categoricals=derived_categoricals, n_subjects=n_subjects, ) - # Attach the consumption-grid upper bound directly to the Model - # instance. Tried surfacing it via a marker function in the regime - # DAG first — pylcm's pruning drops unused-output functions before - # their parameters reach the params template, so the value never - # made it into `resolved_fixed_params`. 
Direct attachment sidesteps - # the templating machinery entirely; `inject_consumption_points` - # reads `model.max_consumption`. - model.max_consumption = max_consumption + # See `MAX_CONSUMPTION` in `baseline.regimes._common` for why this + # rides on the Model instance instead of `fixed_params`. + model.max_consumption = MAX_CONSUMPTION return model diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 327aa15..30198aa 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -195,6 +195,18 @@ class Grids: _AIME_PIECE_N_POINTS: tuple[int, int, int] = (10, 11, 11) +MAX_CONSUMPTION: float = 300_000.0 +"""Upper bound of the runtime consumption grid in $/year. + +Lives here next to the other grid bounds (assets `stop=500_000.0`, +AIME `stop=8_000.0`). The `create_model` factories attach this onto +`model.max_consumption` so `inject_consumption_points` can read it +back at runtime. Routed via a Model attribute rather than +`fixed_params` because pylcm validates `fixed_params` keys against +the regime DAG and rejects entries no function consumes. +""" + + def build_grids( grid_config: GridConfig = GRID_CONFIG, *, diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 3d7fbad..08f5ec6 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -74,7 +74,6 @@ def create_benchmark_model( *, n_subjects: int, - max_consumption: float, pref_type_grid: DiscreteGrid | None = None, ) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. 
@@ -104,7 +103,6 @@ def create_benchmark_model( derived_categoricals=_DERIVED_CATEGORICALS, pref_type_grid=pref_type_grid, n_subjects=n_subjects, - max_consumption=max_consumption, ) diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py index bd342ee..7123c1f 100644 --- a/src/aca_model/consumption_grid.py +++ b/src/aca_model/consumption_grid.py @@ -3,11 +3,12 @@ Consumption is declared as `IrregSpacedGrid(n_points=N)` in `baseline.regimes._common.build_grids` so the bounds can track runtime parameters: the lower bound from the per-iteration -`consumption_floor` parameter, the upper bound from a per-model -`max_consumption` knob attached to the `Model` instance by the -`create_model` factories. Callers must inject the actual gridpoints -into `params` via `inject_consumption_points` before calling -`model.solve()` / `model.simulate()`. +`consumption_floor` parameter, the upper bound from +`MAX_CONSUMPTION` in `baseline.regimes._common`, which the +`create_model` factories attach to `model.max_consumption`. +Callers must inject the actual gridpoints into `params` via +`inject_consumption_points` before calling `model.solve()` / +`model.simulate()`. """ from collections.abc import Mapping @@ -30,8 +31,8 @@ def inject_consumption_points( `params[regime_name]["consumption"] = {"points": }`. Lower bound: `params["consumption_floor"]` (varies per iteration). - Upper bound: `model.max_consumption` (required attribute; set by - the `create_model` factory). + Upper bound: `model.max_consumption` (set by the `create_model` + factory from `MAX_CONSUMPTION` in `baseline.regimes._common`). Args: params: Existing params mapping. 
Returned as a new dict; the input is diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 6173318..adafd66 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -12,7 +12,7 @@ @pytest.mark.long_running def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 - model = create_benchmark_model(n_subjects=n_subjects, max_consumption=300_000.0) + model = create_benchmark_model(n_subjects=n_subjects) _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index accde27..75a87d9 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -11,7 +11,7 @@ from aca_model.baseline.model import create_model from aca_model.baseline.regimes import REGIME_SPECS, RegimeId from aca_model.baseline.regimes import build_regime as _build_regime -from aca_model.baseline.regimes._common import build_grids +from aca_model.baseline.regimes._common import MAX_CONSUMPTION, build_grids _GRIDS = build_grids() @@ -21,24 +21,24 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: - model = create_model(n_subjects=1, max_consumption=300_000.0) + model = create_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 def test_model_age_range() -> None: - model = create_model(n_subjects=1, max_consumption=300_000.0) + model = create_model(n_subjects=1) assert model.ages.values[0] == 51.0 assert model.ages.values[-1] == 95.0 def test_dead_regime_is_terminal() -> None: - model = create_model(n_subjects=1, max_consumption=300_000.0) + model = create_model(n_subjects=1) assert model.regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: - model = create_model(n_subjects=1, max_consumption=300_000.0) + model = create_model(n_subjects=1) for name in REGIME_SPECS: assert not model.regimes[name].terminal @@ 
-170,7 +170,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: - model = create_aca_model(n_subjects=1, max_consumption=300_000.0) + model = create_aca_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 @@ -211,7 +211,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: @pytest.mark.parametrize("policy", list(PolicyVariant)) def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" - model = create_aca_model(n_subjects=1, max_consumption=300_000.0, policy=policy) + model = create_aca_model(n_subjects=1, policy=policy) assert len(model.regimes) == 19 @@ -251,5 +251,10 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" - model = create_model(n_subjects=1, max_consumption=300_000.0) + model = create_model(n_subjects=1) assert len(model.regimes) == 19 + + +def test_max_consumption_attached_from_canonical_constant() -> None: + model = create_model(n_subjects=1) + assert model.max_consumption == MAX_CONSUMPTION From 714fee0496c63547da047670fae058acfae6bfa2 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 07:16:37 +0200 Subject: [PATCH 14/41] Assets grid: subtract MAX_CONSUMPTION margin from the floor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With consumption now declared as `IrregSpacedGrid(n_points=N)` and points filled at runtime from `geomspace(consumption_floor, max_consumption, N)`, the grid clusters densely just above `consumption_floor`. At the lowest-asset / highest-OOP-shock corner, those near-floor consumption choices push `next_assets = cash_on_hand - OOP - consumption` slightly below the assets grid's old lower bound (`0` for the bare model, `-max_annual_labor_income` when wage_params are available). 
Out-of-bounds interpolation of next-period V then injects NaN, which propagates back through E[V] and eventually fails `validate_V`. Symptom on the production solve: `Value function at age 93 in regime 'retiree_oamc_forced_forcedout': 7317 of 207360 values are NaN`, with the `[NOTE]` showing E[V] NaN concentrated at the lowest assets indices and the highest hcc_transitory shock. Subtract `MAX_CONSUMPTION` from the assets floor to give a worst-case single-period drain margin. With 24 linspace points spanning the wider range, the per-point density change is negligible; the dead state and the bare-model fallback get the margin too. The asymmetry fix is the cheapest one — no change to the consumption grid type, no change to per-iteration parameters, no new constraints. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/baseline/regimes/_common.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 30198aa..d32e2c1 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -262,10 +262,25 @@ def build_grids( sigma=1.0, ) - assets_start = 0.0 + # Assets-grid lower bound includes a one-period margin below the + # binding borrowing limit so that `next_assets = cash_on_hand - OOP - + # consumption` stays inside the grid even at the worst-shock × low- + # consumption corner. With the runtime log-spaced consumption grid + # `geomspace(consumption_floor, max_consumption, n_points)`, choices + # cluster densely just above `consumption_floor`, and at the lowest- + # asset/highest-OOP-shock corner those choices push `next_assets` + # slightly off the grid bottom — out-of-bounds interpolation of + # next-period V then injects NaN that propagates through E[V]. 
+ # Subtracting `MAX_CONSUMPTION` gives a worst-case single-period + # drain margin; cheap at production grid sizes (24 linspace points + # over the wider range). + assets_start = -MAX_CONSUMPTION if wage_params is not None and _has_required_wage_keys(wage_params=wage_params): - assets_start = -_compute_max_annual_labor_income( - wage_params=wage_params, wage_res_grid=wage_res + assets_start = ( + -_compute_max_annual_labor_income( + wage_params=wage_params, wage_res_grid=wage_res + ) + - MAX_CONSUMPTION ) return Grids( From 63d2a3819d08cf33f95c0149b6d3531b5292e729 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 11:05:11 +0200 Subject: [PATCH 15/41] Revert "Assets grid: subtract MAX_CONSUMPTION margin from the floor" This reverts commit 714fee0496c63547da047670fae058acfae6bfa2. --- src/aca_model/baseline/regimes/_common.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index d32e2c1..30198aa 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -262,25 +262,10 @@ def build_grids( sigma=1.0, ) - # Assets-grid lower bound includes a one-period margin below the - # binding borrowing limit so that `next_assets = cash_on_hand - OOP - - # consumption` stays inside the grid even at the worst-shock × low- - # consumption corner. With the runtime log-spaced consumption grid - # `geomspace(consumption_floor, max_consumption, n_points)`, choices - # cluster densely just above `consumption_floor`, and at the lowest- - # asset/highest-OOP-shock corner those choices push `next_assets` - # slightly off the grid bottom — out-of-bounds interpolation of - # next-period V then injects NaN that propagates through E[V]. 
- # Subtracting `MAX_CONSUMPTION` gives a worst-case single-period - # drain margin; cheap at production grid sizes (24 linspace points - # over the wider range). - assets_start = -MAX_CONSUMPTION + assets_start = 0.0 if wage_params is not None and _has_required_wage_keys(wage_params=wage_params): - assets_start = ( - -_compute_max_annual_labor_income( - wage_params=wage_params, wage_res_grid=wage_res - ) - - MAX_CONSUMPTION + assets_start = -_compute_max_annual_labor_income( + wage_params=wage_params, wage_res_grid=wage_res ) return Grids( From 4ae44469ef99a8cdc26da164f0009743aaa72652 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 11:08:02 +0200 Subject: [PATCH 16/41] Wire pension imputation correction (FJ 2011 Appendix A.5) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two new DAG functions in canwork & ss != "forced" regimes: - target_his(his, labor_supply, is_medicaid_eligible): HIS class of the surviving target regime. Mirrors the cross-HIS branches inside _make_transition_canwork (tied → nongroup when stopping work, Medicaid override → nongroup). - imputed_pension_wealth_next_period(next_aime, target_his, period, ...): computes pw_next_imputed = benefit_imputed(next_pia, next_period, target_his) · epdv_constant_pension[next_period] using bare-name parameters into 1-period-shifted views of the imputation arrays (`*_next_period`). Inlining is required because pylcm's AST shape inference doesn't trace nested calls into pensions.benefit. next_assets continues to consume pension_assets_adjustment, which now sees a real imputed_pension_wealth_next_period via the DAG (previously fixed to 0.0 in aca-estimation). The chained dependency next_aime → imputed_pension_wealth_next_period → pension_assets_adjustment is unblocked by pylcm exempting next_ names from fixed_param extraction (PR pylcm#342). 
Also drops pension_assets_adjustment from borrowing_constraint: a negative correction at a cross-HIS transition can leave no feasible action and inject `-inf` into V via `argmax_and_max(initial=-inf, where=F_arr)`, which then cancels with `0 * -inf = NaN`. The correction is a post-decision shift on next-period assets and must not gate the current consumption choice. --- src/aca_model/agent/assets_and_income.py | 13 +++++-- src/aca_model/baseline/health_insurance.py | 23 ++++++++++++ src/aca_model/baseline/regimes/_nongroup.py | 4 ++ src/aca_model/baseline/regimes/_retiree.py | 4 ++ src/aca_model/baseline/regimes/_tied.py | 4 ++ src/aca_model/environment/pensions.py | 41 ++++++++++++++++++++- 6 files changed, 85 insertions(+), 4 deletions(-) diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index dfa83ef..46d4b1c 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -69,7 +69,14 @@ def borrowing_constraint( consumption: ContinuousAction, cash_on_hand: FloatND, transfers: FloatND, - pension_assets_adjustment: FloatND, ) -> BoolND: - """Consumption cannot exceed available resources (no borrowing).""" - return consumption <= cash_on_hand + transfers + pension_assets_adjustment + """Consumption cannot exceed available resources (no borrowing). + + `pension_assets_adjustment` is excluded: it can be negative (e.g., + when the imputation overstates next-period pension wealth at a + cross-HIS transition), and including it here can leave no feasible + action at low-asset / mid-AIME corners. The correction enters + `next_assets` instead — a post-decision shift that does not gate + the current consumption choice. 
+ """ + return consumption <= cash_on_hand + transfers diff --git a/src/aca_model/baseline/health_insurance.py b/src/aca_model/baseline/health_insurance.py index 741d160..3732d6d 100644 --- a/src/aca_model/baseline/health_insurance.py +++ b/src/aca_model/baseline/health_insurance.py @@ -246,6 +246,29 @@ def is_medicaid_eligible(is_ssi_eligible: BoolND) -> BoolND: return is_ssi_eligible +def target_his( + his: IntND, + labor_supply: DiscreteAction, + is_medicaid_eligible: BoolND, +) -> IntND: + """Return the HIS class of the surviving target regime. + + Mirrors the cross-HIS branches inside `_make_transition_canwork` (retiree, + tied, nongroup): tied agents who stop working become nongroup, and + Medicaid-eligible agents are overridden to nongroup. Used by + `imputed_pension_wealth_next_period` to look up next-period imputation + coefficients at the target's HIS. + """ + tied_to_ng = (his == HealthInsuranceState.tied) & ( + labor_supply == LaborSupply.do_not_work + ) + return jnp.where( + tied_to_ng | is_medicaid_eligible, + HealthInsuranceState.nongroup, + his, + ).astype(jnp.int32) + + def oop_with_medicaid( primary_oop: FloatND, is_medicaid_eligible: BoolND, diff --git a/src/aca_model/baseline/regimes/_nongroup.py b/src/aca_model/baseline/regimes/_nongroup.py index 5cdb6dc..7ee82ff 100644 --- a/src/aca_model/baseline/regimes/_nongroup.py +++ b/src/aca_model/baseline/regimes/_nongroup.py @@ -99,6 +99,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/baseline/regimes/_retiree.py b/src/aca_model/baseline/regimes/_retiree.py index ac76bfd..a941fa9 100644 --- 
a/src/aca_model/baseline/regimes/_retiree.py +++ b/src/aca_model/baseline/regimes/_retiree.py @@ -109,6 +109,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/baseline/regimes/_tied.py b/src/aca_model/baseline/regimes/_tied.py index 5d59274..4351cf5 100644 --- a/src/aca_model/baseline/regimes/_tied.py +++ b/src/aca_model/baseline/regimes/_tied.py @@ -83,6 +83,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/environment/pensions.py b/src/aca_model/environment/pensions.py index a23a800..eef72d4 100644 --- a/src/aca_model/environment/pensions.py +++ b/src/aca_model/environment/pensions.py @@ -4,7 +4,7 @@ """ import jax.numpy as jnp -from lcm.typing import FloatND, IntND, Period +from lcm.typing import ContinuousState, FloatND, IntND, Period def benefit( @@ -164,3 +164,42 @@ def assets_adjustment( * unconditional_survival_prob[period] * (pension_wealth_next_before_adjustment - imputed_pension_wealth_next_period) ) + + +def imputed_pension_wealth_next_period( + next_aime: ContinuousState, + target_his: IntND, + period: Period, + pia_table: FloatND, + pia_aime_grid: FloatND, + imp_intercept_next_period: FloatND, + imp_pia_coeff_next_period: FloatND, + 
imp_pia_kink_0_coeff_next_period: FloatND, + imp_pia_kink_1_coeff_next_period: FloatND, + imp_kink_0_next_period: FloatND, + imp_kink_1_next_period: FloatND, + imp_fraction_receiving_next_period: FloatND, + epdv_constant_pension_next_period: FloatND, +) -> FloatND: + """Imputed pension wealth at next period using the target regime's HIS. + + Mirrors `benefit` and `wealth` but indexes into 1-period-shifted views + of the imputation arrays so all subscripts use bare-name parameters + (`period`, `target_his`). Inlining is required: pylcm's AST shape + inference inspects the registered function's body and does not trace + through nested calls into `benefit`. + """ + next_pia = jnp.interp(next_aime, pia_aime_grid, pia_table) + + intercept = imp_intercept_next_period[period, target_his] + pia_pred = imp_pia_coeff_next_period[period, target_his] * next_pia + kink_0_adj = imp_pia_kink_0_coeff_next_period[period, target_his] * jnp.maximum( + 0.0, next_pia - imp_kink_0_next_period[period] + ) + kink_1_adj = imp_pia_kink_1_coeff_next_period[period, target_his] * jnp.maximum( + 0.0, next_pia - imp_kink_1_next_period[period] + ) + + full_benefit = jnp.maximum(0.0, intercept + pia_pred + kink_0_adj + kink_1_adj) + benefit_next = full_benefit * imp_fraction_receiving_next_period[period] + return benefit_next * epdv_constant_pension_next_period[period] From 83f22500e97a6675aa4cd15235dea359dae94f2d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 11:37:08 +0200 Subject: [PATCH 17/41] =?UTF-8?q?Bump=20pyproject-fmt=20v2.19.0=20?= =?UTF-8?q?=E2=86=92=20v2.21.1=20and=20ruff-pre-commit=20v0.15.6=20?= =?UTF-8?q?=E2=86=92=20v0.15.12?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e36542..f3188ab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml 
@@ -5,7 +5,7 @@ repos: - id: check-hooks-apply - id: check-useless-excludes - repo: https://github.com/tox-dev/pyproject-fmt - rev: v2.19.0 + rev: v2.21.1 hooks: - id: pyproject-fmt - repo: https://github.com/lyz-code/yamlfix @@ -47,7 +47,7 @@ repos: hooks: - id: yamllint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.6 + rev: v0.15.12 hooks: - id: ruff-check args: From 3453080fd08afa049483f6ddda215a998a55b757 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 20:48:08 +0200 Subject: [PATCH 18/41] get_benchmark_params: filter obsolete imputed_pension_wealth_next_period key The frozen benchmark_params.pkl was generated when aca-estimation's _assemble_params.py still wrote the placeholder `fp["imputed_pension_wealth_next_period"] = 0.0` into fixed_params. Now that the regime registers `imputed_pension_wealth_next_period` as a DAG function (pension imputation correction in 4ae4446), pylcm's `_resolve_fixed_params` rejects the stale key with `InvalidParamsError: Unknown keys: ['imputed_pension_wealth_next_period']`. Drop the key on load so the snapshot stays valid. Regenerating `benchmark_params.pkl` end-to-end would also remove it; the filter is a no-op for a fresh snapshot. 
--- src/aca_model/benchmark.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 08f5ec6..47cb628 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -122,13 +122,23 @@ def get_benchmark_params( """ with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) - fixed_params = data["fixed_params"] + fixed_params = { + k: v for k, v in data["fixed_params"].items() if k not in _STALE_FIXED_KEYS + } params = _truncate_pref_type_indexed(data["params"]) if model is not None: params = inject_consumption_points(params=params, model=model) return fixed_params, params +# Keys that the older aca-estimation `_assemble_params.py` wrote into +# `fixed_params` but that the current regime now resolves as a DAG +# function. Drop them on load so pylcm's `_resolve_fixed_params` does +# not reject the snapshot. Regenerating `benchmark_params.pkl` would +# also remove these — the filter is a no-op when the snapshot is fresh. +_STALE_FIXED_KEYS: frozenset[str] = frozenset({"imputed_pension_wealth_next_period"}) + + def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: """Return a copy of `params` with pref_type-indexed Series cut to 2 rows. From b2e90bb58a1c6721046a3e860a95a29485b25117 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 21:04:19 +0200 Subject: [PATCH 19/41] get_benchmark_params: synthesise _next_period shifted views The frozen `benchmark_params.pkl` predates aca-data's `_shift_one_period_forward` change, so the 1-period-shifted views the pension correction consumes are missing. Synthesise them on load with the same transformation aca-data applies. Regenerating the snapshot end-to-end would also produce the keys; this filter is a no-op for a fresh snapshot. 
--- src/aca_model/benchmark.py | 50 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 47cb628..5c24d4f 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -125,6 +125,7 @@ def get_benchmark_params( fixed_params = { k: v for k, v in data["fixed_params"].items() if k not in _STALE_FIXED_KEYS } + fixed_params = _add_shifted_imputation_arrays(fixed_params) params = _truncate_pref_type_indexed(data["params"]) if model is not None: params = inject_consumption_points(params=params, model=model) @@ -139,6 +140,55 @@ def get_benchmark_params( _STALE_FIXED_KEYS: frozenset[str] = frozenset({"imputed_pension_wealth_next_period"}) +# Source → derived key mapping for the 1-period-shifted views of the +# imputation arrays. The current pension correction (`imputed_pension_ +# wealth_next_period`) consumes these. The frozen `benchmark_params.pkl` +# predates aca-data's `_shift_one_period_forward` change, so synthesise +# the shifted views on load. The transformation is deterministic: row +# `period` carries the original at row `period + 1`; the last row holds +# flat. A regenerated snapshot can drop this synthesis (the filter is a +# no-op when the keys already exist). +_SHIFTED_IMPUTATION_KEYS: tuple[str, ...] 
= ( + "imp_intercept", + "imp_pia_coeff", + "imp_pia_kink_0_coeff", + "imp_pia_kink_1_coeff", + "imp_kink_0", + "imp_kink_1", + "imp_fraction_receiving", + "epdv_constant_pension", +) + + +def _add_shifted_imputation_arrays(fixed_params: dict[str, Any]) -> dict[str, Any]: + """Synthesise `_next_period` views from the source arrays.""" + out = dict(fixed_params) + for key in _SHIFTED_IMPUTATION_KEYS: + next_period_key = f"{key}_next_period" + if next_period_key in out or key not in out: + continue + out[next_period_key] = _shift_one_period_forward(out[key]) + return out + + +def _shift_one_period_forward(sr: pd.Series) -> pd.Series: + """Shift age-axis values forward one position (last row held flat).""" + if isinstance(sr.index, pd.MultiIndex) and sr.index.names[0] == "age": + n_periods = sr.index.levshape[0] + n_other = int( + np.prod([sr.index.levshape[i] for i in range(1, sr.index.nlevels)]) + ) + values = sr.to_numpy().reshape(n_periods, n_other) + shifted = np.concatenate([values[1:], values[-1:]], axis=0) + return pd.Series(shifted.ravel(), index=sr.index) + if sr.index.name == "age": + values = sr.to_numpy() + shifted = np.concatenate([values[1:], values[-1:]]) + return pd.Series(shifted, index=sr.index) + msg = f"Unexpected index for _shift_one_period_forward: {sr.index!r}" + raise ValueError(msg) + + def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: """Return a copy of `params` with pref_type-indexed Series cut to 2 rows. From 35eddcc9ee06b960c30c6ea09e3f07541f3144a6 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 21:18:43 +0200 Subject: [PATCH 20/41] benchmark: declare target_his as derived categorical MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit target_his is a DAG function returning a HealthInsuranceState int, used to index 2D imputation arrays inside imputed_pension_wealth_next_period. 
pylcm needs the categorical mapping declared so array_from_series can reshape (age, target_his)-indexed Series correctly. Mirrors the existing 'his' entry — same enum class. --- src/aca_model/benchmark.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 5c24d4f..13dfae4 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -56,6 +56,7 @@ "good_health": DiscreteGrid(GoodHealth), "is_married": DiscreteGrid(IsMarried), "his": DiscreteGrid(HealthInsuranceState), + "target_his": DiscreteGrid(HealthInsuranceState), "pref_type": DiscreteGrid(BenchmarkPrefType), } From 64d656791230ebf20622c247bf2935880de2fcfd Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 21:32:47 +0200 Subject: [PATCH 21/41] _shift_one_period_forward: rename his level to target_his The shifted imputation arrays (`imp_*_next_period`) are consumed by `imputed_pension_wealth_next_period(target_his, period, ...)`. pylcm's `_validate_and_reorder_levels` matches Series MultiIndex level names against the function's parameter names, so the level needs to be `target_his`, not `his`. --- src/aca_model/benchmark.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 13dfae4..8e242b0 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -173,7 +173,12 @@ def _add_shifted_imputation_arrays(fixed_params: dict[str, Any]) -> dict[str, An def _shift_one_period_forward(sr: pd.Series) -> pd.Series: - """Shift age-axis values forward one position (last row held flat).""" + """Shift age-axis values forward one position (last row held flat). + + For (age, his)-indexed inputs, also rename the `his` level to + `target_his` so the resulting Series matches the level naming the + consuming `imputed_pension_wealth_next_period` function expects. 
+ """ if isinstance(sr.index, pd.MultiIndex) and sr.index.names[0] == "age": n_periods = sr.index.levshape[0] n_other = int( @@ -181,7 +186,10 @@ def _shift_one_period_forward(sr: pd.Series) -> pd.Series: ) values = sr.to_numpy().reshape(n_periods, n_other) shifted = np.concatenate([values[1:], values[-1:]], axis=0) - return pd.Series(shifted.ravel(), index=sr.index) + new_index = sr.index.rename( + [_rename_his_level(name) for name in sr.index.names] + ) + return pd.Series(shifted.ravel(), index=new_index) if sr.index.name == "age": values = sr.to_numpy() shifted = np.concatenate([values[1:], values[-1:]]) @@ -190,6 +198,11 @@ def _shift_one_period_forward(sr: pd.Series) -> pd.Series: raise ValueError(msg) +def _rename_his_level(name: str) -> str: + """Rename `his` to `target_his`, leave others alone.""" + return "target_his" if name == "his" else name + + def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: """Return a copy of `params` with pref_type-indexed Series cut to 2 rows. From f09b5e34102ff42f739b95be5a9d388795b734a1 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 22:00:55 +0200 Subject: [PATCH 22/41] Per-target next_assets: dead target uses next_assets_terminal (no pension chain) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit state_transitions["assets"] becomes a per-target dict. The dead target gets a simpler `next_assets_terminal` (cash + transfers - consumption - oop) without the `pension_assets_adjustment` chain, because: 1. There is no future for a dead agent — the imputation correction is meaningless. 2. `pension_assets_adjustment` consumes `imputed_pension_wealth_next_period` which consumes `next_aime`. The dead per-target transitions don't include `next_aime` (dead has no aime state), so dags can't resolve it and pylcm leaks `next_aime` into the kernel signature with no value to pass. 
Non-dead targets keep `assets_and_income.next_assets` (full version with the pension correction). --- src/aca_model/agent/assets_and_income.py | 19 +++++++++++++- src/aca_model/baseline/regimes/_common.py | 30 ++++++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 46d4b1c..cb89c89 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -55,7 +55,7 @@ def next_assets( consumption: ContinuousAction, oop_costs: FloatND, ) -> ContinuousState: - """Compute beginning-of-next-period assets. + """Compute beginning-of-next-period assets for non-terminal targets. OOP health costs are deducted here (not from cash_on_hand) so that the consumption choice does not condition on the HCC shock realization. @@ -65,6 +65,23 @@ def next_assets( ) +def next_assets_terminal( + cash_on_hand: FloatND, + transfers: FloatND, + consumption: ContinuousAction, + oop_costs: FloatND, +) -> ContinuousState: + """Compute beginning-of-next-period assets for the dead/terminal target. + + No `pension_assets_adjustment` term: with no future, there is no + next-period pension wealth to impute against. Avoiding the dependency + also keeps the `dead` per-target transition's DAG free of `next_aime` + (which would otherwise need to come from a transition `dead` does not + have, since `aime` is not a state in the terminal regime). 
+ """ + return cash_on_hand + transfers - consumption - oop_costs + + def borrowing_constraint( consumption: ContinuousAction, cash_on_hand: FloatND, diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 30198aa..56887d9 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -644,7 +644,7 @@ def build_state_transitions(spec: dict[str, str]) -> dict: """Build the state transitions dict for a non-dead regime.""" transitions: dict = {} transitions["health"] = _build_per_target_health(spec) - transitions["assets"] = assets_and_income.next_assets + transitions["assets"] = _build_per_target_next_assets(spec) transitions["pref_type"] = None transitions["aime"] = ( social_security.next_aime @@ -661,6 +661,34 @@ def build_state_transitions(spec: dict[str, str]) -> dict: return transitions +def _build_per_target_next_assets(spec: dict[str, str]) -> dict: + """Build per-target assets transitions. + + The `dead` target uses `next_assets_terminal` (no + `pension_assets_adjustment`), so the dead per-target DAG does not + pull in the `next_aime`-dependent imputation chain — `dead` has no + `aime` state and pylcm cannot resolve `next_aime` there. Non-dead + targets use the full `next_assets` with the pension correction. + """ + targets = precompute_targets(spec) + id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + + result: dict = {} + seen_ids: set[int] = set() + + for target_id in targets.values(): + if target_id in seen_ids: + continue + seen_ids.add(target_id) + target_name = id_to_name.get(target_id) + if target_name is None: + continue + result[target_name] = assets_and_income.next_assets + + result["dead"] = assets_and_income.next_assets_terminal + return result + + def _build_per_target_health(spec: dict[str, str]) -> dict: """Build per-target health transitions. 
From e1a3eb2478c8616317afa762583f50d9c31de86d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Tue, 5 May 2026 20:45:22 +0200 Subject: [PATCH 23/41] create_model: register target_his as derived categorical at base layer The pension imputation correction's `imputed_pension_wealth_next_period` indexes shifted arrays via `arr[period, target_his]`, where `target_his` is a DAG output (computed by `health_insurance.target_his` on nongroup/tied/retiree regimes), not a state. pylcm reads the level name `target_his` off the function body via AST inference and rejects matching `pd.Series` fixed_params unless `target_his` is declared as a derived categorical. Production `task_simulate_baseline` calls `create_model(...)` directly, which previously only forwarded the user's `derived_categoricals` arg. The benchmark module was masking this by injecting target_his via `_DERIVED_CATEGORICALS`. Move the declaration to `create_model` itself so the correction works in production without per-caller setup. Tighten the param annotation: pylcm's `Model.derived_categoricals` is a flat `Mapping[str, DiscreteGrid]`, never the nested form. 
--- src/aca_model/baseline/model.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index fe181eb..b0a2d79 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -15,6 +15,7 @@ from lcm import AgeGrid, DiscreteGrid, Model +from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId, build_all_regimes from aca_model.baseline.regimes._common import MAX_CONSUMPTION from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig @@ -25,8 +26,7 @@ def create_model( n_subjects: int, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] - | None = None, + derived_categoricals: Mapping[str, DiscreteGrid] | None = None, grid_config: GridConfig = GRID_CONFIG, pref_type_grid: DiscreteGrid | None = None, ) -> Model: @@ -69,13 +69,24 @@ def create_model( pref_type_grid=pref_type_grid, ) + # `target_his` is a DAG output of `health_insurance.target_his` (set on + # nongroup/tied/retiree regimes). The pension imputation correction + # (`imputed_pension_wealth_next_period`) indexes shifted arrays by + # `arr[period, target_his]`; pylcm needs the categorical declared so + # `pd.Series` fixed_params with a `target_his` index level resolve. 
+ base_derived: dict[str, DiscreteGrid] = { + "target_his": DiscreteGrid(HealthInsuranceState), + } + if derived_categoricals is not None: + base_derived.update(derived_categoricals) + model = Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description="Baseline structural retirement model (pre-ACA)", fixed_params=fixed_params or {}, - derived_categoricals=derived_categoricals, + derived_categoricals=base_derived, n_subjects=n_subjects, ) # See `MAX_CONSUMPTION` in `baseline.regimes._common` for why this From 00ee7d2236be9c62286c2aeceffc3b2fc5128b4a Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 05:11:54 +0200 Subject: [PATCH 24/41] aca/model.create_model: register target_his at base layer Same fix as baseline.model.create_model e1a3eb2: ACA variant model creation also takes its own path through `Model(...)`, so the production `task_simulate_aca_*` flows hit the same "Unrecognised indexing parameter 'target_his'" error after the pension correction landed. Move the derived-categorical declaration into the function itself rather than relying on per-caller setup. Tighten the param annotation to match pylcm's flat `Mapping[str, DiscreteGrid]`. 
--- src/aca_model/aca/model.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index ee8efc6..1cc7ff4 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -11,6 +11,7 @@ from aca_model.aca import PolicyVariant from aca_model.aca.regimes import build_all_regimes +from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId from aca_model.baseline.regimes._common import MAX_CONSUMPTION from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig @@ -22,8 +23,7 @@ def create_model( policy: PolicyVariant = PolicyVariant.ACA, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] - | None = None, + derived_categoricals: Mapping[str, DiscreteGrid] | None = None, grid_config: GridConfig = GRID_CONFIG, ) -> Model: """Create an ACA policy variant model. @@ -61,13 +61,21 @@ def create_model( wage_params=wage_params, ) + # See `baseline.model.create_model` for why `target_his` is declared + # as a base-layer derived categorical. 
+ base_derived: dict[str, DiscreteGrid] = { + "target_his": DiscreteGrid(HealthInsuranceState), + } + if derived_categoricals is not None: + base_derived.update(derived_categoricals) + model = Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description=f"Structural retirement model ({policy.name})", fixed_params=fixed_params or {}, - derived_categoricals=derived_categoricals, + derived_categoricals=base_derived, n_subjects=n_subjects, ) model.max_consumption = MAX_CONSUMPTION From edfa540ad23299d625fc2c247970014fd31fb91e Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 09:39:33 +0200 Subject: [PATCH 25/41] =?UTF-8?q?tests:=20positive=20regression=20guard=20?= =?UTF-8?q?=E2=80=94=20assets=3D-$1M=20passes=20benchmark=20validation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Asserts that `validate_initial_conditions` admits a subject placed at `assets = -1_000_000` in `retiree_nomc_inelig_canwork` under the benchmark model. Encodes the economic story: with the consumption floor / transfer system, any past assets level is representable — `c = c_floor` is always feasible because `transfers` tops up cash-on-hand to the floor. The test passes today on benchmark params; it doesn't reproduce the gpu-01 failure (production-side, separate setup loaded by `aca-estimation`'s `assemble_fixed_params`). Kept as a permanent regression guard so a future change that re-introduces a constraint shape that rejects extreme negatives is caught immediately at benchmark scale. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../test_initial_conditions_extreme_assets.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 tests/test_initial_conditions_extreme_assets.py diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py new file mode 100644 index 0000000..0007966 --- /dev/null +++ b/tests/test_initial_conditions_extreme_assets.py @@ -0,0 +1,50 @@ +"""Subjects at extreme negative assets must clear `validate_initial_conditions`. + +The transfer system (`agent.assets_and_income.transfers`) tops cash-on-hand +to `consumption_floor * equivalence_scale` at any starting state, so the +lowest consumption-grid point is always a feasible action regardless of +how negative starting assets are. The model's constraints — and pylcm's +`validate_initial_conditions` pass — must reflect this. +""" + +import jax.numpy as jnp +from lcm.simulation.initial_conditions import validate_initial_conditions + +from aca_model.benchmark import ( + create_benchmark_model, + get_benchmark_initial_conditions, + get_benchmark_params, +) + + +def test_extreme_negative_assets_subject_passes_validation() -> None: + """A subject placed at `assets = -1_000_000` clears initial-conditions validation. + + HRS bottom-codes very-large-negative net wealth at exactly $-1{,}000{,}000$. + Such subjects should remain in the simulated population: the consumption + floor / transfer system absorbs them, with `c = c_floor` always feasible. 
+ """ + n_subjects = 1 + model = create_benchmark_model(n_subjects=n_subjects) + _, params = get_benchmark_params(model=model) + + initial_conditions = get_benchmark_initial_conditions( + model=model, n_subjects=n_subjects, seed=0 + ) + initial_conditions = { + **initial_conditions, + "assets": jnp.asarray([-1_000_000.0]), + "regime": jnp.asarray( + [model.regime_names_to_ids["retiree_nomc_inelig_canwork"]], + dtype=jnp.int32, + ), + } + + internal_params = model._process_params(params) # noqa: SLF001 + validate_initial_conditions( + initial_conditions=initial_conditions, + internal_regimes=model.internal_regimes, + regime_names_to_ids=model.regime_names_to_ids, + internal_params=internal_params, + ages=model.ages, + ) From d05df9e19433a7bdeeb30828692f3042608174d0 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 10:00:09 +0200 Subject: [PATCH 26/41] borrowing_constraint: use max(cash_on_hand, floor) to dodge fp32 cancellation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The expression `cash_on_hand + transfers` suffers float32 catastrophic cancellation when `|cash_on_hand|` is much larger than `consumption_floor`. For a subject at $-1{,}000{,}000$ in starting assets: cash_on_hand ≈ -1e6 (dominated by assets) transfers = max(0, c_floor - cash_on_hand) ≈ c_floor + 1e6 cash_on_hand + transfers ≈ c_floor ± 0.1 (fp32 error at 1e6 magnitude) The lowest grid `c` is exactly `c_floor`. With unfavorable rounding, `c_floor <= c_floor - 0.1` is False — every action gets rejected and `validate_initial_conditions` raises. This is exactly the failure gpu-01 hit on `task_simulate_aca_*`: the per-constraint diagnostic showed `borrowing_constraint = False` (rejects every action by itself) while `positive_leisure = True`. 
The algebraic identity `cash_on_hand + transfers == max(cash_on_hand, floor)` (where `floor = c_floor * equivalence_scale`) holds exactly because `transfers` is defined as `max(0, floor - cash_on_hand)`. Substituting in: cash_on_hand + max(0, floor - cash_on_hand) = max(cash_on_hand, cash_on_hand + floor - cash_on_hand) = max(cash_on_hand, floor) The `max` form has no cancellation: it returns `floor` exactly when `cash_on_hand < floor`, and `cash_on_hand` exactly otherwise. Switch the constraint to take `consumption_floor` and `equivalence_scale` directly and compute `floor = consumption_floor * equivalence_scale` in-line. Add a precision-specific unit test asserting `c = c_floor` is admitted at `cash_on_hand = -$1M` in fp32. The pre-existing benchmark-based regression guard (`test_extreme_negative_assets_subject_passes_validation`) didn't catch the bug because benchmark params land on the favorable side of the rounding; the new test exercises the exact cancellation case. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/agent/assets_and_income.py | 23 +++++++++++++++---- .../test_initial_conditions_extreme_assets.py | 22 ++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index cb89c89..629fc42 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -85,15 +85,28 @@ def next_assets_terminal( def borrowing_constraint( consumption: ContinuousAction, cash_on_hand: FloatND, - transfers: FloatND, + consumption_floor: float, + equivalence_scale: FloatND, ) -> BoolND: - """Consumption cannot exceed available resources (no borrowing). - - `pension_assets_adjustment` is excluded: it can be negative (e.g., + """Consumption cannot exceed available resources after transfers.
+ + Post-transfer resources are `max(cash_on_hand, consumption_floor * + equivalence_scale)`: the transfer system tops `cash_on_hand` to the + floor when below, otherwise resources are unchanged. The algebraic + identity is `cash_on_hand + transfers == max(cash_on_hand, floor)`, + but writing it as `cash_on_hand + transfers` triggers float32 + catastrophic cancellation when `|cash_on_hand|` dwarfs + `consumption_floor` — e.g. a subject at $-1{,}000{,}000$ in starting + assets gives `(-1e6) + (c_floor + 1e6)` with ~0.1 of rounding error, + which can wipe out the `c == c_floor` boundary and reject every + feasible action. The `max` form has no cancellation. + + `pension_assets_adjustment` is excluded: it can be negative (e.g. when the imputation overstates next-period pension wealth at a cross-HIS transition), and including it here can leave no feasible action at low-asset / mid-AIME corners. The correction enters `next_assets` instead — a post-decision shift that does not gate the current consumption choice. """ - return consumption <= cash_on_hand + transfers + floor = consumption_floor * equivalence_scale + return consumption <= jnp.maximum(cash_on_hand, floor) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 0007966..bd88045 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -10,6 +10,7 @@ import jax.numpy as jnp from lcm.simulation.initial_conditions import validate_initial_conditions +from aca_model.agent.assets_and_income import borrowing_constraint from aca_model.benchmark import ( create_benchmark_model, get_benchmark_initial_conditions, @@ -17,6 +18,27 @@ ) +def test_borrowing_constraint_admits_c_floor_at_million_dollar_negative_cash() -> None: + """At `cash_on_hand = -$1M` (fp32), `c = c_floor` remains a feasible choice. 
+ + Computing `cash_on_hand + transfers` directly suffers float32 catastrophic + cancellation: `(-1e6) + (c_floor + 1e6)` loses ~0.1 of precision, enough + to wipe out the `c == c_floor` boundary. The constraint must use the + algebraically equivalent but numerically stable `max(cash_on_hand, floor)` + form. + """ + consumption_floor = 5_000.0 + admitted = bool( + borrowing_constraint( + consumption=jnp.float32(consumption_floor), + cash_on_hand=jnp.float32(-1_000_000.0), + consumption_floor=consumption_floor, + equivalence_scale=jnp.float32(1.0), + ) + ) + assert admitted + + def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. From 4af83596dac2161c6485bea46f16fec0744e69c9 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 10:05:40 +0200 Subject: [PATCH 27/41] ci: bump pylcm pin to e4cae2aa (post-#342, post-#340 diagnostic) The previous pin (6c610d1, "Lock integer dtype to int32 end-to-end") predates pylcm #342, so the test_initial_conditions_extreme_assets test (and any other test that solves a benchmark regime carrying the pension-imputation correction) raised: InvalidParamsError: Missing required parameter: 'retiree_nomc_inelig_canwork__imputed_pension_wealth_next_period__next_aime' #342's `regime_template` change exempts `next_` references inside transition signatures from `fixed_param` extraction, which the correction's `imputed_pension_wealth_next_period(next_aime, ...)` signature relies on. The new pin tracks `feat/simulate-aot-n-subjects`, which carries #342, #339, #340 (n_subjects API used by `create_benchmark_model`), and the per-constraint validation diagnostic. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 110aafe..565245d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@6c610d19644d3f524ad112ed16c0621ee2ecd326" + git+https://github.com/OpenSourceEconomics/pylcm.git@e4cae2aa57d4bf568b8ebbade55d44571e3a086f" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest From 0c7f2d589e8dba50dfc115f33dd44ba7e6396ae0 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 10:56:49 +0200 Subject: [PATCH 28/41] =?UTF-8?q?wip:=20debug=20script=20=E2=80=94=20cash?= =?UTF-8?q?=5Fon=5Fhand=20per=20failing=20subject?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug_cash_on_hand.py | 176 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 debug_cash_on_hand.py diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py new file mode 100644 index 0000000..bac4a08 --- /dev/null +++ b/debug_cash_on_hand.py @@ -0,0 +1,176 @@ +"""Print cash_on_hand for the failing subjects at every labor_supply choice. + +If `cash_on_hand` evaluates to NaN for any subject, that explains why my +new `borrowing_constraint = c <= max(cash_on_hand, floor)` rejects every +action: `max(NaN, floor) == NaN` and `c <= NaN == False`. 
+ +Usage on gpu-01: + cd ~/aca-dev + pixi run -e cuda12 python aca-model/debug_cash_on_hand.py +""" + +import pickle + +import jax.numpy as jnp +import numpy as np +import pandas as pd +from dags import concatenate_functions + +from aca_data.config import data_catalog +from aca_estimation._assemble_params import ( + _NON_MODEL_KEYS, + assemble_fixed_params, + assemble_params, + broadcast_to_template, +) +from aca_estimation._type_prediction import triple_initdist_by_pref_type +from aca_model.aca import PolicyVariant +from aca_model.aca.model import create_model as create_aca_model +from aca_model.config import GRID_CONFIG_FOR_RUN +from aca_model.consumption_grid import inject_consumption_points + +# Subjects whose `borrowing_constraint=False` in the gpu-01 production +# diagnostic. (subject_id, regime_name) tuples. Subject 1299 is included +# as a positive control: production showed `borrowing_constraint=True` +# for it, so its cash_on_hand should be finite. +_TARGETS: tuple[tuple[int, str], ...] 
= ( + (1131, "nongroup_nomc_inelig_canwork"), + (1299, "nongroup_nomc_inelig_canwork"), # positive control + (9013, "retiree_nomc_inelig_canwork"), + (10108, "nongroup_dimc_inelig_canwork"), +) + + +def _load_pickle(name: str): + with open(data_catalog[name], "rb") as fh: + return pickle.load(fh) + + +def main() -> None: + ss = _load_pickle("social_security_params") + tax = _load_pickle("tax_params") + ssi = _load_pickle("ssi_medicaid_params") + hi = _load_pickle("health_insurance_params") + pension = _load_pickle("pension_params") + wage = _load_pickle("wage_offer") + transition = _load_pickle("transition_params") + env = _load_pickle("environment_constants") + hcc_insurer = _load_pickle("hcc_insurer_params") + pref = _load_pickle("preference_start_values") + initdist_df = pd.read_pickle(data_catalog["initial_conditions"]) + + n_subjects = 3 * len(initdist_df) + bare_model = create_aca_model( + policy=PolicyVariant.ACA, grid_config=GRID_CONFIG_FOR_RUN, n_subjects=1 + ) + template = bare_model.get_params_template() + fixed_params = assemble_fixed_params( + bare_model=bare_model, + ss_params=ss, + tax_params=tax, + ssi_params=ssi, + hi_params=hi, + pension_params=pension, + wage_params=wage, + transition_params=transition, + env_params=env, + hcc_insurer_params=hcc_insurer, + pref_params=pref, + ) + broadcast_to_template(params=fixed_params, template=template, required=False) + params = assemble_params( + pref_params=pref, base_wage_profile=wage["log_ft_wage_base"] + ) + + model = create_aca_model( + n_subjects=n_subjects, + policy=PolicyVariant.ACA, + fixed_params=fixed_params, + wage_params=wage, + grid_config=GRID_CONFIG_FOR_RUN, + ) + model_params = {k: v for k, v in params.items() if k not in _NON_MODEL_KEYS} + model_params = inject_consumption_points(params=model_params, model=model) + initial = triple_initdist_by_pref_type(initdist_df) + + internal_params = model._process_params(model_params) # noqa: SLF001 + + # Evaluate cash_on_hand and borrowing_constraint 
for each target subject + # at each labor_supply choice with c = consumption_floor. + consumption_floor = float(model_params["consumption_floor"]) + for subject_id, regime_name in _TARGETS: + regime = model.regimes[regime_name] + internal_regime = model.internal_regimes[regime_name] + functions = internal_regime.simulate_functions.functions + constraints = internal_regime.simulate_functions.constraints + regime_params = { + **internal_regime.resolved_fixed_params, + **dict(internal_params.get(regime_name, {})), + } + + # Build a function returning (cash_on_hand, borrowing_constraint). + targets = ["cash_on_hand"] + if "borrowing_constraint" in constraints: + targets.append("borrowing_constraint") + all_funcs = dict(functions) + all_funcs.update(dict(constraints)) + evaluator = concatenate_functions( + functions=all_funcs, + targets=targets, + return_type="dict", + enforce_signature=False, + set_annotations=True, + ) + + # Per-subject states (single subject; pull idx subject_id from the + # already-tripled initial conditions). 
+ subject_state = { + k: v[subject_id : subject_id + 1] + for k, v in initial.items() + if k != "regime" + } + + labor_supply_grid = np.asarray(regime.actions["labor_supply"].to_jax()) + print(f"\n=== subject {subject_id} ({regime_name}) ===") + print( + f" state: assets={float(subject_state['assets'][0]):.2f}, " + f"aime={float(subject_state['aime'][0]):.2f}, " + f"spousal_income={int(subject_state['spousal_income'][0])}, " + f"health={int(subject_state['health'][0])}, " + f"hcc_persistent={float(subject_state['hcc_persistent'][0]):.4f}, " + f"hcc_transitory={float(subject_state['hcc_transitory'][0]):.4f}" + ) + for ls in labor_supply_grid: + kwargs = { + **{k: v[0] for k, v in subject_state.items()}, + "consumption": jnp.float32(consumption_floor), + "labor_supply": jnp.int32(int(ls)), + "age": jnp.float32(51.0), + "period": jnp.int32(0), + **{k: v for k, v in regime_params.items()}, + } + try: + out = evaluator( + **{ + k: v + for k, v in kwargs.items() + if k in evaluator.__signature__.parameters + } + ) + coh = float(out["cash_on_hand"]) + bc = ( + bool(out.get("borrowing_constraint", True)) + if "borrowing_constraint" in out + else "n/a" + ) + nan_flag = " <-- NaN!" 
if not np.isfinite(coh) else "" + print( + f" ls={int(ls):d}: cash_on_hand={coh:14.2f} " + f"borrowing_constraint(c=c_floor)={bc}{nan_flag}" + ) + except (KeyError, TypeError) as exc: + print(f" ls={int(ls):d}: eval failed: {exc!r}") + + +if __name__ == "__main__": + main() From 8ffbf5c53063919fff3bbd1f0f49ea4f1691c321 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 11:05:57 +0200 Subject: [PATCH 29/41] wip: fix imports in debug script (broadcast_to_template + ACA_DATA_BLD) --- debug_cash_on_hand.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py index bac4a08..c9fc3bc 100644 --- a/debug_cash_on_hand.py +++ b/debug_cash_on_hand.py @@ -15,14 +15,14 @@ import numpy as np import pandas as pd from dags import concatenate_functions +from lcm.params.processing import broadcast_to_template -from aca_data.config import data_catalog from aca_estimation._assemble_params import ( _NON_MODEL_KEYS, assemble_fixed_params, assemble_params, - broadcast_to_template, ) +from aca_estimation.config import ACA_DATA_BLD from aca_estimation._type_prediction import triple_initdist_by_pref_type from aca_model.aca import PolicyVariant from aca_model.aca.model import create_model as create_aca_model @@ -41,23 +41,23 @@ ) -def _load_pickle(name: str): - with open(data_catalog[name], "rb") as fh: +def _load(name: str): + with open(ACA_DATA_BLD / f"{name}.pkl", "rb") as fh: return pickle.load(fh) def main() -> None: - ss = _load_pickle("social_security_params") - tax = _load_pickle("tax_params") - ssi = _load_pickle("ssi_medicaid_params") - hi = _load_pickle("health_insurance_params") - pension = _load_pickle("pension_params") - wage = _load_pickle("wage_offer") - transition = _load_pickle("transition_params") - env = _load_pickle("environment_constants") - hcc_insurer = _load_pickle("hcc_insurer_params") - pref = _load_pickle("preference_start_values") - initdist_df = 
pd.read_pickle(data_catalog["initial_conditions"]) + ss = _load("social_security_params") + tax = _load("tax_params") + ssi = _load("ssi_medicaid_params") + hi = _load("health_insurance_params") + pension = _load("pension_params") + wage = _load("wage_params") + transition = _load("transition_probs") + env = _load("environment_constants") + hcc_insurer = _load("hcc_insurer_params") + pref = _load("preference_start_values") + initdist_df = pd.read_pickle(ACA_DATA_BLD / "initial_conditions.pkl") n_subjects = 3 * len(initdist_df) bare_model = create_aca_model( From 81cca3c5fe0995b27e4eafceb4352bc9a11c8dc3 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 11:07:26 +0200 Subject: [PATCH 30/41] wip: import GRID_CONFIG_FOR_RUN from aca_estimation --- debug_cash_on_hand.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py index c9fc3bc..2b9e4b4 100644 --- a/debug_cash_on_hand.py +++ b/debug_cash_on_hand.py @@ -26,7 +26,7 @@ from aca_estimation._type_prediction import triple_initdist_by_pref_type from aca_model.aca import PolicyVariant from aca_model.aca.model import create_model as create_aca_model -from aca_model.config import GRID_CONFIG_FOR_RUN +from aca_estimation.config import GRID_CONFIG_FOR_RUN from aca_model.consumption_grid import inject_consumption_points # Subjects whose `borrowing_constraint=False` in the gpu-01 production From e320f41ac4137a14c8f499d88aa6b49838b60f5f Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 11:11:40 +0200 Subject: [PATCH 31/41] wip: pass derived_categoricals to create_aca_model in debug --- debug_cash_on_hand.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py index 2b9e4b4..a2b6fd3 100644 --- a/debug_cash_on_hand.py +++ b/debug_cash_on_hand.py @@ -17,18 +17,30 @@ from dags import concatenate_functions from lcm.params.processing import 
broadcast_to_template +from lcm import DiscreteGrid + from aca_estimation._assemble_params import ( _NON_MODEL_KEYS, assemble_fixed_params, assemble_params, ) -from aca_estimation.config import ACA_DATA_BLD from aca_estimation._type_prediction import triple_initdist_by_pref_type +from aca_estimation.config import ACA_DATA_BLD, GRID_CONFIG_FOR_RUN from aca_model.aca import PolicyVariant from aca_model.aca.model import create_model as create_aca_model -from aca_estimation.config import GRID_CONFIG_FOR_RUN +from aca_model.agent.health import GoodHealth +from aca_model.agent.labor_market import IsMarried +from aca_model.agent.preferences import PrefType +from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.consumption_grid import inject_consumption_points +_DERIVED_CATEGORICALS = { + "good_health": DiscreteGrid(GoodHealth), + "is_married": DiscreteGrid(IsMarried), + "his": DiscreteGrid(HealthInsuranceState), + "pref_type": DiscreteGrid(PrefType), +} + # Subjects whose `borrowing_constraint=False` in the gpu-01 production # diagnostic. (subject_id, regime_name) tuples. 
Subject 1299 is included # as a positive control: production showed `borrowing_constraint=True` @@ -87,6 +99,7 @@ def main() -> None: policy=PolicyVariant.ACA, fixed_params=fixed_params, wage_params=wage, + derived_categoricals=_DERIVED_CATEGORICALS, grid_config=GRID_CONFIG_FOR_RUN, ) model_params = {k: v for k, v in params.items() if k not in _NON_MODEL_KEYS} From 2208fa6777f46c55bb03d54dc06eb5fd536c8cac Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 11:20:46 +0200 Subject: [PATCH 32/41] wip: augment fixed_params for ACA policy in debug --- debug_cash_on_hand.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py index a2b6fd3..72531f3 100644 --- a/debug_cash_on_hand.py +++ b/debug_cash_on_hand.py @@ -89,6 +89,11 @@ def main() -> None: hcc_insurer_params=hcc_insurer, pref_params=pref, ) + from aca_estimation.task_simulate_aca import _augment_fixed_params_for_aca + + _augment_fixed_params_for_aca( + fixed_params=fixed_params, ssi_params=ssi, policy=PolicyVariant.ACA + ) broadcast_to_template(params=fixed_params, template=template, required=False) params = assemble_params( pref_params=pref, base_wage_profile=wage["log_ft_wage_base"] From 8adabda7952b5d0e96218322f3fe7620dbb10a62 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 11:54:15 +0200 Subject: [PATCH 33/41] borrowing_constraint: cast consumption_floor to consumption's dtype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Production failure root cause: `consumption_floor` is a Python fp64 float (≈ 1597.0921419521899); `consumption` arrives from the model's fp32 grid (`jnp.geomspace(consumption_floor, ...)`), quantized to the nearest fp32 value 1597.0921630859375 — about 2.1e-5 (roughly 0.17 of an fp32 ulp, which is 2^-13 ≈ 1.22e-4 at this magnitude) above the input.
Without an explicit dtype cast on the floor, `consumption_floor * equivalence_scale` keeps its fp64 type, the comparison promotes to fp64, and the lowest grid point evaluates as 1597.0921630859375 > 1597.0921419521899 → False. Constraint rejects every action. Cast `consumption_floor` to `consumption.dtype` before the multiply so both sides of the `max` use the same precision. Constraint then admits c=c_floor by exact equality in fp32. Diagnosed via the per-constraint admissibility table (pylcm 838473e/ e4cae2a): production showed `borrowing_constraint=False` at modest asset levels (e.g. -$42k), where neither cash_on_hand magnitude nor NaN propagation could explain the rejection. Local repro pinned the ulp mismatch. Add `test_borrowing_constraint_admits_c_floor_with_python_float_floor` as a regression guard at the precise production scenario. Drop the debug script; it served its purpose. Co-Authored-By: Claude Opus 4.7 (1M context) --- debug_cash_on_hand.py | 194 ------------------ src/aca_model/agent/assets_and_income.py | 2 +- .../test_initial_conditions_extreme_assets.py | 28 +++ 3 files changed, 29 insertions(+), 195 deletions(-) delete mode 100644 debug_cash_on_hand.py diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py deleted file mode 100644 index 72531f3..0000000 --- a/debug_cash_on_hand.py +++ /dev/null @@ -1,194 +0,0 @@ -"""Print cash_on_hand for the failing subjects at every labor_supply choice. - -If `cash_on_hand` evaluates to NaN for any subject, that explains why my -new `borrowing_constraint = c <= max(cash_on_hand, floor)` rejects every -action: `max(NaN, floor) == NaN` and `c <= NaN == False`. 
- -Usage on gpu-01: - cd ~/aca-dev - pixi run -e cuda12 python aca-model/debug_cash_on_hand.py -""" - -import pickle - -import jax.numpy as jnp -import numpy as np -import pandas as pd -from dags import concatenate_functions -from lcm.params.processing import broadcast_to_template - -from lcm import DiscreteGrid - -from aca_estimation._assemble_params import ( - _NON_MODEL_KEYS, - assemble_fixed_params, - assemble_params, -) -from aca_estimation._type_prediction import triple_initdist_by_pref_type -from aca_estimation.config import ACA_DATA_BLD, GRID_CONFIG_FOR_RUN -from aca_model.aca import PolicyVariant -from aca_model.aca.model import create_model as create_aca_model -from aca_model.agent.health import GoodHealth -from aca_model.agent.labor_market import IsMarried -from aca_model.agent.preferences import PrefType -from aca_model.baseline.health_insurance import HealthInsuranceState -from aca_model.consumption_grid import inject_consumption_points - -_DERIVED_CATEGORICALS = { - "good_health": DiscreteGrid(GoodHealth), - "is_married": DiscreteGrid(IsMarried), - "his": DiscreteGrid(HealthInsuranceState), - "pref_type": DiscreteGrid(PrefType), -} - -# Subjects whose `borrowing_constraint=False` in the gpu-01 production -# diagnostic. (subject_id, regime_name) tuples. Subject 1299 is included -# as a positive control: production showed `borrowing_constraint=True` -# for it, so its cash_on_hand should be finite. -_TARGETS: tuple[tuple[int, str], ...] 
= ( - (1131, "nongroup_nomc_inelig_canwork"), - (1299, "nongroup_nomc_inelig_canwork"), # positive control - (9013, "retiree_nomc_inelig_canwork"), - (10108, "nongroup_dimc_inelig_canwork"), -) - - -def _load(name: str): - with open(ACA_DATA_BLD / f"{name}.pkl", "rb") as fh: - return pickle.load(fh) - - -def main() -> None: - ss = _load("social_security_params") - tax = _load("tax_params") - ssi = _load("ssi_medicaid_params") - hi = _load("health_insurance_params") - pension = _load("pension_params") - wage = _load("wage_params") - transition = _load("transition_probs") - env = _load("environment_constants") - hcc_insurer = _load("hcc_insurer_params") - pref = _load("preference_start_values") - initdist_df = pd.read_pickle(ACA_DATA_BLD / "initial_conditions.pkl") - - n_subjects = 3 * len(initdist_df) - bare_model = create_aca_model( - policy=PolicyVariant.ACA, grid_config=GRID_CONFIG_FOR_RUN, n_subjects=1 - ) - template = bare_model.get_params_template() - fixed_params = assemble_fixed_params( - bare_model=bare_model, - ss_params=ss, - tax_params=tax, - ssi_params=ssi, - hi_params=hi, - pension_params=pension, - wage_params=wage, - transition_params=transition, - env_params=env, - hcc_insurer_params=hcc_insurer, - pref_params=pref, - ) - from aca_estimation.task_simulate_aca import _augment_fixed_params_for_aca - - _augment_fixed_params_for_aca( - fixed_params=fixed_params, ssi_params=ssi, policy=PolicyVariant.ACA - ) - broadcast_to_template(params=fixed_params, template=template, required=False) - params = assemble_params( - pref_params=pref, base_wage_profile=wage["log_ft_wage_base"] - ) - - model = create_aca_model( - n_subjects=n_subjects, - policy=PolicyVariant.ACA, - fixed_params=fixed_params, - wage_params=wage, - derived_categoricals=_DERIVED_CATEGORICALS, - grid_config=GRID_CONFIG_FOR_RUN, - ) - model_params = {k: v for k, v in params.items() if k not in _NON_MODEL_KEYS} - model_params = inject_consumption_points(params=model_params, model=model) - initial 
= triple_initdist_by_pref_type(initdist_df) - - internal_params = model._process_params(model_params) # noqa: SLF001 - - # Evaluate cash_on_hand and borrowing_constraint for each target subject - # at each labor_supply choice with c = consumption_floor. - consumption_floor = float(model_params["consumption_floor"]) - for subject_id, regime_name in _TARGETS: - regime = model.regimes[regime_name] - internal_regime = model.internal_regimes[regime_name] - functions = internal_regime.simulate_functions.functions - constraints = internal_regime.simulate_functions.constraints - regime_params = { - **internal_regime.resolved_fixed_params, - **dict(internal_params.get(regime_name, {})), - } - - # Build a function returning (cash_on_hand, borrowing_constraint). - targets = ["cash_on_hand"] - if "borrowing_constraint" in constraints: - targets.append("borrowing_constraint") - all_funcs = dict(functions) - all_funcs.update(dict(constraints)) - evaluator = concatenate_functions( - functions=all_funcs, - targets=targets, - return_type="dict", - enforce_signature=False, - set_annotations=True, - ) - - # Per-subject states (single subject; pull idx subject_id from the - # already-tripled initial conditions). 
- subject_state = { - k: v[subject_id : subject_id + 1] - for k, v in initial.items() - if k != "regime" - } - - labor_supply_grid = np.asarray(regime.actions["labor_supply"].to_jax()) - print(f"\n=== subject {subject_id} ({regime_name}) ===") - print( - f" state: assets={float(subject_state['assets'][0]):.2f}, " - f"aime={float(subject_state['aime'][0]):.2f}, " - f"spousal_income={int(subject_state['spousal_income'][0])}, " - f"health={int(subject_state['health'][0])}, " - f"hcc_persistent={float(subject_state['hcc_persistent'][0]):.4f}, " - f"hcc_transitory={float(subject_state['hcc_transitory'][0]):.4f}" - ) - for ls in labor_supply_grid: - kwargs = { - **{k: v[0] for k, v in subject_state.items()}, - "consumption": jnp.float32(consumption_floor), - "labor_supply": jnp.int32(int(ls)), - "age": jnp.float32(51.0), - "period": jnp.int32(0), - **{k: v for k, v in regime_params.items()}, - } - try: - out = evaluator( - **{ - k: v - for k, v in kwargs.items() - if k in evaluator.__signature__.parameters - } - ) - coh = float(out["cash_on_hand"]) - bc = ( - bool(out.get("borrowing_constraint", True)) - if "borrowing_constraint" in out - else "n/a" - ) - nan_flag = " <-- NaN!" if not np.isfinite(coh) else "" - print( - f" ls={int(ls):d}: cash_on_hand={coh:14.2f} " - f"borrowing_constraint(c=c_floor)={bc}{nan_flag}" - ) - except (KeyError, TypeError) as exc: - print(f" ls={int(ls):d}: eval failed: {exc!r}") - - -if __name__ == "__main__": - main() diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 629fc42..374edf4 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -108,5 +108,5 @@ def borrowing_constraint( `next_assets` instead — a post-decision shift that does not gate the current consumption choice. 
""" - floor = consumption_floor * equivalence_scale + floor = jnp.asarray(consumption_floor, dtype=consumption.dtype) * equivalence_scale return consumption <= jnp.maximum(cash_on_hand, floor) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index bd88045..6df5e18 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -39,6 +39,34 @@ def test_borrowing_constraint_admits_c_floor_at_million_dollar_negative_cash() - assert admitted +def test_borrowing_constraint_admits_c_floor_with_python_float_floor() -> None: + """Python-fp64 `consumption_floor` against fp32 `consumption` must compare in fp32. + + `consumption_floor` arrives at the constraint as a Python float (fp64), but + `consumption` comes from the model's fp32 grid (`jnp.geomspace(...)`), + quantized to a value that differs from the fp64 input by one fp32 ulp + (~2e-5 at $c_{floor} \\approx 1597$). Without an explicit dtype cast on the + floor, the comparison promotes to fp64 and the lowest grid point fails + the constraint. The fix forces the floor into `consumption.dtype` before + the `max` so both sides use the same precision. + + Reproduces the production failure on gpu-01 where every subject in + `nongroup_nomc_inelig_canwork` (and similar regimes) hit + `borrowing_constraint=False` despite legitimate cash_on_hand values. + """ + consumption_floor = 1597.0921419521899 # production value, fp64 + consumption_fp32 = jnp.float32(consumption_floor) + admitted = bool( + borrowing_constraint( + consumption=consumption_fp32, + cash_on_hand=jnp.float32(-44_937.9), + consumption_floor=consumption_floor, # raw Python float + equivalence_scale=jnp.float32(1.0), + ) + ) + assert admitted + + def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. 
From c895bd9f8a027355263d589d93dd21cca59af902 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 06:39:26 +0200 Subject: [PATCH 34/41] borrowing_constraint: drop dtype cast workaround `jnp.asarray(consumption_floor, dtype=consumption.dtype)` quantized the Python-float `consumption_floor` to the action grid's dtype to match the fp32-quantized consumption grid, so the `c == c_floor` boundary compared as exact equality. The pylcm canonical-float boundary cast (#345) routes every continuous-grid `to_jax()` through `canonical_float_dtype()`. Under `jax_enable_x64=True` (set in `aca_model/__init__.py`) that's `fp64`, so the action grid no longer quantizes the floor and Python-float / grid-value cannot disagree on dtype in the first place. Drop the regression test pinned to the cast workaround; the `max(cash_on_hand, floor)` cancellation guard and the full validate-initial-conditions integration test stay in place. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/agent/assets_and_income.py | 2 +- .../test_initial_conditions_extreme_assets.py | 28 ------------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 374edf4..629fc42 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -108,5 +108,5 @@ def borrowing_constraint( `next_assets` instead — a post-decision shift that does not gate the current consumption choice. 
""" - floor = jnp.asarray(consumption_floor, dtype=consumption.dtype) * equivalence_scale + floor = consumption_floor * equivalence_scale return consumption <= jnp.maximum(cash_on_hand, floor) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 6df5e18..bd88045 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -39,34 +39,6 @@ def test_borrowing_constraint_admits_c_floor_at_million_dollar_negative_cash() - assert admitted -def test_borrowing_constraint_admits_c_floor_with_python_float_floor() -> None: - """Python-fp64 `consumption_floor` against fp32 `consumption` must compare in fp32. - - `consumption_floor` arrives at the constraint as a Python float (fp64), but - `consumption` comes from the model's fp32 grid (`jnp.geomspace(...)`), - quantized to a value that differs from the fp64 input by one fp32 ulp - (~2e-5 at $c_{floor} \\approx 1597$). Without an explicit dtype cast on the - floor, the comparison promotes to fp64 and the lowest grid point fails - the constraint. The fix forces the floor into `consumption.dtype` before - the `max` so both sides use the same precision. - - Reproduces the production failure on gpu-01 where every subject in - `nongroup_nomc_inelig_canwork` (and similar regimes) hit - `borrowing_constraint=False` despite legitimate cash_on_hand values. - """ - consumption_floor = 1597.0921419521899 # production value, fp64 - consumption_fp32 = jnp.float32(consumption_floor) - admitted = bool( - borrowing_constraint( - consumption=consumption_fp32, - cash_on_hand=jnp.float32(-44_937.9), - consumption_floor=consumption_floor, # raw Python float - equivalence_scale=jnp.float32(1.0), - ) - ) - assert admitted - - def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. 
From e0cc62211438afd877865504281228c1f205d90e Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 15:47:19 +0200 Subject: [PATCH 35/41] tests: switch helpers import to relative form MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `from tests.helpers.social_security import …` collided with the sibling `tests/__init__.py` packages in aca-data and aca-estimation when pytest collected from the aca-dev workspace root — whichever `tests` package got imported first shadowed the others. Use a relative import so each test module resolves its own helpers package unambiguously. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_social_security.py | 2 +- tests/test_ss_benefit_integration.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_social_security.py b/tests/test_social_security.py index d75e458..e399d3b 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -9,7 +9,7 @@ from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from tests.helpers.social_security import compute_di_dropout_scale, compute_pia_table +from .helpers.social_security import compute_di_dropout_scale, compute_pia_table ATOL = 0.01 diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index 0e77ea5..dadcac9 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -9,7 +9,7 @@ from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from tests.helpers.social_security import compute_pia_table +from .helpers.social_security import compute_pia_table ATOL = 0.01 From 3d2faf4a8f04b27e60e41f4bb2d3efe4c35e1f49 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 
17:47:58 +0200 Subject: [PATCH 36/41] tests: drop tests/__init__.py; expose helpers via conftest sys.path Reverts the relative-import attempt and instead removes the empty tests/__init__.py (which was colliding with aca-data and aca-estimation's identically named stubs across the aca-dev workspace). A new tests/conftest.py prepends the tests directory to sys.path so `from helpers.social_security import ...` resolves unambiguously. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/__init__.py | 0 tests/conftest.py | 4 ++++ tests/test_social_security.py | 2 +- tests/test_ss_benefit_integration.py | 2 +- 4 files changed, 6 insertions(+), 2 deletions(-) delete mode 100644 tests/__init__.py create mode 100644 tests/conftest.py diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1da1dcf --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent)) diff --git a/tests/test_social_security.py b/tests/test_social_security.py index e399d3b..d612f7d 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -9,7 +9,7 @@ from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from .helpers.social_security import compute_di_dropout_scale, compute_pia_table +from helpers.social_security import compute_di_dropout_scale, compute_pia_table ATOL = 0.01 diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index dadcac9..488df32 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -9,7 +9,7 @@ from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from 
.helpers.social_security import compute_pia_table +from helpers.social_security import compute_pia_table ATOL = 0.01 From 97c84cd02bc08461e1a4316a013c2ddf24f13261 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 22:16:09 +0200 Subject: [PATCH 37/41] Drop precision-related workarounds and function defaults MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cleanup driven by pylcm's canonical-float boundary cast (#345). With every input pinned to fp64 under `jax_enable_x64=True` (which `aca_model/__init__.py` sets at import), aca-side precision workarounds no longer have a hook. Source: - `borrowing_constraint`: switch from `consumption <= max(cash_on_hand, floor)` to `consumption <= cash_on_hand + transfers`. The two are algebraically identical (`cash_on_hand + transfers == max(cash_on_hand, floor)`); the `max` form was justified by float32 catastrophic cancellation at extreme negative cash_on_hand, which cannot occur under fp64. The constraint now consumes `transfers` directly instead of recomputing `consumption_floor * equivalence_scale` — `transfers` is already a DAG node, so the resolved interface is shorter. Defaults dropped (callers must pass everything explicitly): - `aca_model.benchmark.create_benchmark_model`: `pref_type_grid`. - `aca_model.benchmark.get_benchmark_params`: `model`. - `aca_model.benchmark.get_benchmark_initial_conditions`: `n_subjects`, `seed`. - `aca_model.baseline.model.create_model`: `fixed_params`, `wage_params`, `derived_categoricals`, `grid_config`, `pref_type_grid`. - `aca_model.aca.model.create_model`: `policy`, `fixed_params`, `wage_params`, `derived_categoricals`, `grid_config`. - `aca_model.baseline.regimes.build_all_regimes`: `grid_config`, `fixed_params`, `wage_params`, `pref_type_grid`. - `aca_model.aca.regimes.build_all_regimes`: `grid_config`, `fixed_params`, `wage_params`. - `aca_model.baseline.regimes._common.build_grids`: `grid_config`, `fixed_params`, `wage_params`, `pref_type_grid`. - Drop `GRID_CONFIG` import where it was only used as a default value. 
Tests: - New `tests/helpers/model.py` exposes `make_baseline_model` and `make_aca_model` factories that wrap `create_model` with `None` for every optional input. Tests that don't need fixed params reach the factories through the helper rather than spelling out six `None`s each. Production code stays default-free. - New `test_benchmark_simulate_obeys_borrowing_constraint`: pins the invariant `consumption <= cash_on_hand + transfers` on every alive row of the benchmark simulation. Catches a regression that drops the constraint from a regime, replaces transfers with something looser, or lets an action grid skip the floor. - `test_initial_conditions_extreme_assets`: drop the fp32-specific cancellation regression test (the runtime no longer reaches that path); replace with a pair of unit tests for the new `borrowing_constraint(consumption, cash_on_hand, transfers)` signature. --- src/aca_model/aca/model.py | 38 ++++++++------- src/aca_model/aca/regimes/__init__.py | 15 +++--- src/aca_model/agent/assets_and_income.py | 33 ++++--------- src/aca_model/baseline/model.py | 45 +++++++++--------- src/aca_model/baseline/regimes/__init__.py | 15 +++--- src/aca_model/baseline/regimes/_common.py | 10 ++-- src/aca_model/benchmark.py | 26 +++++----- tests/helpers/model.py | 38 +++++++++++++++ tests/test_benchmark.py | 47 ++++++++++++++++++- .../test_initial_conditions_extreme_assets.py | 46 ++++++++++++------ tests/test_model_creation.py | 39 ++++++++++----- tests/test_social_security.py | 2 +- tests/test_ss_benefit_integration.py | 2 +- 13 files changed, 235 insertions(+), 121 deletions(-) create mode 100644 tests/helpers/model.py diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 1cc7ff4..b76adc6 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -14,36 +14,40 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId from aca_model.baseline.regimes._common import 
MAX_CONSUMPTION -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.config import MODEL_CONFIG, GridConfig def create_model( *, n_subjects: int, - policy: PolicyVariant = PolicyVariant.ACA, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid] | None = None, - grid_config: GridConfig = GRID_CONFIG, + policy: PolicyVariant, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + derived_categoricals: Mapping[str, DiscreteGrid] | None, + grid_config: GridConfig, ) -> Model: """Create an ACA policy variant model. Args: n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. - policy: Which ACA policy combination to apply. - fixed_params: Parameters to fix at model creation time. These are - partialled into compiled functions and removed from the params - template. Pass data-derived constants here; only estimation - parameters should go through `model.simulate(params=...)`. + policy: Which ACA policy combination to apply (e.g. + `PolicyVariant.ACA`). + fixed_params: Parameters to fix at model creation time, or `None` + to skip. Fixed params are partialled into compiled functions + and removed from the params template. Pass data-derived + constants here; only estimation parameters should go through + `model.simulate(params=...)`. wage_params: Data-derived wage profile dict (`log_ft_wage_mean`, `log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build time to size the assets-floor to `-max_annual_labor_income`. - Not routed to the pylcm Model. - derived_categoricals: Extra categorical mappings for derived variables - not in the model's state/action grids. Needed when `fixed_params` - contains `pd.Series` indexed by DAG function outputs. - grid_config: Continuous-grid point counts. Defaults to production - values. + Not routed to the pylcm Model. `None` skips the floor sizing. 
+ derived_categoricals: Extra categorical mappings for derived + variables not in the model's state/action grids, or `None`. + Needed when `fixed_params` contains `pd.Series` indexed by DAG + function outputs. + grid_config: Continuous-grid point counts. Pass `GRID_CONFIG` for + production values or `BENCHMARK_GRID_CONFIG` for the + fast-but-structurally-faithful benchmark. Returns: pylcm Model with ACA-specific function overrides. diff --git a/src/aca_model/aca/regimes/__init__.py b/src/aca_model/aca/regimes/__init__.py index 2c143bd..5b9f4bf 100644 --- a/src/aca_model/aca/regimes/__init__.py +++ b/src/aca_model/aca/regimes/__init__.py @@ -10,19 +10,22 @@ from aca_model.aca.regimes._overrides import apply_aca_overrides from aca_model.baseline.regimes import build_all_regimes as baseline_build_all_regimes from aca_model.baseline.regimes._common import REGIME_SPECS -from aca_model.config import GRID_CONFIG, GridConfig +from aca_model.config import GridConfig def build_all_regimes( - policy: PolicyVariant, - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, + policy: PolicyVariant, + grid_config: GridConfig, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, ) -> dict[str, Regime]: """Build all 19 regimes with ACA policy overrides.""" regimes = baseline_build_all_regimes( - grid_config, fixed_params=fixed_params, wage_params=wage_params + grid_config=grid_config, + fixed_params=fixed_params, + wage_params=wage_params, + pref_type_grid=None, ) result = {} for name, regime in regimes.items(): diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 629fc42..c07fdb4 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -85,28 +85,15 @@ def next_assets_terminal( def borrowing_constraint( consumption: ContinuousAction, cash_on_hand: FloatND, - 
consumption_floor: float, - equivalence_scale: FloatND, + transfers: FloatND, ) -> BoolND: - """Consumption cannot exceed available resources after transfers. - - Post-transfer resources are `max(cash_on_hand, consumption_floor * - equivalence_scale)`: the transfer system tops `cash_on_hand` to the - floor when below, otherwise resources are unchanged. The algebraic - identity is `cash_on_hand + transfers == max(cash_on_hand, floor)`, - but writing it as `cash_on_hand + transfers` triggers float32 - catastrophic cancellation when `|cash_on_hand|` dwarfs - `consumption_floor` — e.g. a subject at $-1{,}000{,}000$ in starting - assets gives `(-1e6) + (c_floor + 1e6)` with ~0.1 of rounding error, - which can wipe out the `c == c_floor` boundary and reject every - feasible action. The `max` form has no cancellation. - - `pension_assets_adjustment` is excluded: it can be negative (e.g. - when the imputation overstates next-period pension wealth at a - cross-HIS transition), and including it here can leave no feasible - action at low-asset / mid-AIME corners. The correction enters - `next_assets` instead — a post-decision shift that does not gate - the current consumption choice. + """Consumption cannot exceed post-transfer resources. + + `pension_assets_adjustment` is excluded from the constraint: it can + be negative (e.g. when the imputation overstates next-period pension + wealth at a cross-HIS transition), and including it here can leave + no feasible action at low-asset / mid-AIME corners. The correction + enters `next_assets` instead — a post-decision shift that does not + gate the current consumption choice. 
""" - floor = consumption_floor * equivalence_scale - return consumption <= jnp.maximum(cash_on_hand, floor) + return consumption <= cash_on_hand + transfers diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index b0a2d79..1185eeb 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -18,39 +18,42 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId, build_all_regimes from aca_model.baseline.regimes._common import MAX_CONSUMPTION -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.config import MODEL_CONFIG, GridConfig def create_model( *, n_subjects: int, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid] | None = None, - grid_config: GridConfig = GRID_CONFIG, - pref_type_grid: DiscreteGrid | None = None, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + derived_categoricals: Mapping[str, DiscreteGrid] | None, + grid_config: GridConfig, + pref_type_grid: DiscreteGrid | None, ) -> Model: """Create the baseline structural retirement model. Args: n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. - fixed_params: Parameters to fix at model creation time. These are - partialled into compiled functions and removed from the params - template. Pass data-derived constants here; only estimation - parameters should go through `model.simulate(params=...)`. + fixed_params: Parameters to fix at model creation time, or `None` + to skip. Fixed params are partialled into compiled functions + and removed from the params template. Pass data-derived + constants here; only estimation parameters should go through + `model.simulate(params=...)`. 
wage_params: Data-derived wage profile dict (`log_ft_wage_mean`, `log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build time to size the assets-floor to `-max_annual_labor_income`. - Not routed to the pylcm Model. - derived_categoricals: Extra categorical mappings for derived variables - not in the model's state/action grids. Needed when `fixed_params` - contains `pd.Series` indexed by DAG function outputs. - grid_config: Continuous-grid point counts. Defaults to production - values; pass `BENCHMARK_GRID_CONFIG` for a fast-but-structurally- - faithful benchmark. - pref_type_grid: Optional override for the `pref_type` `DiscreteGrid`. - Defaults to `DiscreteGrid(PrefType)`. Used by the benchmark to - substitute a 2-type variant with `DispatchStrategy.PARTITION_SCAN`. + Not routed to the pylcm Model. `None` skips the floor sizing. + derived_categoricals: Extra categorical mappings for derived + variables not in the model's state/action grids, or `None`. + Needed when `fixed_params` contains `pd.Series` indexed by DAG + function outputs. + grid_config: Continuous-grid point counts. Pass `GRID_CONFIG` for + production values or `BENCHMARK_GRID_CONFIG` for the + fast-but-structurally-faithful benchmark. + pref_type_grid: Pref-type `DiscreteGrid`, or `None` to use + `DiscreteGrid(PrefType)`. Pass a custom grid (e.g. with a + `DispatchStrategy.PARTITION_SCAN` strategy) to substitute the + production layout. 
Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -63,7 +66,7 @@ def create_model( step="Y", ) regimes = build_all_regimes( - grid_config, + grid_config=grid_config, fixed_params=fixed_params, wage_params=wage_params, pref_type_grid=pref_type_grid, diff --git a/src/aca_model/baseline/regimes/__init__.py b/src/aca_model/baseline/regimes/__init__.py index a0eaf9e..02e8a05 100644 --- a/src/aca_model/baseline/regimes/__init__.py +++ b/src/aca_model/baseline/regimes/__init__.py @@ -25,7 +25,7 @@ build_dead_regime, build_grids, ) -from aca_model.config import GRID_CONFIG, GridConfig +from aca_model.config import GridConfig __all__ = [ "REGIME_SPECS", @@ -58,11 +58,11 @@ def build_regime(name: str, grids: Grids) -> Regime: def build_all_regimes( - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - pref_type_grid: DiscreteGrid | None = None, + grid_config: GridConfig, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + pref_type_grid: DiscreteGrid | None, ) -> dict[str, Regime]: """Build all 19 baseline regimes (18 non-terminal + dead). @@ -71,10 +71,11 @@ def build_all_regimes( either being `None` keeps the corresponding static fallback. `pref_type_grid` lets callers inject a compact or partition-lifted `DiscreteGrid(...)` (e.g. the benchmark uses a 2-type - `BenchmarkPrefType` with `DispatchStrategy.PARTITION_SCAN`). + `BenchmarkPrefType` with `DispatchStrategy.PARTITION_SCAN`); `None` + falls back to `DiscreteGrid(PrefType)`. 
""" grids = build_grids( - grid_config, + grid_config=grid_config, fixed_params=fixed_params, wage_params=wage_params, pref_type_grid=pref_type_grid, diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 56887d9..25347c6 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -36,7 +36,7 @@ from aca_model.agent.preferences import PrefType from aca_model.baseline import health_insurance from aca_model.baseline.health_insurance import BuyPrivate -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.config import MODEL_CONFIG, GridConfig from aca_model.environment import social_security, taxes from aca_model.environment.social_security import ClaimedSS @@ -208,11 +208,11 @@ class Grids: def build_grids( - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - pref_type_grid: DiscreteGrid | None = None, + grid_config: GridConfig, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + pref_type_grid: DiscreteGrid | None, ) -> Grids: """Build continuous-state/action grids from a `GridConfig`. diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 8e242b0..19416f2 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -75,7 +75,7 @@ def create_benchmark_model( *, n_subjects: int, - pref_type_grid: DiscreteGrid | None = None, + pref_type_grid: DiscreteGrid, ) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. @@ -85,22 +85,21 @@ def create_benchmark_model( `n_aime_batch_size = 0`). Args: - pref_type_grid: Override for the pref_type grid. Default is a plain - `DiscreteGrid(BenchmarkPrefType)` (fused vmap). 
Pass - `DiscreteGrid(BenchmarkPrefType, dispatch=DispatchStrategy.PARTITION_SCAN)` - (or `PARTITION_VMAP`) to get the partition-lifted kernel — the - recommended production setting for aca-model at scale, but only - supported on pylcm versions that expose `DispatchStrategy`. n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. When set, the first matching `simulate(...)` call AOT-compiles all simulate functions for that batch shape. + pref_type_grid: Pref-type grid. Pass `DiscreteGrid(BenchmarkPrefType)` + for plain fused-vmap, or + `DiscreteGrid(BenchmarkPrefType, dispatch=DispatchStrategy.PARTITION_SCAN)` + (or `PARTITION_VMAP`) for the partition-lifted kernel — the + recommended production setting for aca-model at scale, but only + supported on pylcm versions that expose `DispatchStrategy`. """ - if pref_type_grid is None: - pref_type_grid = DiscreteGrid(BenchmarkPrefType) - fixed_params, _ = get_benchmark_params() + fixed_params, _ = get_benchmark_params(model=None) return create_model( grid_config=BENCHMARK_GRID_CONFIG, fixed_params=fixed_params, + wage_params=None, derived_categoricals=_DERIVED_CATEGORICALS, pref_type_grid=pref_type_grid, n_subjects=n_subjects, @@ -108,7 +107,7 @@ def create_benchmark_model( def get_benchmark_params( - *, model: Model | None = None + *, model: Model | None ) -> tuple[dict[str, Any], dict[str, Any]]: """Load the frozen `(fixed_params, params)` snapshot. @@ -119,7 +118,8 @@ def get_benchmark_params( When `model` is provided, consumption gridpoints are injected into `params` for each regime that declares `consumption` as an `IrregSpacedGrid` with runtime-supplied points. The lower bound is - read from `params["consumption_floor"]`. + read from `params["consumption_floor"]`. Pass `model=None` to skip + injection (e.g. when constructing the model with `fixed_params`). 
""" with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) @@ -222,7 +222,7 @@ def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: def get_benchmark_initial_conditions( - *, model: Model, n_subjects: int = 100, seed: int = 42 + *, model: Model, n_subjects: int, seed: int ) -> dict[str, Array]: """Draw random feasible initial conditions across five age-51 regimes. diff --git a/tests/helpers/model.py b/tests/helpers/model.py new file mode 100644 index 0000000..930c33e --- /dev/null +++ b/tests/helpers/model.py @@ -0,0 +1,38 @@ +"""Tiny factories that wrap `create_model` with `None` for every optional input. + +Used by tests that don't need fixed params, wage params, or a custom pref-type +grid. These helpers exist so production `create_model` factories can stay +default-free without forcing every test call site to spell out +`fixed_params=None, wage_params=None, ...` six times. +""" + +from lcm import Model + +from aca_model.aca.health_insurance import PolicyVariant +from aca_model.aca.model import create_model as _create_aca_model +from aca_model.baseline.model import create_model as _create_baseline_model +from aca_model.config import GRID_CONFIG + + +def make_baseline_model(*, n_subjects: int) -> Model: + """Baseline model with `GRID_CONFIG` and no fixed/wage/derived params.""" + return _create_baseline_model( + n_subjects=n_subjects, + fixed_params=None, + wage_params=None, + derived_categoricals=None, + grid_config=GRID_CONFIG, + pref_type_grid=None, + ) + + +def make_aca_model(*, n_subjects: int, policy: PolicyVariant) -> Model: + """ACA model with `GRID_CONFIG` and no fixed/wage/derived params.""" + return _create_aca_model( + n_subjects=n_subjects, + policy=policy, + fixed_params=None, + wage_params=None, + derived_categoricals=None, + grid_config=GRID_CONFIG, + ) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index adafd66..5e5a68d 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -1,7 +1,9 
@@ """Integration test: the benchmark-sized baseline solves + simulates end-to-end.""" import pytest +from lcm import DiscreteGrid +from aca_model.agent.preferences import BenchmarkPrefType from aca_model.benchmark import ( create_benchmark_model, get_benchmark_initial_conditions, @@ -12,7 +14,10 @@ @pytest.mark.long_running def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 - model = create_benchmark_model(n_subjects=n_subjects) + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 @@ -31,3 +36,43 @@ def test_benchmark_model_simulates_end_to_end() -> None: # Period 0 rows reflect initial conditions — no NaN in continuous states. period_0 = df.loc[df["period"] == 0] assert not period_0[["assets", "aime", "value"]].isna().any().any() + + +@pytest.mark.long_running +def test_benchmark_simulate_obeys_borrowing_constraint() -> None: + """`consumption <= cash_on_hand + transfers` holds for every alive row. + + The simulator only ever picks feasible actions — the borrowing + constraint must hold post-hoc on the simulated panel. A regression + that drops the constraint from a regime, replaces transfers with + something looser, or lets an action grid skip the floor would + surface as a row with `consumption > cash_on_hand + transfers`. 
+ """ + n_subjects = 4 + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + _, params = get_benchmark_params(model=model) + initial_conditions = get_benchmark_initial_conditions( + model=model, n_subjects=n_subjects, seed=0 + ) + + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + check_initial_conditions=False, + ) + + df = result.to_dataframe(additional_targets=["cash_on_hand", "transfers"]) + alive = df.loc[df["regime"] != "dead"].copy() + slack = (alive["cash_on_hand"] + alive["transfers"]) - alive["consumption"] + # Non-negative within fp64 tolerance; allow 1e-6 of the magnitude scale + # to absorb the float64 rounding budget. + assert (slack >= -1e-6).all(), ( + f"borrowing_constraint violated on " + f"{int((slack < -1e-6).sum())} row(s); " + f"min slack = {slack.min():.6g}" + ) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index bd88045..47aeb6a 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -8,9 +8,11 @@ """ import jax.numpy as jnp +from lcm import DiscreteGrid from lcm.simulation.initial_conditions import validate_initial_conditions from aca_model.agent.assets_and_income import borrowing_constraint +from aca_model.agent.preferences import BenchmarkPrefType from aca_model.benchmark import ( create_benchmark_model, get_benchmark_initial_conditions, @@ -18,27 +20,40 @@ ) -def test_borrowing_constraint_admits_c_floor_at_million_dollar_negative_cash() -> None: - """At `cash_on_hand = -$1M` (fp32), `c = c_floor` remains a feasible choice. 
+def test_borrowing_constraint_admits_consumption_at_post_transfer_resources() -> None: + """`consumption == cash_on_hand + transfers` is feasible by equality.""" + cash_on_hand = jnp.asarray(-50_000.0) + transfers = jnp.asarray(55_000.0) + consumption = cash_on_hand + transfers - Computing `cash_on_hand + transfers` directly suffers float32 catastrophic - cancellation: `(-1e6) + (c_floor + 1e6)` loses ~0.1 of precision, enough - to wipe out the `c == c_floor` boundary. The constraint must use the - algebraically equivalent but numerically stable `max(cash_on_hand, floor)` - form. - """ - consumption_floor = 5_000.0 admitted = bool( borrowing_constraint( - consumption=jnp.float32(consumption_floor), - cash_on_hand=jnp.float32(-1_000_000.0), - consumption_floor=consumption_floor, - equivalence_scale=jnp.float32(1.0), + consumption=consumption, + cash_on_hand=cash_on_hand, + transfers=transfers, ) ) assert admitted +def test_borrowing_constraint_rejects_consumption_above_post_transfer_resources() -> ( + None +): + """`consumption > cash_on_hand + transfers` is rejected.""" + cash_on_hand = jnp.asarray(-50_000.0) + transfers = jnp.asarray(55_000.0) + consumption = cash_on_hand + transfers + 1.0 + + admitted = bool( + borrowing_constraint( + consumption=consumption, + cash_on_hand=cash_on_hand, + transfers=transfers, + ) + ) + assert not admitted + + def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. @@ -47,7 +62,10 @@ def test_extreme_negative_assets_subject_passes_validation() -> None: floor / transfer system absorbs them, with `c = c_floor` always feasible. 
""" n_subjects = 1 - model = create_benchmark_model(n_subjects=n_subjects) + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 75a87d9..fca2ef6 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -3,17 +3,32 @@ from collections.abc import Mapping import pytest +from helpers.model import make_aca_model, make_baseline_model from aca_model.aca import health_insurance as aca_hi from aca_model.aca.health_insurance import PolicyVariant -from aca_model.aca.model import create_model as create_aca_model -from aca_model.aca.regimes import build_all_regimes as build_aca_regimes -from aca_model.baseline.model import create_model +from aca_model.aca.regimes import build_all_regimes as _build_aca_regimes from aca_model.baseline.regimes import REGIME_SPECS, RegimeId from aca_model.baseline.regimes import build_regime as _build_regime from aca_model.baseline.regimes._common import MAX_CONSUMPTION, build_grids +from aca_model.config import GRID_CONFIG -_GRIDS = build_grids() + +def build_aca_regimes(policy: PolicyVariant) -> dict: + return _build_aca_regimes( + policy=policy, + grid_config=GRID_CONFIG, + fixed_params=None, + wage_params=None, + ) + + +_GRIDS = build_grids( + grid_config=GRID_CONFIG, + fixed_params=None, + wage_params=None, + pref_type_grid=None, +) def build_regime(name: str): @@ -21,24 +36,24 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 def test_model_age_range() -> None: - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) assert model.ages.values[0] == 51.0 assert model.ages.values[-1] == 95.0 def 
test_dead_regime_is_terminal() -> None: - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) assert model.regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) for name in REGIME_SPECS: assert not model.regimes[name].terminal @@ -170,7 +185,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: - model = create_aca_model(n_subjects=1) + model = make_aca_model(n_subjects=1, policy=PolicyVariant.ACA) assert len(model.regimes) == 19 assert model.n_periods == 45 @@ -211,7 +226,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: @pytest.mark.parametrize("policy", list(PolicyVariant)) def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" - model = create_aca_model(n_subjects=1, policy=policy) + model = make_aca_model(n_subjects=1, policy=policy) assert len(model.regimes) == 19 @@ -251,10 +266,10 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) assert len(model.regimes) == 19 def test_max_consumption_attached_from_canonical_constant() -> None: - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) assert model.max_consumption == MAX_CONSUMPTION diff --git a/tests/test_social_security.py b/tests/test_social_security.py index d612f7d..90c5128 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -5,11 +5,11 @@ import jax.numpy as jnp import numpy as np +from helpers.social_security import compute_di_dropout_scale, compute_pia_table from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from 
aca_model.environment.social_security import ClaimedSS -from helpers.social_security import compute_di_dropout_scale, compute_pia_table ATOL = 0.01 diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index 488df32..ef09775 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -5,11 +5,11 @@ """ import jax.numpy as jnp +from helpers.social_security import compute_pia_table from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from helpers.social_security import compute_pia_table ATOL = 0.01 From 9d59174143dc065e791da909c18db15ca0856aac Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 06:18:40 +0200 Subject: [PATCH 38/41] borrowing_constraint: restore max() form for kink-stability at extreme cash MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `consumption <= cash_on_hand + transfers` form (algebraically identical to `consumption <= max(cash_on_hand, floor)`) rounds short by sub-ULP at extreme `|cash_on_hand|` ~ 1e6 — for HRS-bottom-coded subjects at `assets=-$1,000,000`, the additive RHS comes in at `floor - 5.7e-11` (fp64), flipping the kink-boundary `<=` for the lowest consumption gridpoint. Production task_simulate_aca_no_mandate on HPC fails at validate_initial_conditions for those subjects. The `max(cash_on_hand, floor)` form has no cancellation and returns `floor` exactly when `cash_on_hand < floor`. This is a general floating-point precision concern at extreme operands, not an fp32-specific workaround. Docstring updated accordingly. Reverts the signature back to `(consumption, cash_on_hand, consumption_floor, equivalence_scale)`. Tests: - `test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash`: unit-level reproducer of the production failure — passes only with the `max` form.
- The two new `_at_floor` / `_above_post_transfer_resources` unit tests switch back to the new signature. - `test_benchmark_simulate_obeys_borrowing_constraint`: post-hoc check uses `max(cash_on_hand, floor)` rather than `cash_on_hand + transfers` (the additive form has the same sub-ULP issue and would spuriously trip on the same rows). --- src/aca_model/agent/assets_and_income.py | 15 +++++- tests/test_benchmark.py | 29 +++++++---- .../test_initial_conditions_extreme_assets.py | 50 +++++++++++++++---- 3 files changed, 71 insertions(+), 23 deletions(-) diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index c07fdb4..b0ee689 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -85,10 +85,20 @@ def next_assets_terminal( def borrowing_constraint( consumption: ContinuousAction, cash_on_hand: FloatND, - transfers: FloatND, + consumption_floor: float, + equivalence_scale: FloatND, ) -> BoolND: """Consumption cannot exceed post-transfer resources. + Post-transfer resources are `max(cash_on_hand, consumption_floor * + equivalence_scale)`: the transfer system tops `cash_on_hand` to the + floor when below, otherwise resources are unchanged. The algebraic + identity is `cash_on_hand + transfers == max(cash_on_hand, floor)`; + the `max` form is preferred because the additive form rounds to + `floor + ε` (with `|ε| ~ ULP(|cash_on_hand|)`) at extreme cash, which + flips the kink-boundary comparison for HRS-bottom-coded subjects at + `assets=-$1{,}000{,}000`. The `max` form returns `floor` exactly. + `pension_assets_adjustment` is excluded from the constraint: it can be negative (e.g. when the imputation overstates next-period pension wealth at a cross-HIS transition), and including it here can leave @@ -96,4 +106,5 @@ def borrowing_constraint( enters `next_assets` instead — a post-decision shift that does not gate the current consumption choice. 
""" - return consumption <= cash_on_hand + transfers + floor = consumption_floor * equivalence_scale + return consumption <= jnp.maximum(cash_on_hand, floor) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 5e5a68d..8b1ed88 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -40,14 +40,21 @@ def test_benchmark_model_simulates_end_to_end() -> None: @pytest.mark.long_running def test_benchmark_simulate_obeys_borrowing_constraint() -> None: - """`consumption <= cash_on_hand + transfers` holds for every alive row. + """`consumption <= max(cash_on_hand, floor)` holds for every alive row. The simulator only ever picks feasible actions — the borrowing constraint must hold post-hoc on the simulated panel. A regression - that drops the constraint from a regime, replaces transfers with + that drops the constraint from a regime, replaces the floor with something looser, or lets an action grid skip the floor would - surface as a row with `consumption > cash_on_hand + transfers`. + surface as a row with `consumption > max(cash_on_hand, floor)`. + + The constraint's RHS is `max(cash_on_hand, floor)` rather than + `cash_on_hand + transfers`: the additive form rounds short by + sub-ULP at extreme `|cash_on_hand|`, so the post-hoc check would + also flip on the same kink. """ + import numpy as np + n_subjects = 4 model = create_benchmark_model( n_subjects=n_subjects, @@ -66,13 +73,15 @@ def test_benchmark_simulate_obeys_borrowing_constraint() -> None: check_initial_conditions=False, ) - df = result.to_dataframe(additional_targets=["cash_on_hand", "transfers"]) + df = result.to_dataframe( + additional_targets=["cash_on_hand", "equivalence_scale"] + ) alive = df.loc[df["regime"] != "dead"].copy() - slack = (alive["cash_on_hand"] + alive["transfers"]) - alive["consumption"] - # Non-negative within fp64 tolerance; allow 1e-6 of the magnitude scale - # to absorb the float64 rounding budget. 
- assert (slack >= -1e-6).all(), ( - f"borrowing_constraint violated on " - f"{int((slack < -1e-6).sum())} row(s); " + consumption_floor = float(params["consumption_floor"]) + floor = consumption_floor * alive["equivalence_scale"].to_numpy() + rhs = np.maximum(alive["cash_on_hand"].to_numpy(), floor) + slack = rhs - alive["consumption"].to_numpy() + assert (slack >= 0).all(), ( + f"borrowing_constraint violated on {int((slack < 0).sum())} row(s); " f"min slack = {slack.min():.6g}" ) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 47aeb6a..6078547 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -20,17 +20,18 @@ ) -def test_borrowing_constraint_admits_consumption_at_post_transfer_resources() -> None: - """`consumption == cash_on_hand + transfers` is feasible by equality.""" - cash_on_hand = jnp.asarray(-50_000.0) - transfers = jnp.asarray(55_000.0) - consumption = cash_on_hand + transfers +def test_borrowing_constraint_admits_consumption_at_floor() -> None: + """`consumption == consumption_floor` at the kink is feasible by equality.""" + consumption_floor = 5_000.0 + equivalence_scale = jnp.asarray(1.0) + cash_on_hand = jnp.asarray(-50_000.0) # below floor — RHS = floor admitted = bool( borrowing_constraint( - consumption=consumption, + consumption=jnp.asarray(consumption_floor), cash_on_hand=cash_on_hand, - transfers=transfers, + consumption_floor=consumption_floor, + equivalence_scale=equivalence_scale, ) ) assert admitted @@ -39,21 +40,48 @@ def test_borrowing_constraint_admits_consumption_at_post_transfer_resources() -> def test_borrowing_constraint_rejects_consumption_above_post_transfer_resources() -> ( None ): - """`consumption > cash_on_hand + transfers` is rejected.""" + """`consumption > max(cash_on_hand, floor)` is rejected.""" + consumption_floor = 5_000.0 + equivalence_scale = jnp.asarray(1.0) cash_on_hand = 
jnp.asarray(-50_000.0) - transfers = jnp.asarray(55_000.0) - consumption = cash_on_hand + transfers + 1.0 + consumption = jnp.asarray(consumption_floor + 1.0) admitted = bool( borrowing_constraint( consumption=consumption, cash_on_hand=cash_on_hand, - transfers=transfers, + consumption_floor=consumption_floor, + equivalence_scale=equivalence_scale, ) ) assert not admitted +def test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash() -> None: + """The kink-boundary check survives sub-ULP rounding at `|cash_on_hand| ~ 1e6`. + + Reproduces the production failure mode at `assets=-$1{,}000{,}000$` (HRS + bottom-code): the algebraically equivalent `cash_on_hand + transfers` + form rounds to `floor - 5.7e-11` at fp64, flipping `consumption <= ...` + for the lowest consumption gridpoint. The `max(cash_on_hand, floor)` + form returns `floor` exactly. + """ + consumption_floor = 1597.0921419521899 # production value + equivalence_scale = jnp.asarray(1.0) + cash_on_hand = jnp.asarray(-1_000_000.0) + consumption = jnp.asarray(consumption_floor) # lowest grid point + + admitted = bool( + borrowing_constraint( + consumption=consumption, + cash_on_hand=cash_on_hand, + consumption_floor=consumption_floor, + equivalence_scale=equivalence_scale, + ) + ) + assert admitted + + def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. From 67edfe0f54a305c23297f17ec53aee07b7d90496 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 08:00:39 +0200 Subject: [PATCH 39/41] consumption_grid: pin first gridpoint to consumption_floor exactly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `jnp.geomspace(consumption_floor, max_consumption, num=n)` returns `consumption_floor * r^0 == consumption_floor` mathematically, but some XLA backends drift the first point by sub-ULP. 
CUDA at n=70 produces `consumption_floor + 2.27e-13`. The borrowing_constraint compares `consumption[0]` against `max(cash_on_hand, consumption_floor)` and any positive drift above `consumption_floor` flips the kink- boundary `<=` for subjects with very negative cash — explaining the HPC-only `task_simulate` failures (~250 subjects) that didn't reproduce on CPU. Pin the first gridpoint back to `consumption_floor` after geomspace. The same drift exists at the upper end (`pts[-1] != max_consumption` exactly) but doesn't flip any constraint comparison, so it's left alone. `tests/test_consumption_grid.py` parametrises the invariant over `n_points = 5, 16, 64, 70, 100` so a future XLA / JAX upgrade that introduces drift at any of these counts surfaces here rather than at `validate_initial_conditions` on HPC. --- src/aca_model/consumption_grid.py | 16 +++++++++-- tests/test_consumption_grid.py | 48 +++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 tests/test_consumption_grid.py diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py index 7123c1f..7e004fa 100644 --- a/src/aca_model/consumption_grid.py +++ b/src/aca_model/consumption_grid.py @@ -69,5 +69,17 @@ def _compute_consumption_points( max_consumption: float, n_points: int, ) -> Array: - """Return log-spaced consumption gridpoints from floor to max.""" - return jnp.geomspace(consumption_floor, max_consumption, num=n_points) + """Return log-spaced consumption gridpoints from floor to max. + + `jnp.geomspace` computes intermediate points as `start * r^i` with + `r = (stop/start)^(1/(n-1))`; the first point is `start * r^0`, + which is `start` mathematically but can be off by sub-ULP under + some XLA backends (CUDA + 70 points: `start + 2.27e-13`). 
The + borrowing constraint compares the first action against + `max(cash_on_hand, consumption_floor)`, and any positive drift + above `consumption_floor` flips the kink-boundary `<=` for + subjects with very negative cash. Pin the first element back to + `consumption_floor` exactly. + """ + pts = jnp.geomspace(consumption_floor, max_consumption, num=n_points) + return pts.at[0].set(consumption_floor) diff --git a/tests/test_consumption_grid.py b/tests/test_consumption_grid.py new file mode 100644 index 0000000..40b7caa --- /dev/null +++ b/tests/test_consumption_grid.py @@ -0,0 +1,48 @@ +"""Consumption-grid invariants required by the borrowing constraint. + +The borrowing constraint in `agent.assets_and_income.borrowing_constraint` +compares the lowest consumption action against +`max(cash_on_hand, consumption_floor * equivalence_scale)`. For subjects +with cash below the floor (HRS bottom-coded `assets=-$1,000,000`, +moderate-negative-asset retirees etc.) this RHS collapses to exactly +`consumption_floor` for singles. The constraint is feasible iff the +lowest consumption gridpoint is `<= consumption_floor`. + +`jnp.geomspace(start, stop, num=n)` returns `start * r^i` with +`r = (stop/start)^(1/(n-1))`; mathematically `r^0 == 1` so the first +point equals `start`, but XLA backends can drift by sub-ULP for some +`(start, stop, n)` combinations (observed: CUDA, n=70, drift +2.27e-13). +A positive drift above `consumption_floor` flips the kink-boundary `<=` +and rejects every action for those subjects. + +`_compute_consumption_points` therefore pins the first point back to +`consumption_floor` after `geomspace`. Test that invariant directly.
+""" + +import jax.numpy as jnp +import pytest + +from aca_model.consumption_grid import _compute_consumption_points + + +@pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) +def test_compute_consumption_points_first_equals_floor_exactly(n_points: int) -> None: + """The first gridpoint equals `consumption_floor` exactly under any `n_points`.""" + consumption_floor = 1597.0921419521899 # production value + pts = _compute_consumption_points( + consumption_floor=consumption_floor, + max_consumption=300_000.0, + n_points=n_points, + ) + assert float(pts[0]) == consumption_floor + + +def test_compute_consumption_points_strictly_increasing() -> None: + """Gridpoints are strictly increasing — no kink-pinning ties.""" + pts = _compute_consumption_points( + consumption_floor=1597.0921419521899, + max_consumption=300_000.0, + n_points=70, + ) + diffs = jnp.diff(pts) + assert bool((diffs > 0).all()) From d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 17:55:40 +0200 Subject: [PATCH 40/41] ci: bump pylcm pin to 2f486dc Sweeps in the dtype-barrier polish, simulate AOT-during-solve, and the persistence/benchmark fixes from feat/canonical-float-dtype. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 565245d..fa0891f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@e4cae2aa57d4bf568b8ebbade55d44571e3a086f" + git+https://github.com/OpenSourceEconomics/pylcm.git@2f486dc36425ca6339a36cc8214ab4aef1d85df2" - name: Install aca-model with test deps run: pip install -e . 
pytest pdbp - name: Run pytest From 9dd1e2f40706f2dcdc2e9938797af42f3f8d0b23 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 18:00:25 +0200 Subject: [PATCH 41/41] ci: bump pylcm pin to 61c2436 Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fa0891f..1f1ab53 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@2f486dc36425ca6339a36cc8214ab4aef1d85df2" + git+https://github.com/OpenSourceEconomics/pylcm.git@61c2436b67ecd9df1c70e80b770be77681c5df63" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest