Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
e881313
Package B: float dtype barriers at the API boundary
hmgaudecker May 6, 2026
2ee56ca
Merge branch 'feat/simulate-aot-n-subjects' into feat/canonical-float…
hmgaudecker May 6, 2026
ef180d0
Fix Package B 32-bit precision test: build float overflow fixture wit…
hmgaudecker May 6, 2026
1a42ffb
Merge branch 'feat/simulate-aot-n-subjects' into feat/canonical-float…
hmgaudecker May 6, 2026
75c8d25
Merge branch 'feat/simulate-aot-n-subjects' into feat/canonical-float…
hmgaudecker May 7, 2026
1947560
Merge branch 'feat/simulate-aot-n-subjects' into feat/canonical-float…
hmgaudecker May 7, 2026
09f3d03
Address PR #345 review
hmgaudecker May 7, 2026
b8dc490
bench_aca_baseline: pass pref_type_grid to create_benchmark_model
hmgaudecker May 7, 2026
3c3af21
Merge cleanup/aca-bench-no-defaults
hmgaudecker May 7, 2026
cbe65be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2026
a6b0ac1
bench_aca_baseline: hoist aca_model + lcm imports to module top
hmgaudecker May 8, 2026
381dc25
Merge feat/simulate-aot-n-subjects + reorder reshape-before-cast
hmgaudecker May 8, 2026
9d26643
_validate_param_types: drop dead branches post-whitelist
hmgaudecker May 8, 2026
530d50c
Tighten Scalar* aliases to JAX-only; convert grid endpoints at constr…
hmgaudecker May 8, 2026
f4515ec
Keep coordinate helpers strict; convert at Grid.get_coordinate boundary
hmgaudecker May 8, 2026
9fc2f49
bench_aca_baseline: build on CPU to keep parent process CUDA-free
hmgaudecker May 8, 2026
88a85ae
save_simulate_snapshot: strip AOT-compiled regimes before pickling re…
hmgaudecker May 8, 2026
f2d18fa
bench_aca_baseline: defer aca_model + lcm imports back into _build
hmgaudecker May 8, 2026
fff1537
Tighten internal types: ScalarInt n_points, JAX-only Period/Age, kw-only
hmgaudecker May 8, 2026
aea8735
linspace/logspace: drop int(n_points) cast in favour of ty:ignore
hmgaudecker May 8, 2026
f4069c1
Piecewise n_points: sum the cached _piece_n_points array
hmgaudecker May 8, 2026
bf12b61
benchmarks: bump aca-model pin to 67edfe0f
hmgaudecker May 8, 2026
1bae789
simulate: keep period: int through the loop, cast at the JIT boundary
hmgaudecker May 8, 2026
2f486dc
Merge branch 'feat/simulate-aot-n-subjects' into feat/canonical-float…
hmgaudecker May 8, 2026
c419377
benchmarks: bump aca-model pin to d9339ab
hmgaudecker May 8, 2026
61c2436
simulate orchestrates simulate-AOT compile, not solve
hmgaudecker May 8, 2026
1deed36
tests: drop noqa: ARG001 + collapse x64-fixture signatures
hmgaudecker May 9, 2026
00f3b4a
test_next_state: pass JAX scalars instead of ty:ignore-ing Python ones
hmgaudecker May 9, 2026
ca66ba9
simulate: swap AOT-compiled regimes for lazy ones on the result
hmgaudecker May 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions benchmarks/bench_aca_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,30 @@


def _build() -> tuple[object, object, object]:
"""Build the aca-baseline model, params, and initial conditions."""
"""Build the aca-baseline model, params, and initial conditions.

aca_model and lcm imports are deferred to the function body — ASV's
forkserver runs `preimport` to discover benchmarks across every
`bench_*.py` module before forking workers. Importing JAX at module
top loads the multithreaded XLA backend into the forkserver; every
subsequent `os.fork()` inherits a corrupted CUDA context and the
first device op in the worker aborts with
`CUDA_ERROR_NOT_INITIALIZED`. Per-call imports keep JAX out of the
forkserver and confine it to the worker process.
"""
from aca_model.agent.preferences import BenchmarkPrefType
from aca_model.benchmark import (
create_benchmark_model,
get_benchmark_initial_conditions,
get_benchmark_params,
)

model = create_benchmark_model(n_subjects=_N_SUBJECTS)
from lcm import DiscreteGrid

