diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index a9364879..8b15efab 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -47,14 +47,30 @@ def _build() -> tuple[object, object, object]: - """Build the aca-baseline model, params, and initial conditions.""" + """Build the aca-baseline model, params, and initial conditions. + + aca_model and lcm imports are deferred to the function body — ASV's + forkserver runs `preimport` to discover benchmarks across every + `bench_*.py` module before forking workers. Importing JAX at module + top loads the multithreaded XLA backend into the forkserver; every + subsequent `os.fork()` inherits a corrupted CUDA context and the + first device op in the worker aborts with + `CUDA_ERROR_NOT_INITIALIZED`. Per-call imports keep JAX out of the + forkserver and confine it to the worker process. + """ + from aca_model.agent.preferences import BenchmarkPrefType from aca_model.benchmark import ( create_benchmark_model, get_benchmark_initial_conditions, get_benchmark_params, ) - model = create_benchmark_model(n_subjects=_N_SUBJECTS) + from lcm import DiscreteGrid + + model = create_benchmark_model( + n_subjects=_N_SUBJECTS, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) _, model_params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=_N_SUBJECTS, seed=0 diff --git a/pixi.lock b/pixi.lock index 4115222a..7507bb3e 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 + - pypi: 
git+https://github.com/OpenSourceEconomics/aca-model.git?rev=d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6#d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6#d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev195+ga908c8405.d20260505 - sha256: 44c6bd65422fdc0a7d3167cf852107aeca15bf6687a44b57a6749ad553943f11 + version: 0.0.2.dev256+g61c2436b6.d20260509 + sha256: 80f8e8823a9b7d58cdd78b45377e426887832da149342b1563ba0bfbe91653ea requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index bb10a893..b810d68c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "f09b5e34102ff42f739b95be5a9d388795b734a1" } +aca-model = { git = 
"https://github.com/OpenSourceEconomics/aca-model.git", rev = "d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } @@ -242,6 +242,12 @@ per-file-ignores."tests/*" = [ "S301", # Use of pickle "SLF001", # Private member access ] +per-file-ignores."tests/test_dtypes.py" = [ + "ARG001", # Unused function argument (x64_enabled / x64_disabled fixtures) +] +per-file-ignores."tests/test_float_dtype_invariants.py" = [ + "ARG001", # Unused function argument (x64_disabled fixture) +] per-file-ignores."tests/test_next_state.py" = [ "ARG001", # Unused function argument "ARG005", # Unused lambda argument diff --git a/src/lcm/ages.py b/src/lcm/ages.py index 8dee9c73..11d2c88c 100644 --- a/src/lcm/ages.py +++ b/src/lcm/ages.py @@ -10,7 +10,7 @@ import jax.numpy as jnp from lcm.exceptions import GridInitializationError, format_messages -from lcm.typing import Age, Float1D, Int1D +from lcm.typing import Float1D, Int1D STEP_UNITS: MappingProxyType[str, Fraction] = MappingProxyType( { @@ -129,7 +129,7 @@ def exact_step_size(self) -> int | Fraction | None: """ return self._exact_step_size - def period_to_age(self, period: int) -> Age: + def period_to_age(self, period: int) -> int | float: """Convert a period index to the corresponding age. Args: @@ -151,7 +151,7 @@ def period_to_age(self, period: int) -> Age: return int(self._values[period]) return float(self._values[period]) - def age_to_period(self, age: Age) -> int: + def age_to_period(self, age: float) -> int: """Convert an age to the corresponding period index. 
Args: @@ -172,12 +172,14 @@ def age_to_period(self, age: Age) -> int: raise ValueError(msg) from None @functools.cached_property - def _age_to_period_map(self) -> dict[Age, int]: + def _age_to_period_map(self) -> dict[int | float, int]: if self._is_integer: return {int(v): i for i, v in enumerate(self._exact_values)} return {float(v): i for i, v in enumerate(self._exact_values)} - def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...]: + def get_periods_where( + self, predicate: Callable[[int | float], bool] + ) -> tuple[int, ...]: """Get period indices where predicate is True. Args: @@ -187,7 +189,7 @@ def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...] Tuple of period indices where predicate(age) is True. """ - _convert: Callable[[object], Age] = int if self._is_integer else float # ty: ignore[invalid-assignment] + _convert: Callable[[object], int | float] = int if self._is_integer else float # ty: ignore[invalid-assignment] return tuple( period for period in range(self.n_periods) diff --git a/src/lcm/dtypes.py b/src/lcm/dtypes.py index 51dd958c..f2cf3170 100644 --- a/src/lcm/dtypes.py +++ b/src/lcm/dtypes.py @@ -3,22 +3,32 @@ Used at every API boundary that accepts user data (params, initial conditions, regime-id arrays) — always called from Python, never inside JIT. Each helper validates that the value fits the target dtype and -raises a clearly-named error if not. - -Casts further down the simulate stack (e.g. transition outputs landing -in the state pool) use plain `.astype` and rely on the boundary cast -above them having already pinned the canonical dtype. +raises a clearly-named error if not. Once an input has crossed the +boundary it carries the canonical dtype unchanged through the simulate +stack; downstream code does not re-cast. 
""" +import jax import jax.numpy as jnp import numpy as np from jax import Array _INT32_MIN = int(np.iinfo(np.int32).min) _INT32_MAX = int(np.iinfo(np.int32).max) +_FLOAT32_MAX = float(np.finfo(np.float32).max) + + +def canonical_float_dtype() -> jnp.dtype: + """Return pylcm's canonical float dtype, derived from `jax_enable_x64`. + + Returns `jnp.float64` if `jax.config.jax_enable_x64` is True, + otherwise `jnp.float32`. The value is read at call time, not at + import, so toggling the JAX config (e.g. between tests) is honoured. + """ + return jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32 -def safe_to_int32(value: object, *, name: str) -> Array: +def safe_to_int_dtype(value: object, *, name: str) -> Array: """Cast a scalar, sequence, or array to `jnp.int32`, checking int32 range. Args: @@ -46,3 +56,41 @@ def safe_to_int32(value: object, *, name: str) -> Array: ) raise ValueError(msg) return jnp.asarray(np_value, dtype=jnp.int32) + + +def safe_to_float_dtype(value: object, *, name: str) -> Array: + """Cast a scalar, sequence, or array to the canonical float dtype. + + Range check fires only on a down-cast: + + - Down-cast (float64 → float32 under `jax_enable_x64=False`): raise + `OverflowError` if any element exceeds float32 magnitude rather + than letting JAX silently saturate to ``±inf``. + - Up-cast or same-width cast: skip the range check. Precision loss + within range is not an error — it is an inherent consequence of + `jax_enable_x64=False`. + + Args: + value: A Python float, numpy/JAX scalar, or array-like. + name: Qualified name of the leaf — surfaced in the error message. + + Returns: + A JAX array at `canonical_float_dtype()` (0-d if `value` was a + scalar). + + Raises: + OverflowError: If down-casting to `float32` would saturate any + element to `±inf`. The message names the leaf via `name`. 
+ + """ + target_dtype = canonical_float_dtype() + np_value = np.asarray(value) + if target_dtype == jnp.float32 and np_value.size > 0: + max_mag = float(np.max(np.abs(np_value))) + if max_mag > _FLOAT32_MAX: + msg = ( + f"{name}: float32 overflow — max |value| {max_mag:g} " + f"exceeds float32 max {_FLOAT32_MAX:g}." + ) + raise OverflowError(msg) + return jnp.asarray(np_value, dtype=target_dtype) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index 9a147c44..839d4cef 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -7,12 +7,14 @@ import jax.numpy as jnp from jax import Array +from lcm.dtypes import canonical_float_dtype from lcm.exceptions import GridInitializationError, format_messages from lcm.grids import coordinates as grid_coordinates from lcm.grids.base import Grid from lcm.typing import ( Float1D, ScalarFloat, + ScalarInt, ) @@ -37,24 +39,39 @@ def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True, kw_only=True, init=False) class UniformContinuousGrid(ContinuousGrid, ABC): - """Grid with start/stop/n_points for linearly or logarithmically spaced values.""" - - start: int | float - """The start value of the grid.""" + """Grid with start/stop/n_points for linearly or logarithmically spaced values. - stop: int | float - """The stop value of the grid.""" - - n_points: int - """The number of points in the grid.""" + `start` and `stop` are stored as JAX scalars at `canonical_float_dtype()`, + `n_points` as a `jnp.int32` JAX scalar — converted from the Python + literals (or other numeric inputs) supplied at construction. 
+ """ - def __post_init__(self) -> None: - _validate_continuous_grid( - start=self.start, - stop=self.stop, - n_points=self.n_points, + start: ScalarFloat + """The start value of the grid (JAX scalar at `canonical_float_dtype()`).""" + + stop: ScalarFloat + """The stop value of the grid (JAX scalar at `canonical_float_dtype()`).""" + + n_points: ScalarInt + """The number of points in the grid (`jnp.int32` JAX scalar).""" + + def __init__( + self, + *, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int | ScalarInt, + batch_size: int = 0, + ) -> None: + _init_uniform_grid( + self, + start=start, + stop=stop, + n_points=n_points, + batch_size=batch_size, + requires_positive_start=False, ) @abstractmethod @@ -109,7 +126,10 @@ def get_coordinate(self, value: Array) -> Array: ... def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" return grid_coordinates.get_linspace_coordinate( - value=value, start=self.start, stop=self.stop, n_points=self.n_points + value=value, + start=self.start, + stop=self.stop, + n_points=self.n_points, ) @@ -124,11 +144,20 @@ class LogSpacedGrid(UniformContinuousGrid): """ - def __post_init__(self) -> None: - _validate_continuous_grid( - start=self.start, - stop=self.stop, - n_points=self.n_points, + def __init__( + self, + *, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int | ScalarInt, + batch_size: int = 0, + ) -> None: + _init_uniform_grid( + self, + start=start, + stop=stop, + n_points=n_points, + batch_size=batch_size, requires_positive_start=True, ) @@ -145,19 +174,58 @@ def get_coordinate(self, value: Array) -> Array: ... 
def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: """Return the generalized coordinate of a value in the grid.""" return grid_coordinates.get_logspace_coordinate( - value=value, start=self.start, stop=self.stop, n_points=self.n_points + value=value, + start=self.start, + stop=self.stop, + n_points=self.n_points, ) -@dataclass(frozen=True, kw_only=True) +def _init_uniform_grid( + grid: UniformContinuousGrid, + *, + start: float | ScalarFloat, + stop: float | ScalarFloat, + n_points: int | ScalarInt, + batch_size: int, + requires_positive_start: bool, +) -> None: + """Cast `start` / `stop` / `n_points` to canonical JAX scalars, validate, store. + + `jnp.asarray(..., dtype=canonical_float_dtype())` and `jnp.int32(...)` lift + every numeric input at the boundary: Python literals from user + construction, JAX scalars from `dataclasses.replace` round-trips, anything + else raises here. The validator can then assume strict `ScalarFloat` / + `ScalarInt` types and only check value invariants (finiteness, ordering, + positivity). + """ + dtype = canonical_float_dtype() + start_jax = jnp.asarray(start, dtype=dtype) + stop_jax = jnp.asarray(stop, dtype=dtype) + n_points_jax = jnp.int32(n_points) + _validate_continuous_grid( + start=start_jax, + stop=stop_jax, + n_points=n_points_jax, + requires_positive_start=requires_positive_start, + ) + object.__setattr__(grid, "start", start_jax) + object.__setattr__(grid, "stop", stop_jax) + object.__setattr__(grid, "n_points", n_points_jax) + object.__setattr__(grid, "batch_size", batch_size) + + +@dataclass(frozen=True, kw_only=True, init=False) class IrregSpacedGrid(ContinuousGrid): """A grid of continuous values at irregular (user-specified) points. This grid type is useful for representing non-uniformly spaced points such as Gauss-Hermite quadrature nodes. - When `points` is omitted and only `n_points` is given, the `points` must be - supplied at runtime via the params. 
+ `points` is stored as a JAX array at `canonical_float_dtype()`, converted + from the Python sequence supplied at construction. When `points` is + omitted and only `n_points` is given, the points must be supplied at + runtime via the params. Example: -------- @@ -166,31 +234,44 @@ class IrregSpacedGrid(ContinuousGrid): """ - points: Sequence[float] | Float1D | None = None + points: Float1D | None """The grid points in ascending order, or `None` for runtime-supplied points.""" - n_points: int | None = None + n_points: int """Number of points. Derived from `len(points)` when points are given.""" - def __post_init__(self) -> None: - if self.points is not None: - _validate_irreg_spaced_grid(self.points) - # Derive n_points from points if not explicitly set - if self.n_points is None: - object.__setattr__(self, "n_points", len(self.points)) - elif self.n_points != len(self.points): + def __init__( + self, + *, + points: Sequence[float] | Float1D | None = None, + n_points: int | None = None, + batch_size: int = 0, + ) -> None: + if points is not None: + _validate_irreg_spaced_grid(points) + derived_n = len(points) + if n_points is None: + n_points = derived_n + elif n_points != derived_n: raise GridInitializationError( - f"n_points ({self.n_points}) does not match " - f"len(points) ({len(self.points)})" + f"n_points ({n_points}) does not match len(points) ({derived_n})" ) - elif self.n_points is None: + stored_points: Float1D | None = jnp.asarray( + points, dtype=canonical_float_dtype() + ) + elif n_points is None: raise GridInitializationError( "Either points or n_points must be specified for IrregSpacedGrid." 
) - elif self.n_points < 2: # noqa: PLR2004 + elif n_points < 2: # noqa: PLR2004 raise GridInitializationError( - f"n_points must be at least 2, got {self.n_points}" + f"n_points must be at least 2, got {n_points}" ) + else: + stored_points = None + object.__setattr__(self, "points", stored_points) + object.__setattr__(self, "n_points", n_points) + object.__setattr__(self, "batch_size", batch_size) @property def pass_points_at_runtime(self) -> bool: @@ -215,7 +296,7 @@ def to_jax(self) -> Float1D: f"read from `.states[name]` or `.continuous_actions[name]`. " f"Use `.n_points` if only the shape is needed." ) - return jnp.asarray(self.points) + return self.points @overload def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... @@ -229,18 +310,22 @@ def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: "initialization or use IrregSpacedGrid(n_points=...) and " "supply points at runtime via params." ) - return grid_coordinates.get_irreg_coordinate(value=value, points=self.to_jax()) + return grid_coordinates.get_irreg_coordinate(value=value, points=self.points) def _validate_continuous_grid( *, - start: float, - stop: float, - n_points: int, + start: ScalarFloat, + stop: ScalarFloat, + n_points: ScalarInt, requires_positive_start: bool = False, ) -> None: """Validate the continuous grid parameters. + `start` and `stop` are post-cast canonical-dtype JAX scalars (the + boundary cast in `_init_uniform_grid` already rejects non-numeric + inputs); the checks here cover only value invariants. + Args: start: The start value of the grid. stop: The stop value of the grid. 
@@ -254,32 +339,24 @@ def _validate_continuous_grid( """ error_messages = [] - valid_start_type = isinstance(start, int | float) - if not valid_start_type: - error_messages.append("start must be a scalar int or float value") - - valid_stop_type = isinstance(stop, int | float) - if not valid_stop_type: - error_messages.append("stop must be a scalar int or float value") - # Reject NaN/inf early — `start >= stop` returns False for NaN, so an # un-finite start would otherwise pass silently and produce a broken grid. - if valid_start_type and not jnp.isfinite(start): + start_finite = bool(jnp.isfinite(start)) + if not start_finite: error_messages.append(f"start must be finite, got {start}") - valid_start_type = False - if valid_stop_type and not jnp.isfinite(stop): + stop_finite = bool(jnp.isfinite(stop)) + if not stop_finite: error_messages.append(f"stop must be finite, got {stop}") - valid_stop_type = False - if not isinstance(n_points, int) or n_points < 1: + if n_points < 1: error_messages.append( f"n_points must be an int greater than 0 but is {n_points}", ) - if valid_start_type and valid_stop_type and start >= stop: + if start_finite and stop_finite and start >= stop: error_messages.append("start must be less than stop") - if valid_start_type and requires_positive_start and start <= 0: + if start_finite and requires_positive_start and start <= 0: error_messages.append( f"start must be > 0 for a log-spaced grid (got {start}); " f"`log(x)` is undefined for `x <= 0`." diff --git a/src/lcm/grids/coordinates.py b/src/lcm/grids/coordinates.py index 0edce4bd..dc95b40c 100644 --- a/src/lcm/grids/coordinates.py +++ b/src/lcm/grids/coordinates.py @@ -1,27 +1,9 @@ """Functions to generate and work with different kinds of grids. -Grid generation functions must have the following signature: - - Signature (start: ScalarFloat, stop: ScalarFloat, n_points: int) -> jax.Array - -They take start and end points and create a grid of points between them. 
- - -Interpolation info functions must have the following signature: - - Signature ( - value: ScalarFloat, - start: ScalarFloat, - stop: ScalarFloat, - n_points: int - ) -> ScalarInt - -They take the information required to generate a grid, and return an index corresponding -to the value, which is a point in the space but not necessarily a grid point. - -Some of the arguments will not be used by all functions but the aligned interface makes -it easy to call functions interchangeably. - +These helpers operate on JAX scalars (`ScalarFloat` for endpoints/values, +`ScalarInt` or Python `int` for `n_points`) — every production caller is +either a `Grid` method that has already converted user input to a JAX +scalar, or a piecewise dispatch that selects pieces via `searchsorted`. """ from typing import overload @@ -32,14 +14,19 @@ from lcm.typing import Float1D, ScalarFloat, ScalarInt -def linspace(*, start: ScalarFloat, stop: ScalarFloat, n_points: int) -> Float1D: +def linspace( + *, + start: ScalarFloat, + stop: ScalarFloat, + n_points: ScalarInt, +) -> Float1D: """Wrapper around jnp.linspace. Returns a linearly spaced grid between start and stop with n_points, including both endpoints. """ - return jnp.linspace(start, stop, n_points) + return jnp.linspace(start, stop, n_points) # ty: ignore[no-matching-overload] @overload @@ -70,7 +57,12 @@ def get_linspace_coordinate( return (value - start) / step_length -def logspace(*, start: ScalarFloat, stop: ScalarFloat, n_points: int) -> Float1D: +def logspace( + *, + start: ScalarFloat, + stop: ScalarFloat, + n_points: ScalarInt, +) -> Float1D: """Wrapper around jnp.logspace. 
Returns a logarithmically spaced grid between start and stop with n_points, @@ -87,7 +79,7 @@ def logspace(*, start: ScalarFloat, stop: ScalarFloat, n_points: int) -> Float1D """ start_linear = jnp.log(start) stop_linear = jnp.log(stop) - grid = jnp.logspace(start_linear, stop_linear, n_points, base=jnp.e) + grid = jnp.logspace(start_linear, stop_linear, n_points, base=jnp.e) # ty: ignore[invalid-argument-type] return grid.at[0].set(start).at[-1].set(stop) diff --git a/src/lcm/grids/piecewise.py b/src/lcm/grids/piecewise.py index e3a252f2..43ff2a91 100644 --- a/src/lcm/grids/piecewise.py +++ b/src/lcm/grids/piecewise.py @@ -13,12 +13,17 @@ Float1D, Int1D, ScalarFloat, + ScalarInt, ) -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True, kw_only=True, init=False) class Piece: - """A piece of a piecewise linearly spaced grid.""" + """A piece of a piecewise linearly spaced grid. + + `n_points` is stored as a `jnp.int32` JAX scalar, converted from the + Python literal supplied at construction. + """ interval: str | portion.Interval """The interval for this piece. @@ -26,8 +31,17 @@ class Piece: Can be a string like "[1, 4)" or a `portion.Interval`. 
""" - n_points: int - """The number of grid points in this piece.""" + n_points: ScalarInt + """The number of grid points in this piece (`jnp.int32` JAX scalar).""" + + def __init__( + self, + *, + interval: str | portion.Interval, + n_points: int | ScalarInt, + ) -> None: + object.__setattr__(self, "interval", interval) + object.__setattr__(self, "n_points", jnp.int32(n_points)) @dataclass(frozen=True, kw_only=True) @@ -69,14 +83,14 @@ def __post_init__(self) -> None: _init_piecewise_grid_cache(self) @property - def n_points(self) -> int: + def n_points(self) -> ScalarInt: """Return the total number of points in the grid.""" - return sum(p.n_points for p in self.pieces) + return self._piece_n_points.sum() def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" piece_arrays = [ - jnp.linspace(self._piece_starts[i], self._piece_stops[i], p.n_points) + jnp.linspace(self._piece_starts[i], self._piece_stops[i], p.n_points) # ty: ignore[no-matching-overload] for i, p in enumerate(self.pieces) ] return jnp.concatenate(piece_arrays) @@ -136,9 +150,9 @@ def __post_init__(self) -> None: _init_piecewise_grid_cache(self) @property - def n_points(self) -> int: + def n_points(self) -> ScalarInt: """Return the total number of points in the grid.""" - return sum(p.n_points for p in self.pieces) + return self._piece_n_points.sum() def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" @@ -267,7 +281,7 @@ def _validate_piecewise_lin_spaced_grid( # noqa: C901, PLR0912 ) continue - if not isinstance(piece.n_points, int) or piece.n_points < 2: # noqa: PLR2004 + if piece.n_points < 2: # noqa: PLR2004 error_messages.append( f"pieces[{i}].n_points must be an int >= 2, but is {piece.n_points}" ) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 59617eb6..c74842e6 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -271,13 +271,9 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac if points_key not in 
all_params: continue if in_states: - state_replacements[name] = cast( - "ContinuousState", all_params[points_key] - ) + state_replacements[name] = all_params[points_key] else: - action_replacements[name] = cast( - "ContinuousAction", all_params[points_key] - ) + action_replacements[name] = all_params[points_key] # `_ShockGrid` is state-only by construction (intrinsic # transitions, forbidden as actions per AGENTS.md). The # `in_states` gate makes that invariant explicit — a diff --git a/src/lcm/model.py b/src/lcm/model.py index a44abb83..23104c5b 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -25,7 +25,8 @@ initial_conditions_from_dataframe, ) from lcm.params.processing import ( - process_params, + broadcast_to_template, + cast_params_to_canonical_dtypes, ) from lcm.persistence import ( save_simulate_snapshot, @@ -118,21 +119,9 @@ class Model: _warned_n_subjects: set[int] """Mismatching `actual_n_subjects` already warned about (one warning each).""" - _simulate_compile_future: ( - Future[MappingProxyType[RegimeName, InternalRegime]] | None - ) - """Pending background AOT compile started by `solve(...)`, or `None`. - - `solve(...)` kicks off `compile_all_simulate_functions` in a single - background thread so XLA compilation overlaps with the GPU-bound - backward induction. `simulate(...)` awaits the future before - dispatching the AOT-compiled program. Cleared after the result lands - in `_simulate_compile_cache`. - """ - _simulate_compile_lock: threading.Lock - """Serialises mutations of `_simulate_compile_cache`, `_warned_n_subjects`, - and `_simulate_compile_future`. + """Serialises mutations of `_simulate_compile_cache` and + `_warned_n_subjects`. The check-then-set on each container is held under this lock. 
The consequent `log.warning` call sits outside the lock so concurrent @@ -180,7 +169,6 @@ def __init__( self.n_subjects = n_subjects self._simulate_compile_cache = {} self._warned_n_subjects = set() - self._simulate_compile_future = None self._simulate_compile_lock = threading.Lock() validate_model_inputs( @@ -216,16 +204,13 @@ def __getstate__(self) -> dict[str, object]: Drops `_simulate_compile_lock` (a `threading.Lock`, not pickleable), `_simulate_compile_cache` (compiled XLA programs that can't survive - a process boundary), `_warned_n_subjects` (its companion set), and - `_simulate_compile_future` (a `Future` tied to the originating thread - pool). + a process boundary), and `_warned_n_subjects` (its companion set). `__setstate__` restores all three to their fresh state. """ state = self.__dict__.copy() state.pop("_simulate_compile_lock", None) state.pop("_simulate_compile_cache", None) state.pop("_warned_n_subjects", None) - state.pop("_simulate_compile_future", None) return state def __setstate__(self, state: dict[str, object]) -> None: @@ -233,7 +218,6 @@ def __setstate__(self, state: dict[str, object]) -> None: self.__dict__.update(state) self._simulate_compile_cache = {} self._warned_n_subjects = set() - self._simulate_compile_future = None self._simulate_compile_lock = threading.Lock() def get_params_template(self) -> UserFacingParamsTemplate: @@ -297,17 +281,34 @@ def solve( internal_params=internal_params, ages=self.ages, ) - self._maybe_start_simulate_compile_async( + return self._solve_compiled( internal_params=internal_params, + params=params, + log=get_logger(log_level=log_level), + log_level=log_level, + log_path=log_path, + log_keep_n_latest=log_keep_n_latest, max_compilation_workers=max_compilation_workers, - logger=get_logger(log_level=log_level), ) + + def _solve_compiled( + self, + *, + internal_params: InternalParams, + params: UserParams, + log: logging.Logger, + log_level: LogLevel, + log_path: str | Path | None, + log_keep_n_latest: int, 
+ max_compilation_workers: int | None, + ) -> MappingProxyType[int, MappingProxyType[RegimeName, FloatND]]: + """Run backward induction, persisting a snapshot on debug or NaN failure.""" try: period_to_regime_to_V_arr = solve( internal_params=internal_params, ages=self.ages, internal_regimes=self.internal_regimes, - logger=get_logger(log_level=log_level), + logger=log, enable_jit=self.enable_jit, max_compilation_workers=max_compilation_workers, ) @@ -332,61 +333,55 @@ def solve( ) return period_to_regime_to_V_arr - def _maybe_start_simulate_compile_async( + def _spawn_simulate_compile( self, *, + n_subjects: int, internal_params: InternalParams, max_compilation_workers: int | None, logger: logging.Logger, - ) -> None: - """Spawn `compile_all_simulate_functions` in a background thread. + ) -> Future[MappingProxyType[RegimeName, InternalRegime]]: + """Submit `compile_all_simulate_functions` to a single-thread executor. - Called from `solve(...)` so the simulate-side XLA compilation runs in - parallel with the GPU-bound backward induction. No-op when - `n_subjects is None`, when the cache for this size is already - populated, or when a compile is already in flight. + Caller decides whether to spawn (`n_subjects` set, batch shape + matches, no cache hit). The returned `Future` runs in parallel with + whatever the caller does next — typically `_solve_compiled(...)`. 
""" - if self.n_subjects is None: - return - with self._simulate_compile_lock: - if self.n_subjects in self._simulate_compile_cache: - return - if self._simulate_compile_future is not None: - return - executor = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="lcm-simulate-compile" - ) - self._simulate_compile_future = executor.submit( - compile_all_simulate_functions, - internal_regimes=self.internal_regimes, - internal_params=internal_params, - ages=self.ages, - n_subjects=self.n_subjects, - max_compilation_workers=max_compilation_workers, - logger=logger, - ) - executor.shutdown(wait=False) + executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="lcm-simulate-compile" + ) + future = executor.submit( + compile_all_simulate_functions, + internal_regimes=self.internal_regimes, + internal_params=internal_params, + ages=self.ages, + n_subjects=n_subjects, + max_compilation_workers=max_compilation_workers, + logger=logger, + ) + executor.shutdown(wait=False) + return future def _resolve_simulate_internal_regimes( self, *, + compile_future: Future[MappingProxyType[RegimeName, InternalRegime]] | None, actual_n_subjects: int, - internal_params: InternalParams, log: logging.Logger, - max_compilation_workers: int | None, ) -> MappingProxyType[RegimeName, InternalRegime]: """Return internal_regimes to use for simulate; AOT cache when matching. - Three dispatch cases: + Dispatch by `n_subjects` and batch-shape match: - `n_subjects is None`: return the original `internal_regimes` (purely lazy path). - - `actual_n_subjects != n_subjects`: return the original - `internal_regimes` and log a warning the first time each - mismatching size is seen. - - `actual_n_subjects == n_subjects`: return the cached AOT-compiled - regimes. If `solve(...)` started a background compile, await it - here; otherwise compile synchronously. + - `actual_n_subjects != n_subjects`: warn once per mismatching size, + return the original `internal_regimes`. 
+ - `actual_n_subjects == n_subjects`, `compile_future is not None`: + await it and cache the result. + - `actual_n_subjects == n_subjects`, `compile_future is None`: cache + must already hold the entry (caller spawned only on cache miss); + return the cached compiled regimes. """ if self.n_subjects is None: return self.internal_regimes @@ -403,28 +398,12 @@ def _resolve_simulate_internal_regimes( self.n_subjects, ) return self.internal_regimes - with self._simulate_compile_lock: - if self.n_subjects in self._simulate_compile_cache: - return self._simulate_compile_cache[self.n_subjects] - future = self._simulate_compile_future - if future is not None: - compiled = future.result() + if compile_future is not None: + compiled = compile_future.result() with self._simulate_compile_lock: self._simulate_compile_cache[self.n_subjects] = compiled - self._simulate_compile_future = None return compiled with self._simulate_compile_lock: - if self.n_subjects not in self._simulate_compile_cache: - self._simulate_compile_cache[self.n_subjects] = ( - compile_all_simulate_functions( - internal_regimes=self.internal_regimes, - internal_params=internal_params, - ages=self.ages, - n_subjects=self.n_subjects, - max_compilation_workers=max_compilation_workers, - logger=log, - ) - ) return self._simulate_compile_cache[self.n_subjects] def simulate( @@ -507,33 +486,35 @@ def simulate( ages=self.ages, ) log = get_logger(log_level=log_level) - if period_to_regime_to_V_arr is None: - try: - period_to_regime_to_V_arr = solve( + actual_n_subjects = len(next(iter(initial_conditions.values()))) + n_subjects = self.n_subjects + compile_future: Future[MappingProxyType[RegimeName, InternalRegime]] | None = ( + None + ) + if n_subjects is not None and n_subjects == actual_n_subjects: + with self._simulate_compile_lock: + needs_compile = n_subjects not in self._simulate_compile_cache + if needs_compile: + compile_future = self._spawn_simulate_compile( + n_subjects=n_subjects, 
internal_params=internal_params, - ages=self.ages, - internal_regimes=self.internal_regimes, - logger=log, - enable_jit=self.enable_jit, max_compilation_workers=max_compilation_workers, + logger=log, ) - except InvalidValueFunctionError as exc: - if log_path is not None and exc.partial_solution is not None: - snap_dir = save_solve_snapshot( - model=self, - params=params, - period_to_regime_to_V_arr=exc.partial_solution, # ty: ignore[invalid-argument-type] - log_path=Path(log_path), - log_keep_n_latest=log_keep_n_latest, - ) - exc.add_note(f"Snapshot saved to {snap_dir}") - raise - actual_n_subjects = len(next(iter(initial_conditions.values()))) + if period_to_regime_to_V_arr is None: + period_to_regime_to_V_arr = self._solve_compiled( + internal_params=internal_params, + params=params, + log=log, + log_level=log_level, + log_path=log_path, + log_keep_n_latest=log_keep_n_latest, + max_compilation_workers=max_compilation_workers, + ) simulate_internal_regimes = self._resolve_simulate_internal_regimes( + compile_future=compile_future, actual_n_subjects=actual_n_subjects, - internal_params=internal_params, log=log, - max_compilation_workers=max_compilation_workers, ) result = simulate( internal_params=internal_params, @@ -546,6 +527,13 @@ def simulate( simulation_output_dtypes=self.simulation_output_dtypes, seed=seed, ) + # AOT-compiled regimes carry `jax.stages.Compiled` callables that + # wrap an unpicklable `LoadedExecutable`. `to_dataframe` only reads + # the lazy DAG functions / constraints / transitions on + # `simulate_functions`, never the compiled callables — so swap in + # the lazy regimes to keep the result cloudpickle-safe. 
+ if simulate_internal_regimes is not self.internal_regimes: + result._internal_regimes = self.internal_regimes # noqa: SLF001 if log_level == "debug" and log_path is not None: save_simulate_snapshot( model=self, @@ -559,9 +547,15 @@ def simulate( return result def _process_params(self, params: UserParams) -> InternalParams: - """Broadcast, convert Series, and validate user params.""" - internal_params = process_params( - params=params, params_template=self._params_template + """Broadcast, convert Series, dtype-cast, and validate user params. + + Step order matters: `convert_series_in_params` runs *between* + `broadcast_to_template` and `cast_params_to_canonical_dtypes` so + the dtype cast walks a uniform tree (no `pd.Series` to special- + case). + """ + internal_params = broadcast_to_template( + params=params, template=self._params_template, required=True ) if has_series(internal_params): internal_params = convert_series_in_params( @@ -570,6 +564,7 @@ def _process_params(self, params: UserParams) -> InternalParams: regimes=self.regimes, regime_names_to_ids=self.regime_names_to_ids, ) + internal_params = cast_params_to_canonical_dtypes(internal_params) _validate_param_types(internal_params) return internal_params diff --git a/src/lcm/model_processing.py b/src/lcm/model_processing.py index d141c6a3..4f149736 100644 --- a/src/lcm/model_processing.py +++ b/src/lcm/model_processing.py @@ -20,6 +20,7 @@ from lcm.params import MappingLeaf from lcm.params.processing import ( broadcast_to_template, + cast_params_to_canonical_dtypes, create_params_template, ) from lcm.params.sequence_leaf import SequenceLeaf @@ -128,6 +129,7 @@ def _build_regimes_and_template_with_fixed_params( regimes=regimes, regime_names_to_ids=regime_names_to_ids, ) + fixed_internal = cast_params_to_canonical_dtypes(fixed_internal) _validate_param_types(fixed_internal) return ( @@ -421,11 +423,11 @@ def _filter_kwargs_for_func( def _validate_param_types(internal_params: InternalParams) -> None: - 
"""Raise if any param leaf is not a Python scalar or JAX array. + """Raise if any param leaf is not a JAX `Array` or container leaf. - After processing, every leaf value (including inside MappingLeaf / - SequenceLeaf containers) must be a Python scalar (float, int, bool) or a - JAX array. Notably, numpy arrays and pandas Series are not accepted. + Defense-in-depth check after `cast_params_to_canonical_dtypes`: by the + time this runs, every leaf must be a JAX `Array`, or a `MappingLeaf` / + `SequenceLeaf` whose contents recursively satisfy the same rule. """ for regime_name, regime_params in internal_params.items(): for key, value in regime_params.items(): @@ -433,7 +435,7 @@ def _validate_param_types(internal_params: InternalParams) -> None: def _check_leaf(value: object, path: str) -> None: - """Check a single leaf value, recursing into MappingLeaf/SequenceLeaf.""" + """Check a single leaf, recursing into `MappingLeaf` / `SequenceLeaf`.""" if isinstance(value, MappingLeaf): for k, v in value.data.items(): _check_leaf(v, f"{path}.{k}") @@ -442,17 +444,8 @@ def _check_leaf(value: object, path: str) -> None: for i, v in enumerate(value.data): _check_leaf(v, f"{path}[{i}]") return - if isinstance(value, (float, int, bool)): + if isinstance(value, Array): return - if hasattr(value, "dtype") and hasattr(value, "shape"): - if isinstance(value, Array): - return - type_name = type(value).__module__ + "." + type(value).__name__ - msg = ( - f"Parameter '{path}' is a {type_name} (shape {value.shape}). " - f"Use jnp.array() or pass a pd.Series with a named index." - ) - raise InvalidParamsError(msg) type_name = type(value).__module__ + "." + type(value).__name__ - msg = f"Parameter '{path}' has unexpected type {type_name}." + msg = f"Parameter {path!r} is a {type_name}, expected a JAX Array." 
raise InvalidParamsError(msg) diff --git a/src/lcm/pandas_utils.py b/src/lcm/pandas_utils.py index cf1b4cb0..adc04eac 100644 --- a/src/lcm/pandas_utils.py +++ b/src/lcm/pandas_utils.py @@ -12,6 +12,7 @@ from jax import Array from lcm.ages import PSEUDO_STATE_NAMES, AgeGrid +from lcm.dtypes import canonical_float_dtype from lcm.grids import DiscreteGrid, IrregSpacedGrid from lcm.params import MappingLeaf from lcm.params.sequence_leaf import SequenceLeaf @@ -149,7 +150,7 @@ def initial_conditions_from_dataframe( # noqa: C901 initial_conditions: dict[str, Array] = { col: jnp.array(arr, dtype=jnp.int32) if col in discrete_state_names - else jnp.array(arr) + else jnp.array(arr, dtype=canonical_float_dtype()) for col, arr in result_arrays.items() } initial_conditions["regime"] = jnp.array( @@ -371,12 +372,12 @@ def array_from_series( """ if func is None: - return jnp.array(sr.to_numpy(), dtype=float) + return jnp.array(sr.to_numpy(), dtype=canonical_float_dtype()) indexing_params = _get_func_indexing_params(func=func, array_param_name=param_name) if not indexing_params: - return jnp.array(sr.to_numpy(), dtype=float) + return jnp.array(sr.to_numpy(), dtype=canonical_float_dtype()) grids = _resolve_categoricals( regimes=regimes, @@ -706,7 +707,7 @@ def _scatter_series( shape = [m.size for m in level_mappings] if len(series) == 0: - return jnp.full(shape, fill_value) + return jnp.full(shape, fill_value, dtype=canonical_float_dtype()) index_arrays = [ _map_level( @@ -717,7 +718,7 @@ def _scatter_series( result = np.full(shape, fill_value) result[tuple(index_arrays)] = series.to_numpy() - return jnp.array(result) + return jnp.array(result, dtype=canonical_float_dtype()) def _map_level(*, mapping: _LevelMapping, level_values: pd.Index) -> np.ndarray: diff --git a/src/lcm/params/processing.py b/src/lcm/params/processing.py index ae937c44..d6afbccd 100644 --- a/src/lcm/params/processing.py +++ b/src/lcm/params/processing.py @@ -1,10 +1,23 @@ """Process user-provided params into 
internal params. `process_params` resolves user-supplied parameters against the model's -template, then runs a boundary-cast pass that normalises every integer -leaf — Python `int`, typed JAX integer arrays, numpy integer arrays, -and integers inside `MappingLeaf` / `SequenceLeaf` — to `jnp.int32`. -Out-of-range values surface as `ValueError` with the offending leaf's +template, then runs a boundary-cast pass that normalises every numeric +leaf to a canonical pylcm dtype: + +- Python `bool` (and `np.bool_` arrays) cast to `jnp.bool_`. +- Python `int` and typed integer arrays cast to `jnp.int32`. Out-of- + range values surface as `ValueError`. +- Python `float` and typed float arrays cast to `canonical_float_dtype()`. + Down-cast overflow surfaces as `OverflowError`. +- `MappingLeaf` / `SequenceLeaf` containers recurse. + +The pass runs as the *last* step over `internal_params` — `pd.Series` +leaves are reshaped to JAX arrays via `convert_series_in_params` +beforehand, so by the time the cast walks the tree, every numeric leaf +is either a JAX array, a numpy array, or a Python scalar. + +Anything else (`pd.Series` (defensive), strings, complex/object arrays, +custom objects) raises `InvalidParamsError` with the offending leaf's qualified name. 
""" @@ -12,11 +25,13 @@ from types import MappingProxyType from typing import Any, cast +import jax.numpy as jnp import numpy as np +import pandas as pd from dags.tree import QNAME_DELIMITER, qname_from_tree_path, tree_path_from_qname from jax import Array -from lcm.dtypes import safe_to_int32 +from lcm.dtypes import safe_to_float_dtype, safe_to_int_dtype from lcm.exceptions import InvalidNameError, InvalidParamsError from lcm.interfaces import InternalRegime from lcm.params.mapping_leaf import MappingLeaf @@ -47,11 +62,16 @@ def process_params( - Regime level: `{"regime_0": {"arg_0": 0.0}}` — propagates within regime_0 - Function level: `{"regime_0": {"func": {"arg_0": 0.0}}}` — direct specification - The output always matches the params_template skeleton. Every integer - leaf — Python `int`, typed JAX or numpy integer arrays, and integers - inside `MappingLeaf` / `SequenceLeaf` — is cast to `jnp.int32` so the - AOT signature is stable across calls. Python `bool` and float leaves - are handled by the float-side cast pass. + The output always matches the params_template skeleton. Every numeric + leaf — Python `bool` / `int` / `float`, typed JAX or numpy arrays, and + numerics inside `MappingLeaf` / `SequenceLeaf` — is cast to the + canonical pylcm dtype so the AOT signature is stable across calls. + + Callers that pass `pd.Series` leaves should orchestrate the steps + themselves: `broadcast_to_template` (resolve), `convert_series_in_params` + (multi-index reshape), then `cast_params_to_canonical_dtypes`. The + one-shot `process_params` raises on `pd.Series` because the dtype + cast does not know how to reshape multi-index data. Args: params: User-provided parameters dictionary. @@ -61,13 +81,19 @@ def process_params( Immutable mapping with the same structure as params_template. Raises: - InvalidParamsError: If params contains unexpected keys or type mismatches. + InvalidParamsError: If params contains unexpected keys, type + mismatches, or unsupported leaf types. 
InvalidNameError: If the same parameter is specified at multiple levels. ValueError: If a typed integer leaf carries a value outside the int32 range; the message names the offending parameter qname. + OverflowError: If a typed float leaf would saturate to `±inf` on + down-cast to `float32`; the message names the offending qname. """ - return broadcast_to_template(params=params, template=params_template, required=True) + internal = broadcast_to_template( + params=params, template=params_template, required=True + ) + return cast_params_to_canonical_dtypes(internal) def broadcast_to_template( @@ -84,6 +110,10 @@ def broadcast_to_template( 2. Regime level: `regime__param` 3. Model level: `param` + Returns the resolved structure with leaves left as the user supplied + them; dtype canonicalisation is a separate step + (`cast_params_to_canonical_dtypes`). + Args: params: User-provided values at any nesting depth. template: Target structure defining all valid 3-part keys. @@ -129,62 +159,118 @@ def broadcast_to_template( if unknown: raise InvalidParamsError(f"Unknown keys: {sorted(unknown)}") - for regime, leaves in result.items(): - for param_qname, value in leaves.items(): - leaves[param_qname] = _cast_int_leaves_to_int32( - value, name=f"{regime}{QNAME_DELIMITER}{param_qname}" - ) - return cast( "InternalParams", MappingProxyType({k: MappingProxyType(v) for k, v in result.items()}), ) -def _cast_int_leaves_to_int32(value: Any, *, name: str) -> Any: # noqa: ANN401 - """Normalise integer leaves in a params value to `jnp.int32`. +def cast_params_to_canonical_dtypes(internal_params: InternalParams) -> InternalParams: + """Cast every numeric leaf of `internal_params` to its canonical pylcm dtype. + + Runs as a separate pass so the orchestrator can interpose + `convert_series_in_params` between broadcast and cast — by the time + this pass walks the tree, no `pd.Series` leaf should remain. 
+ + Args: + internal_params: Output of `broadcast_to_template`, optionally + after `convert_series_in_params`. + + Returns: + New immutable mapping with every leaf cast to its canonical dtype. + + """ + return cast( + "InternalParams", + MappingProxyType( + { + regime: MappingProxyType( + { + param_qname: _cast_leaves_to_canonical_dtype( + value, name=f"{regime}{QNAME_DELIMITER}{param_qname}" + ) + for param_qname, value in leaves.items() + } + ) + for regime, leaves in internal_params.items() + } + ), + ) + + +def _cast_leaves_to_canonical_dtype(value: Any, *, name: str) -> Any: # noqa: ANN401, C901, PLR0911 + """Cast a single params leaf to its canonical pylcm dtype. + + Strict whitelist — every code path either casts or raises. Casts: - - Python `int` scalars — to `jnp.int32` so the DAG sees a JAX scalar - with a pinned dtype rather than a Python int that JAX would - otherwise promote per call site. - - Typed JAX or numpy integer arrays (`jnp.array(..., dtype=jnp.int64)`, - `np.array(...)`) — cast to `int32` to keep the AOT signature stable. - - Integer leaves inside `MappingLeaf` / `SequenceLeaf` — recurse. + - `MappingLeaf` / `SequenceLeaf`: recurse on contents. + - Python `bool`: `jnp.bool_(value)` (must come before `int` — + `True` is a Python `int` subclass). + - Python `int`: `safe_to_int_dtype(value)` → `jnp.int32`. + - Python `float`: `safe_to_float_dtype(value)` → canonical float. + - JAX or numpy array, dispatch on `dtype.kind`: + - `"b"` (bool) → `jnp.asarray(..., dtype=jnp.bool_)`. + - `"i"` / `"u"` (signed/unsigned int) → `safe_to_int_dtype`. + - `"f"` (float) → `safe_to_float_dtype`. + + Raises `InvalidParamsError` for: - Passes through unchanged: + - `pd.Series`: defensive — the orchestrator must run + `convert_series_in_params` before this pass. + - Array dtypes other than bool/int/float (e.g. complex, object, + string). + - Anything else (`str`, `None`, `dict`, lists, custom objects). 
- - Python `bool` scalars — handled by the float-side cast pass once - it lands. - - Float and non-numeric typed leaves — handled by a separate float- - normalisation pass. """ if isinstance(value, MappingLeaf): return MappingLeaf( { - k: _cast_int_leaves_to_int32(v, name=f"{name}.{k}") + k: _cast_leaves_to_canonical_dtype(v, name=f"{name}.{k}") for k, v in value.data.items() } ) if isinstance(value, SequenceLeaf): return SequenceLeaf( [ - _cast_int_leaves_to_int32(v, name=f"{name}[{i}]") + _cast_leaves_to_canonical_dtype(v, name=f"{name}[{i}]") for i, v in enumerate(value.data) ] ) - # `bool` is a subclass of `int`, so test for it first and short-circuit - # — bool handling lands with the float-side cast pass, not here. + if isinstance(value, pd.Series): + msg = ( + f"{name!r}: pd.Series leaf reached the dtype cast — " + f"`convert_series_in_params` must run between " + f"`broadcast_to_template` and `cast_params_to_canonical_dtypes`." + ) + raise InvalidParamsError(msg) + # `bool` before `int` — `True` is a Python `int` subclass. if isinstance(value, bool): - return value + return jnp.bool_(value) if isinstance(value, int): - return safe_to_int32(value, name=name) - if isinstance(value, (Array, np.ndarray)) and np.issubdtype( - value.dtype, np.integer - ): - return safe_to_int32(value, name=name) - return value + return safe_to_int_dtype(value, name=name) + if isinstance(value, float): + return safe_to_float_dtype(value, name=name) + if isinstance(value, (Array, np.ndarray)): + kind = value.dtype.kind + if kind == "b": + return jnp.asarray(value, dtype=jnp.bool_) + if kind in ("i", "u"): + return safe_to_int_dtype(value, name=name) + if kind == "f": + return safe_to_float_dtype(value, name=name) + msg = ( + f"{name!r}: array dtype {value.dtype} not supported " + f"(expected bool / int / float)." 
+ ) + raise InvalidParamsError(msg) + msg = ( + f"{name!r}: unsupported leaf type {type(value).__name__} " + f"(expected bool / int / float / numpy or JAX array / " + f"MappingLeaf / SequenceLeaf)." + ) + raise InvalidParamsError(msg) def _find_candidates( diff --git a/src/lcm/persistence.py b/src/lcm/persistence.py index 6f746173..c8498a19 100644 --- a/src/lcm/persistence.py +++ b/src/lcm/persistence.py @@ -219,7 +219,10 @@ def save_simulate_snapshot( _save_pkl(snap_dir / "model.pkl", model) _save_pkl(snap_dir / "params.pkl", params) _save_pkl(snap_dir / "initial_conditions.pkl", initial_conditions) - _save_pkl(snap_dir / "result.pkl", _strip_V_arr_from_result(result)) + _save_pkl( + snap_dir / "result.pkl", + _strip_V_arr_from_result(result=result, model=model), + ) _save_h5(snap_dir / "arrays.h5", period_to_regime_to_V_arr) _write_metadata( snap_dir, @@ -290,14 +293,22 @@ def _find_project_root() -> Path | None: return None -def _strip_V_arr_from_result(result: SimulationResult) -> SimulationResult: - """Create a copy of result with value arrays replaced by an empty mapping. - - Avoid storing period_to_regime_to_V_arr both in the pickle and in the HDF5 file. +def _strip_V_arr_from_result( + *, result: SimulationResult, model: Model +) -> SimulationResult: + """Create a copy of result with value arrays and compiled callables stripped. + `period_to_regime_to_V_arr` is dropped to avoid storing it both in the + pickle and in the HDF5 file. `_internal_regimes` is overwritten with the + model's lazy-path `internal_regimes`: when `Model(n_subjects=N)` is set + the result carries the AOT-compiled regimes, whose + `jax.stages.Compiled` callables hold a `LoadedExecutable` that cannot + be pickled. The lazy regimes carry the same metadata and cloud-pickle + cleanly (model.pkl uses the same set). 
""" stripped = copy.copy(result) object.__setattr__(stripped, "_period_to_regime_to_V_arr", MappingProxyType({})) + object.__setattr__(stripped, "_internal_regimes", model.internal_regimes) return stripped diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 69ca8d9d..3db3063e 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -16,6 +16,7 @@ from jax import numpy as jnp from lcm.ages import PSEUDO_STATE_NAMES, AgeGrid +from lcm.dtypes import canonical_float_dtype, safe_to_float_dtype from lcm.exceptions import ( InvalidInitialConditionsError, format_messages, @@ -77,9 +78,14 @@ def build_initial_states( n_subjects, MISSING_CAT_CODE, dtype=target_dtype ) elif state_name in initial_states: - flat[key] = initial_states[state_name] + # Cast user-supplied continuous states to the canonical float + # dtype so the simulate state pool has one signature across + # periods regardless of the user-supplied dtype. 
+ flat[key] = safe_to_float_dtype( + initial_states[state_name], name=f"initial_states.{state_name}" + ) else: - flat[key] = jnp.full(n_subjects, jnp.nan) + flat[key] = jnp.full(n_subjects, jnp.nan, dtype=canonical_float_dtype()) return MappingProxyType(flat) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index d1ab42ab..12040623 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -31,6 +31,8 @@ IntND, RegimeName, RegimeNamesToIds, + ScalarFloat, + ScalarInt, ) from lcm.utils.error_handling import validate_V from lcm.utils.logging import ( @@ -199,7 +201,7 @@ def _simulate_regime_in_period( regime_name: RegimeName, internal_regime: InternalRegime, period: int, - age: float, + age: ScalarInt | ScalarFloat, states: MappingProxyType[str, Array], subject_regime_ids: Int1D, new_subject_regime_ids: Int1D, diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index ba7cc39c..74f54896 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -25,6 +25,8 @@ Int1D, RegimeName, RegimeNamesToIds, + ScalarFloat, + ScalarInt, ) from lcm.utils.namespace import flatten_regime_namespace @@ -70,7 +72,7 @@ def calculate_next_states( internal_regime: InternalRegime, optimal_actions: MappingProxyType[ActionName, Array], period: int, - age: float, + age: ScalarInt | ScalarFloat, regime_params: FlatRegimeParams, states: MappingProxyType[str, Array], state_action_space: StateActionSpace, @@ -148,7 +150,7 @@ def calculate_next_regime_membership( state_action_space: StateActionSpace, optimal_actions: MappingProxyType[ActionName, Array], period: int, - age: float, + age: ScalarInt | ScalarFloat, regime_params: FlatRegimeParams, regime_names_to_ids: MappingProxyType[RegimeName, int], new_subject_regime_ids: Int1D, @@ -287,19 +289,9 @@ def _update_states_for_subjects( for target, target_next_states in computed_next_states.items(): for next_state_name, next_state_values in 
target_next_states.items(): state_name = f"{target}__{next_state_name.removeprefix('next_')}" - target_dtype = all_states[state_name].dtype - # Preserve storage dtype only when the transition output is the - # same numeric kind. Across kinds (e.g. int storage + float - # transition output) leave JAX's promotion in place; the - # cross-kind boundary cast belongs to Package B. - new_values = ( - next_state_values.astype(target_dtype) - if next_state_values.dtype.kind == target_dtype.kind - else next_state_values - ) updated_states[state_name] = jnp.where( subject_indices, - new_values, + next_state_values, all_states[state_name], ) diff --git a/src/lcm/typing.py b/src/lcm/typing.py index da6815cf..9553ebea 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -22,13 +22,13 @@ type Int1D = Int32[Array, "_"] # noqa: F821 type Bool1D = Bool[Array, "_"] # noqa: F821 -# Many JAX functions are designed to work with scalar numerical values. This also -# includes zero dimensional jax arrays. -type ScalarInt = int | Int32[Scalar, ""] -type ScalarFloat = float | Float[Scalar, ""] +# Zero-dimensional JAX scalars — pylcm's canonical scalar form post boundary cast. +type ScalarInt = Int32[Scalar, ""] +type ScalarFloat = Float[Scalar, ""] +type ScalarBool = Bool[Scalar, ""] -type Period = int | Int1D -type Age = int | float +type Period = ScalarInt +type Age = ScalarInt | ScalarFloat type RegimeName = str type StateName = str type ActionName = str @@ -54,7 +54,7 @@ # Internal regime parameters: A flat mapping with function-qualified names. # Keys are always function-qualified (e.g., "utility__risk_aversion", # "H__discount_factor"). Values are scalars or arrays. 
-type FlatRegimeParams = MappingProxyType[str, bool | float | Array] +type FlatRegimeParams = MappingProxyType[str, Array] type InternalParams = MappingProxyType[RegimeName, FlatRegimeParams] # Immutable templates, used internally diff --git a/src/lcm/utils/error_handling.py b/src/lcm/utils/error_handling.py index faeed9fe..5d89a054 100644 --- a/src/lcm/utils/error_handling.py +++ b/src/lcm/utils/error_handling.py @@ -39,7 +39,7 @@ def validate_V( *, V_arr: Array, - age: ScalarInt | ScalarFloat, + age: float | ScalarInt | ScalarFloat, regime_name: RegimeName | None = None, partial_solution: object = None, compute_intermediates: Callable | None = None, @@ -287,8 +287,8 @@ def validate_regime_transition_probs( regime_transition_probs: MappingProxyType[str, Array], active_regimes_next_period: tuple[RegimeName, ...], regime_name: RegimeName, - age: ScalarInt | ScalarFloat, - next_age: ScalarInt | ScalarFloat, + age: float | ScalarInt | ScalarFloat, + next_age: float | ScalarInt | ScalarFloat, state_action_values: MappingProxyType[str, Array] | None = None, ) -> None: """Validate regime transition probabilities. @@ -543,7 +543,7 @@ def _validate_no_reachable_incomplete_targets( regime_transition_probs: MappingProxyType[str, Array], active_regimes_next_period: tuple[RegimeName, ...], regime_name: RegimeName, - age: ScalarInt | ScalarFloat, + age: float | ScalarInt | ScalarFloat, ) -> None: """Check that targets with incomplete stochastic transitions are unreachable. diff --git a/src/lcm/utils/logging.py b/src/lcm/utils/logging.py index a339af67..c924efe8 100644 --- a/src/lcm/utils/logging.py +++ b/src/lcm/utils/logging.py @@ -60,7 +60,7 @@ def log_nan_in_V( *, logger: logging.Logger, regime_name: str, - age: ScalarInt | ScalarFloat, + age: float | ScalarInt | ScalarFloat, V_arr: FloatND, ) -> None: """Log a warning if V_arr contains NaN or Inf values. 
@@ -79,7 +79,7 @@ def log_nan_in_V( def log_period_header( *, logger: logging.Logger, - age: ScalarInt | ScalarFloat, + age: float | ScalarInt | ScalarFloat, n_active_regimes: int, ) -> None: """Log the start of a period. diff --git a/tests/conftest.py b/tests/conftest.py index 5d9e627a..b889345f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ +from collections.abc import Iterator from dataclasses import make_dataclass import pytest +from jax import config as jax_config # Module-level precision settings (updated by pytest_configure based on --precision) X64_ENABLED: bool = True @@ -28,11 +30,31 @@ def pytest_configure(config): X64_ENABLED = config.getoption("--precision") == "64" DECIMAL_PRECISION = 12 if X64_ENABLED else 5 - from jax import config as jax_config # noqa: PLC0415 - jax_config.update("jax_enable_x64", val=X64_ENABLED) @pytest.fixture(scope="session") def binary_category_class(): return make_dataclass("BinaryCategoryClass", [("cat0", int, 0), ("cat1", int, 1)]) + + +@pytest.fixture(name="x64_disabled") +def _fixture_x64_disabled() -> Iterator[None]: + """Run the test with `jax_enable_x64=False`, restoring afterwards.""" + previous = jax_config.read("jax_enable_x64") + jax_config.update("jax_enable_x64", val=False) + try: + yield + finally: + jax_config.update("jax_enable_x64", val=previous) + + +@pytest.fixture(name="x64_enabled") +def _fixture_x64_enabled() -> Iterator[None]: + """Run the test with `jax_enable_x64=True`, restoring afterwards.""" + previous = jax_config.read("jax_enable_x64") + jax_config.update("jax_enable_x64", val=True) + try: + yield + finally: + jax_config.update("jax_enable_x64", val=previous) diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 1ead8fa2..cb434a72 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -64,10 +64,10 @@ def test_simulate_using_raw_inputs(simulate_inputs): { "working_life": MappingProxyType( { - 
"H__discount_factor": 1.0, - "utility__disutility_of_work": 1.0, - "next_wealth__interest_rate": 0.05, - "next_regime__final_age_alive": 0, + "H__discount_factor": jnp.asarray(1.0), + "utility__disutility_of_work": jnp.asarray(1.0), + "next_wealth__interest_rate": jnp.asarray(0.05), + "next_regime__final_age_alive": jnp.asarray(0), } ), "dead": MappingProxyType({}), diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py index 660d99db..698256e4 100644 --- a/tests/simulation/test_simulate_aot.py +++ b/tests/simulation/test_simulate_aot.py @@ -156,21 +156,20 @@ def test_simulate_first_matching_call_populates_aot_cache() -> None: assert n_subjects in model._simulate_compile_cache -def test_solve_with_n_subjects_kicks_off_background_simulate_compile() -> None: - """`solve(...)` spawns the simulate AOT compile in the background. +def test_solve_does_not_populate_simulate_compile_cache() -> None: + """`solve(...)` does not touch simulate-side compile state. - The follow-on `simulate(...)` then awaits the in-flight `Future` instead - of compiling synchronously, so XLA compilation overlaps with the - GPU-bound backward induction in production. + Simulate AOT compilation is driven entirely by `simulate(...)`; calling + `solve(...)` alone leaves `_simulate_compile_cache` empty. """ n_periods = 3 n_subjects = 4 model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) params = get_params(n_periods=n_periods) - assert model._simulate_compile_future is None model.solve(params=params) - assert model._simulate_compile_future is not None + + assert dict(model._simulate_compile_cache) == {} _DECLARED_N = 4 @@ -270,6 +269,30 @@ def test_simulate_warns_only_once_per_mismatching_size( assert len(mismatch_warnings) == 1 +def test_simulate_result_pickles_when_n_subjects_matches() -> None: + """`simulate(...)` returns a result that round-trips through cloudpickle. 
+ + With `n_subjects` matching the batch shape, the simulate path runs + AOT-compiled callables that wrap `LoadedExecutable` (unpicklable). + `to_dataframe` doesn't need those callables, so the returned result + must carry the lazy regimes — otherwise downstream pickling + (e.g. pytask handing the result to the next task) fails. + """ + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + + result = model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=_build_initial_conditions(n_subjects=n_subjects), + ) + + restored = cloudpickle.loads(cloudpickle.dumps(result)) + assert restored.n_subjects == n_subjects + + def test_unpickled_model_can_simulate_with_aot() -> None: """A cloudpickle round-tripped `Model` still drives `simulate(...)` with AOT.""" n_periods = 3 diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index 5538e5e1..9b6095d1 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -52,7 +52,7 @@ def test_solve_brute(): # ================================================================================== # create the params # ================================================================================== - internal_params = MappingProxyType({"discount_factor": 0.9}) + internal_params = MappingProxyType({"discount_factor": jnp.asarray(0.9)}) # ================================================================================== # create the list of state_action_spaces diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 43894cbd..8121a021 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from lcm.dtypes import safe_to_int32 +from lcm.dtypes import canonical_float_dtype, safe_to_float_dtype, safe_to_int_dtype @pytest.mark.parametrize( @@ -12,9 +12,9 @@ [7, np.asarray([0, 1, -3], 
dtype=np.int64)], ids=["python-int", "int64-array"], ) -def test_safe_to_int32_returns_int32(value: object) -> None: - """`safe_to_int32` returns a `jnp.int32` array for any in-range int input.""" - out = safe_to_int32(value, name="x") +def test_safe_to_int_dtype_returns_int32(value: object): + """`safe_to_int_dtype` returns a `jnp.int32` array for any in-range int input.""" + out = safe_to_int_dtype(value, name="x") assert out.dtype == jnp.int32 @@ -26,31 +26,81 @@ def test_safe_to_int32_returns_int32(value: object) -> None: ], ids=["python-int", "int64-array"], ) -def test_safe_to_int32_preserves_in_range_values( - value: object, expected: object -) -> None: - """`safe_to_int32` preserves element values for in-range inputs.""" - out = safe_to_int32(value, name="x") +def test_safe_to_int_dtype_preserves_in_range_values(value: object, expected: object): + """`safe_to_int_dtype` preserves element values for in-range inputs.""" + out = safe_to_int_dtype(value, name="x") np.testing.assert_array_equal(np.asarray(out), expected) -def test_safe_to_int32_raises_on_python_int_overflow() -> None: +def test_safe_to_int_dtype_raises_on_python_int_overflow(): """A Python int above int32 max raises `ValueError` naming the leaf.""" with pytest.raises(ValueError, match="my_param"): - safe_to_int32(2**32, name="my_param") + safe_to_int_dtype(2**32, name="my_param") -def test_safe_to_int32_raises_on_array_overflow() -> None: +def test_safe_to_int_dtype_raises_on_array_overflow(): """An int64 array containing values above int32 max raises with the leaf name.""" # Use numpy here: `jnp.asarray(..., dtype=jnp.int64)` truncates to int32 # under `jax_enable_x64=False` and trips JAX's own overflow guard before - # `safe_to_int32` ever sees the value. + # `safe_to_int_dtype` ever sees the value. 
arr = np.asarray([1, 2, 2**32], dtype=np.int64) with pytest.raises(ValueError, match="regime"): - safe_to_int32(arr, name="regime") + safe_to_int_dtype(arr, name="regime") -def test_safe_to_int32_raises_on_underflow() -> None: +def test_safe_to_int_dtype_raises_on_underflow(): """A Python int below int32 min raises `ValueError` naming the leaf.""" with pytest.raises(ValueError, match="offset"): - safe_to_int32(-(2**40), name="offset") + safe_to_int_dtype(-(2**40), name="offset") + + +def test_canonical_float_dtype_is_float32_under_no_x64(x64_disabled: None): + """`canonical_float_dtype()` is `float32` when `jax_enable_x64=False`.""" + assert canonical_float_dtype() == jnp.float32 + + +def test_canonical_float_dtype_is_float64_under_x64(x64_enabled: None): + """`canonical_float_dtype()` is `float64` when `jax_enable_x64=True`.""" + assert canonical_float_dtype() == jnp.float64 + + +def test_safe_to_float_dtype_casts_python_float_to_canonical(x64_disabled: None): + """A Python float lands at `float32` under no-x64.""" + out = safe_to_float_dtype(0.5, name="x") + assert out.dtype == jnp.float32 + assert float(out) == 0.5 + + +def test_safe_to_float_dtype_casts_float64_array_to_float32(x64_disabled: None): + """A `float64` array within float32 range is downcast to `float32`. + + Build the input with `np.asarray` rather than `jnp.asarray` — under + `jax_enable_x64=False`, JAX silently truncates a `float64` request + to `float32` at construction time, so a JAX-built input would never + reach the helper as `float64` and the down-cast path would not be + exercised. 
+ """ + arr = np.asarray([0.1, 0.2, 0.3], dtype=np.float64) + out = safe_to_float_dtype(arr, name="x") + assert out.dtype == jnp.float32 + + +def test_safe_to_float_dtype_passes_array_through_under_x64(x64_enabled: None): + """Under x64, a `float64` array is preserved (no down-cast required).""" + arr = jnp.asarray([0.1, 0.2, 0.3], dtype=jnp.float64) + out = safe_to_float_dtype(arr, name="x") + assert out.dtype == jnp.float64 + + +def test_safe_to_float_dtype_raises_on_overflow_when_downcasting(x64_disabled: None): + """A `float64` value above float32 max raises `OverflowError`, naming the leaf.""" + big = 1e40 + with pytest.raises(OverflowError, match="big_param"): + safe_to_float_dtype(big, name="big_param") + + +def test_safe_to_float_dtype_no_overflow_check_when_upcasting(x64_enabled: None): + """Casting `float32` -> `float64` (up) skips the overflow check.""" + arr = jnp.asarray([0.1, 0.2], dtype=jnp.float32) + out = safe_to_float_dtype(arr, name="x") + assert out.dtype == jnp.float64 diff --git a/tests/test_float_dtype_invariants.py b/tests/test_float_dtype_invariants.py new file mode 100644 index 00000000..068ccbbb --- /dev/null +++ b/tests/test_float_dtype_invariants.py @@ -0,0 +1,281 @@ +"""Float dtypes follow `canonical_float_dtype()` across pylcm boundaries.""" + +from collections.abc import Callable +from types import MappingProxyType + +import jax.numpy as jnp +import numpy as np +import pytest + +from lcm.dtypes import canonical_float_dtype +from lcm.grids import IrregSpacedGrid, LinSpacedGrid, LogSpacedGrid +from lcm.params import MappingLeaf +from lcm.params.processing import process_params +from lcm.params.sequence_leaf import SequenceLeaf +from lcm.simulation.initial_conditions import build_initial_states +from tests.test_models.deterministic.regression import ( + RegimeId, + get_model, + get_params, +) + + +def test_build_initial_states_casts_user_float64_to_canonical(x64_disabled: None): + """A float64 continuous initial state lands at 
`canonical_float_dtype()`.""" + model = get_model(n_periods=3) + initial_states = { + "wealth": np.asarray([20.0, 50.0], dtype=np.float64), + "age": np.asarray([18.0, 18.0], dtype=np.float64), + } + flat = build_initial_states( + initial_states=initial_states, # ty: ignore[invalid-argument-type] + internal_regimes=model.internal_regimes, + ) + assert flat["working_life__wealth"].dtype == canonical_float_dtype() + + +def test_build_initial_states_casts_user_int_to_canonical(x64_disabled: None): + """A continuous initial state given as int32 lands at `canonical_float_dtype()`.""" + model = get_model(n_periods=3) + initial_states = { + "wealth": jnp.asarray([20, 50], dtype=jnp.int32), + "age": jnp.asarray([18, 18], dtype=jnp.int32), + } + flat = build_initial_states( + initial_states=initial_states, + internal_regimes=model.internal_regimes, + ) + assert flat["working_life__wealth"].dtype == canonical_float_dtype() + + +def test_build_initial_states_missing_continuous_fallback_dtype_is_canonical( + x64_disabled: None, +): + """A missing continuous state falls back to a canonical-dtype array.""" + model = get_model(n_periods=3) + # Supply a placeholder state to set n_subjects without touching `wealth`. + flat = build_initial_states( + initial_states={"placeholder": jnp.asarray([0.0, 0.0])}, + internal_regimes=model.internal_regimes, + ) + assert flat["working_life__wealth"].dtype == canonical_float_dtype() + + +def test_build_initial_states_missing_continuous_fallback_values_are_nan( + x64_disabled: None, +): + """A missing continuous state falls back to an all-NaN array. + + Pinning only the dtype would let a regression that fills the fallback + with zeros (or anything else representable) pass; assert the values. 
+ """ + model = get_model(n_periods=3) + flat = build_initial_states( + initial_states={"placeholder": jnp.asarray([0.0, 0.0])}, + internal_regimes=model.internal_regimes, + ) + assert bool(jnp.all(jnp.isnan(flat["working_life__wealth"]))) + + +def test_process_params_casts_float64_array_to_canonical_under_no_x64( + x64_disabled: None, +): + """A `float64` array param is downcast to `float32` under `jax_enable_x64=False`. + + Build with `np.asarray` rather than `jnp.asarray` — the JAX builder + silently truncates to `float32` under no-x64 at construction time, so a + JAX-built input would never reach the helper as `float64`. + """ + template = MappingProxyType({"regime_a": MappingProxyType({"schedule": "Array"})}) + user_params = { + "regime_a": {"schedule": np.asarray([0.1, 0.2, 0.3], dtype=np.float64)} + } + + out = process_params( + params=user_params, # ty: ignore[invalid-argument-type] + params_template=template, # ty: ignore[invalid-argument-type] + ) + + schedule = out["regime_a"]["schedule"] + assert schedule.dtype == jnp.float32 + + +def test_process_params_casts_python_float_to_canonical(x64_disabled: None): + """A Python `float` param leaf is cast to `canonical_float_dtype()`.""" + template = MappingProxyType( + {"regime_a": MappingProxyType({"discount_factor": "float"})} + ) + user_params = {"regime_a": {"discount_factor": 0.95}} + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + discount_factor = out["regime_a"]["discount_factor"] + np.testing.assert_allclose(float(discount_factor), 0.95, rtol=1e-6) + assert discount_factor.dtype == canonical_float_dtype() + + +def test_process_params_float_array_overflow_raises_with_qualified_name( + x64_disabled: None, +): + """An out-of-float32 float64 array raises naming the qualified leaf.""" + template = MappingProxyType({"regime_a": MappingProxyType({"schedule": "Array"})}) + user_params = {"regime_a": {"schedule": np.asarray([0.0, 1e40], 
dtype=np.float64)}} + + with pytest.raises(OverflowError, match="schedule"): + process_params( + params=user_params, # ty: ignore[invalid-argument-type] + params_template=template, # ty: ignore[invalid-argument-type] + ) + + +def test_simulate_state_pool_dtype_stable_across_periods(x64_disabled: None): + """A multi-period simulate keeps every state's dtype stable across periods. + + The intended invariant is per-state stability; failing on any single + state still gives an actionable signal because the assertion message + names the offending state and its observed dtypes. + """ + n_periods = 4 + model = get_model(n_periods=n_periods) + params = get_params(n_periods=n_periods) + initial = { + "wealth": jnp.asarray([20.0, 50.0, 80.0]), + "age": jnp.asarray([18.0, 18.0, 18.0]), + "regime": jnp.asarray([RegimeId.working_life] * 3), + } + + result = model.simulate( + params=params, period_to_regime_to_V_arr=None, initial_conditions=initial + ) + + seen: dict[str, set] = {} + for period_data in result.raw_results.values(): + for snap in period_data.values(): + for state_name, arr in snap.states.items(): + seen.setdefault(state_name, set()).add(arr.dtype) + drifted = {name: dtypes for name, dtypes in seen.items() if len(dtypes) != 1} + assert not drifted, f"States drifted across periods: {drifted}" + + +def test_solve_v_arrays_at_canonical_float_dtype(x64_disabled: None): + """Every V-array returned by `model.solve()` is at `canonical_float_dtype()`.""" + model = get_model(n_periods=3) + period_to_regime_to_V_arr = model.solve(params=get_params(n_periods=3)) + target = canonical_float_dtype() + wrong = { + (period, regime_name): v_arr.dtype + for period, period_v in period_to_regime_to_V_arr.items() + for regime_name, v_arr in period_v.items() + if v_arr.dtype != target + } + assert not wrong, f"V-arrays not at {target}: {wrong}" + + +@pytest.mark.parametrize( + "make_grid", + [ + lambda: LinSpacedGrid(start=0, stop=1, n_points=5), + lambda: LogSpacedGrid(start=1, 
stop=10, n_points=5),
+        lambda: IrregSpacedGrid(points=(0.0, 0.5, 1.0)),
+    ],
+    ids=["linspaced", "logspaced", "irregspaced"],
+)
+def test_continuous_grid_to_jax_dtype_is_canonical_under_no_x64(
+    make_grid: Callable[[], LinSpacedGrid | LogSpacedGrid | IrregSpacedGrid],
+    x64_disabled: None,
+):
+    """Continuous grid `to_jax()` materialises at `float32` under no-x64.
+
+    Asserts the concrete target dtype rather than `canonical_float_dtype()`
+    so the test fails if a future grid implementation hardcodes `float64`
+    (which JAX would silently truncate to `float32` under no-x64; the
+    helper-side comparison would mask that, while the literal-side
+    comparison surfaces it).
+
+    Grids are constructed inside the test body so the `x64_disabled`
+    fixture is in effect; the grid dtype is now fixed at construction
+    time by `jax_enable_x64`.
+    """
+    grid = make_grid()
+    assert grid.to_jax().dtype == jnp.float32
+
+
+@pytest.mark.parametrize("attr", ["start", "stop"])
+def test_uniform_grid_stores_endpoints_as_canonical_jax_scalar(
+    attr: str, x64_disabled: None
+):
+    """`LinSpacedGrid` stores `start`/`stop` as JAX scalars at canonical dtype."""
+    grid = LinSpacedGrid(start=0.0, stop=100.0, n_points=10)
+    value = getattr(grid, attr)
+    assert isinstance(value, jnp.ndarray)
+    assert value.dtype == canonical_float_dtype()
+
+
+def test_irreg_grid_stores_points_as_canonical_jax_array(x64_disabled: None):
+    """`IrregSpacedGrid` stores `points` as a JAX array at canonical dtype."""
+    grid = IrregSpacedGrid(points=(0.0, 0.5, 1.0))
+    assert isinstance(grid.points, jnp.ndarray)
+    assert grid.points.dtype == canonical_float_dtype()
+
+
+@pytest.mark.parametrize("key", ["low", "high"])
+def test_process_params_casts_float_array_inside_mapping_leaf_to_canonical(
+    key: str, x64_disabled: None
+):
+    """`MappingLeaf` float arrays land at `canonical_float_dtype()`."""
+    template = MappingProxyType(
+        {"regime_a": MappingProxyType({"sched": "MappingLeaf"})}
+    )
+    user_params = {
+        "regime_a": {
+            
"sched": MappingLeaf( + { + "low": np.asarray([0.1, 0.2], dtype=np.float64), + "high": np.asarray([0.5, 0.7], dtype=np.float64), + } + ) + } + } + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + assert ( + out["regime_a"]["sched"].data[key].dtype # ty: ignore[unresolved-attribute] + == jnp.float32 + ) + + +@pytest.mark.parametrize("index", [0, 1]) +def test_process_params_casts_float_array_inside_sequence_leaf_to_canonical( + index: int, x64_disabled: None +): + """`SequenceLeaf` float arrays land at `canonical_float_dtype()`.""" + template = MappingProxyType( + {"regime_a": MappingProxyType({"sched": "SequenceLeaf"})} + ) + user_params = { + "regime_a": { + "sched": SequenceLeaf( + [ + np.asarray([0.1, 0.2], dtype=np.float64), + np.asarray([0.5, 0.7], dtype=np.float64), + ] + ) + } + } + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + assert ( + out["regime_a"]["sched"].data[index].dtype # ty: ignore[unresolved-attribute] + == jnp.float32 + ) diff --git a/tests/test_grid_helpers.py b/tests/test_grid_helpers.py index 2b698d48..a3994c1b 100644 --- a/tests/test_grid_helpers.py +++ b/tests/test_grid_helpers.py @@ -15,19 +15,23 @@ def test_linspace(): - calculated = linspace(start=1, stop=2, n_points=6) + calculated = linspace( + start=jnp.asarray(1.0), stop=jnp.asarray(2.0), n_points=jnp.int32(6) + ) expected = np.array([1, 1.2, 1.4, 1.6, 1.8, 2]) aaae(calculated, expected, decimal=DECIMAL_PRECISION) def test_linspace_mapped_value(): """For reference of the grid values, see expected grid in `test_linspace`.""" + start = jnp.asarray(1.0) + stop = jnp.asarray(2.0) # Get position corresponding to a value in the grid calculated = get_linspace_coordinate( - value=1.2, - start=1, - stop=2, - n_points=6, + value=jnp.asarray(1.2), + start=start, + stop=stop, + n_points=jnp.int32(6), ) assert np.allclose(calculated, 1.0) @@ -36,25 +40,27 
@@ def test_linspace_mapped_value(): # Here, the value is 1.3, that is in the middle of 1.2 and 1.4, which have the # positions 1 and 2, respectively. Therefore, we want the position to be 1.5. calculated = get_linspace_coordinate( - value=1.3, - start=1, - stop=2, - n_points=6, + value=jnp.asarray(1.3), + start=start, + stop=stop, + n_points=jnp.int32(6), ) assert np.allclose(calculated, 1.5) # Get position corresponding to a value that is outside the grid calculated = get_linspace_coordinate( - value=0.6, - start=1, - stop=2, - n_points=6, + value=jnp.asarray(0.6), + start=start, + stop=stop, + n_points=jnp.int32(6), ) assert np.allclose(calculated, -2.0) def test_logspace(): - calculated = logspace(start=1, stop=100, n_points=7) + calculated = logspace( + start=jnp.asarray(1.0), stop=jnp.asarray(100.0), n_points=jnp.int32(7) + ) expected = np.array( [ 1.0, @@ -72,10 +78,10 @@ def test_logspace(): def test_logspace_mapped_value(): """For reference of the grid values, see expected grid in `test_logspace`.""" calculated = get_logspace_coordinate( - value=(2.15443469 + 4.64158883) / 2, - start=1, - stop=100, - n_points=7, + value=jnp.asarray((2.15443469 + 4.64158883) / 2), + start=jnp.asarray(1.0), + stop=jnp.asarray(100.0), + n_points=jnp.int32(7), ) assert np.allclose(calculated, 1.5) @@ -84,9 +90,9 @@ def test_logspace_mapped_value(): def test_map_coordinates_linear(): """Illustrative test on how the output of get_linspace_coordinate can be used.""" grid_info = { - "start": 0, - "stop": 1, - "n_points": 3, + "start": jnp.asarray(0.0), + "stop": jnp.asarray(1.0), + "n_points": jnp.int32(3), } grid = linspace(**grid_info) # [0, 0.5, 1] @@ -96,7 +102,7 @@ def test_map_coordinates_linear(): # We choose a coordinate that is exactly in the middle between the first and second # entry of the grid. 
coordinate = get_linspace_coordinate( - value=0.25, + value=jnp.asarray(0.25), **grid_info, ) @@ -109,9 +115,9 @@ def test_map_coordinates_linear(): def test_map_coordinates_logarithmic(): """Illustrative test on how the output of get_logspace_coordinate can be used.""" grid_info = { - "start": 1, - "stop": 2, - "n_points": 3, + "start": jnp.asarray(1.0), + "stop": jnp.asarray(2.0), + "n_points": jnp.int32(3), } grid = logspace(**grid_info) # [1.0, 1.414213562373095, 2.0] @@ -121,7 +127,7 @@ def test_map_coordinates_logarithmic(): # We choose a coordinate that is exactly in the middle between the first and second # entry of the grid. coordinate = get_logspace_coordinate( - value=(1.0 + 1.414213562373095) / 2, + value=jnp.asarray((1.0 + 1.414213562373095) / 2), **grid_info, ) @@ -134,9 +140,9 @@ def test_map_coordinates_logarithmic(): def test_map_coordinates_linear_outside_grid(): """Illustrative test on what happens to values outside the grid.""" grid_info = { - "start": 1, - "stop": 2, - "n_points": 2, + "start": jnp.asarray(1.0), + "stop": jnp.asarray(2.0), + "n_points": jnp.int32(2), } grid = linspace(**grid_info) # [1, 2] @@ -146,8 +152,8 @@ def test_map_coordinates_linear_outside_grid(): # Get coordinates corresponding to values outside the grid [1, 2] coordinates = jnp.array( [ - get_linspace_coordinate(value=grid_val, **grid_info) - for grid_val in [-1, 0, 3] + get_linspace_coordinate(value=jnp.asarray(grid_val), **grid_info) + for grid_val in [-1.0, 0.0, 3.0] ] ) @@ -158,16 +164,28 @@ def test_map_coordinates_linear_outside_grid(): def test_get_linspace_coordinate_with_array(): values = jnp.array([1.0, 1.2, 1.5]) - coords = get_linspace_coordinate(value=values, start=1, stop=2, n_points=6) + coords = get_linspace_coordinate( + value=values, + start=jnp.asarray(1.0), + stop=jnp.asarray(2.0), + n_points=jnp.int32(6), + ) expected = jnp.array([0.0, 1.0, 2.5]) aaae(coords, expected, decimal=DECIMAL_PRECISION) def test_get_logspace_coordinate_with_array(): - 
grid = logspace(start=1, stop=100, n_points=7) + grid = logspace( + start=jnp.asarray(1.0), stop=jnp.asarray(100.0), n_points=jnp.int32(7) + ) mid = (float(grid[1]) + float(grid[2])) / 2 values = jnp.array([mid]) - coords = get_logspace_coordinate(value=values, start=1, stop=100, n_points=7) + coords = get_logspace_coordinate( + value=values, + start=jnp.asarray(1.0), + stop=jnp.asarray(100.0), + n_points=jnp.int32(7), + ) aaae(coords, jnp.array([1.5]), decimal=DECIMAL_PRECISION) diff --git a/tests/test_grids.py b/tests/test_grids.py index da2370d9..cd9cbfb1 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -167,34 +167,38 @@ class UnorderedCat: assert grid.ordered is False -def test_validate_continuous_grid_invalid_start(): - error_msg = "start must be a scalar int or float value" - with pytest.raises(GridInitializationError, match=error_msg): - _validate_continuous_grid(start="a", stop=1, n_points=10) # ty: ignore[invalid-argument-type] +def test_lin_spaced_grid_rejects_non_numeric_start(): + """Non-numeric `start` is rejected at the boundary cast.""" + with pytest.raises((TypeError, ValueError)): + LinSpacedGrid(start="a", stop=1, n_points=10) # ty: ignore[invalid-argument-type] -def test_validate_continuous_grid_invalid_stop(): - error_msg = "stop must be a scalar int or float value" - with pytest.raises(GridInitializationError, match=error_msg): - _validate_continuous_grid(start=1, stop="a", n_points=10) # ty: ignore[invalid-argument-type] +def test_lin_spaced_grid_rejects_non_numeric_stop(): + """Non-numeric `stop` is rejected at the boundary cast.""" + with pytest.raises((TypeError, ValueError)): + LinSpacedGrid(start=1, stop="a", n_points=10) # ty: ignore[invalid-argument-type] -def test_validate_continuous_grid_invalid_n_points(): - error_msg = "n_points must be an int greater than 0 but is a" - with pytest.raises(GridInitializationError, match=error_msg): - _validate_continuous_grid(start=1, stop=2, n_points="a") # ty: 
ignore[invalid-argument-type] +def test_lin_spaced_grid_rejects_non_numeric_n_points(): + """Non-numeric `n_points` is rejected at the boundary cast.""" + with pytest.raises((TypeError, ValueError)): + LinSpacedGrid(start=1, stop=2, n_points="a") # ty: ignore[invalid-argument-type] def test_validate_continuous_grid_negative_n_points(): error_msg = "n_points must be an int greater than 0 but is -1" with pytest.raises(GridInitializationError, match=error_msg): - _validate_continuous_grid(start=1, stop=2, n_points=-1) + _validate_continuous_grid( + start=jnp.asarray(1.0), stop=jnp.asarray(2.0), n_points=jnp.int32(-1) + ) def test_validate_continuous_grid_start_greater_than_stop(): error_msg = "start must be less than stop" with pytest.raises(GridInitializationError, match=error_msg): - _validate_continuous_grid(start=2, stop=1, n_points=10) + _validate_continuous_grid( + start=jnp.asarray(2.0), stop=jnp.asarray(1.0), n_points=jnp.int32(10) + ) def test_linspace_grid_creation(): @@ -229,12 +233,20 @@ def test_logspace_grid_rejects_negative_start(): def test_validate_continuous_grid_rejects_nan_start(): with pytest.raises(GridInitializationError, match="start must be finite"): - _validate_continuous_grid(start=float("nan"), stop=10, n_points=5) + _validate_continuous_grid( + start=jnp.asarray(float("nan")), + stop=jnp.asarray(10.0), + n_points=jnp.int32(5), + ) def test_validate_continuous_grid_rejects_inf_stop(): with pytest.raises(GridInitializationError, match="stop must be finite"): - _validate_continuous_grid(start=1, stop=float("inf"), n_points=5) + _validate_continuous_grid( + start=jnp.asarray(1.0), + stop=jnp.asarray(float("inf")), + n_points=jnp.int32(5), + ) def test_irreg_spaced_grid_rejects_nan_points(): @@ -332,8 +344,9 @@ def test_linspaced_coordinates_match_other_grid_types( rtol = base_rtol * max_magnitude for value in all_test_values: - lin_coord = float(lin_grid.get_coordinate(value)) - other_coord = float(other_grid.get_coordinate(value)) + value_jax 
= jnp.asarray(value) + lin_coord = float(lin_grid.get_coordinate(value_jax)) + other_coord = float(other_grid.get_coordinate(value_jax)) assert np.isclose(lin_coord, other_coord, rtol=rtol), ( f"Mismatch at value {value} for {grid_type} vs LinSpacedGrid " f"({start}, {stop}, {n_points}): " @@ -583,7 +596,7 @@ def test_piecewise_log_spaced_grid_coordinate_at_gridpoints(): grid = PiecewiseLogSpacedGrid(pieces=(Piece(interval="[1, 100]", n_points=3),)) points = grid.to_jax() for i, p in enumerate(points): - coord = float(grid.get_coordinate(float(p))) + coord = float(grid.get_coordinate(p)) assert coord == pytest.approx(i) @@ -595,8 +608,8 @@ def test_piecewise_log_spaced_grid_coordinate_multi_piece(): Piece(interval="[10, 100]", n_points=2), ) ) - assert float(grid.get_coordinate(10.0)) == pytest.approx(2.0) - assert float(grid.get_coordinate(100.0)) == pytest.approx(3.0) + assert float(grid.get_coordinate(jnp.asarray(10.0))) == pytest.approx(2.0) + assert float(grid.get_coordinate(jnp.asarray(100.0))) == pytest.approx(3.0) def _create_boundary_test_grid(grid_cls, boundary_style: str): @@ -671,9 +684,9 @@ def test_piecewise_boundary_conditions(grid_cls, boundary_style: str): def test_piecewise_single_piece(): """Test piecewise grid with single piece works correctly.""" grid = PiecewiseLinSpacedGrid(pieces=(Piece(interval="[0, 10]", n_points=11),)) - assert float(grid.get_coordinate(0.0)) == pytest.approx(0.0) - assert float(grid.get_coordinate(5.0)) == pytest.approx(5.0) - assert float(grid.get_coordinate(10.0)) == pytest.approx(10.0) + assert float(grid.get_coordinate(jnp.asarray(0.0))) == pytest.approx(0.0) + assert float(grid.get_coordinate(jnp.asarray(5.0))) == pytest.approx(5.0) + assert float(grid.get_coordinate(jnp.asarray(10.0))) == pytest.approx(10.0) def test_lin_spaced_grid_get_coordinate_with_array(): diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py index 9e9c6011..e089c0e6 100644 --- 
a/tests/test_int_dtype_invariants.py +++ b/tests/test_int_dtype_invariants.py @@ -61,13 +61,23 @@ def test_missing_cat_code_is_int32_minimum() -> None: assert jnp.iinfo(jnp.int32).min == MISSING_CAT_CODE -def test_update_states_for_subjects_preserves_storage_dtype() -> None: - """A transition that returns int64 cannot promote the storage pool to int64.""" +def test_update_states_for_subjects_keeps_same_dtype_round_trip() -> None: + """Canonical-dtype transition outputs round-trip through the state pool. + + With every input boundary pinned to the canonical dtype, a well-typed user + transition returns canonical-dtype outputs and `_update_states_for_subjects` + writes them through `jnp.where` without dtype change. This test pins the + contract for the int side; mixed-dtype inputs are out of scope — the + function does not defend against transitions that violate the canonical- + dtype invariant. + """ all_states = MappingProxyType( {"work__health": jnp.asarray([0, 1, 0, 1], dtype=jnp.int32)} ) - int64_next = jnp.asarray([1, 1, 1, 1], dtype=jnp.int64) - computed = MappingProxyType({"work": MappingProxyType({"next_health": int64_next})}) + next_values = jnp.asarray([1, 1, 1, 1], dtype=jnp.int32) + computed = MappingProxyType( + {"work": MappingProxyType({"next_health": next_values})} + ) subjects = jnp.asarray([True, False, True, False]) updated = _update_states_for_subjects( @@ -91,7 +101,7 @@ def test_process_params_casts_python_int_to_int32() -> None: final_age = out["regime_a"]["final_age"] assert int(final_age) == 65 - assert final_age.dtype == jnp.int32 # ty: ignore[unresolved-attribute] + assert final_age.dtype == jnp.int32 def test_process_params_casts_int64_array_to_int32() -> None: @@ -105,7 +115,7 @@ def test_process_params_casts_int64_array_to_int32() -> None: ) schedule = out["regime_a"]["schedule"] - assert schedule.dtype == jnp.int32 # ty: ignore[unresolved-attribute] + assert schedule.dtype == jnp.int32 def 
test_process_params_int_array_overflow_raises_with_qualified_name() -> None: diff --git a/tests/test_next_state.py b/tests/test_next_state.py index a1f688c5..ccb79b39 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -37,18 +37,18 @@ def test_get_next_state_function_with_solve_target(): ) flat_regime_params = { - "discount_factor": 1.0, - "utility__disutility_of_work": 1.0, - "next_wealth__interest_rate": 0.05, + "discount_factor": jnp.asarray(1.0), + "utility__disutility_of_work": jnp.asarray(1.0), + "next_wealth__interest_rate": jnp.asarray(0.05), } - action = {"labor_supply": 1, "consumption": 10} - state = {"wealth": 20} + action = {"labor_supply": jnp.asarray(1), "consumption": jnp.asarray(10.0)} + state = {"wealth": jnp.asarray(20.0)} got = got_func( **action, **state, - period=1, - age=1.0, + period=jnp.int32(1), + age=jnp.asarray(1.0), **flat_regime_params, ) assert got == {"next_wealth": 1.05 * (20 - 10)} diff --git a/tests/test_pandas_utils.py b/tests/test_pandas_utils.py index e1c67139..a8dba7b6 100644 --- a/tests/test_pandas_utils.py +++ b/tests/test_pandas_utils.py @@ -1443,8 +1443,8 @@ def test_convert_series_function_level_series() -> None: regime_names_to_ids=model.regime_names_to_ids, ) arr = result["working_life"]["next_partner__probs_array"] - assert arr.shape == (3, 2, 2, 2) # ty: ignore[unresolved-attribute] - assert float(arr[0, 0, 0, 0]) == pytest.approx(1.0) # ty: ignore[not-subscriptable] + assert arr.shape == (3, 2, 2, 2) + assert float(arr[0, 0, 0, 0]) == pytest.approx(1.0) def test_convert_series_model_level_scalar_passthrough() -> None: @@ -1484,7 +1484,7 @@ def test_convert_series_regime_level_series() -> None: regime_names_to_ids=model.regime_names_to_ids, ) arr = result["working_life"]["next_partner__probs_array"] - assert arr.shape == (3, 2, 2, 2) # ty: ignore[unresolved-attribute] + assert arr.shape == (3, 2, 2, 2) def test_convert_series_mixed_dict() -> None: @@ -1511,7 +1511,7 @@ def 
test_convert_series_mixed_dict() -> None: ) assert result["working_life"]["H__discount_factor"] == 0.95 assert result["working_life"]["utility__disutility_of_work"] == 0.5 - assert result["working_life"]["next_partner__probs_array"].shape == (3, 2, 2, 2) # ty: ignore[unresolved-attribute] + assert result["working_life"]["next_partner__probs_array"].shape == (3, 2, 2, 2) assert result["working_life"]["next_wealth__interest_rate"] == 0.05 np.testing.assert_allclose( result["working_life"]["labor_income__wage"], jnp.array([10.0]) @@ -1655,7 +1655,7 @@ def test_convert_series_with_derived_categoricals() -> None: regime_names_to_ids=model.regime_names_to_ids, ) arr = result["retirement"]["next_partner__probs_array"] - assert arr.shape == (3, 2, 2, 2) # ty: ignore[unresolved-attribute] + assert arr.shape == (3, 2, 2, 2) def test_convert_series_per_target_transition() -> None: @@ -1735,7 +1735,7 @@ def _next_wealth(wealth: float) -> float: regime_names_to_ids=model.regime_names_to_ids, ) arr = result["working"]["to_working_next_health__probs_array"] - assert arr.shape == (3, 2, 2) # ty: ignore[unresolved-attribute] + assert arr.shape == (3, 2, 2) def test_build_outcome_mapping_qualified_func_name() -> None: @@ -1833,8 +1833,8 @@ def _next_wealth_sc(wealth: float) -> float: ages=model.ages, regime_names_to_ids=model.regime_names_to_ids, ) - assert result_both["regime_a"]["utility__rates"].shape == (2,) # ty: ignore[unresolved-attribute] - assert result_both["regime_b"]["utility__rates"].shape == (3,) # ty: ignore[unresolved-attribute] + assert result_both["regime_a"]["utility__rates"].shape == (2,) + assert result_both["regime_b"]["utility__rates"].shape == (3,) def test_convert_series_runtime_grid_param() -> None: @@ -2021,7 +2021,7 @@ def _health_probs_cross( arr = result["pre65"]["to_post65_next_health__health_trans_probs_cross"] # Shape: (n_ages=2, n_source_health=3, n_target_health=2) # n_ages=2 because AgeGrid has ages [0, 1]; missing age 1 is NaN-filled. 
- assert arr.shape == (2, 3, 2) # ty: ignore[unresolved-attribute] + assert arr.shape == (2, 3, 2) def test_resolve_categoricals_includes_derived_when_no_regime_name() -> None: diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 68cb88d0..561fc1a5 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -29,7 +29,7 @@ def _retired_utility(wealth: ContinuousState) -> FloatND: return jnp.log(wealth) -def _build_tiny_model(): +def _build_tiny_model(*, enable_jit: bool, n_subjects: int): def utility(consumption: ContinuousAction, wealth: ContinuousState) -> FloatND: return jnp.log(consumption + wealth) @@ -60,7 +60,8 @@ def next_regime(period: int) -> ScalarInt: regimes={"working": working, "retired": retired}, ages=ages, regime_id_class=_RegimeId, - enable_jit=False, + enable_jit=enable_jit, + n_subjects=n_subjects, ) params = {"discount_factor": 0.95} return model, params @@ -76,7 +77,7 @@ def _initial_conditions(): @pytest.fixture def model_and_params(): - return _build_tiny_model() + return _build_tiny_model(enable_jit=False, n_subjects=2) @pytest.fixture @@ -171,6 +172,30 @@ def test_simulate_with_solve_debug_persists_snapshot(tmp_path, model_and_params) assert snapshot.result is not None +def test_simulate_debug_persists_snapshot_with_aot_compiled_regimes(tmp_path): + """Debug snapshot saves successfully when `n_subjects` triggers AOT compile. + + AOT compilation produces `jax.stages.Compiled` callables on each + `InternalRegime.simulate_functions`; their backing `LoadedExecutable` + cannot be pickled. The snapshot path must strip those before pickling + `result.pkl`. 
+ """ + model, params = _build_tiny_model(enable_jit=True, n_subjects=2) + model.simulate( + params=params, + initial_conditions=_initial_conditions(), + period_to_regime_to_V_arr=None, + log_level="debug", + log_path=tmp_path, + ) + + dirs = sorted(tmp_path.glob("simulate_snapshot_*/")) + assert len(dirs) == 1 + snapshot = load_snapshot(dirs[0]) + assert isinstance(snapshot, SimulateSnapshot) + assert snapshot.result is not None + + def test_solve_no_persistence_when_not_debug(tmp_path, model_and_params): model, params = model_and_params model.solve(params=params, log_level="progress", log_path=tmp_path) diff --git a/tests/test_regime_state_mismatch.py b/tests/test_regime_state_mismatch.py index 6332c1ae..11abb8e9 100644 --- a/tests/test_regime_state_mismatch.py +++ b/tests/test_regime_state_mismatch.py @@ -461,7 +461,7 @@ class _RegimeId: b: int dead: int - def next_regime() -> ScalarInt: + def next_regime() -> int: return _RegimeId.dead a = Regime( @@ -505,7 +505,7 @@ class _RegimeId: b: int dead: int - def next_regime() -> ScalarInt: + def next_regime() -> int: return _RegimeId.dead a = Regime( diff --git a/tests/test_validate_param_types.py b/tests/test_validate_param_types.py index 939950ce..17428606 100644 --- a/tests/test_validate_param_types.py +++ b/tests/test_validate_param_types.py @@ -1,11 +1,17 @@ -"""Test that numpy arrays in params are rejected after processing.""" +"""Tests for params accepted at the boundary by `_validate_param_types`. + +After `process_params` casts typed numeric arrays to canonical pylcm +dtypes, every supported user input form (numpy arrays, JAX arrays, +Python scalars) reaches the validator as a JAX array or Python scalar +and is accepted. 
+"""

 import jax.numpy as jnp
 import numpy as np
-import pytest
+from jax import Array

 from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical
-from lcm.exceptions import InvalidParamsError
+from lcm.dtypes import canonical_float_dtype


 @categorical(ordered=True)
@@ -49,20 +55,31 @@ def _make_model() -> Model:
     )


-def test_numpy_array_param_rejected() -> None:
-    """Passing a numpy array as a param should raise InvalidParamsError."""
+def test_numpy_array_param_normalised_to_canonical_jax_array() -> None:
+    """A numpy array param is cast to a JAX array at `canonical_float_dtype()`."""
     model = _make_model()
-    with pytest.raises(InvalidParamsError, match=r"numpy\.ndarray"):
-        model.solve(params={"bonus": np.array(1.0), "discount_factor": 0.95}) # ty: ignore[invalid-argument-type]
+    internal = model._process_params(
+        params={"bonus": np.asarray(1.0, dtype=np.float64), "discount_factor": 0.95} # ty: ignore[invalid-argument-type]
+    )
+    bonus = internal["working"]["utility__bonus"]
+    assert isinstance(bonus, Array)
+    assert bonus.dtype == canonical_float_dtype()


-def test_jax_array_param_accepted() -> None:
-    """JAX arrays should be accepted."""
+def test_jax_array_param_kept_at_canonical_dtype() -> None:
+    """A typed JAX array param is kept (or cast) at `canonical_float_dtype()`."""
     model = _make_model()
-    model.solve(params={"bonus": jnp.array(1.0), "discount_factor": 0.95})
+    internal = model._process_params(
+        params={"bonus": jnp.asarray(1.0), "discount_factor": 0.95}
+    )
+    bonus = internal["working"]["utility__bonus"]
+    assert bonus.dtype == canonical_float_dtype()


-def test_python_scalar_param_accepted() -> None:
-    """Python scalars should be accepted."""
+def test_python_float_param_cast_to_canonical_dtype() -> None:
+    """A Python `float` param is cast to `canonical_float_dtype()`."""
     model = _make_model()
-    model.solve(params={"bonus": 1.0, "discount_factor": 0.95})
+    internal = model._process_params(params={"bonus": 1.0, "discount_factor": 0.95})
+    bonus = internal["working"]["utility__bonus"]
+    assert float(bonus) == 1.0
+    assert bonus.dtype == canonical_float_dtype()