Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1c65f45
Support runtime-supplied points on continuous-action IrregSpacedGrids
hmgaudecker Apr 29, 2026
1792279
benchmarks: bump aca-model to runtime-consumption-points version
hmgaudecker Apr 29, 2026
cf00e99
Fail loudly when reading runtime IrregSpacedGrid before substitution
hmgaudecker Apr 29, 2026
db98cde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
03ba800
Fix remaining ruff check errors in test_single_feasible_action.py
hmgaudecker Apr 29, 2026
3769d2d
Raise on regime-function-output indexed by discrete state in a consumer
hmgaudecker Apr 29, 2026
282542f
benchmarks: bump aca-model to dead-regime-NaN fix
hmgaudecker Apr 29, 2026
72c83f7
Substitute runtime-supplied action gridpoints in simulate's state-act…
hmgaudecker Apr 29, 2026
db6214f
Guard log-spaced grids against non-positive start; reject non-finite …
hmgaudecker Apr 30, 2026
f56b9be
Address PR review: docstrings, type hints, validation, drop separator…
hmgaudecker Apr 30, 2026
3b7be82
LogSpacedGrid docstring: drop redundant rationale
hmgaudecker May 1, 2026
8589146
IrregSpacedGrid.to_jax docstring: shorter, point at the substituted-g…
hmgaudecker May 1, 2026
541392c
IrregSpacedGrid.to_jax error message: same shape as docstring
hmgaudecker May 1, 2026
54d22b0
Use jnp.isfinite in grid validators; drop math import
hmgaudecker May 1, 2026
ba38876
Drop cryptic aca_model reference from validator docstring
hmgaudecker May 1, 2026
4d11dea
Validator: also flag function-output indexed by a derived categorical
hmgaudecker May 1, 2026
01608bb
create_regime_state_action_space docstring: trim rationale
hmgaudecker May 1, 2026
6241056
state_action_space: move private helpers below public function (deep …
hmgaudecker May 1, 2026
2efe9e1
Drop application-specific (aca) references from test docstrings
hmgaudecker May 1, 2026
f1c4d5d
Validator: rename to discrete_grid_names, also include discrete actions
hmgaudecker May 1, 2026
d9f37ce
Validator: tighten to actual footgun shape; correct behaviour descrip…
hmgaudecker May 1, 2026
d66d85a
Boilerplate refresh: dags module, current pixi/hook pins, drop stale …
hmgaudecker May 4, 2026
f18d27c
Bump .ai-instructions: pyproject-fmt + ruff + check-jsonschema rev pins
hmgaudecker May 4, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ai-instructions
10 changes: 5 additions & 5 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.5
with:
pixi-version: v0.66.0
pixi-version: v0.67.2
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: tests-cpu
Expand Down Expand Up @@ -59,7 +59,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.5
with:
pixi-version: v0.66.0
pixi-version: v0.67.2
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: type-checking
Expand All @@ -82,7 +82,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.5
with:
pixi-version: v0.66.0
pixi-version: v0.67.2
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: tests-cuda12
Expand All @@ -101,7 +101,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.5
with:
pixi-version: v0.66.0
pixi-version: v0.67.2
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: tests-cuda12
Expand All @@ -116,7 +116,7 @@ jobs:
# - uses: actions/checkout@v6
# - uses: prefix-dev/setup-pixi@v0.9.5
# with:
# pixi-version: v0.66.0
# pixi-version: v0.67.2
# cache: true
# cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
# environments: tests-cpu
Expand Down
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ docs/_build/
.pixi/
node_modules/

# pytask
.pytask.sqlite3

# Python
__pycache__/
*.py[cod]
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ repos:
hooks:
- id: yamllint
- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.37.1
rev: 0.37.2
hooks:
- id: check-github-workflows
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.11
rev: v0.15.12
hooks:
- id: ruff-check
args:
Expand Down Expand Up @@ -86,7 +86,7 @@ repos:
args:
- --wrap
- '88'
files: (AGENTS\.md|CLAUDE\.md|README\.md|modules/.*\.md|profiles/.*\.md)
files: (AGENTS\.md|CLAUDE\.md|README\.md)
- id: mdformat
additional_dependencies:
- mdformat-myst
Expand Down
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@.ai-instructions/profiles/tier-a.md @.ai-instructions/modules/jax.md
@.ai-instructions/modules/pandas.md @.ai-instructions/modules/plotting.md
@.ai-instructions/modules/dags.md

# PyLCM

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/bench_aca_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _build() -> tuple[object, object, object]:
)

