Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
4123fe9
Anchor consumption grid lower bound to consumption_floor parameter
hmgaudecker Apr 29, 2026
1342861
Refactor utility_scale_factor to take pref_type, return scalar
hmgaudecker Apr 29, 2026
8cd8e37
Halve aca-model production assets batch size to fit V100 16GB
hmgaudecker Apr 30, 2026
84d484a
Revert assets batch_size halving — V100 OOM was elsewhere
hmgaudecker Apr 30, 2026
f0892ef
Re-halve production assets batch size: V100 still OOMs per-period
hmgaudecker May 1, 2026
08e42cb
config: add n_aime_batch_size to splay AIME outer-loop on V100
hmgaudecker May 1, 2026
e08fc19
consumption_grid: read upper bound from `max_consumption` fixed param
hmgaudecker May 1, 2026
c1ffb2a
create_model: default `max_consumption` into fixed_params
hmgaudecker May 2, 2026
a217687
create_model: forward n_subjects through baseline + aca + benchmark
hmgaudecker May 2, 2026
d1eb320
create_model: require n_subjects (no default)
hmgaudecker May 2, 2026
9e25205
ci: install pylcm from feat/simulate-aot-n-subjects (carries Model.n_…
hmgaudecker May 2, 2026
cdd1016
consumption_grid: max_consumption is a required factory arg, attached…
hmgaudecker May 2, 2026
31a0ad2
Move max_consumption to canonical constant; drop kwarg threading
hmgaudecker May 3, 2026
714fee0
Assets grid: subtract MAX_CONSUMPTION margin from the floor
hmgaudecker May 4, 2026
63d2a38
Revert "Assets grid: subtract MAX_CONSUMPTION margin from the floor"
hmgaudecker May 4, 2026
4ae4446
Wire pension imputation correction (FJ 2011 Appendix A.5)
hmgaudecker May 4, 2026
83f2250
Bump pyproject-fmt v2.19.0 → v2.21.1 and ruff-pre-commit v0.15.6 → v0…
hmgaudecker May 4, 2026
3453080
get_benchmark_params: filter obsolete imputed_pension_wealth_next_per…
hmgaudecker May 4, 2026
b2e90bb
get_benchmark_params: synthesise <key>_next_period shifted views
hmgaudecker May 4, 2026
35eddcc
benchmark: declare target_his as derived categorical
hmgaudecker May 4, 2026
64d6567
_shift_one_period_forward: rename his level to target_his
hmgaudecker May 4, 2026
f09b5e3
Per-target next_assets: dead target uses next_assets_terminal (no pen…
hmgaudecker May 4, 2026
e1a3eb2
create_model: register target_his as derived categorical at base layer
hmgaudecker May 5, 2026
00ee7d2
aca/model.create_model: register target_his at base layer
hmgaudecker May 6, 2026
edfa540
tests: positive regression guard — assets=-$1M passes benchmark valid…
hmgaudecker May 6, 2026
d05df9e
borrowing_constraint: use max(cash_on_hand, floor) to dodge fp32 canc…
hmgaudecker May 6, 2026
4af8359
ci: bump pylcm pin to e4cae2aa (post-#342, post-#340 diagnostic)
hmgaudecker May 6, 2026
0c7f2d5
wip: debug script — cash_on_hand per failing subject
hmgaudecker May 6, 2026
8ffbf5c
wip: fix imports in debug script (broadcast_to_template + ACA_DATA_BLD)
hmgaudecker May 6, 2026
81cca3c
wip: import GRID_CONFIG_FOR_RUN from aca_estimation
hmgaudecker May 6, 2026
e320f41
wip: pass derived_categoricals to create_aca_model in debug
hmgaudecker May 6, 2026
2208fa6
wip: augment fixed_params for ACA policy in debug
hmgaudecker May 6, 2026
8adabda
borrowing_constraint: cast consumption_floor to consumption's dtype
hmgaudecker May 6, 2026
c895bd9
borrowing_constraint: drop dtype cast workaround
hmgaudecker May 7, 2026
e0cc622
tests: switch helpers import to relative form
hmgaudecker May 7, 2026
3d2faf4
tests: drop tests/__init__.py; expose helpers via conftest sys.path
hmgaudecker May 7, 2026
97c84cd
Drop precision-related workarounds and function defaults
hmgaudecker May 7, 2026
4901c9c
Merge cleanup/no-defaults-no-precision-workarounds
hmgaudecker May 7, 2026
9d59174
borrowing_constraint: restore max() form for kink-stability at extrem…
hmgaudecker May 8, 2026
67edfe0
consumption_grid: pin first gridpoint to consumption_floor exactly
hmgaudecker May 8, 2026
d9339ab
ci: bump pylcm pin to 2f486dc
hmgaudecker May 8, 2026
9dd1e2f
ci: bump pylcm pin to 61c2436
hmgaudecker May 8, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ jobs:
- uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install pylcm
- name: Install pylcm (unreleased feature branch required)
run: >-
pip install "pylcm @
git+https://github.com/OpenSourceEconomics/pylcm.git@main"
git+https://github.com/OpenSourceEconomics/pylcm.git@61c2436b67ecd9df1c70e80b770be77681c5df63"
- name: Install aca-model with test deps
run: pip install -e . pytest pdbp
- name: Run pytest
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ repos:
- id: check-hooks-apply
- id: check-useless-excludes
- repo: https://github.com/tox-dev/pyproject-fmt
rev: v2.19.0
rev: v2.21.1
hooks:
- id: pyproject-fmt
- repo: https://github.com/lyz-code/yamlfix
Expand Down Expand Up @@ -47,7 +47,7 @@ repos:
hooks:
- id: yamllint
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.6
rev: v0.15.12
hooks:
- id: ruff-check
args:
Expand Down
58 changes: 38 additions & 20 deletions src/aca_model/aca/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,43 @@

from aca_model.aca import PolicyVariant
from aca_model.aca.regimes import build_all_regimes
from aca_model.baseline.health_insurance import HealthInsuranceState
from aca_model.baseline.regimes import RegimeId
from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig
from aca_model.baseline.regimes._common import MAX_CONSUMPTION
from aca_model.config import MODEL_CONFIG, GridConfig


def create_model(
*,
policy: PolicyVariant = PolicyVariant.ACA,
fixed_params: Mapping[str, Any] | None = None,
wage_params: Mapping[str, Any] | None = None,
derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]]
| None = None,
grid_config: GridConfig = GRID_CONFIG,
n_subjects: int,
policy: PolicyVariant,
fixed_params: Mapping[str, Any] | None,
wage_params: Mapping[str, Any] | None,
derived_categoricals: Mapping[str, DiscreteGrid] | None,
grid_config: GridConfig,
) -> Model:
"""Create an ACA policy variant model.

Args:
policy: Which ACA policy combination to apply.
fixed_params: Parameters to fix at model creation time. These are
partialled into compiled functions and removed from the params
template. Pass data-derived constants here; only estimation
parameters should go through `model.simulate(params=...)`.
n_subjects: Forwarded to `lcm.Model(n_subjects=...)`.
policy: Which ACA policy combination to apply (e.g.
`PolicyVariant.ACA`).
fixed_params: Parameters to fix at model creation time, or `None`
to skip. Fixed params are partialled into compiled functions
and removed from the params template. Pass data-derived
constants here; only estimation parameters should go through
`model.simulate(params=...)`.
wage_params: Data-derived wage profile dict (`log_ft_wage_mean`,
`log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build
time to size the assets-floor to `-max_annual_labor_income`.
Not routed to the pylcm Model.
derived_categoricals: Extra categorical mappings for derived variables
not in the model's state/action grids. Needed when `fixed_params`
contains `pd.Series` indexed by DAG function outputs.
grid_config: Continuous-grid point counts. Defaults to production
values.
Not routed to the pylcm Model. `None` skips the floor sizing.
derived_categoricals: Extra categorical mappings for derived
variables not in the model's state/action grids, or `None`.
Needed when `fixed_params` contains `pd.Series` indexed by DAG
function outputs.
grid_config: Continuous-grid point counts. Pass `GRID_CONFIG` for
production values or `BENCHMARK_GRID_CONFIG` for the
fast-but-structurally-faithful benchmark.

Returns:
pylcm Model with ACA-specific function overrides.
Expand All @@ -58,11 +65,22 @@ def create_model(
wage_params=wage_params,
)

return Model(
# See `baseline.model.create_model` for why `target_his` is declared
# as a base-layer derived categorical.
base_derived: dict[str, DiscreteGrid] = {
"target_his": DiscreteGrid(HealthInsuranceState),
}
if derived_categoricals is not None:
base_derived.update(derived_categoricals)

model = Model(
regimes=regimes,
ages=ages,
regime_id_class=RegimeId,
description=f"Structural retirement model ({policy.name})",
fixed_params=fixed_params or {},
derived_categoricals=derived_categoricals,
derived_categoricals=base_derived,
n_subjects=n_subjects,
)
model.max_consumption = MAX_CONSUMPTION
return model
15 changes: 9 additions & 6 deletions src/aca_model/aca/regimes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,22 @@
from aca_model.aca.regimes._overrides import apply_aca_overrides
from aca_model.baseline.regimes import build_all_regimes as baseline_build_all_regimes
from aca_model.baseline.regimes._common import REGIME_SPECS
from aca_model.config import GRID_CONFIG, GridConfig
from aca_model.config import GridConfig


def build_all_regimes(
policy: PolicyVariant,
grid_config: GridConfig = GRID_CONFIG,
*,
fixed_params: Mapping[str, Any] | None = None,
wage_params: Mapping[str, Any] | None = None,
policy: PolicyVariant,
grid_config: GridConfig,
fixed_params: Mapping[str, Any] | None,
wage_params: Mapping[str, Any] | None,
) -> dict[str, Regime]:
"""Build all 19 regimes with ACA policy overrides."""
regimes = baseline_build_all_regimes(
grid_config, fixed_params=fixed_params, wage_params=wage_params
grid_config=grid_config,
fixed_params=fixed_params,
wage_params=wage_params,
pref_type_grid=None,
)
result = {}
for name, regime in regimes.items():
Expand Down
45 changes: 40 additions & 5 deletions src/aca_model/agent/assets_and_income.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def next_assets(
consumption: ContinuousAction,
oop_costs: FloatND,
) -> ContinuousState:
"""Compute beginning-of-next-period assets.
"""Compute beginning-of-next-period assets for non-terminal targets.

OOP health costs are deducted here (not from cash_on_hand) so that the
consumption choice does not condition on the HCC shock realization.
Expand All @@ -65,11 +65,46 @@ def next_assets(
)


def next_assets_terminal(
cash_on_hand: FloatND,
transfers: FloatND,
consumption: ContinuousAction,
oop_costs: FloatND,
) -> ContinuousState:
"""Compute beginning-of-next-period assets for the dead/terminal target.

No `pension_assets_adjustment` term: with no future, there is no
next-period pension wealth to impute against. Avoiding the dependency
also keeps the `dead` per-target transition's DAG free of `next_aime`
(which would otherwise need to come from a transition `dead` does not
have, since `aime` is not a state in the terminal regime).
"""
return cash_on_hand + transfers - consumption - oop_costs


def borrowing_constraint(
consumption: ContinuousAction,
cash_on_hand: FloatND,
transfers: FloatND,
pension_assets_adjustment: FloatND,
consumption_floor: float,
equivalence_scale: FloatND,
) -> BoolND:
"""Consumption cannot exceed available resources (no borrowing)."""
return consumption <= cash_on_hand + transfers + pension_assets_adjustment
"""Consumption cannot exceed post-transfer resources.

Post-transfer resources are `max(cash_on_hand, consumption_floor *
equivalence_scale)`: the transfer system tops `cash_on_hand` to the
floor when below, otherwise resources are unchanged. The algebraic
identity is `cash_on_hand + transfers == max(cash_on_hand, floor)`;
the `max` form is preferred because the additive form rounds to
`floor + ε` (with `|ε| ~ ULP(|cash_on_hand|)`) at extreme cash, which
flips the kink-boundary comparison for HRS-bottom-coded subjects at
`assets=-$1{,}000{,}000`. The `max` form returns `floor` exactly.

`pension_assets_adjustment` is excluded from the constraint: it can
be negative (e.g. when the imputation overstates next-period pension
wealth at a cross-HIS transition), and including it here can leave
no feasible action at low-asset / mid-AIME corners. The correction
enters `next_assets` instead — a post-decision shift that does not
gate the current consumption choice.
"""
floor = consumption_floor * equivalence_scale
return consumption <= jnp.maximum(cash_on_hand, floor)
42 changes: 25 additions & 17 deletions src/aca_model/agent/preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,11 @@ def utility(
"""Within-period utility: CES aggregator over consumption and leisure.

u = utility_scale_factor * ((c/eq_scale)^α * l^(1-α))^(1-γ) / (1-γ)
with log case for γ=1. `consumption_weight`, `coefficient_rra`, and
`utility_scale_factor` are indexed by `pref_type`.
with log case for γ=1. `consumption_weight` and `coefficient_rra` are
pref-type-indexed Series sourced directly from params; `utility_scale_factor`
is a regime-function output (already a per-cell scalar — must NOT be
re-indexed by pref_type, see `aca_model.agent.preferences.utility_scale_factor`
for why).
"""
alpha = consumption_weight[pref_type]
gamma = coefficient_rra[pref_type]
Expand All @@ -147,7 +150,7 @@ def utility(
jnp.log(composite),
composite**one_minus_gamma / one_minus_gamma,
)
return u * utility_scale_factor[pref_type]
return u * utility_scale_factor


def discount_factor(
Expand All @@ -164,6 +167,7 @@ def discount_factor(


def utility_scale_factor(
pref_type: DiscreteState,
average_consumption: float,
consumption_weight: FloatND,
coefficient_rra: FloatND,
Expand All @@ -174,26 +178,29 @@ def utility_scale_factor(
reference_age: int,
scale_reference_age: int,
) -> FloatND:
"""Compute scale factor so utility is approximately 1 at typical values.

Uses leisure at `scale_reference_age` when working `scale_reference_hours`
(after fixed costs) and average consumption. Returns one scale per
preference type, indexed by pref_type.
"""Compute the scale factor so utility is approximately 1 at typical values.

Returns the scalar for the cell's `pref_type`. Mirrors the `discount_factor`
pattern: take the state as input, return a per-cell scalar. Registering this
as a regime function and then doing `utility_scale_factor[pref_type]` in a
downstream consumer is invalid — pylcm broadcasts function outputs to
per-cell scalars before consumption, and the validator in
`lcm.regime_building.validation` raises on that clash.
"""
alpha = consumption_weight[pref_type]
gamma = coefficient_rra[pref_type]
age_offset = scale_reference_age - reference_age
average_leisure = (
time_endowment
- scale_reference_hours
- (fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * age_offset)
)
u_cons = average_consumption**consumption_weight
u_leisure = average_leisure ** (1.0 - consumption_weight)
u_cons = average_consumption**alpha
u_leisure = average_leisure ** (1.0 - alpha)

one_minus_gamma = jnp.where(
jnp.isclose(coefficient_rra, 1.0), 1.0, 1.0 - coefficient_rra
)
one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma)
raw = jnp.where(
jnp.isclose(coefficient_rra, 1.0),
jnp.isclose(gamma, 1.0),
jnp.log(u_cons * u_leisure),
(u_cons * u_leisure) ** one_minus_gamma / one_minus_gamma,
)
Expand Down Expand Up @@ -237,8 +244,9 @@ def bequest(
"""Bequest function for terminal/dead states.

bequest = scale * bwt * (max(0,a) + shifter)^(α*(1-γ)) / (1-γ)
`consumption_weight`, `coefficient_rra`, and `utility_scale_factor`
are indexed by `pref_type`.
`consumption_weight` and `coefficient_rra` are pref-type-indexed Series
from params; `utility_scale_factor` is a regime-function output (already a
per-cell scalar — must NOT be re-indexed by pref_type).
"""
alpha = consumption_weight[pref_type]
gamma = coefficient_rra[pref_type]
Expand All @@ -250,4 +258,4 @@ def bequest(
jnp.log(assets_shifted),
assets_shifted ** (one_minus_gamma * alpha) / one_minus_gamma,
)
return val * scaled_bequest_weight * utility_scale_factor[pref_type]
return val * scaled_bequest_weight * utility_scale_factor
23 changes: 23 additions & 0 deletions src/aca_model/baseline/health_insurance.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,29 @@ def is_medicaid_eligible(is_ssi_eligible: BoolND) -> BoolND:
return is_ssi_eligible


def target_his(
his: IntND,
labor_supply: DiscreteAction,
is_medicaid_eligible: BoolND,
) -> IntND:
"""Return the HIS class of the surviving target regime.

Mirrors the cross-HIS branches inside `_make_transition_canwork` (retiree,
tied, nongroup): tied agents who stop working become nongroup, and
Medicaid-eligible agents are overridden to nongroup. Used by
`imputed_pension_wealth_next_period` to look up next-period imputation
coefficients at the target's HIS.
"""
tied_to_ng = (his == HealthInsuranceState.tied) & (
labor_supply == LaborSupply.do_not_work
)
return jnp.where(
tied_to_ng | is_medicaid_eligible,
HealthInsuranceState.nongroup,
his,
).astype(jnp.int32)


def oop_with_medicaid(
primary_oop: FloatND,
is_medicaid_eligible: BoolND,
Expand Down
Loading