diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 67c82fa..ff5244e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,10 +26,10 @@ jobs: - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - - name: Install pylcm + - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@main" + git+https://github.com/OpenSourceEconomics/pylcm.git@00f3b4a" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e36542..f3188ab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: check-hooks-apply - id: check-useless-excludes - repo: https://github.com/tox-dev/pyproject-fmt - rev: v2.19.0 + rev: v2.21.1 hooks: - id: pyproject-fmt - repo: https://github.com/lyz-code/yamlfix @@ -47,7 +47,7 @@ repos: hooks: - id: yamllint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.6 + rev: v0.15.12 hooks: - id: ruff-check args: diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 8b4507a..b76adc6 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -11,36 +11,43 @@ from aca_model.aca import PolicyVariant from aca_model.aca.regimes import build_all_regimes +from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.baseline.regimes._common import MAX_CONSUMPTION +from aca_model.config import MODEL_CONFIG, GridConfig def create_model( *, - policy: PolicyVariant = PolicyVariant.ACA, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] - | None = None, - grid_config: GridConfig = GRID_CONFIG, + n_subjects: int, + policy: PolicyVariant, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + derived_categoricals: Mapping[str, DiscreteGrid] | None, + grid_config: GridConfig, ) -> Model: """Create an ACA policy variant model. Args: - policy: Which ACA policy combination to apply. - fixed_params: Parameters to fix at model creation time. These are - partialled into compiled functions and removed from the params - template. Pass data-derived constants here; only estimation - parameters should go through `model.simulate(params=...)`. + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. + policy: Which ACA policy combination to apply (e.g. + `PolicyVariant.ACA`). + fixed_params: Parameters to fix at model creation time, or `None` + to skip. Fixed params are partialled into compiled functions + and removed from the params template. Pass data-derived + constants here; only estimation parameters should go through + `model.simulate(params=...)`. wage_params: Data-derived wage profile dict (`log_ft_wage_mean`, `log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build time to size the assets-floor to `-max_annual_labor_income`. - Not routed to the pylcm Model. - derived_categoricals: Extra categorical mappings for derived variables - not in the model's state/action grids. Needed when `fixed_params` - contains `pd.Series` indexed by DAG function outputs. - grid_config: Continuous-grid point counts. Defaults to production - values. + Not routed to the pylcm Model. `None` skips the floor sizing. + derived_categoricals: Extra categorical mappings for derived + variables not in the model's state/action grids, or `None`. + Needed when `fixed_params` contains `pd.Series` indexed by DAG + function outputs. + grid_config: Continuous-grid point counts. Pass `GRID_CONFIG` for + production values or `BENCHMARK_GRID_CONFIG` for the + fast-but-structurally-faithful benchmark. Returns: pylcm Model with ACA-specific function overrides. @@ -58,11 +65,22 @@ def create_model( wage_params=wage_params, ) - return Model( + # See `baseline.model.create_model` for why `target_his` is declared + # as a base-layer derived categorical. + base_derived: dict[str, DiscreteGrid] = { + "target_his": DiscreteGrid(HealthInsuranceState), + } + if derived_categoricals is not None: + base_derived.update(derived_categoricals) + + model = Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description=f"Structural retirement model ({policy.name})", fixed_params=fixed_params or {}, - derived_categoricals=derived_categoricals, + derived_categoricals=base_derived, + n_subjects=n_subjects, ) + model.max_consumption = MAX_CONSUMPTION + return model diff --git a/src/aca_model/aca/regimes/__init__.py b/src/aca_model/aca/regimes/__init__.py index 2c143bd..5b9f4bf 100644 --- a/src/aca_model/aca/regimes/__init__.py +++ b/src/aca_model/aca/regimes/__init__.py @@ -10,19 +10,22 @@ from aca_model.aca.regimes._overrides import apply_aca_overrides from aca_model.baseline.regimes import build_all_regimes as baseline_build_all_regimes from aca_model.baseline.regimes._common import REGIME_SPECS -from aca_model.config import GRID_CONFIG, GridConfig +from aca_model.config import GridConfig def build_all_regimes( - policy: PolicyVariant, - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, + policy: PolicyVariant, + grid_config: GridConfig, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, ) -> dict[str, Regime]: """Build all 19 regimes with ACA policy overrides.""" regimes = baseline_build_all_regimes( - grid_config, fixed_params=fixed_params, wage_params=wage_params + grid_config=grid_config, + fixed_params=fixed_params, + wage_params=wage_params, + pref_type_grid=None, ) result = {} for name, regime in regimes.items(): diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index dfa83ef..b0ee689 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -55,7 +55,7 @@ def next_assets( consumption: ContinuousAction, oop_costs: FloatND, ) -> ContinuousState: - """Compute beginning-of-next-period assets. + """Compute beginning-of-next-period assets for non-terminal targets. OOP health costs are deducted here (not from cash_on_hand) so that the consumption choice does not condition on the HCC shock realization. @@ -65,11 +65,46 @@ def next_assets( ) +def next_assets_terminal( + cash_on_hand: FloatND, + transfers: FloatND, + consumption: ContinuousAction, + oop_costs: FloatND, +) -> ContinuousState: + """Compute beginning-of-next-period assets for the dead/terminal target. + + No `pension_assets_adjustment` term: with no future, there is no + next-period pension wealth to impute against. Avoiding the dependency + also keeps the `dead` per-target transition's DAG free of `next_aime` + (which would otherwise need to come from a transition `dead` does not + have, since `aime` is not a state in the terminal regime). + """ + return cash_on_hand + transfers - consumption - oop_costs + + def borrowing_constraint( consumption: ContinuousAction, cash_on_hand: FloatND, - transfers: FloatND, - pension_assets_adjustment: FloatND, + consumption_floor: float, + equivalence_scale: FloatND, ) -> BoolND: - """Consumption cannot exceed available resources (no borrowing).""" - return consumption <= cash_on_hand + transfers + pension_assets_adjustment + """Consumption cannot exceed post-transfer resources. + + Post-transfer resources are `max(cash_on_hand, consumption_floor * + equivalence_scale)`: the transfer system tops `cash_on_hand` to the + floor when below, otherwise resources are unchanged. The algebraic + identity is `cash_on_hand + transfers == max(cash_on_hand, floor)`; + the `max` form is preferred because the additive form rounds to + `floor + ε` (with `|ε| ~ ULP(|cash_on_hand|)`) at extreme cash, which + flips the kink-boundary comparison for HRS-bottom-coded subjects at + `assets=-$1{,}000{,}000`. The `max` form returns `floor` exactly. + + `pension_assets_adjustment` is excluded from the constraint: it can + be negative (e.g. when the imputation overstates next-period pension + wealth at a cross-HIS transition), and including it here can leave + no feasible action at low-asset / mid-AIME corners. The correction + enters `next_assets` instead — a post-decision shift that does not + gate the current consumption choice. + """ + floor = consumption_floor * equivalence_scale + return consumption <= jnp.maximum(cash_on_hand, floor) diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 3b0bb5e..28a8367 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -133,8 +133,11 @@ def utility( """Within-period utility: CES aggregator over consumption and leisure. u = utility_scale_factor * ((c/eq_scale)^α * l^(1-α))^(1-γ) / (1-γ) - with log case for γ=1. `consumption_weight`, `coefficient_rra`, and - `utility_scale_factor` are indexed by `pref_type`. + with log case for γ=1. `consumption_weight` and `coefficient_rra` are + pref-type-indexed Series sourced directly from params; `utility_scale_factor` + is a regime-function output (already a per-cell scalar — must NOT be + re-indexed by pref_type, see `aca_model.agent.preferences.utility_scale_factor` + for why). """ alpha = consumption_weight[pref_type] gamma = coefficient_rra[pref_type] @@ -147,7 +150,7 @@ def utility( jnp.log(composite), composite**one_minus_gamma / one_minus_gamma, ) - return u * utility_scale_factor[pref_type] + return u * utility_scale_factor def discount_factor( @@ -164,6 +167,7 @@ def discount_factor( def utility_scale_factor( + pref_type: DiscreteState, average_consumption: float, consumption_weight: FloatND, coefficient_rra: FloatND, @@ -174,26 +178,29 @@ def utility_scale_factor( reference_age: int, scale_reference_age: int, ) -> FloatND: - """Compute scale factor so utility is approximately 1 at typical values. - - Uses leisure at `scale_reference_age` when working `scale_reference_hours` - (after fixed costs) and average consumption. Returns one scale per - preference type, indexed by pref_type. + """Compute the scale factor so utility is approximately 1 at typical values. + + Returns the scalar for the cell's `pref_type`. Mirrors the `discount_factor` + pattern: take the state as input, return a per-cell scalar. Registering this + as a regime function and then doing `utility_scale_factor[pref_type]` in a + downstream consumer is invalid — pylcm broadcasts function outputs to + per-cell scalars before consumption, and the validator in + `lcm.regime_building.validation` raises on that clash. """ + alpha = consumption_weight[pref_type] + gamma = coefficient_rra[pref_type] age_offset = scale_reference_age - reference_age average_leisure = ( time_endowment - scale_reference_hours - (fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * age_offset) ) - u_cons = average_consumption**consumption_weight - u_leisure = average_leisure ** (1.0 - consumption_weight) + u_cons = average_consumption**alpha + u_leisure = average_leisure ** (1.0 - alpha) - one_minus_gamma = jnp.where( - jnp.isclose(coefficient_rra, 1.0), 1.0, 1.0 - coefficient_rra - ) + one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) raw = jnp.where( - jnp.isclose(coefficient_rra, 1.0), + jnp.isclose(gamma, 1.0), jnp.log(u_cons * u_leisure), (u_cons * u_leisure) ** one_minus_gamma / one_minus_gamma, ) @@ -237,8 +244,9 @@ def bequest( """Bequest function for terminal/dead states. bequest = scale * bwt * (max(0,a) + shifter)^(α*(1-γ)) / (1-γ) - `consumption_weight`, `coefficient_rra`, and `utility_scale_factor` - are indexed by `pref_type`. + `consumption_weight` and `coefficient_rra` are pref-type-indexed Series + from params; `utility_scale_factor` is a regime-function output (already a + per-cell scalar — must NOT be re-indexed by pref_type). """ alpha = consumption_weight[pref_type] gamma = coefficient_rra[pref_type] @@ -250,4 +258,4 @@ def bequest( jnp.log(assets_shifted), assets_shifted ** (one_minus_gamma * alpha) / one_minus_gamma, ) - return val * scaled_bequest_weight * utility_scale_factor[pref_type] + return val * scaled_bequest_weight * utility_scale_factor diff --git a/src/aca_model/baseline/health_insurance.py b/src/aca_model/baseline/health_insurance.py index 741d160..3732d6d 100644 --- a/src/aca_model/baseline/health_insurance.py +++ b/src/aca_model/baseline/health_insurance.py @@ -246,6 +246,29 @@ def is_medicaid_eligible(is_ssi_eligible: BoolND) -> BoolND: return is_ssi_eligible +def target_his( + his: IntND, + labor_supply: DiscreteAction, + is_medicaid_eligible: BoolND, +) -> IntND: + """Return the HIS class of the surviving target regime. + + Mirrors the cross-HIS branches inside `_make_transition_canwork` (retiree, + tied, nongroup): tied agents who stop working become nongroup, and + Medicaid-eligible agents are overridden to nongroup. Used by + `imputed_pension_wealth_next_period` to look up next-period imputation + coefficients at the target's HIS. + """ + tied_to_ng = (his == HealthInsuranceState.tied) & ( + labor_supply == LaborSupply.do_not_work + ) + return jnp.where( + tied_to_ng | is_medicaid_eligible, + HealthInsuranceState.nongroup, + his, + ).astype(jnp.int32) + + def oop_with_medicaid( primary_oop: FloatND, is_medicaid_eligible: BoolND, diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index a886495..1185eeb 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -5,7 +5,7 @@ Usage: from aca_model.baseline.model import create_model - model = create_model() + model = create_model(n_subjects=...) params = get_default_params() V = model.solve(params) """ @@ -15,39 +15,45 @@ from lcm import AgeGrid, DiscreteGrid, Model +from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId, build_all_regimes -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.baseline.regimes._common import MAX_CONSUMPTION +from aca_model.config import MODEL_CONFIG, GridConfig def create_model( *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] - | None = None, - grid_config: GridConfig = GRID_CONFIG, - pref_type_grid: DiscreteGrid | None = None, + n_subjects: int, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + derived_categoricals: Mapping[str, DiscreteGrid] | None, + grid_config: GridConfig, + pref_type_grid: DiscreteGrid | None, ) -> Model: """Create the baseline structural retirement model. Args: - fixed_params: Parameters to fix at model creation time. These are - partialled into compiled functions and removed from the params - template. Pass data-derived constants here; only estimation - parameters should go through `model.simulate(params=...)`. + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. + fixed_params: Parameters to fix at model creation time, or `None` + to skip. Fixed params are partialled into compiled functions + and removed from the params template. Pass data-derived + constants here; only estimation parameters should go through + `model.simulate(params=...)`. wage_params: Data-derived wage profile dict (`log_ft_wage_mean`, `log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build time to size the assets-floor to `-max_annual_labor_income`. - Not routed to the pylcm Model. - derived_categoricals: Extra categorical mappings for derived variables - not in the model's state/action grids. Needed when `fixed_params` - contains `pd.Series` indexed by DAG function outputs. - grid_config: Continuous-grid point counts. Defaults to production - values; pass `BENCHMARK_GRID_CONFIG` for a fast-but-structurally- - faithful benchmark. - pref_type_grid: Optional override for the `pref_type` `DiscreteGrid`. - Defaults to `DiscreteGrid(PrefType)`. Used by the benchmark to - substitute a 2-type variant with `DispatchStrategy.PARTITION_SCAN`. + Not routed to the pylcm Model. `None` skips the floor sizing. + derived_categoricals: Extra categorical mappings for derived + variables not in the model's state/action grids, or `None`. + Needed when `fixed_params` contains `pd.Series` indexed by DAG + function outputs. + grid_config: Continuous-grid point counts. Pass `GRID_CONFIG` for + production values or `BENCHMARK_GRID_CONFIG` for the + fast-but-structurally-faithful benchmark. + pref_type_grid: Pref-type `DiscreteGrid`, or `None` to use + `DiscreteGrid(PrefType)`. Pass a custom grid (e.g. with a + `DispatchStrategy.PARTITION_SCAN` strategy) to substitute the + production layout. Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -60,17 +66,33 @@ def create_model( step="Y", ) regimes = build_all_regimes( - grid_config, + grid_config=grid_config, fixed_params=fixed_params, wage_params=wage_params, pref_type_grid=pref_type_grid, ) - return Model( + # `target_his` is a DAG output of `health_insurance.target_his` (set on + # nongroup/tied/retiree regimes). The pension imputation correction + # (`imputed_pension_wealth_next_period`) indexes shifted arrays by + # `arr[period, target_his]`; pylcm needs the categorical declared so + # `pd.Series` fixed_params with a `target_his` index level resolve. + base_derived: dict[str, DiscreteGrid] = { + "target_his": DiscreteGrid(HealthInsuranceState), + } + if derived_categoricals is not None: + base_derived.update(derived_categoricals) + + model = Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description="Baseline structural retirement model (pre-ACA)", fixed_params=fixed_params or {}, - derived_categoricals=derived_categoricals, + derived_categoricals=base_derived, + n_subjects=n_subjects, ) + # See `MAX_CONSUMPTION` in `baseline.regimes._common` for why this + # rides on the Model instance instead of `fixed_params`. + model.max_consumption = MAX_CONSUMPTION + return model diff --git a/src/aca_model/baseline/regimes/__init__.py b/src/aca_model/baseline/regimes/__init__.py index a0eaf9e..02e8a05 100644 --- a/src/aca_model/baseline/regimes/__init__.py +++ b/src/aca_model/baseline/regimes/__init__.py @@ -25,7 +25,7 @@ build_dead_regime, build_grids, ) -from aca_model.config import GRID_CONFIG, GridConfig +from aca_model.config import GridConfig __all__ = [ "REGIME_SPECS", @@ -58,11 +58,11 @@ def build_regime(name: str, grids: Grids) -> Regime: def build_all_regimes( - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - pref_type_grid: DiscreteGrid | None = None, + grid_config: GridConfig, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + pref_type_grid: DiscreteGrid | None, ) -> dict[str, Regime]: """Build all 19 baseline regimes (18 non-terminal + dead). @@ -71,10 +71,11 @@ def build_all_regimes( either being `None` keeps the corresponding static fallback. `pref_type_grid` lets callers inject a compact or partition-lifted `DiscreteGrid(...)` (e.g. the benchmark uses a 2-type - `BenchmarkPrefType` with `DispatchStrategy.PARTITION_SCAN`). + `BenchmarkPrefType` with `DispatchStrategy.PARTITION_SCAN`); `None` + falls back to `DiscreteGrid(PrefType)`. """ grids = build_grids( - grid_config, + grid_config=grid_config, fixed_params=fixed_params, wage_params=wage_params, pref_type_grid=pref_type_grid, diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 688a504..25347c6 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -36,7 +36,7 @@ from aca_model.agent.preferences import PrefType from aca_model.baseline import health_insurance from aca_model.baseline.health_insurance import BuyPrivate -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.config import MODEL_CONFIG, GridConfig from aca_model.environment import social_security, taxes from aca_model.environment.social_security import ClaimedSS @@ -194,21 +194,25 @@ class Grids: # bend points (0 → kink_0 → kink_1 → kink_2). Total = 32. _AIME_PIECE_N_POINTS: tuple[int, int, int] = (10, 11, 11) -# Consumption grid: log-spaced from the lower bound of the -# `consumption_floor` parameter (BOUNDS in task_estimate_parameters) -# up to a high value that brackets the unconstrained optimum for the -# richest agents in the state space. Mirrors the struct-ret design -# (concentrate gridpoints where CRRA curvature is highest). -_CONSUMPTION_GRID_START: float = 100.0 -_CONSUMPTION_GRID_STOP: float = 300_000.0 + +MAX_CONSUMPTION: float = 300_000.0 +"""Upper bound of the runtime consumption grid in $/year. + +Lives here next to the other grid bounds (assets `stop=500_000.0`, +AIME `stop=8_000.0`). The `create_model` factories attach this onto +`model.max_consumption` so `inject_consumption_points` can read it +back at runtime. Routed via a Model attribute rather than +`fixed_params` because pylcm validates `fixed_params` keys against +the regime DAG and rejects entries no function consumes. +""" def build_grids( - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - pref_type_grid: DiscreteGrid | None = None, + grid_config: GridConfig, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + pref_type_grid: DiscreteGrid | None, ) -> Grids: """Build continuous-state/action grids from a `GridConfig`. @@ -273,14 +277,7 @@ def build_grids( ), aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params), consumption=IrregSpacedGrid( - points=tuple( - float(c) - for c in np.geomspace( - _CONSUMPTION_GRID_START, - _CONSUMPTION_GRID_STOP, - num=grid_config.n_consumption_gridpoints, - ) - ), + n_points=grid_config.n_consumption_gridpoints, ), wage_res=wage_res, hcc_persistent=hcc_persistent, @@ -302,7 +299,10 @@ def _build_aime_grid( """ if fixed_params is None or "pia_aime_grid" not in fixed_params: return LinSpacedGrid( - start=0.0, stop=8_000.0, n_points=grid_config.n_aime_gridpoints + start=0.0, + stop=8_000.0, + n_points=grid_config.n_aime_gridpoints, + batch_size=grid_config.n_aime_batch_size, ) kinks = [float(k) for k in np.asarray(fixed_params["pia_aime_grid"])] pieces = ( @@ -310,7 +310,9 @@ def _build_aime_grid( Piece(interval=f"[{kinks[1]}, {kinks[2]})", n_points=_AIME_PIECE_N_POINTS[1]), Piece(interval=f"[{kinks[2]}, {kinks[3]}]", n_points=_AIME_PIECE_N_POINTS[2]), ) - return PiecewiseLinSpacedGrid(pieces=pieces) + return PiecewiseLinSpacedGrid( + pieces=pieces, batch_size=grid_config.n_aime_batch_size + ) def _has_required_wage_keys(*, wage_params: Mapping[str, Any]) -> bool: @@ -642,7 +644,7 @@ def build_state_transitions(spec: dict[str, str]) -> dict: """Build the state transitions dict for a non-dead regime.""" transitions: dict = {} transitions["health"] = _build_per_target_health(spec) - transitions["assets"] = assets_and_income.next_assets + transitions["assets"] = _build_per_target_next_assets(spec) transitions["pref_type"] = None transitions["aime"] = ( social_security.next_aime @@ -659,6 +661,34 @@ def build_state_transitions(spec: dict[str, str]) -> dict: return transitions +def _build_per_target_next_assets(spec: dict[str, str]) -> dict: + """Build per-target assets transitions. + + The `dead` target uses `next_assets_terminal` (no + `pension_assets_adjustment`), so the dead per-target DAG does not + pull in the `next_aime`-dependent imputation chain — `dead` has no + `aime` state and pylcm cannot resolve `next_aime` there. Non-dead + targets use the full `next_assets` with the pension correction. + """ + targets = precompute_targets(spec) + id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + + result: dict = {} + seen_ids: set[int] = set() + + for target_id in targets.values(): + if target_id in seen_ids: + continue + seen_ids.add(target_id) + target_name = id_to_name.get(target_id) + if target_name is None: + continue + result[target_name] = assets_and_income.next_assets + + result["dead"] = assets_and_income.next_assets_terminal + return result + + def _build_per_target_health(spec: dict[str, str]) -> dict: """Build per-target health transitions. diff --git a/src/aca_model/baseline/regimes/_nongroup.py b/src/aca_model/baseline/regimes/_nongroup.py index 5cdb6dc..7ee82ff 100644 --- a/src/aca_model/baseline/regimes/_nongroup.py +++ b/src/aca_model/baseline/regimes/_nongroup.py @@ -99,6 +99,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/baseline/regimes/_retiree.py b/src/aca_model/baseline/regimes/_retiree.py index ac76bfd..a941fa9 100644 --- a/src/aca_model/baseline/regimes/_retiree.py +++ b/src/aca_model/baseline/regimes/_retiree.py @@ -109,6 +109,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/baseline/regimes/_tied.py b/src/aca_model/baseline/regimes/_tied.py index 5d59274..4351cf5 100644 --- a/src/aca_model/baseline/regimes/_tied.py +++ b/src/aca_model/baseline/regimes/_tied.py @@ -83,6 +83,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 5a822c9..19416f2 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -44,6 +44,7 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.model import create_model from aca_model.config import BENCHMARK_GRID_CONFIG +from aca_model.consumption_grid import inject_consumption_points _PARAMS_FILE = ( Path(__file__).resolve().parent / "_benchmark_data" / "benchmark_params.pkl" @@ -55,6 +56,7 @@ "good_health": DiscreteGrid(GoodHealth), "is_married": DiscreteGrid(IsMarried), "his": DiscreteGrid(HealthInsuranceState), + "target_his": DiscreteGrid(HealthInsuranceState), "pref_type": DiscreteGrid(BenchmarkPrefType), } @@ -70,46 +72,137 @@ ) -def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Model: +def create_benchmark_model( + *, + n_subjects: int, + pref_type_grid: DiscreteGrid, +) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. The benchmark uses a 2-type `BenchmarkPrefType`. No `batch_size != 0` on any grid (continuous grids inherit - `BENCHMARK_GRID_CONFIG.n_assets_batch_size = 0`). + `BENCHMARK_GRID_CONFIG.n_assets_batch_size = 0` and + `n_aime_batch_size = 0`). Args: - pref_type_grid: Override for the pref_type grid. Default is a plain - `DiscreteGrid(BenchmarkPrefType)` (fused vmap). Pass + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. When set, the + first matching `simulate(...)` call AOT-compiles all simulate + functions for that batch shape. + pref_type_grid: Pref-type grid. Pass `DiscreteGrid(BenchmarkPrefType)` + for plain fused-vmap, or `DiscreteGrid(BenchmarkPrefType, dispatch=DispatchStrategy.PARTITION_SCAN)` - (or `PARTITION_VMAP`) to get the partition-lifted kernel — the + (or `PARTITION_VMAP`) for the partition-lifted kernel — the recommended production setting for aca-model at scale, but only supported on pylcm versions that expose `DispatchStrategy`. """ - if pref_type_grid is None: - pref_type_grid = DiscreteGrid(BenchmarkPrefType) - fixed_params, _ = get_benchmark_params() + fixed_params, _ = get_benchmark_params(model=None) return create_model( grid_config=BENCHMARK_GRID_CONFIG, fixed_params=fixed_params, + wage_params=None, derived_categoricals=_DERIVED_CATEGORICALS, pref_type_grid=pref_type_grid, + n_subjects=n_subjects, ) -def get_benchmark_params() -> tuple[dict[str, Any], dict[str, Any]]: +def get_benchmark_params( + *, model: Model | None +) -> tuple[dict[str, Any], dict[str, Any]]: """Load the frozen `(fixed_params, params)` snapshot. Pref-type-indexed `pd.Series` in `params` are truncated to `_N_BENCHMARK_PREF_TYPES` rows so they line up with `BenchmarkPrefType`'s categories. + + When `model` is provided, consumption gridpoints are injected into + `params` for each regime that declares `consumption` as an + `IrregSpacedGrid` with runtime-supplied points. The lower bound is + read from `params["consumption_floor"]`. Pass `model=None` to skip + injection (e.g. when constructing the model with `fixed_params`). """ with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) - fixed_params = data["fixed_params"] + fixed_params = { + k: v for k, v in data["fixed_params"].items() if k not in _STALE_FIXED_KEYS + } + fixed_params = _add_shifted_imputation_arrays(fixed_params) params = _truncate_pref_type_indexed(data["params"]) + if model is not None: + params = inject_consumption_points(params=params, model=model) return fixed_params, params +# Keys that the older aca-estimation `_assemble_params.py` wrote into +# `fixed_params` but that the current regime now resolves as a DAG +# function. Drop them on load so pylcm's `_resolve_fixed_params` does +# not reject the snapshot. Regenerating `benchmark_params.pkl` would +# also remove these — the filter is a no-op when the snapshot is fresh. +_STALE_FIXED_KEYS: frozenset[str] = frozenset({"imputed_pension_wealth_next_period"}) + + +# Source → derived key mapping for the 1-period-shifted views of the +# imputation arrays. The current pension correction (`imputed_pension_ +# wealth_next_period`) consumes these. The frozen `benchmark_params.pkl` +# predates aca-data's `_shift_one_period_forward` change, so synthesise +# the shifted views on load. The transformation is deterministic: row +# `period` carries the original at row `period + 1`; the last row holds +# flat. A regenerated snapshot can drop this synthesis (the filter is a +# no-op when the keys already exist). +_SHIFTED_IMPUTATION_KEYS: tuple[str, ...] = ( + "imp_intercept", + "imp_pia_coeff", + "imp_pia_kink_0_coeff", + "imp_pia_kink_1_coeff", + "imp_kink_0", + "imp_kink_1", + "imp_fraction_receiving", + "epdv_constant_pension", +) + + +def _add_shifted_imputation_arrays(fixed_params: dict[str, Any]) -> dict[str, Any]: + """Synthesise `_next_period` views from the source arrays.""" + out = dict(fixed_params) + for key in _SHIFTED_IMPUTATION_KEYS: + next_period_key = f"{key}_next_period" + if next_period_key in out or key not in out: + continue + out[next_period_key] = _shift_one_period_forward(out[key]) + return out + + +def _shift_one_period_forward(sr: pd.Series) -> pd.Series: + """Shift age-axis values forward one position (last row held flat). + + For (age, his)-indexed inputs, also rename the `his` level to + `target_his` so the resulting Series matches the level naming the + consuming `imputed_pension_wealth_next_period` function expects. + """ + if isinstance(sr.index, pd.MultiIndex) and sr.index.names[0] == "age": + n_periods = sr.index.levshape[0] + n_other = int( + np.prod([sr.index.levshape[i] for i in range(1, sr.index.nlevels)]) + ) + values = sr.to_numpy().reshape(n_periods, n_other) + shifted = np.concatenate([values[1:], values[-1:]], axis=0) + new_index = sr.index.rename( + [_rename_his_level(name) for name in sr.index.names] + ) + return pd.Series(shifted.ravel(), index=new_index) + if sr.index.name == "age": + values = sr.to_numpy() + shifted = np.concatenate([values[1:], values[-1:]]) + return pd.Series(shifted, index=sr.index) + msg = f"Unexpected index for _shift_one_period_forward: {sr.index!r}" + raise ValueError(msg) + + +def _rename_his_level(name: str) -> str: + """Rename `his` to `target_his`, leave others alone.""" + return "target_his" if name == "his" else name + + def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: """Return a copy of `params` with pref_type-indexed Series cut to 2 rows. @@ -129,7 +222,7 @@ def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: def get_benchmark_initial_conditions( - *, model: Model, n_subjects: int = 100, seed: int = 42 + *, model: Model, n_subjects: int, seed: int ) -> dict[str, Array]: """Draw random feasible initial conditions across five age-51 regimes. @@ -143,10 +236,14 @@ def get_benchmark_initial_conditions( regime = rng.choice(regime_ids, size=n_subjects).astype(np.int32) # Grid ranges come from any of the five regimes (shared structure). + # Use to_jax() so the helper handles both LinSpacedGrid and + # PiecewiseLinSpacedGrid (the latter has no `.start` / `.stop`). ref_regime = model.regimes[_INITIAL_REGIMES[0]] grids = ref_regime.states - assets_lo, assets_hi = grids["assets"].start, grids["assets"].stop - aime_lo, aime_hi = grids["aime"].start, grids["aime"].stop + assets_pts = np.asarray(grids["assets"].to_jax()) + aime_pts = np.asarray(grids["aime"].to_jax()) + assets_lo, assets_hi = float(assets_pts.min()), float(assets_pts.max()) + aime_lo, aime_hi = float(aime_pts.min()), float(aime_pts.max()) hcc_p_pts = np.asarray(grids["hcc_persistent"].to_jax()) hcc_t_pts = np.asarray(grids["hcc_transitory"].to_jax()) wage_res_pts = np.asarray(grids["log_ft_wage_res"].to_jax()) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 1bae45f..37fc0c8 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -33,10 +33,12 @@ class GridConfig: n_wage_res_gridpoints: int = 5 n_hcc_persistent_gridpoints: int = 3 n_hcc_transitory_gridpoints: int = 5 - # `batch_size` on the assets grid: chunked vmap stride for the - # outer state loop. Useful at prod sizes for memory reasons; set - # to 0 in BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. - n_assets_batch_size: int = 2 + # `batch_size` on the assets / AIME grids: chunked vmap stride for the + # outer state loop. Both partition the per-period Q intermediate so it + # fits in V100 16 GB once we splay across `pref_type`. Set to 0 in + # BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. + n_assets_batch_size: int = 1 + n_aime_batch_size: int = 1 MODEL_CONFIG = ModelConfig() @@ -50,4 +52,5 @@ class GridConfig: n_hcc_persistent_gridpoints=3, n_hcc_transitory_gridpoints=3, n_assets_batch_size=0, + n_aime_batch_size=0, ) diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py new file mode 100644 index 0000000..7e004fa --- /dev/null +++ b/src/aca_model/consumption_grid.py @@ -0,0 +1,85 @@ +"""Runtime-supplied gridpoints for the consumption action. + +Consumption is declared as `IrregSpacedGrid(n_points=N)` in +`baseline.regimes._common.build_grids` so the bounds can track +runtime parameters: the lower bound from the per-iteration +`consumption_floor` parameter, the upper bound from +`MAX_CONSUMPTION` in `baseline.regimes._common`, which the +`create_model` factories attach to `model.max_consumption`. +Callers must inject the actual gridpoints into `params` via +`inject_consumption_points` before calling `model.solve()` / +`model.simulate()`. +""" + +from collections.abc import Mapping +from typing import Any + +import jax.numpy as jnp +from jax import Array +from lcm import IrregSpacedGrid, Model + + +def inject_consumption_points( + *, + params: Mapping[str, Any], + model: Model, +) -> dict[str, Any]: + """Inject consumption gridpoints into per-regime params. + + Walks every regime, finds the action whose grid is an + `IrregSpacedGrid` with runtime-supplied points, and writes + `params[regime_name]["consumption"] = {"points": }`. + + Lower bound: `params["consumption_floor"]` (varies per iteration). + Upper bound: `model.max_consumption` (set by the `create_model` + factory from `MAX_CONSUMPTION` in `baseline.regimes._common`). + + Args: + params: Existing params mapping. Returned as a new dict; the input is + not mutated. + model: Model whose regime specs determine which regimes need points. + + Returns: + New params dict with consumption points injected. + """ + consumption_floor = float(params["consumption_floor"]) + max_consumption = float(model.max_consumption) + out: dict[str, Any] = dict(params) + for regime_name, regime in model.regimes.items(): + grid = regime.actions.get("consumption") + if not (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime): + continue + # Runtime-points grids always have `n_points` set (the constructor + # rejects the (points=None, n_points=None) combo); narrow for ty. + assert grid.n_points is not None + points = _compute_consumption_points( + consumption_floor=consumption_floor, + max_consumption=max_consumption, + n_points=grid.n_points, + ) + regime_entry = dict(out.get(regime_name, {})) + regime_entry["consumption"] = {"points": points} + out[regime_name] = regime_entry + return out + + +def _compute_consumption_points( + *, + consumption_floor: float, + max_consumption: float, + n_points: int, +) -> Array: + """Return log-spaced consumption gridpoints from floor to max. + + `jnp.geomspace` computes intermediate points as `start * r^i` with + `r = (stop/start)^(1/(n-1))`; the first point is `start * r^0`, + which is `start` mathematically but can be off by sub-ULP under + some XLA backends (CUDA + 70 points: `start + 2.27e-13`). The + borrowing constraint compares the first action against + `max(cash_on_hand, consumption_floor)`, and any positive drift + above `consumption_floor` flips the kink-boundary `<=` for + subjects with very negative cash. Pin the first element back to + `consumption_floor` exactly. + """ + pts = jnp.geomspace(consumption_floor, max_consumption, num=n_points) + return pts.at[0].set(consumption_floor) diff --git a/src/aca_model/environment/pensions.py b/src/aca_model/environment/pensions.py index a23a800..eef72d4 100644 --- a/src/aca_model/environment/pensions.py +++ b/src/aca_model/environment/pensions.py @@ -4,7 +4,7 @@ """ import jax.numpy as jnp -from lcm.typing import FloatND, IntND, Period +from lcm.typing import ContinuousState, FloatND, IntND, Period def benefit( @@ -164,3 +164,42 @@ def assets_adjustment( * unconditional_survival_prob[period] * (pension_wealth_next_before_adjustment - imputed_pension_wealth_next_period) ) + + +def imputed_pension_wealth_next_period( + next_aime: ContinuousState, + target_his: IntND, + period: Period, + pia_table: FloatND, + pia_aime_grid: FloatND, + imp_intercept_next_period: FloatND, + imp_pia_coeff_next_period: FloatND, + imp_pia_kink_0_coeff_next_period: FloatND, + imp_pia_kink_1_coeff_next_period: FloatND, + imp_kink_0_next_period: FloatND, + imp_kink_1_next_period: FloatND, + imp_fraction_receiving_next_period: FloatND, + epdv_constant_pension_next_period: FloatND, +) -> FloatND: + """Imputed pension wealth at next period using the target regime's HIS. + + Mirrors `benefit` and `wealth` but indexes into 1-period-shifted views + of the imputation arrays so all subscripts use bare-name parameters + (`period`, `target_his`). Inlining is required: pylcm's AST shape + inference inspects the registered function's body and does not trace + through nested calls into `benefit`. + """ + next_pia = jnp.interp(next_aime, pia_aime_grid, pia_table) + + intercept = imp_intercept_next_period[period, target_his] + pia_pred = imp_pia_coeff_next_period[period, target_his] * next_pia + kink_0_adj = imp_pia_kink_0_coeff_next_period[period, target_his] * jnp.maximum( + 0.0, next_pia - imp_kink_0_next_period[period] + ) + kink_1_adj = imp_pia_kink_1_coeff_next_period[period, target_his] * jnp.maximum( + 0.0, next_pia - imp_kink_1_next_period[period] + ) + + full_benefit = jnp.maximum(0.0, intercept + pia_pred + kink_0_adj + kink_1_adj) + benefit_next = full_benefit * imp_fraction_receiving_next_period[period] + return benefit_next * epdv_constant_pension_next_period[period] diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1da1dcf --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent)) diff --git a/tests/helpers/model.py b/tests/helpers/model.py new file mode 100644 index 0000000..930c33e --- /dev/null +++ b/tests/helpers/model.py @@ -0,0 +1,38 @@ +"""Tiny factories that wrap `create_model` with `None` for every optional input. + +Used by tests that don't need fixed params, wage params, or a custom pref-type +grid. These helpers exist so production `create_model` factories can stay +default-free without forcing every test call site to spell out +`fixed_params=None, wage_params=None, ...` six times. +""" + +from lcm import Model + +from aca_model.aca.health_insurance import PolicyVariant +from aca_model.aca.model import create_model as _create_aca_model +from aca_model.baseline.model import create_model as _create_baseline_model +from aca_model.config import GRID_CONFIG + + +def make_baseline_model(*, n_subjects: int) -> Model: + """Baseline model with `GRID_CONFIG` and no fixed/wage/derived params.""" + return _create_baseline_model( + n_subjects=n_subjects, + fixed_params=None, + wage_params=None, + derived_categoricals=None, + grid_config=GRID_CONFIG, + pref_type_grid=None, + ) + + +def make_aca_model(*, n_subjects: int, policy: PolicyVariant) -> Model: + """ACA model with `GRID_CONFIG` and no fixed/wage/derived params.""" + return _create_aca_model( + n_subjects=n_subjects, + policy=policy, + fixed_params=None, + wage_params=None, + derived_categoricals=None, + grid_config=GRID_CONFIG, + ) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 72fb473..8b1ed88 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -1,7 +1,9 @@ """Integration test: the benchmark-sized baseline solves + simulates end-to-end.""" import pytest +from lcm import DiscreteGrid +from aca_model.agent.preferences import BenchmarkPrefType from aca_model.benchmark import ( create_benchmark_model, get_benchmark_initial_conditions, @@ -12,8 +14,11 @@ @pytest.mark.long_running def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 - model = create_benchmark_model() - _, params = get_benchmark_params() + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 ) @@ -31,3 +36,52 @@ def test_benchmark_model_simulates_end_to_end() -> None: # Period 0 rows reflect initial conditions — no NaN in continuous states. period_0 = df.loc[df["period"] == 0] assert not period_0[["assets", "aime", "value"]].isna().any().any() + + +@pytest.mark.long_running +def test_benchmark_simulate_obeys_borrowing_constraint() -> None: + """`consumption <= max(cash_on_hand, floor)` holds for every alive row. + + The simulator only ever picks feasible actions — the borrowing + constraint must hold post-hoc on the simulated panel. A regression + that drops the constraint from a regime, replaces the floor with + something looser, or lets an action grid skip the floor would + surface as a row with `consumption > max(cash_on_hand, floor)`. + + The constraint's RHS is `max(cash_on_hand, floor)` rather than + `cash_on_hand + transfers`: the additive form rounds short by + sub-ULP at extreme `|cash_on_hand|`, so the post-hoc check would + also flip on the same kink. + """ + import numpy as np + + n_subjects = 4 + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + _, params = get_benchmark_params(model=model) + initial_conditions = get_benchmark_initial_conditions( + model=model, n_subjects=n_subjects, seed=0 + ) + + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + check_initial_conditions=False, + ) + + df = result.to_dataframe( + additional_targets=["cash_on_hand", "equivalence_scale"] + ) + alive = df.loc[df["regime"] != "dead"].copy() + consumption_floor = float(params["consumption_floor"]) + floor = consumption_floor * alive["equivalence_scale"].to_numpy() + rhs = np.maximum(alive["cash_on_hand"].to_numpy(), floor) + slack = rhs - alive["consumption"].to_numpy() + assert (slack >= 0).all(), ( + f"borrowing_constraint violated on {int((slack < 0).sum())} row(s); " + f"min slack = {slack.min():.6g}" + ) diff --git a/tests/test_consumption_grid.py b/tests/test_consumption_grid.py new file mode 100644 index 0000000..40b7caa --- /dev/null +++ b/tests/test_consumption_grid.py @@ -0,0 +1,48 @@ +"""Consumption-grid invariants required by the borrowing constraint. + +The borrowing constraint in `agent.assets_and_income.borrowing_constraint` +compares the lowest consumption action against +`max(cash_on_hand, consumption_floor * equivalence_scale)`. For subjects +with cash below the floor (HRS bottom-coded `assets=-$1{,}000{,}000$`, +moderate-negative-asset retirees etc.) this RHS collapses to exactly +`consumption_floor` for singles. The constraint is feasible iff the +lowest consumption gridpoint is `<= consumption_floor`. + +`jnp.geomspace(start, stop, num=n)` returns `start * r^i` with +`r = (stop/start)^(1/(n-1))`; mathematically `r^0 == 1` so the first +point equals `start`, but XLA backends can drift by sub-ULP for some +`(start, stop, n)` combinations (observed: CUDA, n=70, drift +2.27e-13). +A positive drift above `consumption_floor` flips the kink-boundary `<=` +and rejects every action for those subjects. + +`_compute_consumption_points` therefore pins the first point back to +`consumption_floor` after `geomspace`. Test that invariant directly. +""" + +import jax.numpy as jnp +import pytest + +from aca_model.consumption_grid import _compute_consumption_points + + +@pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) +def test_compute_consumption_points_first_equals_floor_exactly(n_points: int) -> None: + """The first gridpoint equals `consumption_floor` exactly under any `n_points`.""" + consumption_floor = 1597.0921419521899 # production value + pts = _compute_consumption_points( + consumption_floor=consumption_floor, + max_consumption=300_000.0, + n_points=n_points, + ) + assert float(pts[0]) == consumption_floor + + +def test_compute_consumption_points_strictly_increasing() -> None: + """Gridpoints are strictly increasing — no kink-pinning ties.""" + pts = _compute_consumption_points( + consumption_floor=1597.0921419521899, + max_consumption=300_000.0, + n_points=70, + ) + diffs = jnp.diff(pts) + assert bool((diffs > 0).all()) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py new file mode 100644 index 0000000..6078547 --- /dev/null +++ b/tests/test_initial_conditions_extreme_assets.py @@ -0,0 +1,118 @@ +"""Subjects at extreme negative assets must clear `validate_initial_conditions`. + +The transfer system (`agent.assets_and_income.transfers`) tops cash-on-hand +to `consumption_floor * equivalence_scale` at any starting state, so the +lowest consumption-grid point is always a feasible action regardless of +how negative starting assets are. The model's constraints — and pylcm's +`validate_initial_conditions` pass — must reflect this. +""" + +import jax.numpy as jnp +from lcm import DiscreteGrid +from lcm.simulation.initial_conditions import validate_initial_conditions + +from aca_model.agent.assets_and_income import borrowing_constraint +from aca_model.agent.preferences import BenchmarkPrefType +from aca_model.benchmark import ( + create_benchmark_model, + get_benchmark_initial_conditions, + get_benchmark_params, +) + + +def test_borrowing_constraint_admits_consumption_at_floor() -> None: + """`consumption == consumption_floor` at the kink is feasible by equality.""" + consumption_floor = 5_000.0 + equivalence_scale = jnp.asarray(1.0) + cash_on_hand = jnp.asarray(-50_000.0) # below floor — RHS = floor + + admitted = bool( + borrowing_constraint( + consumption=jnp.asarray(consumption_floor), + cash_on_hand=cash_on_hand, + consumption_floor=consumption_floor, + equivalence_scale=equivalence_scale, + ) + ) + assert admitted + + +def test_borrowing_constraint_rejects_consumption_above_post_transfer_resources() -> ( + None +): + """`consumption > max(cash_on_hand, floor)` is rejected.""" + consumption_floor = 5_000.0 + equivalence_scale = jnp.asarray(1.0) + cash_on_hand = jnp.asarray(-50_000.0) + consumption = jnp.asarray(consumption_floor + 1.0) + + admitted = bool( + borrowing_constraint( + consumption=consumption, + cash_on_hand=cash_on_hand, + consumption_floor=consumption_floor, + equivalence_scale=equivalence_scale, + ) + ) + assert not admitted + + +def test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash() -> None: + """The kink-boundary check survives sub-ULP rounding at `|cash_on_hand| ~ 1e6`. + + Reproduces the production failure mode at `assets=-$1{,}000{,}000$` (HRS + bottom-code): the algebraically equivalent `cash_on_hand + transfers` + form rounds to `floor - 5.7e-11` at fp64, flipping `consumption <= ...` + for the lowest consumption gridpoint. The `max(cash_on_hand, floor)` + form returns `floor` exactly. + """ + consumption_floor = 1597.0921419521899 # production value + equivalence_scale = jnp.asarray(1.0) + cash_on_hand = jnp.asarray(-1_000_000.0) + consumption = jnp.asarray(consumption_floor) # lowest grid point + + admitted = bool( + borrowing_constraint( + consumption=consumption, + cash_on_hand=cash_on_hand, + consumption_floor=consumption_floor, + equivalence_scale=equivalence_scale, + ) + ) + assert admitted + + +def test_extreme_negative_assets_subject_passes_validation() -> None: + """A subject placed at `assets = -1_000_000` clears initial-conditions validation. + + HRS bottom-codes very-large-negative net wealth at exactly $-1{,}000{,}000$. + Such subjects should remain in the simulated population: the consumption + floor / transfer system absorbs them, with `c = c_floor` always feasible. + """ + n_subjects = 1 + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + _, params = get_benchmark_params(model=model) + + initial_conditions = get_benchmark_initial_conditions( + model=model, n_subjects=n_subjects, seed=0 + ) + initial_conditions = { + **initial_conditions, + "assets": jnp.asarray([-1_000_000.0]), + "regime": jnp.asarray( + [model.regime_names_to_ids["retiree_nomc_inelig_canwork"]], + dtype=jnp.int32, + ), + } + + internal_params = model._process_params(params) # noqa: SLF001 + validate_initial_conditions( + initial_conditions=initial_conditions, + internal_regimes=model.internal_regimes, + regime_names_to_ids=model.regime_names_to_ids, + internal_params=internal_params, + ages=model.ages, + ) diff --git a/tests/test_model_components.py b/tests/test_model_components.py index cbb2f72..9876ac8 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -84,7 +84,7 @@ def test_utility_positive_leisure() -> None: consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([2.0, 2.0, 2.0]), equivalence_scale=jnp.array(1.0), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) @@ -97,7 +97,7 @@ def test_utility_log_case() -> None: consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([1.0, 1.0, 1.0]), equivalence_scale=jnp.array(1.0), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) composite = 10000.0**0.4 * 3000.0**0.6 expected = jnp.log(composite) @@ -112,7 +112,7 @@ def test_bequest_positive_assets() -> None: scaled_bequest_weight=0.5, consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([2.0, 2.0, 2.0]), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) @@ -125,7 +125,7 @@ def test_bequest_zero_assets() -> None: scaled_bequest_weight=0.5, consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([2.0, 2.0, 2.0]), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) assert result < 0 # CRRA with γ>1 gives negative values diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 841f2bb..fca2ef6 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -3,17 +3,32 @@ from collections.abc import Mapping import pytest +from helpers.model import make_aca_model, make_baseline_model from aca_model.aca import health_insurance as aca_hi from aca_model.aca.health_insurance import PolicyVariant -from aca_model.aca.model import create_model as create_aca_model -from aca_model.aca.regimes import build_all_regimes as build_aca_regimes -from aca_model.baseline.model import create_model +from aca_model.aca.regimes import build_all_regimes as _build_aca_regimes from aca_model.baseline.regimes import REGIME_SPECS, RegimeId from aca_model.baseline.regimes import build_regime as _build_regime -from aca_model.baseline.regimes._common import build_grids +from aca_model.baseline.regimes._common import MAX_CONSUMPTION, build_grids +from aca_model.config import GRID_CONFIG -_GRIDS = build_grids() + +def build_aca_regimes(policy: PolicyVariant) -> dict: + return _build_aca_regimes( + policy=policy, + grid_config=GRID_CONFIG, + fixed_params=None, + wage_params=None, + ) + + +_GRIDS = build_grids( + grid_config=GRID_CONFIG, + fixed_params=None, + wage_params=None, + pref_type_grid=None, +) def build_regime(name: str): @@ -21,24 +36,24 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: - model = create_model() + model = make_baseline_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 def test_model_age_range() -> None: - model = create_model() + model = make_baseline_model(n_subjects=1) assert model.ages.values[0] == 51.0 assert model.ages.values[-1] == 95.0 def test_dead_regime_is_terminal() -> None: - model = create_model() + model = make_baseline_model(n_subjects=1) assert model.regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: - model = create_model() + model = make_baseline_model(n_subjects=1) for name in REGIME_SPECS: assert not model.regimes[name].terminal @@ -170,7 +185,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: - model = create_aca_model() + model = make_aca_model(n_subjects=1, policy=PolicyVariant.ACA) assert len(model.regimes) == 19 assert model.n_periods == 45 @@ -211,7 +226,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: @pytest.mark.parametrize("policy", list(PolicyVariant)) def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" - model = create_aca_model(policy=policy) + model = make_aca_model(n_subjects=1, policy=policy) assert len(model.regimes) == 19 @@ -251,5 +266,10 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" - model = create_model() + model = make_baseline_model(n_subjects=1) assert len(model.regimes) == 19 + + +def test_max_consumption_attached_from_canonical_constant() -> None: + model = make_baseline_model(n_subjects=1) + assert model.max_consumption == MAX_CONSUMPTION diff --git a/tests/test_preferences.py b/tests/test_preferences.py index 8c1921c..4ff2266 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -33,6 +33,7 @@ def test_utility_scale_factor_crra() -> None: result = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, @@ -43,11 +44,12 @@ def test_utility_scale_factor_crra() -> None: reference_age=REFERENCE_AGE, scale_reference_age=SCALE_REFERENCE_AGE, ) - assert jnp.isclose(result[0], 9_233_279_397_806_166.0, rtol=1e-6) + assert jnp.isclose(result, 9_233_279_397_806_166.0, rtol=1e-6) def test_utility_scale_factor_log() -> None: result = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, @@ -58,7 +60,7 @@ def test_utility_scale_factor_log() -> None: reference_age=REFERENCE_AGE, scale_reference_age=SCALE_REFERENCE_AGE, ) - assert jnp.isclose(result[0], 0.113_073_257_794_546_72, rtol=1e-6) + assert jnp.isclose(result, 0.113_073_257_794_546_72, rtol=1e-6) # --- scaled_bequest_weight --- @@ -105,6 +107,7 @@ def test_scaled_bequest_weight_zero() -> None: def test_utility_log_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, @@ -129,6 +132,7 @@ def test_utility_log_regression() -> None: def test_utility_crra_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, @@ -154,6 +158,7 @@ def test_utility_crra_regression() -> None: def test_utility_married_equivalence() -> None: """Married with equiv-scaled consumption should equal single utility.""" scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, @@ -190,6 +195,7 @@ def test_utility_married_equivalence() -> None: def test_bequest_log_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, @@ -222,6 +228,7 @@ def test_bequest_log_regression() -> None: def test_bequest_crra_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, diff --git a/tests/test_social_security.py b/tests/test_social_security.py index d75e458..90c5128 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -5,11 +5,11 @@ import jax.numpy as jnp import numpy as np +from helpers.social_security import compute_di_dropout_scale, compute_pia_table from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from tests.helpers.social_security import compute_di_dropout_scale, compute_pia_table ATOL = 0.01 diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index 0e77ea5..ef09775 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -5,11 +5,11 @@ """ import jax.numpy as jnp +from helpers.social_security import compute_pia_table from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from tests.helpers.social_security import compute_pia_table ATOL = 0.01