model = create_benchmark_model()
_, model_params = get_benchmark_params()
_, model_params = get_benchmark_params(model=model)
initial_conditions = get_benchmark_initial_conditions(
model=model, n_subjects=_N_SUBJECTS, seed=0
)
Expand Down
8 changes: 4 additions & 4 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" }
tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" }
type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" }
[tool.pixi.feature.benchmarks.pypi-dependencies]
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "adc8a19328608781a5cb2a65ab2d93d580163aae" }
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" }
[tool.pixi.feature.cuda12]
platforms = [ "linux-64" ]
system-requirements = { cuda = "12" }
Expand Down
73 changes: 62 additions & 11 deletions src/lcm/grids/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,22 @@ def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array:
class LogSpacedGrid(UniformContinuousGrid):
"""A logarithmically spaced grid of continuous values.

Requires `start > 0`.

Example:
--------
Let `start = 1`, `stop = 100`, and `n_points = 3`. The grid is `[1, 10, 100]`.

"""

def __post_init__(self) -> None:
_validate_continuous_grid(
start=self.start,
stop=self.stop,
n_points=self.n_points,
requires_positive_start=True,
)

def to_jax(self) -> Float1D:
"""Convert the grid to a Jax array."""
return grid_coordinates.logspace(
Expand Down Expand Up @@ -188,9 +198,23 @@ def pass_points_at_runtime(self) -> bool:
return self.points is None

def to_jax(self) -> Float1D:
"""Convert the grid to a Jax array."""
"""Convert the grid to a Jax array.

Raises `GridInitializationError` for runtime-supplied grids
(`pass_points_at_runtime=True`). To get the substituted points,
call `internal_regime.state_action_space(regime_params=...)` and
read from `.states[name]` or `.continuous_actions[name]`.
"""
if self.points is None:
return jnp.full(self.n_points, jnp.nan)
raise GridInitializationError(
f"IrregSpacedGrid declared with n_points={self.n_points} and "
f"no points; values are supplied at runtime via "
f"params['<regime>']['<grid_name>']['points']. To get the "
f"substituted points, call "
f"`internal_regime.state_action_space(regime_params=...)` and "
f"read from `.states[name]` or `.continuous_actions[name]`. "
f"Use `.n_points` if only the shape is needed."
)
return jnp.asarray(self.points)

@overload
Expand All @@ -213,13 +237,16 @@ def _validate_continuous_grid(
start: float,
stop: float,
n_points: int,
requires_positive_start: bool = False,
) -> None:
"""Validate the continuous grid parameters.

Args:
start: The start value of the grid.
stop: The stop value of the grid.
n_points: The number of points in the grid.
requires_positive_start: If True, also require `start > 0` (used by
log-spaced grids since `log(x)` is undefined for `x <= 0`).

Raises:
GridInitializationError: If the grid parameters are invalid.
Expand All @@ -235,6 +262,15 @@ def _validate_continuous_grid(
if not valid_stop_type:
error_messages.append("stop must be a scalar int or float value")

# Reject NaN/inf early — `start >= stop` returns False for NaN, so an
# un-finite start would otherwise pass silently and produce a broken grid.
if valid_start_type and not jnp.isfinite(start):
error_messages.append(f"start must be finite, got {start}")
valid_start_type = False
if valid_stop_type and not jnp.isfinite(stop):
error_messages.append(f"stop must be finite, got {stop}")
valid_stop_type = False

if not isinstance(n_points, int) or n_points < 1:
error_messages.append(
f"n_points must be an int greater than 0 but is {n_points}",
Expand All @@ -243,6 +279,12 @@ def _validate_continuous_grid(
if valid_start_type and valid_stop_type and start >= stop:
error_messages.append("start must be less than stop")

if valid_start_type and requires_positive_start and start <= 0:
error_messages.append(
f"start must be > 0 for a log-spaced grid (got {start}); "
f"`log(x)` is undefined for `x <= 0`."
)

if error_messages:
msg = format_messages(error_messages)
raise GridInitializationError(msg)
Expand Down Expand Up @@ -275,15 +317,24 @@ def _validate_irreg_spaced_grid(points: Sequence[float] | Float1D) -> None:
f"Non-numeric elements found at indices: {non_numeric}"
)
else:
# Check that points are in ascending order
for i in range(len(points) - 1):
if points[i] >= points[i + 1]:
error_messages.append(
"Points must be in strictly ascending order. "
f"Found points[{i}]={points[i]} >= "
f"points[{i + 1}]={points[i + 1]}"
)
break
# Reject NaN/inf — comparisons with NaN are False, so the
# ascending-order check below would silently let them through.
non_finite = [(i, p) for i, p in enumerate(points) if not jnp.isfinite(p)]
if non_finite:
error_messages.append(
f"All elements of points must be finite. "
f"Non-finite elements found at: {non_finite}"
)
else:
# Check that points are in strictly ascending order
for i in range(len(points) - 1):
if points[i] >= points[i + 1]:
error_messages.append(
"Points must be in strictly ascending order. "
f"Found points[{i}]={points[i]} >= "
f"points[{i + 1}]={points[i + 1]}"
)
break

if error_messages:
msg = format_messages(error_messages)
Expand Down
73 changes: 53 additions & 20 deletions src/lcm/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,12 @@ class InternalRegime:
"""Flat resolved fixed params for this regime, used by to_dataframe targets."""

def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpace:
"""Return the state-action space with runtime state grids filled in.
"""Return the state-action space with runtime grids filled in.

For IrregSpacedGrid with runtime-supplied points, the grid points come from
params as `{state_name}__points`. For _ShockGrid with runtime-supplied params,
the grid points are computed from shock params in the params dict or
resolved_fixed_params.
For IrregSpacedGrid (state or continuous action) with runtime-supplied
points, the grid points come from params as `{name}__points`. For
`_ShockGrid` with runtime-supplied params, the grid points are computed
from shock params in the params dict or `resolved_fixed_params`.

Args:
regime_params: Flat regime parameters supplied at runtime.
Expand All @@ -257,35 +257,68 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac

"""
all_params = {**self.resolved_fixed_params, **regime_params}
replacements: dict[str, ContinuousState | DiscreteState] = {}
for state_name, spec in self.grids.items():
if state_name not in self._base_state_action_space.states:
state_replacements: dict[str, ContinuousState | DiscreteState] = {}
action_replacements: dict[str, ContinuousAction] = {}
for name, spec in self.grids.items():
in_states = name in self._base_state_action_space.states
in_continuous_actions = (
name in self._base_state_action_space.continuous_actions
)
if not (in_states or in_continuous_actions):
continue
if isinstance(spec, IrregSpacedGrid) and spec.pass_points_at_runtime:
points_key = f"{state_name}__points"
points_key = f"{name}__points"
if points_key not in all_params:
continue
replacements[state_name] = cast(
"ContinuousState", all_params[points_key]
)
elif isinstance(spec, _ShockGrid) and spec.params_to_pass_at_runtime:
if in_states:
state_replacements[name] = cast(
"ContinuousState", all_params[points_key]
)
else:
action_replacements[name] = cast(
"ContinuousAction", all_params[points_key]
)
# `_ShockGrid` is state-only by construction (intrinsic
# transitions, forbidden as actions per AGENTS.md). The
# `in_states` gate makes that invariant explicit — a
# `_ShockGrid` reaching the action branch would be a model
# bug, not something this method should silently substitute.
elif (
in_states
and isinstance(spec, _ShockGrid)
and spec.params_to_pass_at_runtime
):
all_present = all(
f"{state_name}__{p}" in all_params
for p in spec.params_to_pass_at_runtime
f"{name}__{p}" in all_params for p in spec.params_to_pass_at_runtime
)
if not all_present:
continue
shock_kw: dict[str, float] = dict(spec.params)
for p in spec.params_to_pass_at_runtime:
shock_kw[p] = cast("float", all_params[f"{state_name}__{p}"])
replacements[state_name] = spec.compute_gridpoints(**shock_kw)
shock_kw[p] = cast("float", all_params[f"{name}__{p}"])
state_replacements[name] = spec.compute_gridpoints(**shock_kw)

if not replacements:
if not state_replacements and not action_replacements:
return self._base_state_action_space

new_states = dict(self._base_state_action_space.states) | replacements
new_states = (
MappingProxyType(
dict(self._base_state_action_space.states) | state_replacements
)
if state_replacements
else None
)
new_continuous_actions = (
MappingProxyType(
dict(self._base_state_action_space.continuous_actions)
| action_replacements
)
if action_replacements
else None
)
return self._base_state_action_space.replace(
states=MappingProxyType(new_states)
states=new_states,
continuous_actions=new_continuous_actions,
)


Expand Down
15 changes: 9 additions & 6 deletions src/lcm/pandas_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,12 +504,15 @@ def _resolve_per_target_template_key(

def _is_runtime_grid_param(*, func_name: FunctionName, regime: Regime) -> bool:
"""Check if a template function key refers to a runtime grid param."""
if func_name not in regime.states:
return False
grid = regime.states[func_name]
return (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime) or (
isinstance(grid, _ShockGrid) and bool(grid.params_to_pass_at_runtime)
)
if func_name in regime.states:
grid = regime.states[func_name]
return (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime) or (
isinstance(grid, _ShockGrid) and bool(grid.params_to_pass_at_runtime)
)
if func_name in regime.actions:
grid = regime.actions[func_name]
return isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime
return False


def _fail_if_period_level(sr: pd.Series) -> None:
Expand Down
Loading