diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index 477d5635..a9364879 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -54,7 +54,7 @@ def _build() -> tuple[object, object, object]: get_benchmark_params, ) - model = create_benchmark_model() + model = create_benchmark_model(n_subjects=_N_SUBJECTS) _, model_params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=_N_SUBJECTS, seed=0 diff --git a/pixi.lock b/pixi.lock index 5c8127d0..4115222a 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=f09b5e34102ff42f739b95be5a9d388795b734a1#f09b5e34102ff42f739b95be5a9d388795b734a1 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev153+gc585f4b82.d20260504 - sha256: 365fd6893bcba5ab807032371430b19a9a9056c80ef3a9a201b56d34ced99e0e + version: 0.0.2.dev195+ga908c8405.d20260505 + sha256: 44c6bd65422fdc0a7d3167cf852107aeca15bf6687a44b57a6749ad553943f11 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 3fe01162..bb10a893 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "f09b5e34102ff42f739b95be5a9d388795b734a1" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } diff --git a/src/lcm/dtypes.py b/src/lcm/dtypes.py new file mode 100644 index 00000000..51dd958c --- /dev/null +++ b/src/lcm/dtypes.py @@ -0,0 +1,48 @@ +"""Boundary-cast helpers that pin user-supplied data to canonical pylcm dtypes. + +Used at every API boundary that accepts user data (params, initial +conditions, regime-id arrays) — always called from Python, never inside +JIT. Each helper validates that the value fits the target dtype and +raises a clearly-named error if not. + +Casts further down the simulate stack (e.g. transition outputs landing +in the state pool) use plain `.astype` and rely on the boundary cast +above them having already pinned the canonical dtype. +""" + +import jax.numpy as jnp +import numpy as np +from jax import Array + +_INT32_MIN = int(np.iinfo(np.int32).min) +_INT32_MAX = int(np.iinfo(np.int32).max) + + +def safe_to_int32(value: object, *, name: str) -> Array: + """Cast a scalar, sequence, or array to `jnp.int32`, checking int32 range. + + Args: + value: A Python int, numpy/JAX integer scalar, or array-like of + integer values. + name: Qualified name of the leaf — surfaced in the error message + so the user can locate the offending input. + + Returns: + A `jnp.int32` array (0-d if `value` was a scalar). + + Raises: + ValueError: If any element of `value` is outside the int32 range + `[-2**31, 2**31 - 1]`. The message names the leaf via `name`. + + """ + np_value = np.asarray(value) + if np_value.size > 0: + lo = int(np_value.min()) + hi = int(np_value.max()) + if lo < _INT32_MIN or hi > _INT32_MAX: + msg = ( + f"{name}: int32 overflow — value range [{lo}, {hi}] " + f"exceeds [{_INT32_MIN}, {_INT32_MAX}]." + ) + raise ValueError(msg) + return jnp.asarray(np_value, dtype=jnp.int32) diff --git a/src/lcm/grids/discrete.py b/src/lcm/grids/discrete.py index f844aad6..72ded5ea 100644 --- a/src/lcm/grids/discrete.py +++ b/src/lcm/grids/discrete.py @@ -48,5 +48,13 @@ def batch_size(self) -> int: return self.__batch_size def to_jax(self) -> Int1D: - """Convert the grid to a Jax array.""" - return jnp.array(self.codes) + """Convert the grid to a Jax array. + + Discrete state/action codes are pinned to `int32` regardless of the + ambient `jax_enable_x64` setting. A single integer dtype across + transitions, V-array indexing, and action lookups keeps the JIT cache + unsplit and lets AOT-compiled programs ship one signature. `int32` + accommodates any realistic category count and matches the + `MISSING_CAT_CODE` sentinel. + """ + return jnp.array(self.codes, dtype=jnp.int32) diff --git a/src/lcm/model.py b/src/lcm/model.py index 9412d660..374d40bf 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -1,6 +1,8 @@ """Collection of classes that are used by the user to define the model and grids.""" import dataclasses +import logging +import threading from collections.abc import Mapping from pathlib import Path from types import MappingProxyType @@ -30,6 +32,7 @@ ) from lcm.regime import Regime from lcm.regime_building.processing import InternalRegime +from lcm.simulation.compile import compile_all_simulate_functions from lcm.simulation.initial_conditions import validate_initial_conditions from lcm.simulation.result import SimulationResult, get_simulation_output_dtypes from lcm.simulation.simulate import simulate @@ -85,9 +88,45 @@ class Model: fixed_params: UserParams """Parameters fixed at model initialization.""" + n_subjects: int | None = None + """Expected simulate batch size; enables AOT compile of simulate functions. + + Dispatch by call shape: + + - `None`: purely lazy behaviour, no AOT. + - First `simulate(...)` with `actual_n == n_subjects`: AOT-compiles all + simulate functions for that batch shape (blocks before solve runs) + and caches them. + - Subsequent `simulate(...)` with the same matching size: reuses the + cached compiled programs. + - `simulate(...)` with a mismatching size: warns once per size and falls + back to the runtime-traced path. + + Param-shape contract: the cache is keyed only on `n_subjects`. The shapes + and dtypes of `internal_params` leaves at the first matching call become + part of the AOT signature; subsequent calls must keep them stable. MSM- + style estimation (varying values, fixed shapes) is the target use case; + construct a fresh `Model` whenever a param array's shape or dtype changes. + """ + _params_template: ParamsTemplate """Template for the model parameters.""" + _simulate_compile_cache: dict[int, MappingProxyType[RegimeName, InternalRegime]] + """AOT-compiled `internal_regimes` per matching `n_subjects`.""" + + _warned_n_subjects: set[int] + """Mismatching `actual_n_subjects` already warned about (one warning each).""" + + _simulate_compile_lock: threading.Lock + """Serialises mutations of `_simulate_compile_cache` and + `_warned_n_subjects`. + + The check-then-set on each container is held under this lock. The + consequent `log.warning` call sits outside the lock so concurrent + simulate() calls don't serialise on logging I/O. + """ + def __init__( self, *, @@ -100,6 +139,7 @@ def __init__( derived_categoricals: Mapping[FunctionName, DiscreteGrid] = MappingProxyType( {} ), + n_subjects: int | None = None, ) -> None: """Initialize the Model. @@ -115,17 +155,26 @@ def __init__( not in states/actions. Broadcast to all regimes (merged with each regime's own `derived_categoricals`). Raises if a regime already has a conflicting entry. + n_subjects: Expected simulate batch size; if set, the first matching + `simulate(...)` call AOT-compiles all simulate functions for + batch shape `n_subjects` before backward induction starts. + `None` keeps the purely lazy behaviour. """ self.description = description self.ages = ages self.n_periods = ages.n_periods self.fixed_params = ensure_containers_are_immutable(fixed_params) + self.n_subjects = n_subjects + self._simulate_compile_cache = {} + self._warned_n_subjects = set() + self._simulate_compile_lock = threading.Lock() validate_model_inputs( n_periods=self.n_periods, regimes=regimes, regime_id_class=regime_id_class, + n_subjects=n_subjects, ) self.regime_names_to_ids = MappingProxyType( dict( @@ -149,6 +198,27 @@ def __init__( regime_names_to_ids=self.regime_names_to_ids, ) + def __getstate__(self) -> dict[str, object]: + """Return a copy of `__dict__` with per-process AOT compile state removed. + + Drops `_simulate_compile_lock` (a `threading.Lock`, not pickleable), + `_simulate_compile_cache` (compiled XLA programs that can't survive + a process boundary), and `_warned_n_subjects` (its companion set). + `__setstate__` restores all three to their fresh state. + """ + state = self.__dict__.copy() + state.pop("_simulate_compile_lock", None) + state.pop("_simulate_compile_cache", None) + state.pop("_warned_n_subjects", None) + return state + + def __setstate__(self, state: dict[str, object]) -> None: + """Restore AOT compile state to a fresh empty cache.""" + self.__dict__.update(state) + self._simulate_compile_cache = {} + self._warned_n_subjects = set() + self._simulate_compile_lock = threading.Lock() + def get_params_template(self) -> UserFacingParamsTemplate: """Get a human-readable params template. @@ -210,12 +280,34 @@ def solve( internal_params=internal_params, ages=self.ages, ) + return self._solve_compiled( + internal_params=internal_params, + params=params, + log=get_logger(log_level=log_level), + log_level=log_level, + log_path=log_path, + log_keep_n_latest=log_keep_n_latest, + max_compilation_workers=max_compilation_workers, + ) + + def _solve_compiled( + self, + *, + internal_params: InternalParams, + params: UserParams, + log: logging.Logger, + log_level: LogLevel, + log_path: str | Path | None, + log_keep_n_latest: int, + max_compilation_workers: int | None, + ) -> MappingProxyType[int, MappingProxyType[RegimeName, FloatND]]: + """Run backward induction, persisting a snapshot on debug or NaN failure.""" try: period_to_regime_to_V_arr = solve( internal_params=internal_params, ages=self.ages, internal_regimes=self.internal_regimes, - logger=get_logger(log_level=log_level), + logger=log, enable_jit=self.enable_jit, max_compilation_workers=max_compilation_workers, ) @@ -240,6 +332,41 @@ def solve( ) return period_to_regime_to_V_arr + def _resolve_simulate_internal_regimes( + self, + *, + actual_n_subjects: int, + log: logging.Logger, + ) -> MappingProxyType[RegimeName, InternalRegime]: + """Return internal_regimes to use for simulate; AOT cache when matching. + + Dispatch by `n_subjects` and batch-shape match: + + - `n_subjects is None`: return the original `internal_regimes` + (purely lazy path). + - `actual_n_subjects != n_subjects`: warn once per mismatching size, + return the original `internal_regimes`. + - `actual_n_subjects == n_subjects`: return the cached compiled + regimes (caller must have populated the cache before calling). + """ + if self.n_subjects is None: + return self.internal_regimes + if actual_n_subjects != self.n_subjects: + with self._simulate_compile_lock: + already_warned = actual_n_subjects in self._warned_n_subjects + if not already_warned: + self._warned_n_subjects.add(actual_n_subjects) + if not already_warned: + log.warning( + "simulate called with n_subjects=%d but model declared " + "n_subjects=%d; falling back to runtime compile.", + actual_n_subjects, + self.n_subjects, + ) + return self.internal_regimes + with self._simulate_compile_lock: + return self._simulate_compile_cache[self.n_subjects] + def simulate( self, *, @@ -320,31 +447,40 @@ def simulate( ages=self.ages, ) log = get_logger(log_level=log_level) - if period_to_regime_to_V_arr is None: - try: - period_to_regime_to_V_arr = solve( + actual_n_subjects = len(next(iter(initial_conditions.values()))) + n_subjects = self.n_subjects + if n_subjects is not None and n_subjects == actual_n_subjects: + with self._simulate_compile_lock: + needs_compile = n_subjects not in self._simulate_compile_cache + if needs_compile: + compiled = compile_all_simulate_functions( + internal_regimes=self.internal_regimes, internal_params=internal_params, ages=self.ages, - internal_regimes=self.internal_regimes, - logger=log, - enable_jit=self.enable_jit, + n_subjects=n_subjects, max_compilation_workers=max_compilation_workers, + logger=log, ) - except InvalidValueFunctionError as exc: - if log_path is not None and exc.partial_solution is not None: - snap_dir = save_solve_snapshot( - model=self, - params=params, - period_to_regime_to_V_arr=exc.partial_solution, # ty: ignore[invalid-argument-type] - log_path=Path(log_path), - log_keep_n_latest=log_keep_n_latest, - ) - exc.add_note(f"Snapshot saved to {snap_dir}") - raise + with self._simulate_compile_lock: + self._simulate_compile_cache[n_subjects] = compiled + if period_to_regime_to_V_arr is None: + period_to_regime_to_V_arr = self._solve_compiled( + internal_params=internal_params, + params=params, + log=log, + log_level=log_level, + log_path=log_path, + log_keep_n_latest=log_keep_n_latest, + max_compilation_workers=max_compilation_workers, + ) + simulate_internal_regimes = self._resolve_simulate_internal_regimes( + actual_n_subjects=actual_n_subjects, + log=log, + ) result = simulate( internal_params=internal_params, initial_conditions=initial_conditions, - internal_regimes=self.internal_regimes, + internal_regimes=simulate_internal_regimes, regime_names_to_ids=self.regime_names_to_ids, logger=log, period_to_regime_to_V_arr=period_to_regime_to_V_arr, @@ -352,6 +488,13 @@ def simulate( simulation_output_dtypes=self.simulation_output_dtypes, seed=seed, ) + # AOT-compiled regimes carry `jax.stages.Compiled` callables that + # wrap an unpicklable `LoadedExecutable`. `to_dataframe` only reads + # the lazy DAG functions / constraints / transitions on + # `simulate_functions`, never the compiled callables — so swap in + # the lazy regimes to keep the result cloudpickle-safe. + if simulate_internal_regimes is not self.internal_regimes: + result._internal_regimes = self.internal_regimes # noqa: SLF001 if log_level == "debug" and log_path is not None: save_simulate_snapshot( model=self, diff --git a/src/lcm/model_processing.py b/src/lcm/model_processing.py index 23dd60a7..d141c6a3 100644 --- a/src/lcm/model_processing.py +++ b/src/lcm/model_processing.py @@ -147,8 +147,11 @@ def validate_model_inputs( n_periods: int, regimes: Mapping[RegimeName, Regime], regime_id_class: type, + n_subjects: int | None = None, ) -> None: """Validate model constructor inputs.""" + _fail_if_invalid_n_subjects(n_subjects=n_subjects) + # Early exit if regimes are not lcm.Regime instances if not all(isinstance(regime, Regime) for regime in regimes.values()): raise ModelInitializationError( @@ -201,6 +204,19 @@ def validate_model_inputs( raise ModelInitializationError(msg) +def _fail_if_invalid_n_subjects(*, n_subjects: int | None) -> None: + """Raise TypeError if non-int, ValueError if non-positive.""" + if n_subjects is None: + return + # `bool` is a subclass of `int`; reject explicitly so True/False don't slip through. + if not isinstance(n_subjects, int) or isinstance(n_subjects, bool): + msg = f"n_subjects must be an int or None, got {type(n_subjects).__name__}." + raise TypeError(msg) + if n_subjects <= 0: + msg = f"n_subjects must be a positive integer, got {n_subjects}." + raise ValueError(msg) + + def _validate_all_variables_used(regimes: Mapping[RegimeName, Regime]) -> list[str]: """Validate that all states and actions are used somewhere in each regime. diff --git a/src/lcm/pandas_utils.py b/src/lcm/pandas_utils.py index 2d696a45..cf1b4cb0 100644 --- a/src/lcm/pandas_utils.py +++ b/src/lcm/pandas_utils.py @@ -153,7 +153,8 @@ def initial_conditions_from_dataframe( # noqa: C901 for col, arr in result_arrays.items() } initial_conditions["regime"] = jnp.array( - df["regime"].map(dict(regime_names_to_ids)).to_numpy() + df["regime"].map(dict(regime_names_to_ids)).to_numpy(), + dtype=jnp.int32, ) return initial_conditions diff --git a/src/lcm/params/processing.py b/src/lcm/params/processing.py index 680e46e7..ae937c44 100644 --- a/src/lcm/params/processing.py +++ b/src/lcm/params/processing.py @@ -1,13 +1,26 @@ -"""Process user-provided params into internal params.""" +"""Process user-provided params into internal params. + +`process_params` resolves user-supplied parameters against the model's +template, then runs a boundary-cast pass that normalises every integer +leaf — Python `int`, typed JAX integer arrays, numpy integer arrays, +and integers inside `MappingLeaf` / `SequenceLeaf` — to `jnp.int32`. +Out-of-range values surface as `ValueError` with the offending leaf's +qualified name. +""" from collections.abc import Mapping from types import MappingProxyType from typing import Any, cast +import numpy as np from dags.tree import QNAME_DELIMITER, qname_from_tree_path, tree_path_from_qname +from jax import Array +from lcm.dtypes import safe_to_int32 from lcm.exceptions import InvalidNameError, InvalidParamsError from lcm.interfaces import InternalRegime +from lcm.params.mapping_leaf import MappingLeaf +from lcm.params.sequence_leaf import SequenceLeaf from lcm.typing import ( InternalParams, ParamsTemplate, @@ -34,7 +47,11 @@ def process_params( - Regime level: `{"regime_0": {"arg_0": 0.0}}` — propagates within regime_0 - Function level: `{"regime_0": {"func": {"arg_0": 0.0}}}` — direct specification - The output always matches the params_template skeleton. + The output always matches the params_template skeleton. Every integer + leaf — Python `int`, typed JAX or numpy integer arrays, and integers + inside `MappingLeaf` / `SequenceLeaf` — is cast to `jnp.int32` so the + AOT signature is stable across calls. Python `bool` and float leaves + are handled by the float-side cast pass. Args: params: User-provided parameters dictionary. @@ -46,6 +63,8 @@ def process_params( Raises: InvalidParamsError: If params contains unexpected keys or type mismatches. InvalidNameError: If the same parameter is specified at multiple levels. + ValueError: If a typed integer leaf carries a value outside the + int32 range; the message names the offending parameter qname. """ return broadcast_to_template(params=params, template=params_template, required=True) @@ -110,12 +129,64 @@ def broadcast_to_template( if unknown: raise InvalidParamsError(f"Unknown keys: {sorted(unknown)}") + for regime, leaves in result.items(): + for param_qname, value in leaves.items(): + leaves[param_qname] = _cast_int_leaves_to_int32( + value, name=f"{regime}{QNAME_DELIMITER}{param_qname}" + ) + return cast( "InternalParams", MappingProxyType({k: MappingProxyType(v) for k, v in result.items()}), ) +def _cast_int_leaves_to_int32(value: Any, *, name: str) -> Any: # noqa: ANN401 + """Normalise integer leaves in a params value to `jnp.int32`. + + Casts: + + - Python `int` scalars — to `jnp.int32` so the DAG sees a JAX scalar + with a pinned dtype rather than a Python int that JAX would + otherwise promote per call site. + - Typed JAX or numpy integer arrays (`jnp.array(..., dtype=jnp.int64)`, + `np.array(...)`) — cast to `int32` to keep the AOT signature stable. + - Integer leaves inside `MappingLeaf` / `SequenceLeaf` — recurse. + + Passes through unchanged: + + - Python `bool` scalars — handled by the float-side cast pass once + it lands. + - Float and non-numeric typed leaves — handled by a separate float- + normalisation pass. + """ + if isinstance(value, MappingLeaf): + return MappingLeaf( + { + k: _cast_int_leaves_to_int32(v, name=f"{name}.{k}") + for k, v in value.data.items() + } + ) + if isinstance(value, SequenceLeaf): + return SequenceLeaf( + [ + _cast_int_leaves_to_int32(v, name=f"{name}[{i}]") + for i, v in enumerate(value.data) + ] + ) + # `bool` is a subclass of `int`, so test for it first and short-circuit + # — bool handling lands with the float-side cast pass, not here. + if isinstance(value, bool): + return value + if isinstance(value, int): + return safe_to_int32(value, name=name) + if isinstance(value, (Array, np.ndarray)) and np.issubdtype( + value.dtype, np.integer + ): + return safe_to_int32(value, name=name) + return value + + def _find_candidates( *, qname: str, diff --git a/src/lcm/regime_building/argmax.py b/src/lcm/regime_building/argmax.py index 0e48cf2b..a4271e2f 100644 --- a/src/lcm/regime_building/argmax.py +++ b/src/lcm/regime_building/argmax.py @@ -43,7 +43,7 @@ def argmax_and_max( # When there are no dimensions to reduce over, return: # - index 0 (trivial argmax since there's only one element) # - the array itself (already the maximum) - return jnp.array(0), a + return jnp.array(0, dtype=jnp.int32), a # Move axis over which to compute the argmax to the back and flatten last dims # ================================================================================== @@ -65,7 +65,7 @@ def argmax_and_max( max_value_mask = a == _max if where is not None: max_value_mask = jnp.logical_and(max_value_mask, where) - _argmax = jnp.argmax(max_value_mask, axis=-1) + _argmax = jnp.argmax(max_value_mask, axis=-1).astype(jnp.int32) return _argmax, _max.reshape(_argmax.shape) diff --git a/src/lcm/simulation/compile.py b/src/lcm/simulation/compile.py new file mode 100644 index 00000000..ad8caeeb --- /dev/null +++ b/src/lcm/simulation/compile.py @@ -0,0 +1,435 @@ +"""AOT-compile simulate functions for a fixed batch size. + +When `Model(n_subjects=N)` is set, `compile_all_simulate_functions(...)` returns +an `internal_regimes` mapping with each regime's `simulate_functions` callables +swapped for AOT-compiled programs sized for batch shape `N`. The existing +simulate call sites then pick them up transparently — no signature changes +downstream. + +Compilation deduplicates callables by identity (only one program per unique +callable), lowers them sequentially (JAX tracing is not thread-safe), then +parallel-compiles them via a `ThreadPoolExecutor` (XLA releases the GIL). +""" + +import dataclasses +import logging +import time +from collections.abc import Callable, Hashable, Mapping +from concurrent.futures import ThreadPoolExecutor, as_completed +from types import MappingProxyType + +import jax +import jax.numpy as jnp +from dags.tree import tree_path_from_qname +from jax import Array + +from lcm.ages import AgeGrid +from lcm.interfaces import InternalRegime +from lcm.simulation.random import generate_simulation_keys +from lcm.solution.solve_brute import ( + _func_dedup_key, + _resolve_compilation_workers, +) +from lcm.typing import ( + FlatRegimeParams, + InternalParams, + RegimeName, +) +from lcm.utils.logging import format_duration +from lcm.utils.namespace import flatten_regime_namespace + + +def compile_all_simulate_functions( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, + ages: AgeGrid, + n_subjects: int, + max_compilation_workers: int | None, + logger: logging.Logger, +) -> MappingProxyType[RegimeName, InternalRegime]: + """AOT-compile every unique simulate function for batch shape `n_subjects`. + + Args: + internal_regimes: Original internal regimes from the Model. + internal_params: Immutable mapping of regime names to flat parameter mappings. + ages: AgeGrid for the model. + n_subjects: Batch size for which to compile. + max_compilation_workers: Maximum threads for parallel XLA compilation. + Defaults to `os.cpu_count()`. + logger: Logger. + + Returns: + Immutable mapping of regime names to InternalRegime where each + regime's `simulate_functions` has its callables replaced by + AOT-compiled programs. + + """ + # Per-regime V-shape lookup for building period-specific templates that + # match the *sparse* mapping `simulate.simulate(...)` actually dispatches: + # `period_to_regime_to_V_arr.get(P+1, {})` — only regimes active at P+1. + regime_V_shapes = _get_regime_V_shapes( + internal_regimes=internal_regimes, + internal_params=internal_params, + ) + + unique, func_keys = _collect_unique_simulate_functions( + internal_regimes=internal_regimes, + internal_params=internal_params, + ages=ages, + n_subjects=n_subjects, + regime_V_shapes=regime_V_shapes, + ) + + n_workers = _resolve_compilation_workers( + max_compilation_workers=max_compilation_workers + ) + n_unique = len(unique) + logger.info( + "Simulate AOT compilation: %d unique functions (%d workers)", + n_unique, + n_workers, + ) + + lowered: dict[Hashable, jax.stages.Lowered] = {} + for i, (key, (func, args, label)) in enumerate(unique.items(), 1): + logger.info("%d/%d %s", i, n_unique, label) + logger.info(" lowering ...") + start = time.monotonic() + # `func` is a `jax.jit`-wrapped callable; ty sees only the abstract + # Callable type, so it can't see `.lower(...)`. + lowered[key] = func.lower(**args) # ty: ignore[unresolved-attribute, invalid-argument-type] + # Drop the concrete lower-args once the `Lowered` object has captured + # its abstract values. This releases V-shaped templates, per-regime + # subject-state/action zeros, and the regime-params view before the + # parallel compile pool starts piling Compiled kernels onto the heap. + unique[key] = (func, None, label) + logger.info( + " lowered in %s", format_duration(seconds=time.monotonic() - start) + ) + + compiled: dict[Hashable, jax.stages.Compiled] = {} + + def _compile_and_log( + *, + key: Hashable, + low: jax.stages.Lowered, + label: str, + ) -> tuple[Hashable, jax.stages.Compiled]: + logger.info(" compiling %s ...", label) + start = time.monotonic() + result = low.compile() + logger.info( + " compiled %s %s", + label, + format_duration(seconds=time.monotonic() - start), + ) + return key, result + + with ThreadPoolExecutor(max_workers=n_workers) as pool: + futures = [ + pool.submit(_compile_and_log, key=key, low=low, label=unique[key][2]) + for key, low in lowered.items() + ] + for future in as_completed(futures): + k, c = future.result() + compiled[k] = c + # Release the HLO module held by the `Lowered` object now that + # its `Compiled` counterpart is in `compiled`; otherwise every + # lowered intermediate stays resident until the slowest compile + # finishes. + del lowered[k] + + return _swap_in_compiled( + internal_regimes=internal_regimes, + compiled=compiled, + func_keys=func_keys, + ) + + +def _collect_unique_simulate_functions( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, + ages: AgeGrid, + n_subjects: int, + regime_V_shapes: dict[RegimeName, tuple[int, ...]], +) -> tuple[ + dict[Hashable, tuple[Callable, dict | None, str]], + dict[tuple[RegimeName, str, int | None], Hashable], +]: + """Walk every regime/period and dedup the simulate functions to compile. + + `argmax_and_max_Q_over_a` dedup keys on `(func_id, active_at_next_period)` + so two periods that share the same argmax callable but see a different + `next_regime_to_V_arr` pytree (different active-regime set at P+1) get + separate compiled programs whose signature matches what runtime actually + dispatches. + """ + unique: dict[Hashable, tuple[Callable, dict | None, str]] = {} + func_keys: dict[tuple[RegimeName, str, int | None], Hashable] = {} + + for regime_name, regime in internal_regimes.items(): + regime_params = internal_params.get(regime_name, MappingProxyType({})) + sf = regime.simulate_functions + + # `sf.argmax_and_max_Q_over_a` has entries for *every* period + # (pylcm builds them across the full age grid), but the regime is + # only dispatched at runtime for periods in `regime.active_periods`. + # Inactive-period entries can carry a `complete_targets` set whose + # shape doesn't match the regime's actual transitions for that + # period; tracing them would surface `next_` bookkeeping + # mismatches the lazy path never reaches. Restrict AOT to active + # periods to mirror runtime. + for period in regime.active_periods: + argmax_func = sf.argmax_and_max_Q_over_a[period] + active_next = _active_regimes_at_period( + internal_regimes=internal_regimes, period=period + 1 + ) + next_regime_to_V_arr = MappingProxyType( + {name: jnp.zeros(regime_V_shapes[name]) for name in active_next} + ) + args = _build_argmax_args( + internal_regime=regime, + regime_params=regime_params, + ages=ages, + period=period, + n_subjects=n_subjects, + next_regime_to_V_arr=next_regime_to_V_arr, + ) + key = ("argmax", _func_dedup_key(func=argmax_func), active_next) + func_keys[(regime_name, "argmax", period)] = key + if key not in unique: + label = ( + f"{regime_name}/argmax_and_max_Q_over_a " + f"(age {ages.values[period].item()})" + ) + unique[key] = (jax.jit(argmax_func), args, label) + + # `next_state` / `crtp` are keyed per-regime: each regime's lower-args + # depend on its own state-action shapes, so even when two regimes + # share a callable identity, their compiled programs are distinct. + if not regime.terminal: + args = _build_next_state_args( + internal_regime=regime, + regime_params=regime_params, + ages=ages, + n_subjects=n_subjects, + ) + key = ("next_state", regime_name, _func_dedup_key(func=sf.next_state)) + func_keys[(regime_name, "next_state", None)] = key + if key not in unique: + # Re-wrap with `jax.jit`: when `fixed_params` are partialled + # into the regime, `sf.next_state` is a `functools.partial` + # (no `.lower()`); plain jit objects are also fine to re-jit. + unique[key] = ( + jax.jit(sf.next_state), + args, + f"{regime_name}/next_state", + ) + + if sf.compute_regime_transition_probs is not None: + args = _build_crtp_args( + internal_regime=regime, + regime_params=regime_params, + ages=ages, + n_subjects=n_subjects, + ) + key = ( + "crtp", + regime_name, + _func_dedup_key(func=sf.compute_regime_transition_probs), + ) + func_keys[(regime_name, "crtp", None)] = key + if key not in unique: + unique[key] = ( + jax.jit(sf.compute_regime_transition_probs), + args, + f"{regime_name}/compute_regime_transition_probs", + ) + + return unique, func_keys + + +def _swap_in_compiled( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + compiled: dict[Hashable, jax.stages.Compiled], + func_keys: dict[tuple[RegimeName, str, int | None], Hashable], +) -> MappingProxyType[RegimeName, InternalRegime]: + """Swap compiled programs into each regime's `simulate_functions`.""" + new_regimes: dict[RegimeName, InternalRegime] = {} + for regime_name, regime in internal_regimes.items(): + sf = regime.simulate_functions + # Only active periods are AOT-compiled (see + # `_collect_unique_simulate_functions`); leave inactive-period + # entries untouched so the existing closure stays in place — they + # are never dispatched at runtime anyway. + argmax_compiled_for_active = { + period: compiled[func_keys[(regime_name, "argmax", period)]] + for period in regime.active_periods + } + argmax_compiled = MappingProxyType( + { + period: argmax_compiled_for_active.get(period, original_func) + for period, original_func in sf.argmax_and_max_Q_over_a.items() + } + ) + if regime.terminal: + next_state_compiled = sf.next_state + else: + next_state_compiled = compiled[func_keys[(regime_name, "next_state", None)]] + if sf.compute_regime_transition_probs is None: + crtp_compiled = None + else: + crtp_compiled = compiled[func_keys[(regime_name, "crtp", None)]] + + new_sf = dataclasses.replace( + sf, + argmax_and_max_Q_over_a=argmax_compiled, + next_state=next_state_compiled, + compute_regime_transition_probs=crtp_compiled, + ) + new_regimes[regime_name] = dataclasses.replace( + regime, simulate_functions=new_sf + ) + + return MappingProxyType(new_regimes) + + +def _get_regime_V_shapes( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, +) -> dict[RegimeName, tuple[int, ...]]: + """Return per-regime V-array shape (one length per state grid). + + Used to construct zero-shaped templates for `next_regime_to_V_arr` + when lowering each period's argmax — the abstract signature only + needs the shapes, not the values. + """ + shapes: dict[RegimeName, tuple[int, ...]] = {} + for regime_name, regime in internal_regimes.items(): + space = regime.state_action_space( + regime_params=internal_params.get(regime_name, MappingProxyType({})) + ) + shapes[regime_name] = tuple(len(v) for v in space.states.values()) + return shapes + + +def _active_regimes_at_period( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + period: int, +) -> tuple[RegimeName, ...]: + """Tuple of regime names active at `period`, in `internal_regimes` order. + + Returned as a `tuple` so it's hashable and pytree-key-stable. An empty + tuple matches the runtime fallback for periods past the last (`{}`). + """ + return tuple( + regime_name + for regime_name, regime in internal_regimes.items() + if period in regime.active_periods + ) + + +def _build_argmax_args( + *, + internal_regime: InternalRegime, + regime_params: FlatRegimeParams, + ages: AgeGrid, + period: int, + n_subjects: int, + next_regime_to_V_arr: MappingProxyType[RegimeName, Array], +) -> dict[str, object]: + base = internal_regime.state_action_space(regime_params=regime_params) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + return { + **subject_states, + **base.discrete_actions, + **base.continuous_actions, + "next_regime_to_V_arr": next_regime_to_V_arr, + **regime_params, + "period": jnp.int32(period), + "age": ages.values[period], + } + + +def _build_next_state_args( + *, + internal_regime: InternalRegime, + regime_params: FlatRegimeParams, + ages: AgeGrid, + n_subjects: int, +) -> dict[str, object]: + base = internal_regime.state_action_space(regime_params=regime_params) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + subject_actions = _subject_shape_arrays( + {**base.discrete_actions, **base.continuous_actions}, + n_subjects=n_subjects, + ) + + stoch_transition_names = ( + internal_regime.simulate_functions.stochastic_transition_names + ) + stoch_next_func_names = sorted( + next_func_name + for next_func_name in flatten_regime_namespace( + internal_regime.simulate_functions.transitions + ) + if tree_path_from_qname(next_func_name)[-1] in stoch_transition_names + ) + _, stoch_keys = generate_simulation_keys( + key=jax.random.key(0), + names=stoch_next_func_names, + n_initial_states=n_subjects, + ) + + return { + **subject_states, + **subject_actions, + **stoch_keys, + "period": jnp.int32(0), + "age": ages.values[0], + **regime_params, + } + + +def _build_crtp_args( + *, + internal_regime: InternalRegime, + regime_params: FlatRegimeParams, + ages: AgeGrid, + n_subjects: int, +) -> dict[str, object]: + base = internal_regime.state_action_space(regime_params=regime_params) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + subject_actions = _subject_shape_arrays( + {**base.discrete_actions, **base.continuous_actions}, + n_subjects=n_subjects, + ) + return { + **subject_states, + **subject_actions, + "period": jnp.int32(0), + "age": ages.values[0], + **regime_params, + } + + +def _subject_shape_arrays( + base_arrays: Mapping[str, Array], + *, + n_subjects: int, +) -> dict[str, Array]: + """Return zeros of shape `(n_subjects,)` mirroring each base array's dtype. + + With `build_initial_states` casting discrete states to the grid dtype, + runtime states (initial + post-transition) share the grid's dtype, so + using `arr.dtype` from the regime's grid here matches runtime. + """ + return { + name: jnp.zeros((n_subjects,), dtype=arr.dtype) + for name, arr in base_arrays.items() + } diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 2cdd3c3b..69ca8d9d 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -64,10 +64,20 @@ def build_initial_states( for regime_name, internal_regime in internal_regimes.items(): for state_name in _get_regime_state_names(internal_regime): key = f"{regime_name}__{state_name}" - if state_name in initial_states: + grid = internal_regime.grids[state_name] + if isinstance(grid, DiscreteGrid): + # Cast user-supplied discrete states to the grid's index + # dtype so every period's argmax sees a single signature + # for that state. + target_dtype = grid.to_jax().dtype + if state_name in initial_states: + flat[key] = initial_states[state_name].astype(target_dtype) + else: + flat[key] = jnp.full( + n_subjects, MISSING_CAT_CODE, dtype=target_dtype + ) + elif state_name in initial_states: flat[key] = initial_states[state_name] - elif isinstance(internal_regime.grids[state_name], DiscreteGrid): - flat[key] = jnp.full(n_subjects, MISSING_CAT_CODE, dtype=jnp.int32) else: flat[key] = jnp.full(n_subjects, jnp.nan) @@ -348,7 +358,7 @@ def _collect_structural_errors( active_mask = active_mask & (~in_regime | period_active) if not jnp.all(active_mask): - invalid_indices = jnp.where(~active_mask)[0] + invalid_indices = jnp.where(~active_mask)[0].astype(jnp.int32) invalid_combos = { (ids_to_regime_names[int(regime_id_arr[i])], float(age_values[i])) for i in invalid_indices @@ -392,7 +402,7 @@ def _collect_feasibility_errors( errors: list[str] = [] for regime_name, internal_regime in internal_regimes.items(): regime_id = regime_names_to_ids[regime_name] - idx_arr = jnp.where(regime_id_arr == regime_id)[0] + idx_arr = jnp.where(regime_id_arr == regime_id)[0].astype(jnp.int32) subject_indices = idx_arr.tolist() if idx_arr.size > 0 else [] if not subject_indices: continue @@ -653,15 +663,107 @@ def _check_combo(action_kw: dict[str, Array]) -> Array: if not infeasible_indices: return None + per_constraint_admits_any = _per_constraint_feasibility( + internal_regime=internal_regime, + subject_states=subject_states, + regime_params=regime_params, + flat_actions=flat_actions, + idx_arr=idx_arr, + infeasible_indices=infeasible_indices, + ) + return _format_infeasibility_message( infeasible_indices=infeasible_indices, internal_regime=internal_regime, regime_name=regime_name, initial_states=initial_states, state_names=state_names, + per_constraint_admits_any=per_constraint_admits_any, ) +def _admits_any_action( + *, + feasibility_func: Callable[..., Array], + action_kwargs: Mapping[str, Array], + params: Mapping[str, object], +) -> bool: + """Return True iff the feasibility function admits ≥ 1 action under params.""" + if action_kwargs: + + def _check_combo(action_kw: dict[str, Array]) -> Array: + return feasibility_func(**action_kw, **params) + + per_combo = jax.vmap(_check_combo)(action_kwargs) + return bool(jnp.any(per_combo)) + return bool(feasibility_func(**params)) + + +def _per_constraint_feasibility( + *, + internal_regime: InternalRegime, + subject_states: Mapping[str, Array], + regime_params: Mapping[str, object], + flat_actions: Mapping[ActionName, Array], + idx_arr: Array, + infeasible_indices: Sequence[int], +) -> dict[str, np.ndarray]: + """Per-constraint feasibility for the infeasible subjects. + + For each constraint, returns a boolean array (one entry per infeasible + subject) indicating whether that constraint *individually* admits at + least one action. Combined with the regime's feasibility verdict, this + distinguishes "constraint X rejects every action by itself" from + "constraints jointly reject everything despite each admitting some". + + Each constraint's feasibility function has its own argument set (a + subset of the combined feasibility's union); filter `subject_states`, + `action_kwargs`, and `filtered_params` per constraint so dags doesn't + raise on stray kwargs. + """ + constraints = internal_regime.simulate_functions.constraints + functions = internal_regime.simulate_functions.functions + if not constraints or not subject_states: + return {} + + infeasible_positions = np.flatnonzero( + np.isin(np.asarray(idx_arr), np.asarray(infeasible_indices)) + ) + infeasible_states = { + name: arr[infeasible_positions] for name, arr in subject_states.items() + } + + out: dict[str, np.ndarray] = {} + for name, constraint_func in constraints.items(): + single_feasibility = _get_feasibility( + functions=functions, + constraints=MappingProxyType({name: constraint_func}), + ) + accepted = get_union_of_args([single_feasibility]) + single_states = {k: v for k, v in infeasible_states.items() if k in accepted} + single_actions = {k: v for k, v in flat_actions.items() if k in accepted} + single_params = {k: v for k, v in regime_params.items() if k in accepted} + n = len(infeasible_indices) + if not single_states: + # Action-only / parameter-only constraint — identical for all subjects. + admits_any = _admits_any_action( + feasibility_func=single_feasibility, + action_kwargs=single_actions, + params=single_params, + ) + out[name] = np.full(n, admits_any, dtype=bool) + continue + any_feasible = _batched_feasibility_check( + feasibility_func=single_feasibility, + subject_states=single_states, + action_kwargs=single_actions, + filtered_params=single_params, + flat_actions=flat_actions, + ) + out[name] = np.asarray(any_feasible) + return out + + def _raise_feasibility_type_error( *, exc: TypeError, @@ -715,6 +817,7 @@ def _format_infeasibility_message( regime_name: RegimeName, initial_states: Mapping[str, Array], state_names: Sequence[str], + per_constraint_admits_any: Mapping[str, np.ndarray], ) -> str: """Format an error message for infeasible subjects. @@ -724,6 +827,12 @@ def _format_infeasibility_message( regime_name: Name of the regime. initial_states: Mapping of state names to arrays. state_names: List of state variable names. + per_constraint_admits_any: Mapping from constraint name to a boolean + array (one entry per infeasible subject) — True where that + constraint *individually* admits at least one action. False + entries identify constraints that reject every action on their + own; rows with all-True entries are infeasible only because the + constraints jointly reject the action set. Returns: Formatted error message string. @@ -745,9 +854,10 @@ def _format_infeasibility_message( if isinstance(grid, DiscreteGrid) and name in state_df.columns: state_df[name] = [grid.categories[int(v)] for v in state_df[name]] - # Constraint names - constraint_names = list(internal_regime.simulate_functions.constraints.keys()) - constraints_str = "\n".join(f" - {name}" for name in constraint_names) + # Append one boolean column per constraint: True = admits ≥ 1 action, + # False = rejects every action by itself for that subject. + for name, mask in per_constraint_admits_any.items(): + state_df[name] = list(mask) # Truncate for large groups n = len(infeasible_indices) @@ -761,10 +871,11 @@ def _format_infeasibility_message( return ( f"All actions are infeasible for {n} subject(s) " f"in regime '{regime_name}'.\n\n" - f"Active constraints:\n{constraints_str}\n\n" - f"Infeasible subjects:\n{table_str}\n\n" - f"No action combination satisfies all constraints for these " - f"initial states." + f"Per-constraint admissibility (True = constraint admits ≥ 1 " + f"action by itself; False = constraint rejects every action):\n" + f"{table_str}\n\n" + f"No action combination satisfies all constraints jointly for " + f"these initial states." ) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index d54d2674..d1ab42ab 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -99,7 +99,9 @@ def simulate( starting_periods = _compute_starting_periods( initial_ages=initial_states["age"], ages=ages ) - subject_regime_ids = jnp.full_like(initial_conditions["regime"], MISSING_CAT_CODE) + subject_regime_ids = jnp.full_like( + initial_conditions["regime"], MISSING_CAT_CODE, dtype=jnp.int32 + ) # Forward simulation simulation_results: dict[RegimeName, dict[int, PeriodRegimeSimulationData]] = { diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index c494d996..ba7cc39c 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -126,7 +126,7 @@ def calculate_next_states( **state_action_space.states, **optimal_actions, **stochastic_variables_keys, - period=period, + period=jnp.int32(period), age=age, **regime_params, ) @@ -187,7 +187,7 @@ def calculate_next_regime_membership( internal_regime.simulate_functions.compute_regime_transition_probs( # ty: ignore[call-non-callable] **state_action_space.states, **optimal_actions, - period=period, + period=jnp.int32(period), age=age, **regime_params, ) @@ -237,8 +237,9 @@ def draw_key_from_dict( """ regime_names = list(d) regime_transition_probs = jnp.array(list(d.values())).T - regime_ids = jnp.array( - [regime_names_to_ids[regime_name] for regime_name in regime_names] + regime_ids = jnp.asarray( + [regime_names_to_ids[regime_name] for regime_name in regime_names], + dtype=jnp.int32, ) def random_id( @@ -286,9 +287,19 @@ def _update_states_for_subjects( for target, target_next_states in computed_next_states.items(): for next_state_name, next_state_values in target_next_states.items(): state_name = f"{target}__{next_state_name.removeprefix('next_')}" + target_dtype = all_states[state_name].dtype + # Preserve storage dtype only when the transition output is the + # same numeric kind. Across kinds (e.g. int storage + float + # transition output) leave JAX's promotion in place; the + # cross-kind boundary cast belongs to Package B. + new_values = ( + next_state_values.astype(target_dtype) + if next_state_values.dtype.kind == target_dtype.kind + else next_state_values + ) updated_states[state_name] = jnp.where( subject_indices, - next_state_values, + new_values, all_states[state_name], ) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index d8152e13..77e63e24 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -510,6 +510,15 @@ def _raise_at( """Run the enriched NaN diagnostic on a single offending row and raise.""" internal_regime = internal_regimes[row.regime_name] regime_params = internal_params[row.regime_name] + # `compute_intermediates` was built from the regime's full `flat_param_names` + # (per-iteration params + fixed params); the live solve loop merges + # `resolved_fixed_params` into `regime_params` implicitly via the partialled + # closures, but we have to do it by hand here to call the diagnostic + # directly. Same merge order as `interfaces.state_action_space` and + # `simulation.result`. + effective_regime_params = MappingProxyType( + {**internal_regime.resolved_fixed_params, **regime_params} + ) state_action_space = internal_regime.state_action_space(regime_params=regime_params) next_regime_to_V_arr = _reconstruct_next_regime_to_V_arr( period=row.period, @@ -529,7 +538,7 @@ def _raise_at( compute_intermediates=compute_intermediates, state_action_space=state_action_space, next_regime_to_V_arr=next_regime_to_V_arr, - internal_params=regime_params, + internal_params=effective_regime_params, period=row.period, ) diff --git a/src/lcm/typing.py b/src/lcm/typing.py index 310c7bc6..da6815cf 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -4,27 +4,27 @@ import pandas as pd from jax import Array -from jaxtyping import Bool, Float, Int, Scalar +from jaxtyping import Bool, Float, Int32, Scalar from lcm.params import MappingLeaf from lcm.params.sequence_leaf import SequenceLeaf type ContinuousState = Float[Array, "..."] type ContinuousAction = Float[Array, "..."] -type DiscreteState = Int[Array, "..."] -type DiscreteAction = Int[Array, "..."] +type DiscreteState = Int32[Array, "..."] +type DiscreteAction = Int32[Array, "..."] type FloatND = Float[Array, "..."] -type IntND = Int[Array, "..."] +type IntND = Int32[Array, "..."] type BoolND = Bool[Array, "..."] type Float1D = Float[Array, "_"] # noqa: F821 -type Int1D = Int[Array, "_"] # noqa: F821 +type Int1D = Int32[Array, "_"] # noqa: F821 type Bool1D = Bool[Array, "_"] # noqa: F821 # Many JAX functions are designed to work with scalar numerical values. This also # includes zero dimensional jax arrays. -type ScalarInt = int | Int[Scalar, ""] +type ScalarInt = int | Int32[Scalar, ""] type ScalarFloat = float | Float[Scalar, ""] type Period = int | Int1D diff --git a/src/lcm/utils/error_handling.py b/src/lcm/utils/error_handling.py index ef631349..faeed9fe 100644 --- a/src/lcm/utils/error_handling.py +++ b/src/lcm/utils/error_handling.py @@ -372,7 +372,7 @@ def _format_sum_violation( {name: jnp.atleast_1d(arr) for name, arr in state_action_values.items()} ) failing_mask = ~jnp.isclose(sum_all, 1.0) - failing_indices = jnp.where(failing_mask)[0] + failing_indices = jnp.where(failing_mask)[0].astype(jnp.int32) failing_sums = sum_all[failing_mask] n_failing = int(failing_indices.shape[0]) n_show = min(n_failing, 5) diff --git a/tests/regime_building/test_regime_processing.py b/tests/regime_building/test_regime_processing.py index d51f59e2..b2def942 100644 --- a/tests/regime_building/test_regime_processing.py +++ b/tests/regime_building/test_regime_processing.py @@ -4,11 +4,14 @@ import jax.numpy as jnp import numpy as np import pandas as pd +import pytest from numpy.testing import assert_array_equal from pandas.testing import assert_frame_equal +from lcm import Regime, categorical from lcm.ages import AgeGrid from lcm.grids import DiscreteGrid, LinSpacedGrid +from lcm.interfaces import InternalRegime from lcm.regime_building.processing import ( _rename_params_to_qnames, process_regimes, @@ -177,6 +180,61 @@ def wealth_constraint(wealth): assert got.index.is_unique +@pytest.fixture(name="two_non_terminal_internal_regimes") +def _two_non_terminal_internal_regimes() -> MappingProxyType[str, InternalRegime]: + """Two non-terminal regimes that share underlying user functions.""" + + def next_x(x): + return x + + def regime_transition(age, final_age): + return jnp.where(age >= final_age, 1, 0) + + @categorical(ordered=False) + class TwoRegimeId: + early: int + late: int + + early = Regime( + transition=regime_transition, + states={"x": LinSpacedGrid(start=0, stop=10, n_points=4)}, + state_transitions={"x": next_x}, + functions={"utility": lambda x: x}, + active=lambda age: age < 1, + ) + late = Regime( + transition=regime_transition, + states={"x": LinSpacedGrid(start=0, stop=10, n_points=6)}, + state_transitions={"x": next_x}, + functions={"utility": lambda x: x}, + active=lambda age: age >= 1, + ) + return process_regimes( + regimes={"early": early, "late": late}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_names_to_ids=MappingProxyType({"early": 0, "late": 1}), + enable_jit=True, + ) + + +@pytest.mark.parametrize( + "attr", + ["next_state", "compute_regime_transition_probs"], +) +def test_simulate_functions_use_per_regime_callables( + two_non_terminal_internal_regimes: MappingProxyType[str, InternalRegime], + attr: str, +) -> None: + """Two regimes built from shared user functions get distinct simulate callables.""" + early_func = getattr( + two_non_terminal_internal_regimes["early"].simulate_functions, attr + ) + late_func = getattr( + two_non_terminal_internal_regimes["late"].simulate_functions, attr + ) + assert id(early_func) != id(late_func) + + def test_rename_params_to_qnames_with_partial(): """Regression: dags >=0.5.1 renames bound partial keywords to qualified names.""" diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py new file mode 100644 index 00000000..698256e4 --- /dev/null +++ b/tests/simulation/test_simulate_aot.py @@ -0,0 +1,326 @@ +"""Tests for simulate-AOT compilation via `Model.n_subjects`. + +When `Model(n_subjects=N)` is set, the first matching `simulate(...)` call +parallel-compiles all simulate functions for batch shape `N`. Subsequent calls +with size `N` reuse the cache; calls with a mismatching size warn once per size +and fall back to the runtime-traced path. AOT works under both `x64=False` +and `x64=True` because integer leaves are normalised to `int32` at every +boundary by `lcm.params.processing` and the simulate state pool. +""" + +import logging +import threading +from dataclasses import dataclass +from typing import Any + +import cloudpickle +import jax.numpy as jnp +import jax.stages +import pytest +from jax import Array + +from lcm import Model +from lcm.ages import AgeGrid +from tests.test_models.deterministic.regression import ( + RegimeId, + dead, + get_params, + working_life, +) + + +def _build_test_model(*, n_periods: int, n_subjects: int | None = None) -> Model: + """Construct the small 2-regime regression model with optional n_subjects.""" + final_age_alive = 18 + n_periods - 2 + return Model( + regimes={ + "working_life": working_life.replace( + active=lambda age: age <= final_age_alive, + ), + "dead": dead, + }, + ages=AgeGrid(start=18, stop=final_age_alive + 1, step="Y"), + regime_id_class=RegimeId, + n_subjects=n_subjects, + ) + + +def _build_initial_conditions(*, n_subjects: int) -> dict[str, Array]: + """Subject array of size `n_subjects` matching the regression test model.""" + wealths = jnp.linspace(20.0, 320.0, num=n_subjects) + return { + "wealth": wealths, + "age": jnp.full((n_subjects,), 18.0), + "regime": jnp.array([RegimeId.working_life] * n_subjects), + } + + +@pytest.mark.parametrize("invalid", [0, -3]) +def test_n_subjects_validation_rejects_non_positive(invalid: int) -> None: + """`Model(n_subjects=0)` and negative values raise `ValueError`.""" + with pytest.raises(ValueError, match="n_subjects"): + _build_test_model(n_periods=3, n_subjects=invalid) + + +def test_n_subjects_validation_rejects_non_int() -> None: + """`Model(n_subjects=1.5)` raises `TypeError`.""" + with pytest.raises(TypeError, match="n_subjects"): + _build_test_model(n_periods=3, n_subjects=1.5) # ty: ignore[invalid-argument-type] + + +def test_n_subjects_none_leaves_aot_cache_empty_after_simulate() -> None: + """`Model(n_subjects=None)` keeps `_simulate_compile_cache` empty after simulate.""" + n_periods = 3 + model = _build_test_model(n_periods=n_periods, n_subjects=None) + params = get_params(n_periods=n_periods) + + model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=_build_initial_conditions(n_subjects=4), + ) + + assert dict(model._simulate_compile_cache) == {} + + +def test_n_subjects_none_yields_simulate_result_sized_to_actual() -> None: + """`Model(n_subjects=None).simulate(...)` returns a result sized to the input.""" + n_periods = 3 + model = _build_test_model(n_periods=n_periods, n_subjects=None) + params = get_params(n_periods=n_periods) + + result = model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=_build_initial_conditions(n_subjects=4), + ) + + assert result.n_subjects == 4 + + +def test_simulate_second_matching_call_does_not_invoke_compile( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Matching second `simulate(...)` invokes `Lowered.compile` zero times.""" + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + + counter = {"count": 0} + original_compile = jax.stages.Lowered.compile + + def counting_compile( + self: jax.stages.Lowered, *args: Any, **kwargs: Any + ) -> jax.stages.Compiled: + counter["count"] += 1 + return original_compile(self, *args, **kwargs) + + monkeypatch.setattr(jax.stages.Lowered, "compile", counting_compile) + + initial_conditions = _build_initial_conditions(n_subjects=n_subjects) + + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + counter["count"] = 0 + + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + + assert counter["count"] == 0 + + +def test_simulate_first_matching_call_populates_aot_cache() -> None: + """Matching first `simulate(...)` populates the cache for that size.""" + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + + assert n_subjects not in model._simulate_compile_cache + + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=_build_initial_conditions(n_subjects=n_subjects), + ) + + assert n_subjects in model._simulate_compile_cache + + +def test_solve_does_not_populate_simulate_compile_cache() -> None: + """`solve(...)` does not touch simulate-side compile state. + + Simulate AOT compilation is driven entirely by `simulate(...)`; calling + `solve(...)` alone leaves `_simulate_compile_cache` empty. + """ + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + + model.solve(params=params) + + assert dict(model._simulate_compile_cache) == {} + + +_DECLARED_N = 4 +_ACTUAL_N = 7 + + +@dataclass(frozen=True) +class _MismatchOutcome: + """Captured simulate-with-mismatch artefacts for assertion.""" + + warnings: list[logging.LogRecord] + model: Model + + +@pytest.fixture(name="mismatch_outcome") +def _mismatch_outcome( + caplog: pytest.LogCaptureFixture, +) -> _MismatchOutcome: + """Run one mismatching `simulate(...)` and capture the WARNING records.""" + n_periods = 3 + model = _build_test_model(n_periods=n_periods, n_subjects=_DECLARED_N) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + + with caplog.at_level(logging.WARNING, logger="lcm"): + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=_build_initial_conditions(n_subjects=_ACTUAL_N), + ) + + warnings = [ + r + for r in caplog.records + if r.levelno == logging.WARNING and "n_subjects" in r.getMessage() + ] + return _MismatchOutcome(warnings=warnings, model=model) + + +def test_simulate_mismatch_emits_one_warning( + mismatch_outcome: _MismatchOutcome, +) -> None: + """A single mismatching call logs exactly one WARNING.""" + assert len(mismatch_outcome.warnings) == 1 + + +def test_simulate_mismatch_warning_names_declared_n( + mismatch_outcome: _MismatchOutcome, +) -> None: + """The mismatch warning message contains the declared `n_subjects`.""" + msg = mismatch_outcome.warnings[0].getMessage() + assert str(_DECLARED_N) in msg + + +def test_simulate_mismatch_warning_names_actual_n( + mismatch_outcome: _MismatchOutcome, +) -> None: + """The mismatch warning message contains the actual `n_subjects`.""" + msg = mismatch_outcome.warnings[0].getMessage() + assert str(_ACTUAL_N) in msg + + +def test_simulate_mismatch_does_not_populate_cache( + mismatch_outcome: _MismatchOutcome, +) -> None: + """A mismatching `n_subjects` falls back to the lazy path — no cache entry.""" + assert _ACTUAL_N not in mismatch_outcome.model._simulate_compile_cache + + +def test_simulate_warns_only_once_per_mismatching_size( + caplog: pytest.LogCaptureFixture, +) -> None: + """Two calls with the same mismatching size produce only one WARNING.""" + n_periods = 3 + model = _build_test_model(n_periods=n_periods, n_subjects=_DECLARED_N) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + initial_conditions = _build_initial_conditions(n_subjects=_ACTUAL_N) + + with caplog.at_level(logging.WARNING, logger="lcm"): + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + + mismatch_warnings = [ + r + for r in caplog.records + if r.levelno == logging.WARNING and "n_subjects" in r.getMessage() + ] + assert len(mismatch_warnings) == 1 + + +def test_simulate_result_pickles_when_n_subjects_matches() -> None: + """`simulate(...)` returns a result that round-trips through cloudpickle. + + With `n_subjects` matching the batch shape, the simulate path runs + AOT-compiled callables that wrap `LoadedExecutable` (unpicklable). + `to_dataframe` doesn't need those callables, so the returned result + must carry the lazy regimes — otherwise downstream pickling + (e.g. pytask handing the result to the next task) fails. + """ + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + + result = model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=_build_initial_conditions(n_subjects=n_subjects), + ) + + restored = cloudpickle.loads(cloudpickle.dumps(result)) + assert restored.n_subjects == n_subjects + + +def test_unpickled_model_can_simulate_with_aot() -> None: + """A cloudpickle round-tripped `Model` still drives `simulate(...)` with AOT.""" + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + initial_conditions = _build_initial_conditions(n_subjects=n_subjects) + + # Populate the AOT cache before pickling — confirms __getstate__ drops it. + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + assert n_subjects in model._simulate_compile_cache + + restored = cloudpickle.loads(cloudpickle.dumps(model)) + + # The restored Model starts with empty AOT state and a fresh lock. + assert dict(restored._simulate_compile_cache) == {} + assert restored._warned_n_subjects == set() + assert isinstance(restored._simulate_compile_lock, type(threading.Lock())) + + # Simulate works post-unpickle and re-populates the cache for that size. + restored.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + assert n_subjects in restored._simulate_compile_cache diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py new file mode 100644 index 00000000..43894cbd --- /dev/null +++ b/tests/test_dtypes.py @@ -0,0 +1,56 @@ +"""Tests for `lcm.dtypes` boundary-cast helpers.""" + +import jax.numpy as jnp +import numpy as np +import pytest + +from lcm.dtypes import safe_to_int32 + + +@pytest.mark.parametrize( + "value", + [7, np.asarray([0, 1, -3], dtype=np.int64)], + ids=["python-int", "int64-array"], +) +def test_safe_to_int32_returns_int32(value: object) -> None: + """`safe_to_int32` returns a `jnp.int32` array for any in-range int input.""" + out = safe_to_int32(value, name="x") + assert out.dtype == jnp.int32 + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (7, 7), + (np.asarray([0, 1, -3], dtype=np.int64), [0, 1, -3]), + ], + ids=["python-int", "int64-array"], +) +def test_safe_to_int32_preserves_in_range_values( + value: object, expected: object +) -> None: + """`safe_to_int32` preserves element values for in-range inputs.""" + out = safe_to_int32(value, name="x") + np.testing.assert_array_equal(np.asarray(out), expected) + + +def test_safe_to_int32_raises_on_python_int_overflow() -> None: + """A Python int above int32 max raises `ValueError` naming the leaf.""" + with pytest.raises(ValueError, match="my_param"): + safe_to_int32(2**32, name="my_param") + + +def test_safe_to_int32_raises_on_array_overflow() -> None: + """An int64 array containing values above int32 max raises with the leaf name.""" + # Use numpy here: `jnp.asarray(..., dtype=jnp.int64)` truncates to int32 + # under `jax_enable_x64=False` and trips JAX's own overflow guard before + # `safe_to_int32` ever sees the value. + arr = np.asarray([1, 2, 2**32], dtype=np.int64) + with pytest.raises(ValueError, match="regime"): + safe_to_int32(arr, name="regime") + + +def test_safe_to_int32_raises_on_underflow() -> None: + """A Python int below int32 min raises `ValueError` naming the leaf.""" + with pytest.raises(ValueError, match="offset"): + safe_to_int32(-(2**40), name="offset") diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py new file mode 100644 index 00000000..9e9c6011 --- /dev/null +++ b/tests/test_int_dtype_invariants.py @@ -0,0 +1,223 @@ +"""Integer dtypes are pinned to int32 across pylcm regardless of x64 mode.""" + +from types import MappingProxyType + +import jax.numpy as jnp +import numpy as np +import pandas as pd +import pytest + +from lcm import Model +from lcm.ages import AgeGrid +from lcm.params import MappingLeaf +from lcm.params.processing import process_params +from lcm.params.sequence_leaf import SequenceLeaf +from lcm.simulation.initial_conditions import ( + MISSING_CAT_CODE, + build_initial_states, +) +from lcm.simulation.transitions import _update_states_for_subjects +from tests.test_models.deterministic.regression import ( + RegimeId, + dead, + get_model, + get_params, + working_life, +) + + +def test_discrete_grid_to_jax_is_int32() -> None: + """Every `DiscreteGrid.to_jax()` in the model returns an `int32` array.""" + model = get_model(n_periods=3) + for regime in model.regimes.values(): + for grid in {**regime.states, **regime.actions}.values(): + jax_arr = grid.to_jax() + if jax_arr.dtype.kind == "i": + assert jax_arr.dtype == jnp.int32, ( + f"Discrete grid yielded {jax_arr.dtype}, expected int32." + ) + + +def test_build_initial_states_discrete_dtype_is_int32() -> None: + """`build_initial_states` casts every discrete state array to `int32`.""" + model = get_model(n_periods=3) + initial_states = { + "wealth": jnp.array([20.0, 50.0]), + "age": jnp.array([18.0, 18.0]), + } + flat = build_initial_states( + initial_states=initial_states, + internal_regimes=model.internal_regimes, + ) + for key, arr in flat.items(): + if arr.dtype.kind == "i": + assert arr.dtype == jnp.int32, ( + f"Initial state {key} has dtype {arr.dtype}, expected int32." + ) + + +def test_missing_cat_code_is_int32_minimum() -> None: + """`MISSING_CAT_CODE` equals `iinfo(int32).min` — never a real category code.""" + assert jnp.iinfo(jnp.int32).min == MISSING_CAT_CODE + + +def test_update_states_for_subjects_preserves_storage_dtype() -> None: + """A transition that returns int64 cannot promote the storage pool to int64.""" + all_states = MappingProxyType( + {"work__health": jnp.asarray([0, 1, 0, 1], dtype=jnp.int32)} + ) + int64_next = jnp.asarray([1, 1, 1, 1], dtype=jnp.int64) + computed = MappingProxyType({"work": MappingProxyType({"next_health": int64_next})}) + subjects = jnp.asarray([True, False, True, False]) + + updated = _update_states_for_subjects( + all_states=all_states, + computed_next_states=computed, + subject_indices=subjects, + ) + + assert updated["work__health"].dtype == jnp.int32 + + +def test_process_params_casts_python_int_to_int32() -> None: + """A Python `int` param leaf is cast to `jnp.int32`.""" + template = MappingProxyType({"regime_a": MappingProxyType({"final_age": "int"})}) + user_params = {"regime_a": {"final_age": 65}} + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + final_age = out["regime_a"]["final_age"] + assert int(final_age) == 65 + assert final_age.dtype == jnp.int32 # ty: ignore[unresolved-attribute] + + +def test_process_params_casts_int64_array_to_int32() -> None: + """A `jnp.int64` array param leaf is normalised to `jnp.int32`.""" + template = MappingProxyType({"regime_a": MappingProxyType({"schedule": "Array"})}) + user_params = {"regime_a": {"schedule": jnp.asarray([0, 1, 2], dtype=jnp.int64)}} + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + schedule = out["regime_a"]["schedule"] + assert schedule.dtype == jnp.int32 # ty: ignore[unresolved-attribute] + + +def test_process_params_int_array_overflow_raises_with_qualified_name() -> None: + """An out-of-int32-range int array surfaces the param's qualified name.""" + template = MappingProxyType({"regime_a": MappingProxyType({"big_param": "Array"})}) + # Numpy here: under `jax_enable_x64=False`, `jnp.asarray(..., dtype=int64)` + # of an out-of-int32 value raises before our helper sees it. + user_params = {"regime_a": {"big_param": np.asarray([0, 2**40], dtype=np.int64)}} + + with pytest.raises(ValueError, match="big_param"): + process_params( + params=user_params, # ty: ignore[invalid-argument-type] + params_template=template, # ty: ignore[invalid-argument-type] + ) + + +@pytest.mark.parametrize("key", ["low", "high"]) +def test_process_params_casts_int_array_inside_mapping_leaf_to_int32(key: str) -> None: + """`MappingLeaf` int arrays land at `jnp.int32` after params processing.""" + template = MappingProxyType( + {"regime_a": MappingProxyType({"sched": "MappingLeaf"})} + ) + user_params = { + "regime_a": { + "sched": MappingLeaf( + { + "low": jnp.asarray([0, 1], dtype=jnp.int64), + "high": jnp.asarray([10, 20], dtype=jnp.int64), + } + ) + } + } + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + assert ( + out["regime_a"]["sched"].data[key].dtype # ty: ignore[unresolved-attribute] + == jnp.int32 + ) + + +@pytest.mark.parametrize("index", [0, 1]) +def test_process_params_casts_int_array_inside_sequence_leaf_to_int32( + index: int, +) -> None: + """`SequenceLeaf` int arrays land at `jnp.int32` after params processing.""" + template = MappingProxyType( + {"regime_a": MappingProxyType({"sched": "SequenceLeaf"})} + ) + user_params = { + "regime_a": { + "sched": SequenceLeaf( + [ + jnp.asarray([0, 1], dtype=jnp.int64), + jnp.asarray([10, 20], dtype=jnp.int64), + ] + ) + } + } + + out = process_params( + params=user_params, + params_template=template, # ty: ignore[invalid-argument-type] + ) + + assert ( + out["regime_a"]["sched"].data[index].dtype # ty: ignore[unresolved-attribute] + == jnp.int32 + ) + + +def test_simulate_accepts_int64_regime_initial_condition_and_round_trips() -> None: + """`regime` as `jnp.int64` simulates the same as `jnp.int32`.""" + n_periods = 3 + final_age_alive = 18 + n_periods - 2 + model = Model( + regimes={ + "working_life": working_life.replace( + active=lambda age: age <= final_age_alive, + ), + "dead": dead, + }, + ages=AgeGrid(start=18, stop=final_age_alive + 1, step="Y"), + regime_id_class=RegimeId, + ) + params = get_params(n_periods=n_periods) + + common = { + "wealth": jnp.linspace(20.0, 80.0, num=4), + "age": jnp.full((4,), 18.0), + } + initial_conditions_int32 = { + **common, + "regime": jnp.asarray([RegimeId.working_life] * 4, dtype=jnp.int32), + } + initial_conditions_int64 = { + **common, + "regime": jnp.asarray([RegimeId.working_life] * 4, dtype=jnp.int64), + } + + df_int32 = model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=initial_conditions_int32, + ).to_dataframe() + df_int64 = model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=initial_conditions_int64, + ).to_dataframe() + + pd.testing.assert_frame_equal(df_int64, df_int32, check_dtype=False) diff --git a/tests/test_next_state.py b/tests/test_next_state.py index 162650cb..a1f688c5 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -3,7 +3,6 @@ import jax.numpy as jnp import pandas as pd -from pybaum import tree_equal from lcm.ages import AgeGrid from lcm.grids import DiscreteGrid @@ -13,10 +12,7 @@ get_next_state_function_for_simulation, get_next_state_function_for_solution, ) -from lcm.typing import ( - ContinuousState, - FloatND, -) +from lcm.typing import ContinuousState from tests.test_models.deterministic.regression import dead, working_life @@ -59,14 +55,19 @@ def test_get_next_state_function_with_solve_target(): def test_get_next_state_function_with_simulate_target(): + """Outputs are nested by target regime: `{target: {next_state: array}}`. + + The combined function dispatches inputs to the per-target DAG and + returns a mapping from target regime name to that target's + `{next_: array}` outputs, matching what + `_update_states_for_subjects` consumes. + """ + def f_a(state: ContinuousState) -> ContinuousState: - return state[0] + return state * 2.0 def f_b(state: ContinuousState) -> ContinuousState: - return None # ty: ignore[invalid-return-type] - - def f_weight_b(state: ContinuousState) -> FloatND: - return jnp.array([0.0, 1.0]) + return state + 1.0 @dataclass class MockCategory: @@ -76,11 +77,12 @@ class MockCategory: all_grids = MappingProxyType( {"mock": MappingProxyType({"b": DiscreteGrid(MockCategory)})} ) - variable_info = pd.DataFrame({"is_shock": [False]}) + variable_info = pd.DataFrame({"is_shock": [False]}, index=["b"]) transitions = MappingProxyType( {"mock": MappingProxyType({"next_a": f_a, "next_b": f_b})} ) - functions = MappingProxyType({"utility": lambda: 0, "f_weight_b": f_weight_b}) + functions = MappingProxyType({"utility": lambda: 0}) + got_func = get_next_state_function_for_simulation( transitions=transitions, # ty: ignore[invalid-argument-type] functions=functions, # ty: ignore[invalid-argument-type] @@ -88,11 +90,12 @@ class MockCategory: variable_info=variable_info, ) - key = jnp.arange(2, dtype="uint32") - got = got_func(state=jnp.arange(2), key_b=key) + got = got_func(state=jnp.array([1.0, 2.0])) - expected = {"a": jnp.array([0]), "b": jnp.array([1])} - assert tree_equal(expected, got) + assert set(got.keys()) == {"mock"} + assert set(got["mock"].keys()) == {"next_a", "next_b"} + assert jnp.array_equal(got["mock"]["next_a"], jnp.array([2.0, 4.0])) + assert jnp.array_equal(got["mock"]["next_b"], jnp.array([2.0, 3.0])) def test_create_stochastic_next_func():