Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
96 commits
Select commit Hold shift + click to select a range
1c65f45
Support runtime-supplied points on continuous-action IrregSpacedGrids
hmgaudecker Apr 29, 2026
1792279
benchmarks: bump aca-model to runtime-consumption-points version
hmgaudecker Apr 29, 2026
cf00e99
Fail loudly when reading runtime IrregSpacedGrid before substitution
hmgaudecker Apr 29, 2026
db98cde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
03ba800
Fix remaining ruff check errors in test_single_feasible_action.py
hmgaudecker Apr 29, 2026
3769d2d
Raise on regime-function-output indexed by discrete state in a consumer
hmgaudecker Apr 29, 2026
282542f
benchmarks: bump aca-model to dead-regime-NaN fix
hmgaudecker Apr 29, 2026
72c83f7
Substitute runtime-supplied action gridpoints in simulate's state-act…
hmgaudecker Apr 29, 2026
db6214f
Guard log-spaced grids against non-positive start; reject non-finite …
hmgaudecker Apr 30, 2026
f56b9be
Address PR review: docstrings, type hints, validation, drop separator…
hmgaudecker Apr 30, 2026
3b7be82
LogSpacedGrid docstring: drop redundant rationale
hmgaudecker May 1, 2026
8589146
IrregSpacedGrid.to_jax docstring: shorter, point at the substituted-g…
hmgaudecker May 1, 2026
541392c
IrregSpacedGrid.to_jax error message: same shape as docstring
hmgaudecker May 1, 2026
54d22b0
Use jnp.isfinite in grid validators; drop math import
hmgaudecker May 1, 2026
ba38876
Drop cryptic aca_model reference from validator docstring
hmgaudecker May 1, 2026
4d11dea
Validator: also flag function-output indexed by a derived categorical
hmgaudecker May 1, 2026
01608bb
create_regime_state_action_space docstring: trim rationale
hmgaudecker May 1, 2026
6241056
state_action_space: move private helpers below public function (deep …
hmgaudecker May 1, 2026
2efe9e1
Drop application-specific (aca) references from test docstrings
hmgaudecker May 1, 2026
f1c4d5d
Validator: rename to discrete_grid_names, also include discrete actions
hmgaudecker May 1, 2026
d9f37ce
Validator: tighten to actual footgun shape; correct behaviour descrip…
hmgaudecker May 1, 2026
2cc46ff
solve_brute: stream NaN/Inf reductions instead of stacking-and-flushing
hmgaudecker May 1, 2026
365da07
solve_brute: stop pinning per-period V templates in diagnostic_rows
hmgaudecker May 1, 2026
bf1cdf4
solve_brute: fail-fast on NaN per period; rewrite stale diagnostic hint
hmgaudecker May 4, 2026
c7745f3
solve/simulate: surface snapshot path in NaN exception note
hmgaudecker May 4, 2026
bc067c1
Model.n_subjects: AOT-compile simulate functions for fixed batch shape
hmgaudecker May 1, 2026
8bb8259
simulate AOT: match runtime's sparse pytree, drop runtime padding
hmgaudecker May 2, 2026
54c72a0
bench_aca_baseline: pass n_subjects=_N_SUBJECTS to create_benchmark_m…
hmgaudecker May 2, 2026
92d038c
benchmarks: bump aca-model rev to carry n_subjects on factories
hmgaudecker May 2, 2026
648afcc
benchmarks: bump aca-model rev + pass max_consumption to factory
hmgaudecker May 2, 2026
596f150
simulate AOT: re-jit `next_state` / `compute_regime_transition_probs`
hmgaudecker May 2, 2026
9fd0524
simulate AOT: only compile active-period argmax, not the full age range
hmgaudecker May 3, 2026
dfb0e8b
simulate AOT: int32 for discrete state lower-args (match runtime)
hmgaudecker May 3, 2026
43723c4
build_initial_states: cast discrete states to grid dtype (one-shot)
hmgaudecker May 3, 2026
99f3f15
DiscreteGrid: pin to_jax() to int32 regardless of x64 mode
hmgaudecker May 3, 2026
458d36f
Lock integer dtype to int32 end-to-end (#341)
hmgaudecker May 3, 2026
e8ede00
benchmarks: bump aca-model rev; drop max_consumption kwarg
hmgaudecker May 3, 2026
4316b6e
solve_brute: merge resolved_fixed_params into NaN diagnostic regime_p…
hmgaudecker May 3, 2026
866a5bb
benchmarks: bump aca-model rev to 714fee0 (assets-floor margin)
hmgaudecker May 4, 2026
a51edae
regime_template: exempt next_<state> names from fixed_param extraction
hmgaudecker May 4, 2026
71a6146
Merge feature/next-state-deps-in-transitions: exempt next_<state> fro…
hmgaudecker May 4, 2026
ac93eec
Merge improve/lazy-solve-diagnostics (incl. next_<state> exempt-set fix)
hmgaudecker May 4, 2026
d66d85a
Boilerplate refresh: dags module, current pixi/hook pins, drop stale …
hmgaudecker May 4, 2026
f18d27c
Bump .ai-instructions: pyproject-fmt + ruff + check-jsonschema rev pins
hmgaudecker May 4, 2026
5e01de6
Merge feature/runtime-action-grids: boilerplate refresh + ai-instruct…
hmgaudecker May 4, 2026
f730e78
Merge feature/next-state-deps-in-transitions: boilerplate refresh + a…
hmgaudecker May 4, 2026
749f83a
Merge improve/lazy-solve-diagnostics: boilerplate refresh + ai-instru…
hmgaudecker May 4, 2026
c3e1838
Merge main: pick up #338 (runtime action grids + validator tightening)
hmgaudecker May 4, 2026
01ba1f3
Merge feature/next-state-deps-in-transitions (carries main → #338)
hmgaudecker May 4, 2026
3229a15
Merge improve/lazy-solve-diagnostics (carries main → #338 → #342)
hmgaudecker May 4, 2026
e45cf4f
regime_template: reject next_<state> params on regular DAG functions
hmgaudecker May 4, 2026
5ff174d
Merge feature/next-state-deps-in-transitions: harden next_<state> val…
hmgaudecker May 4, 2026
5bae789
Merge improve/lazy-solve-diagnostics: harden next_<state> validator
hmgaudecker May 4, 2026
c585f4b
Bump aca-model benchmark pin to 83f22500 (post pension correction)
hmgaudecker May 4, 2026
ffe8820
Merge feature/next-state-deps-in-transitions: bump aca-model benchmar…
hmgaudecker May 4, 2026
e912de4
Merge improve/lazy-solve-diagnostics: bump aca-model benchmark pin to…
hmgaudecker May 4, 2026
f0dd7b5
Revert aca-model pin: this branch lacks Model.n_subjects (introduced …
hmgaudecker May 4, 2026
149ac78
Merge feature/next-state-deps-in-transitions: revert aca-model pin to…
hmgaudecker May 4, 2026
588c9c4
Merge improve/lazy-solve-diagnostics
hmgaudecker May 4, 2026
c00b610
Bump aca-model pin to 83f22500 on #340 (carries pension correction)
hmgaudecker May 4, 2026
e89d5e4
Revert "regime_template: reject next_<state> params on regular DAG fu…
hmgaudecker May 4, 2026
d07d897
Merge feature/next-state-deps-in-transitions: revert validator that c…
hmgaudecker May 4, 2026
53443fc
Merge improve/lazy-solve-diagnostics: revert validator + carry over
hmgaudecker May 4, 2026
e92eeec
Bump aca-model pin to 3453080 (filters stale benchmark_params key)
hmgaudecker May 4, 2026
c117f4c
Bump aca-model pin to b2e90bb (synthesise shifted imputation arrays)
hmgaudecker May 4, 2026
e422876
Bump aca-model pin to 35eddcc (declare target_his derived categorical)
hmgaudecker May 4, 2026
4b9bea3
Bump aca-model pin to 64d6567 (rename shifted-array level to target_his)
hmgaudecker May 4, 2026
9c7edb6
Bump aca-model pin to f09b5e3 (per-target next_assets, dead-target te…
hmgaudecker May 4, 2026
e6066fa
Roll #340 aca-model pin back to 63d2a38 (pre-pension-correction)
hmgaudecker May 4, 2026
a908c84
get_next_state_function_for_simulation: per-target DAG mirrors solve
hmgaudecker May 5, 2026
18d4ade
Revert aca-model rollback: restore f09b5e3 pin (with pension correction)
hmgaudecker May 5, 2026
6c64a77
get_next_state_function_for_simulation: per-target DAG mirrors solve
hmgaudecker May 5, 2026
fed28cd
Merge feature/next-state-deps-in-transitions: pull pylcm simulate-pat…
hmgaudecker May 5, 2026
c969b1a
next_state: real signature for combined; fix trivially-passing test
hmgaudecker May 6, 2026
8a2ad4f
Address #342 review: simulate-path uses concatenate_functions; cleanups
hmgaudecker May 6, 2026
bcdd358
Merge feature/next-state-deps-in-transitions: address #342 review fee…
hmgaudecker May 6, 2026
17347c8
Get rid of H_variables entirely in regime_template.
hmgaudecker May 6, 2026
076b9b6
Bump .ai-instructions: TDD always; behavior-focused docstrings
hmgaudecker May 6, 2026
e6f8e41
Merge feature/next-state-deps-in-transitions: bump .ai-instructions f…
hmgaudecker May 6, 2026
164a88b
AGENTS.md: inline TDD-always testing section directly
hmgaudecker May 6, 2026
07c5ae0
Merge feature/next-state-deps-in-transitions: inline TDD section in A…
hmgaudecker May 6, 2026
5261e29
Address #339 review: drop field-count test; tighten claudish docstrings
hmgaudecker May 6, 2026
110cc0b
regime_template: collapse H_variables into single variables set
hmgaudecker May 6, 2026
4b2895c
Merge feature/next-state-deps-in-transitions: collapse H_variables
hmgaudecker May 6, 2026
a7b9e9a
solve_brute: drop misleading "~2 MB each" magic number from comment
hmgaudecker May 6, 2026
16b570f
AGENTS.md: docstring style — describe state, no PR refs, bulleted lists
hmgaudecker May 6, 2026
1251f0b
Merge feature/next-state-deps-in-transitions: docstring style + .ai-i…
hmgaudecker May 6, 2026
ff65261
solve_brute: apply docstring style — drop PR ref, magic number; bulle…
hmgaudecker May 6, 2026
2539374
Merge improve/lazy-solve-diagnostics: docstring style, .ai-instructio…
hmgaudecker May 6, 2026
e9b7cc5
test_next_state: update to assert nested output shape from #339 merge
hmgaudecker May 6, 2026
eb69432
Merge main: pick up squash of #342
hmgaudecker May 6, 2026
838473e
validate_initial_conditions: per-constraint admissibility in error me…
hmgaudecker May 6, 2026
07f951a
Merge improve/lazy-solve-diagnostics: pick up main-squash + downstrea…
hmgaudecker May 6, 2026
e4cae2a
_per_constraint_feasibility: filter args per single-constraint feasib…
hmgaudecker May 6, 2026
62392c1
Merge branch 'main' into feat/simulate-aot-n-subjects
hmgaudecker May 6, 2026
9e92ca6
benchmarks: pin aca-model to minimal n_subjects-forwarding commit
hmgaudecker May 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/bench_aca_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _build() -> tuple[object, object, object]:
get_benchmark_params,
)

model = create_benchmark_model()
model = create_benchmark_model(n_subjects=_N_SUBJECTS)
_, model_params = get_benchmark_params(model=model)
initial_conditions = get_benchmark_initial_conditions(
model=model, n_subjects=_N_SUBJECTS, seed=0
Expand Down
8 changes: 4 additions & 4 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" }
tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" }
type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" }
[tool.pixi.feature.benchmarks.pypi-dependencies]
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" }
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "3f215a2c44237b9fa3fa74bf78ef93c8d695a517" }
[tool.pixi.feature.cuda12]
platforms = [ "linux-64" ]
system-requirements = { cuda = "12" }
Expand Down
14 changes: 12 additions & 2 deletions src/lcm/grids/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,15 @@ def batch_size(self) -> int:
return self.__batch_size

def to_jax(self) -> Int1D:
"""Convert the grid to a Jax array."""
return jnp.array(self.codes)
"""Convert the grid to a Jax array.

Discrete state/action codes are pinned to `int32` regardless of the
ambient `jax_enable_x64` setting. `jnp.array([...])` would otherwise
produce `int32` in 32-bit mode and `int64` in x64 mode, and
downstream values (transitions, V-array indexing, action lookups)
inherit that ambiguity — which silently splits the JIT cache into
per-period int32/int64 variants and breaks any AOT-compiled
program that ships a single signature. `int32` covers any realistic
category count and matches the `MISSING_CAT_CODE` sentinel.
"""
return jnp.array(self.codes, dtype=jnp.int32)
72 changes: 71 additions & 1 deletion src/lcm/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Collection of classes that are used by the user to define the model and grids."""

import dataclasses
import logging
from collections.abc import Mapping
from pathlib import Path
from types import MappingProxyType
Expand Down Expand Up @@ -30,6 +31,7 @@
)
from lcm.regime import Regime
from lcm.regime_building.processing import InternalRegime
from lcm.simulation.compile import compile_all_simulate_functions
from lcm.simulation.initial_conditions import validate_initial_conditions
from lcm.simulation.result import SimulationResult, get_simulation_output_dtypes
from lcm.simulation.simulate import simulate
Expand Down Expand Up @@ -85,6 +87,16 @@ class Model:
fixed_params: UserParams
"""Parameters fixed at model initialization."""

n_subjects: int | None = None
"""Expected simulate batch size; enables AOT compile of simulate functions.

When set, the first matching `simulate(...)` call AOT-compiles all simulate
functions for batch shape `n_subjects` in parallel. Subsequent calls with the
same size reuse the compiled programs. Calls with a mismatching size warn
once per size and fall back to the runtime-traced path. `None` keeps the
purely lazy behaviour.
"""

_params_template: ParamsTemplate
"""Template for the model parameters."""

Expand All @@ -100,6 +112,7 @@ def __init__(
derived_categoricals: Mapping[FunctionName, DiscreteGrid] = MappingProxyType(
{}
),
n_subjects: int | None = None,
) -> None:
"""Initialize the Model.

Expand All @@ -115,17 +128,27 @@ def __init__(
not in states/actions. Broadcast to all regimes (merged with
each regime's own `derived_categoricals`). Raises if a regime
already has a conflicting entry.
n_subjects: Expected simulate batch size; if set, the first matching
`simulate(...)` call AOT-compiles all simulate functions for
batch shape `n_subjects` in parallel. `None` keeps the purely
lazy behaviour.

"""
self.description = description
self.ages = ages
self.n_periods = ages.n_periods
self.fixed_params = ensure_containers_are_immutable(fixed_params)
self.n_subjects = n_subjects
self._simulate_compile_cache: dict[
int, MappingProxyType[RegimeName, InternalRegime]
] = {}
self._warned_n_subjects: set[int] = set()

validate_model_inputs(
n_periods=self.n_periods,
regimes=regimes,
regime_id_class=regime_id_class,
n_subjects=n_subjects,
)
self.regime_names_to_ids = MappingProxyType(
dict(
Expand Down Expand Up @@ -240,6 +263,46 @@ def solve(
)
return period_to_regime_to_V_arr

def _resolve_simulate_internal_regimes(
self,
*,
actual_n_subjects: int,
internal_params: InternalParams,
log: logging.Logger,
max_compilation_workers: int | None,
) -> MappingProxyType[RegimeName, InternalRegime]:
"""Return internal_regimes to use for simulate; AOT cache when matching.

Returns the original `internal_regimes` when `n_subjects` is `None` or
when the actual batch size mismatches the declared one (logging a
warning once per mismatching size). Otherwise builds and caches the
AOT-compiled regimes for the matching size.
"""
if self.n_subjects is None:
return self.internal_regimes
if actual_n_subjects != self.n_subjects:
if actual_n_subjects not in self._warned_n_subjects:
log.warning(
"simulate called with n_subjects=%d but model declared "
"n_subjects=%d; falling back to runtime compile.",
actual_n_subjects,
self.n_subjects,
)
self._warned_n_subjects.add(actual_n_subjects)
return self.internal_regimes
if self.n_subjects not in self._simulate_compile_cache:
self._simulate_compile_cache[self.n_subjects] = (
compile_all_simulate_functions(
internal_regimes=self.internal_regimes,
internal_params=internal_params,
ages=self.ages,
n_subjects=self.n_subjects,
max_compilation_workers=max_compilation_workers,
logger=log,
)
)
return self._simulate_compile_cache[self.n_subjects]

def simulate(
self,
*,
Expand Down Expand Up @@ -341,10 +404,17 @@ def simulate(
)
exc.add_note(f"Snapshot saved to {snap_dir}")
raise
actual_n_subjects = len(next(iter(initial_conditions.values())))
simulate_internal_regimes = self._resolve_simulate_internal_regimes(
actual_n_subjects=actual_n_subjects,
internal_params=internal_params,
log=log,
max_compilation_workers=max_compilation_workers,
)
result = simulate(
internal_params=internal_params,
initial_conditions=initial_conditions,
internal_regimes=self.internal_regimes,
internal_regimes=simulate_internal_regimes,
regime_names_to_ids=self.regime_names_to_ids,
logger=log,
period_to_regime_to_V_arr=period_to_regime_to_V_arr,
Expand Down
16 changes: 16 additions & 0 deletions src/lcm/model_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,11 @@ def validate_model_inputs(
n_periods: int,
regimes: Mapping[RegimeName, Regime],
regime_id_class: type,
n_subjects: int | None = None,
) -> None:
"""Validate model constructor inputs."""
_fail_if_invalid_n_subjects(n_subjects=n_subjects)

# Early exit if regimes are not lcm.Regime instances
if not all(isinstance(regime, Regime) for regime in regimes.values()):
raise ModelInitializationError(
Expand Down Expand Up @@ -201,6 +204,19 @@ def validate_model_inputs(
raise ModelInitializationError(msg)


def _fail_if_invalid_n_subjects(*, n_subjects: int | None) -> None:
"""Raise TypeError if non-int, ValueError if non-positive."""
if n_subjects is None:
return
# `bool` is a subclass of `int`; reject explicitly so True/False don't slip through.
if not isinstance(n_subjects, int) or isinstance(n_subjects, bool):
msg = f"n_subjects must be an int or None, got {type(n_subjects).__name__}."
raise TypeError(msg)
if n_subjects <= 0:
msg = f"n_subjects must be a positive integer, got {n_subjects}."
raise ValueError(msg)


def _validate_all_variables_used(regimes: Mapping[RegimeName, Regime]) -> list[str]:
"""Validate that all states and actions are used somewhere in each regime.

Expand Down
3 changes: 2 additions & 1 deletion src/lcm/pandas_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def initial_conditions_from_dataframe( # noqa: C901
for col, arr in result_arrays.items()
}
initial_conditions["regime"] = jnp.array(
df["regime"].map(dict(regime_names_to_ids)).to_numpy()
df["regime"].map(dict(regime_names_to_ids)).to_numpy(),
dtype=jnp.int32,
)

return initial_conditions
Expand Down
4 changes: 2 additions & 2 deletions src/lcm/regime_building/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def argmax_and_max(
# When there are no dimensions to reduce over, return:
# - index 0 (trivial argmax since there's only one element)
# - the array itself (already the maximum)
return jnp.array(0), a
return jnp.array(0, dtype=jnp.int32), a

# Move axis over which to compute the argmax to the back and flatten last dims
# ==================================================================================
Expand All @@ -65,7 +65,7 @@ def argmax_and_max(
max_value_mask = a == _max
if where is not None:
max_value_mask = jnp.logical_and(max_value_mask, where)
_argmax = jnp.argmax(max_value_mask, axis=-1)
_argmax = jnp.argmax(max_value_mask, axis=-1).astype(jnp.int32)

return _argmax, _max.reshape(_argmax.shape)

Expand Down
Loading
Loading