model = create_benchmark_model(
n_subjects=_N_SUBJECTS,
pref_type_grid=DiscreteGrid(BenchmarkPrefType),
)
_, model_params = get_benchmark_params(model=model)
initial_conditions = get_benchmark_initial_conditions(
model=model, n_subjects=_N_SUBJECTS, seed=0
Expand Down
8 changes: 4 additions & 4 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" }
tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" }
type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" }
[tool.pixi.feature.benchmarks.pypi-dependencies]
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "f09b5e34102ff42f739b95be5a9d388795b734a1" }
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6" }
[tool.pixi.feature.cuda12]
platforms = [ "linux-64" ]
system-requirements = { cuda = "12" }
Expand Down Expand Up @@ -242,6 +242,12 @@ per-file-ignores."tests/*" = [
"S301", # Use of pickle
"SLF001", # Private member access
]
per-file-ignores."tests/test_dtypes.py" = [
"ARG001", # Unused function argument (x64_enabled / x64_disabled fixtures)
]
per-file-ignores."tests/test_float_dtype_invariants.py" = [
"ARG001", # Unused function argument (x64_disabled fixture)
]
per-file-ignores."tests/test_next_state.py" = [
"ARG001", # Unused function argument
"ARG005", # Unused lambda argument
Expand Down
14 changes: 8 additions & 6 deletions src/lcm/ages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jax.numpy as jnp

from lcm.exceptions import GridInitializationError, format_messages
from lcm.typing import Age, Float1D, Int1D
from lcm.typing import Float1D, Int1D

STEP_UNITS: MappingProxyType[str, Fraction] = MappingProxyType(
{
Expand Down Expand Up @@ -129,7 +129,7 @@ def exact_step_size(self) -> int | Fraction | None:
"""
return self._exact_step_size

def period_to_age(self, period: int) -> Age:
def period_to_age(self, period: int) -> int | float:
"""Convert a period index to the corresponding age.

Args:
Expand All @@ -151,7 +151,7 @@ def period_to_age(self, period: int) -> Age:
return int(self._values[period])
return float(self._values[period])

def age_to_period(self, age: Age) -> int:
def age_to_period(self, age: float) -> int:
"""Convert an age to the corresponding period index.

Args:
Expand All @@ -172,12 +172,14 @@ def age_to_period(self, age: Age) -> int:
raise ValueError(msg) from None

@functools.cached_property
def _age_to_period_map(self) -> dict[Age, int]:
def _age_to_period_map(self) -> dict[int | float, int]:
if self._is_integer:
return {int(v): i for i, v in enumerate(self._exact_values)}
return {float(v): i for i, v in enumerate(self._exact_values)}

def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...]:
def get_periods_where(
self, predicate: Callable[[int | float], bool]
) -> tuple[int, ...]:
"""Get period indices where predicate is True.

Args:
Expand All @@ -187,7 +189,7 @@ def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...]
Tuple of period indices where predicate(age) is True.

"""
_convert: Callable[[object], Age] = int if self._is_integer else float # ty: ignore[invalid-assignment]
_convert: Callable[[object], int | float] = int if self._is_integer else float # ty: ignore[invalid-assignment]
return tuple(
period
for period in range(self.n_periods)
Expand Down
60 changes: 54 additions & 6 deletions src/lcm/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,32 @@
Used at every API boundary that accepts user data (params, initial
conditions, regime-id arrays) — always called from Python, never inside
JIT. Each helper validates that the value fits the target dtype and
raises a clearly-named error if not.

Casts further down the simulate stack (e.g. transition outputs landing
in the state pool) use plain `.astype` and rely on the boundary cast
above them having already pinned the canonical dtype.
raises a clearly-named error if not. Once an input has crossed the
boundary it carries the canonical dtype unchanged through the simulate
stack; downstream code does not re-cast.
"""

import jax
import jax.numpy as jnp
import numpy as np
from jax import Array

_INT32_MIN = int(np.iinfo(np.int32).min)
_INT32_MAX = int(np.iinfo(np.int32).max)
_FLOAT32_MAX = float(np.finfo(np.float32).max)


def canonical_float_dtype() -> jnp.dtype:
"""Return pylcm's canonical float dtype, derived from `jax_enable_x64`.

Returns `jnp.float64` if `jax.config.jax_enable_x64` is True,
otherwise `jnp.float32`. The value is read at call time, not at
import, so toggling the JAX config (e.g. between tests) is honoured.
"""
return jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32


def safe_to_int32(value: object, *, name: str) -> Array:
def safe_to_int_dtype(value: object, *, name: str) -> Array:
"""Cast a scalar, sequence, or array to `jnp.int32`, checking int32 range.

Args:
Expand Down Expand Up @@ -46,3 +56,41 @@ def safe_to_int32(value: object, *, name: str) -> Array:
)
raise ValueError(msg)
return jnp.asarray(np_value, dtype=jnp.int32)


def safe_to_float_dtype(value: object, *, name: str) -> Array:
"""Cast a scalar, sequence, or array to the canonical float dtype.

Range check fires only on a down-cast:

- Down-cast (float64 → float32 under `jax_enable_x64=False`): raise
`OverflowError` if any element exceeds float32 magnitude rather
than letting JAX silently saturate to ``±inf``.
- Up-cast or same-width cast: skip the range check. Precision loss
within range is not an error — it is an inherent consequence of
`jax_enable_x64=False`.

Args:
value: A Python float, numpy/JAX scalar, or array-like.
name: Qualified name of the leaf — surfaced in the error message.

Returns:
A JAX array at `canonical_float_dtype()` (0-d if `value` was a
scalar).

Raises:
OverflowError: If down-casting to `float32` would saturate any
element to `±inf`. The message names the leaf via `name`.

"""
target_dtype = canonical_float_dtype()
np_value = np.asarray(value)
if target_dtype == jnp.float32 and np_value.size > 0:
max_mag = float(np.max(np.abs(np_value)))
if max_mag > _FLOAT32_MAX:
msg = (
f"{name}: float32 overflow — max |value| {max_mag:g} "
f"exceeds float32 max {_FLOAT32_MAX:g}."
)
raise OverflowError(msg)
return jnp.asarray(np_value, dtype=target_dtype)
Loading
Loading