diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index 477d5635..a9364879 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -54,7 +54,7 @@ def _build() -> tuple[object, object, object]: get_benchmark_params, ) - model = create_benchmark_model() + model = create_benchmark_model(n_subjects=_N_SUBJECTS) _, model_params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=_N_SUBJECTS, seed=0 diff --git a/pixi.lock b/pixi.lock index 5c8127d0..2421c6ce 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=3f215a2c44237b9fa3fa74bf78ef93c8d695a517#3f215a2c44237b9fa3fa74bf78ef93c8d695a517 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=3f215a2c44237b9fa3fa74bf78ef93c8d695a517#3f215a2c44237b9fa3fa74bf78ef93c8d695a517 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev153+gc585f4b82.d20260504 - sha256: 365fd6893bcba5ab807032371430b19a9a9056c80ef3a9a201b56d34ced99e0e + version: 0.0.2.dev222+g62392c11e.d20260506 + sha256: 07cd0cb56028dc357c5b4c57c3b06d2f91a85bb499494767e1fa2dd259cdb283 requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index 3fe01162..1c0a12e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "3f215a2c44237b9fa3fa74bf78ef93c8d695a517" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } diff --git a/src/lcm/grids/discrete.py b/src/lcm/grids/discrete.py index f844aad6..acdb850d 100644 --- a/src/lcm/grids/discrete.py +++ b/src/lcm/grids/discrete.py @@ -48,5 +48,15 @@ def batch_size(self) -> int: return self.__batch_size def to_jax(self) -> Int1D: - """Convert the grid to a Jax array.""" - return jnp.array(self.codes) + """Convert the grid to a Jax array. + + Discrete state/action codes are pinned to `int32` regardless of the + ambient `jax_enable_x64` setting. `jnp.array([...])` would otherwise + produce `int32` in 32-bit mode and `int64` in x64 mode, and + downstream values (transitions, V-array indexing, action lookups) + inherit that ambiguity — which silently splits the JIT cache into + per-period int32/int64 variants and breaks any AOT-compiled + program that ships a single signature. `int32` covers any realistic + category count and matches the `MISSING_CAT_CODE` sentinel. + """ + return jnp.array(self.codes, dtype=jnp.int32) diff --git a/src/lcm/model.py b/src/lcm/model.py index 9412d660..8935535a 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -1,6 +1,7 @@ """Collection of classes that are used by the user to define the model and grids.""" import dataclasses +import logging from collections.abc import Mapping from pathlib import Path from types import MappingProxyType @@ -30,6 +31,7 @@ ) from lcm.regime import Regime from lcm.regime_building.processing import InternalRegime +from lcm.simulation.compile import compile_all_simulate_functions from lcm.simulation.initial_conditions import validate_initial_conditions from lcm.simulation.result import SimulationResult, get_simulation_output_dtypes from lcm.simulation.simulate import simulate @@ -85,6 +87,16 @@ class Model: fixed_params: UserParams """Parameters fixed at model initialization.""" + n_subjects: int | None = None + """Expected simulate batch size; enables AOT compile of simulate functions. + + When set, the first matching `simulate(...)` call AOT-compiles all simulate + functions for batch shape `n_subjects` in parallel. Subsequent calls with the + same size reuse the compiled programs. Calls with a mismatching size warn + once per size and fall back to the runtime-traced path. `None` keeps the + purely lazy behaviour. + """ + _params_template: ParamsTemplate """Template for the model parameters.""" @@ -100,6 +112,7 @@ def __init__( derived_categoricals: Mapping[FunctionName, DiscreteGrid] = MappingProxyType( {} ), + n_subjects: int | None = None, ) -> None: """Initialize the Model. @@ -115,17 +128,27 @@ def __init__( not in states/actions. Broadcast to all regimes (merged with each regime's own `derived_categoricals`). Raises if a regime already has a conflicting entry. + n_subjects: Expected simulate batch size; if set, the first matching + `simulate(...)` call AOT-compiles all simulate functions for + batch shape `n_subjects` in parallel. `None` keeps the purely + lazy behaviour. """ self.description = description self.ages = ages self.n_periods = ages.n_periods self.fixed_params = ensure_containers_are_immutable(fixed_params) + self.n_subjects = n_subjects + self._simulate_compile_cache: dict[ + int, MappingProxyType[RegimeName, InternalRegime] + ] = {} + self._warned_n_subjects: set[int] = set() validate_model_inputs( n_periods=self.n_periods, regimes=regimes, regime_id_class=regime_id_class, + n_subjects=n_subjects, ) self.regime_names_to_ids = MappingProxyType( dict( @@ -240,6 +263,46 @@ def solve( ) return period_to_regime_to_V_arr + def _resolve_simulate_internal_regimes( + self, + *, + actual_n_subjects: int, + internal_params: InternalParams, + log: logging.Logger, + max_compilation_workers: int | None, + ) -> MappingProxyType[RegimeName, InternalRegime]: + """Return internal_regimes to use for simulate; AOT cache when matching. + + Returns the original `internal_regimes` when `n_subjects` is `None` or + when the actual batch size mismatches the declared one (logging a + warning once per mismatching size). Otherwise builds and caches the + AOT-compiled regimes for the matching size. + """ + if self.n_subjects is None: + return self.internal_regimes + if actual_n_subjects != self.n_subjects: + if actual_n_subjects not in self._warned_n_subjects: + log.warning( + "simulate called with n_subjects=%d but model declared " + "n_subjects=%d; falling back to runtime compile.", + actual_n_subjects, + self.n_subjects, + ) + self._warned_n_subjects.add(actual_n_subjects) + return self.internal_regimes + if self.n_subjects not in self._simulate_compile_cache: + self._simulate_compile_cache[self.n_subjects] = ( + compile_all_simulate_functions( + internal_regimes=self.internal_regimes, + internal_params=internal_params, + ages=self.ages, + n_subjects=self.n_subjects, + max_compilation_workers=max_compilation_workers, + logger=log, + ) + ) + return self._simulate_compile_cache[self.n_subjects] + def simulate( self, *, @@ -341,10 +404,17 @@ def simulate( ) exc.add_note(f"Snapshot saved to {snap_dir}") raise + actual_n_subjects = len(next(iter(initial_conditions.values()))) + simulate_internal_regimes = self._resolve_simulate_internal_regimes( + actual_n_subjects=actual_n_subjects, + internal_params=internal_params, + log=log, + max_compilation_workers=max_compilation_workers, + ) result = simulate( internal_params=internal_params, initial_conditions=initial_conditions, - internal_regimes=self.internal_regimes, + internal_regimes=simulate_internal_regimes, regime_names_to_ids=self.regime_names_to_ids, logger=log, period_to_regime_to_V_arr=period_to_regime_to_V_arr, diff --git a/src/lcm/model_processing.py b/src/lcm/model_processing.py index 23dd60a7..d141c6a3 100644 --- a/src/lcm/model_processing.py +++ b/src/lcm/model_processing.py @@ -147,8 +147,11 @@ def validate_model_inputs( n_periods: int, regimes: Mapping[RegimeName, Regime], regime_id_class: type, + n_subjects: int | None = None, ) -> None: """Validate model constructor inputs.""" + _fail_if_invalid_n_subjects(n_subjects=n_subjects) + # Early exit if regimes are not lcm.Regime instances if not all(isinstance(regime, Regime) for regime in regimes.values()): raise ModelInitializationError( @@ -201,6 +204,19 @@ def validate_model_inputs( raise ModelInitializationError(msg) +def _fail_if_invalid_n_subjects(*, n_subjects: int | None) -> None: + """Raise TypeError if non-int, ValueError if non-positive.""" + if n_subjects is None: + return + # `bool` is a subclass of `int`; reject explicitly so True/False don't slip through. + if not isinstance(n_subjects, int) or isinstance(n_subjects, bool): + msg = f"n_subjects must be an int or None, got {type(n_subjects).__name__}." + raise TypeError(msg) + if n_subjects <= 0: + msg = f"n_subjects must be a positive integer, got {n_subjects}." + raise ValueError(msg) + + def _validate_all_variables_used(regimes: Mapping[RegimeName, Regime]) -> list[str]: """Validate that all states and actions are used somewhere in each regime. diff --git a/src/lcm/pandas_utils.py b/src/lcm/pandas_utils.py index 2d696a45..cf1b4cb0 100644 --- a/src/lcm/pandas_utils.py +++ b/src/lcm/pandas_utils.py @@ -153,7 +153,8 @@ def initial_conditions_from_dataframe( # noqa: C901 for col, arr in result_arrays.items() } initial_conditions["regime"] = jnp.array( - df["regime"].map(dict(regime_names_to_ids)).to_numpy() + df["regime"].map(dict(regime_names_to_ids)).to_numpy(), + dtype=jnp.int32, ) return initial_conditions diff --git a/src/lcm/regime_building/argmax.py b/src/lcm/regime_building/argmax.py index 0e48cf2b..a4271e2f 100644 --- a/src/lcm/regime_building/argmax.py +++ b/src/lcm/regime_building/argmax.py @@ -43,7 +43,7 @@ def argmax_and_max( # When there are no dimensions to reduce over, return: # - index 0 (trivial argmax since there's only one element) # - the array itself (already the maximum) - return jnp.array(0), a + return jnp.array(0, dtype=jnp.int32), a # Move axis over which to compute the argmax to the back and flatten last dims # ================================================================================== @@ -65,7 +65,7 @@ def argmax_and_max( max_value_mask = a == _max if where is not None: max_value_mask = jnp.logical_and(max_value_mask, where) - _argmax = jnp.argmax(max_value_mask, axis=-1) + _argmax = jnp.argmax(max_value_mask, axis=-1).astype(jnp.int32) return _argmax, _max.reshape(_argmax.shape) diff --git a/src/lcm/simulation/compile.py b/src/lcm/simulation/compile.py new file mode 100644 index 00000000..c325b27f --- /dev/null +++ b/src/lcm/simulation/compile.py @@ -0,0 +1,417 @@ +"""AOT-compile simulate functions for a fixed batch size. + +When `Model(n_subjects=N)` is set, `compile_all_simulate_functions(...)` returns +an `internal_regimes` mapping with each regime's `simulate_functions` callables +swapped for AOT-compiled programs sized for batch shape `N`. The existing +simulate call sites then pick them up transparently — no signature changes +downstream. + +Mirrors the pattern in `solve_brute._compile_all_functions`: deduplicate by +callable identity, sequentially lower (tracing is not thread-safe), then +parallel-compile via `ThreadPoolExecutor` (XLA releases the GIL). +""" + +import dataclasses +import logging +import time +from collections.abc import Callable, Hashable, Mapping +from concurrent.futures import ThreadPoolExecutor, as_completed +from types import MappingProxyType + +import jax +import jax.numpy as jnp +from dags.tree import tree_path_from_qname +from jax import Array + +from lcm.ages import AgeGrid +from lcm.interfaces import InternalRegime +from lcm.simulation.random import generate_simulation_keys +from lcm.solution.solve_brute import ( + _func_dedup_key, + _resolve_compilation_workers, +) +from lcm.typing import ( + FlatRegimeParams, + InternalParams, + RegimeName, +) +from lcm.utils.logging import format_duration +from lcm.utils.namespace import flatten_regime_namespace + + +def compile_all_simulate_functions( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, + ages: AgeGrid, + n_subjects: int, + max_compilation_workers: int | None, + logger: logging.Logger, +) -> MappingProxyType[RegimeName, InternalRegime]: + """AOT-compile every unique simulate function for batch shape `n_subjects`. + + Args: + internal_regimes: Original internal regimes from the Model. + internal_params: Immutable mapping of regime names to flat parameter mappings. + ages: AgeGrid for the model. + n_subjects: Batch size for which to compile. + max_compilation_workers: Maximum threads for parallel XLA compilation. + Defaults to `os.cpu_count()`. + logger: Logger. + + Returns: + Immutable mapping of regime names to InternalRegime where each + regime's `simulate_functions` has its callables replaced by + AOT-compiled programs. + + """ + # Per-regime V-shape lookup for building period-specific templates that + # match the *sparse* mapping `simulate.simulate(...)` actually dispatches: + # `period_to_regime_to_V_arr.get(P+1, {})` — only regimes active at P+1. + regime_V_shapes = _get_regime_V_shapes( + internal_regimes=internal_regimes, + internal_params=internal_params, + ) + + unique, func_keys = _collect_unique_simulate_functions( + internal_regimes=internal_regimes, + internal_params=internal_params, + ages=ages, + n_subjects=n_subjects, + regime_V_shapes=regime_V_shapes, + ) + + n_workers = _resolve_compilation_workers( + max_compilation_workers=max_compilation_workers + ) + n_unique = len(unique) + logger.info( + "Simulate AOT compilation: %d unique functions (%d workers)", + n_unique, + n_workers, + ) + + lowered: dict[Hashable, jax.stages.Lowered] = {} + for i, (key, (func, args, label)) in enumerate(unique.items(), 1): + logger.info("%d/%d %s", i, n_unique, label) + logger.info(" lowering ...") + start = time.monotonic() + # `func` is a `jax.jit`-wrapped callable; ty sees only the abstract + # Callable type, so it can't see `.lower(...)`. + lowered[key] = func.lower(**args) # ty: ignore[unresolved-attribute] + logger.info( + " lowered in %s", format_duration(seconds=time.monotonic() - start) + ) + + compiled: dict[Hashable, jax.stages.Compiled] = {} + + def _compile_and_log( + *, + key: Hashable, + low: jax.stages.Lowered, + label: str, + ) -> tuple[Hashable, jax.stages.Compiled]: + logger.info(" compiling %s ...", label) + start = time.monotonic() + result = low.compile() + logger.info( + " compiled %s %s", + label, + format_duration(seconds=time.monotonic() - start), + ) + return key, result + + with ThreadPoolExecutor(max_workers=n_workers) as pool: + futures = [ + pool.submit(_compile_and_log, key=key, low=low, label=unique[key][2]) + for key, low in lowered.items() + ] + for future in as_completed(futures): + k, c = future.result() + compiled[k] = c + + return _swap_in_compiled( + internal_regimes=internal_regimes, + compiled=compiled, + func_keys=func_keys, + ) + + +def _collect_unique_simulate_functions( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, + ages: AgeGrid, + n_subjects: int, + regime_V_shapes: dict[RegimeName, tuple[int, ...]], +) -> tuple[ + dict[Hashable, tuple[Callable, dict, str]], + dict[tuple[RegimeName, str, int | None], Hashable], +]: + """Walk every regime/period and dedup the simulate functions to compile. + + `argmax_and_max_Q_over_a` dedup keys on `(func_id, active_at_next_period)` + so two periods that share the same argmax callable but see a different + `next_regime_to_V_arr` pytree (different active-regime set at P+1) get + separate compiled programs whose signature matches what runtime actually + dispatches. + """ + unique: dict[Hashable, tuple[Callable, dict, str]] = {} + func_keys: dict[tuple[RegimeName, str, int | None], Hashable] = {} + + for regime_name, regime in internal_regimes.items(): + regime_params = internal_params.get(regime_name, MappingProxyType({})) + sf = regime.simulate_functions + + # `sf.argmax_and_max_Q_over_a` has entries for *every* period + # (pylcm builds them across the full age grid), but the regime is + # only dispatched at runtime for periods in `regime.active_periods`. + # The unused entries can carry a stale `complete_targets` set + # whose shape doesn't match the regime's actual transitions + # (e.g. a forced-canwork regime's argmax for a pre-FRA period + # has choose targets in scope, even though the regime never + # reaches that period at runtime). Tracing those would surface + # `next_` bookkeeping inconsistencies that the lazy path + # never trips. Restrict AOT to active periods to mirror runtime. + for period in regime.active_periods: + argmax_func = sf.argmax_and_max_Q_over_a[period] + active_next = _active_regimes_at_period( + internal_regimes=internal_regimes, period=period + 1 + ) + next_regime_to_V_arr = MappingProxyType( + {name: jnp.zeros(regime_V_shapes[name]) for name in active_next} + ) + args = _build_argmax_args( + internal_regime=regime, + regime_params=regime_params, + ages=ages, + period=period, + n_subjects=n_subjects, + next_regime_to_V_arr=next_regime_to_V_arr, + ) + key = ("argmax", _func_dedup_key(func=argmax_func), active_next) + func_keys[(regime_name, "argmax", period)] = key + if key not in unique: + label = ( + f"{regime_name}/argmax_and_max_Q_over_a " + f"(age {ages.values[period].item()})" + ) + unique[key] = (jax.jit(argmax_func), args, label) + + if not regime.terminal: + args = _build_next_state_args( + internal_regime=regime, + regime_params=regime_params, + ages=ages, + n_subjects=n_subjects, + ) + key = ("next_state", _func_dedup_key(func=sf.next_state)) + func_keys[(regime_name, "next_state", None)] = key + if key not in unique: + # Re-wrap with `jax.jit`: when `fixed_params` are partialled + # into the regime, `sf.next_state` is a `functools.partial` + # (no `.lower()`); plain jit objects are also fine to re-jit. + unique[key] = ( + jax.jit(sf.next_state), + args, + f"{regime_name}/next_state", + ) + + if sf.compute_regime_transition_probs is not None: + args = _build_crtp_args( + internal_regime=regime, + regime_params=regime_params, + ages=ages, + n_subjects=n_subjects, + ) + key = ("crtp", _func_dedup_key(func=sf.compute_regime_transition_probs)) + func_keys[(regime_name, "crtp", None)] = key + if key not in unique: + unique[key] = ( + jax.jit(sf.compute_regime_transition_probs), + args, + f"{regime_name}/compute_regime_transition_probs", + ) + + return unique, func_keys + + +def _swap_in_compiled( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + compiled: dict[Hashable, jax.stages.Compiled], + func_keys: dict[tuple[RegimeName, str, int | None], Hashable], +) -> MappingProxyType[RegimeName, InternalRegime]: + """Swap compiled programs into each regime's `simulate_functions`.""" + new_regimes: dict[RegimeName, InternalRegime] = {} + for regime_name, regime in internal_regimes.items(): + sf = regime.simulate_functions + # Only active periods are AOT-compiled (see + # `_collect_unique_simulate_functions`); leave inactive-period + # entries untouched so the existing closure stays in place — they + # are never dispatched at runtime anyway. + argmax_compiled_for_active = { + period: compiled[func_keys[(regime_name, "argmax", period)]] + for period in regime.active_periods + } + argmax_compiled = MappingProxyType( + { + period: argmax_compiled_for_active.get(period, original_func) + for period, original_func in sf.argmax_and_max_Q_over_a.items() + } + ) + if regime.terminal: + next_state_compiled = sf.next_state + else: + next_state_compiled = compiled[func_keys[(regime_name, "next_state", None)]] + if sf.compute_regime_transition_probs is None: + crtp_compiled = None + else: + crtp_compiled = compiled[func_keys[(regime_name, "crtp", None)]] + + new_sf = dataclasses.replace( + sf, + argmax_and_max_Q_over_a=argmax_compiled, + next_state=next_state_compiled, + compute_regime_transition_probs=crtp_compiled, + ) + new_regimes[regime_name] = dataclasses.replace( + regime, simulate_functions=new_sf + ) + + return MappingProxyType(new_regimes) + + +def _get_regime_V_shapes( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + internal_params: InternalParams, +) -> dict[RegimeName, tuple[int, ...]]: + shapes: dict[RegimeName, tuple[int, ...]] = {} + for regime_name, regime in internal_regimes.items(): + space = regime.state_action_space( + regime_params=internal_params.get(regime_name, MappingProxyType({})) + ) + shapes[regime_name] = tuple(len(v) for v in space.states.values()) + return shapes + + +def _active_regimes_at_period( + *, + internal_regimes: MappingProxyType[RegimeName, InternalRegime], + period: int, +) -> tuple[RegimeName, ...]: + """Tuple of regime names active at `period`, in `internal_regimes` order. + + Returned as a `tuple` so it's hashable and pytree-key-stable. An empty + tuple matches the runtime fallback for periods past the last (`{}`). + """ + return tuple( + regime_name + for regime_name, regime in internal_regimes.items() + if period in regime.active_periods + ) + + +def _build_argmax_args( + *, + internal_regime: InternalRegime, + regime_params: FlatRegimeParams, + ages: AgeGrid, + period: int, + n_subjects: int, + next_regime_to_V_arr: MappingProxyType[RegimeName, Array], +) -> dict[str, object]: + base = internal_regime.state_action_space(regime_params=regime_params) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + return { + **subject_states, + **base.discrete_actions, + **base.continuous_actions, + "next_regime_to_V_arr": next_regime_to_V_arr, + **regime_params, + "period": jnp.int32(period), + "age": ages.values[period], + } + + +def _build_next_state_args( + *, + internal_regime: InternalRegime, + regime_params: FlatRegimeParams, + ages: AgeGrid, + n_subjects: int, +) -> dict[str, object]: + base = internal_regime.state_action_space(regime_params=regime_params) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + subject_actions = _subject_shape_arrays( + {**base.discrete_actions, **base.continuous_actions}, + n_subjects=n_subjects, + ) + + stoch_transition_names = ( + internal_regime.simulate_functions.stochastic_transition_names + ) + stoch_next_func_names = sorted( + next_func_name + for next_func_name in flatten_regime_namespace( + internal_regime.simulate_functions.transitions + ) + if tree_path_from_qname(next_func_name)[-1] in stoch_transition_names + ) + _, stoch_keys = generate_simulation_keys( + key=jax.random.key(0), + names=stoch_next_func_names, + n_initial_states=n_subjects, + ) + + # `period` is passed as a plain Python int by `calculate_next_states` + # (transitions.py), which traces as the default-precision int. Match that + # here so the lowered shape signature lines up with the runtime call. + return { + **subject_states, + **subject_actions, + **stoch_keys, + "period": 0, + "age": ages.values[0], + **regime_params, + } + + +def _build_crtp_args( + *, + internal_regime: InternalRegime, + regime_params: FlatRegimeParams, + ages: AgeGrid, + n_subjects: int, +) -> dict[str, object]: + base = internal_regime.state_action_space(regime_params=regime_params) + subject_states = _subject_shape_arrays(base.states, n_subjects=n_subjects) + subject_actions = _subject_shape_arrays( + {**base.discrete_actions, **base.continuous_actions}, + n_subjects=n_subjects, + ) + return { + **subject_states, + **subject_actions, + "period": 0, + "age": ages.values[0], + **regime_params, + } + + +def _subject_shape_arrays( + base_arrays: Mapping[str, Array], + *, + n_subjects: int, +) -> dict[str, Array]: + """Return zeros of shape `(n_subjects,)` mirroring each base array's dtype. + + With `build_initial_states` casting discrete states to the grid dtype, + runtime states (initial + post-transition) share the grid's dtype, so + using `arr.dtype` from the regime's grid here matches runtime. + """ + return { + name: jnp.zeros((n_subjects,), dtype=arr.dtype) + for name, arr in base_arrays.items() + } diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 2cdd3c3b..6ef75a89 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -64,10 +64,24 @@ def build_initial_states( for regime_name, internal_regime in internal_regimes.items(): for state_name in _get_regime_state_names(internal_regime): key = f"{regime_name}__{state_name}" - if state_name in initial_states: + grid = internal_regime.grids[state_name] + if isinstance(grid, DiscreteGrid): + # Match the grid's index dtype so the state is index-stable + # across the simulate loop. Without this, period-0 dispatch + # carries the user-supplied dtype (often int32) but post- + # transition states are promoted to the grid dtype (int64 + # under x64), forcing JAX to compile two argmax variants + # per regime and breaking AOT-compiled programs that key + # on a single signature. + target_dtype = grid.to_jax().dtype + if state_name in initial_states: + flat[key] = initial_states[state_name].astype(target_dtype) + else: + flat[key] = jnp.full( + n_subjects, MISSING_CAT_CODE, dtype=target_dtype + ) + elif state_name in initial_states: flat[key] = initial_states[state_name] - elif isinstance(internal_regime.grids[state_name], DiscreteGrid): - flat[key] = jnp.full(n_subjects, MISSING_CAT_CODE, dtype=jnp.int32) else: flat[key] = jnp.full(n_subjects, jnp.nan) @@ -348,7 +362,7 @@ def _collect_structural_errors( active_mask = active_mask & (~in_regime | period_active) if not jnp.all(active_mask): - invalid_indices = jnp.where(~active_mask)[0] + invalid_indices = jnp.where(~active_mask)[0].astype(jnp.int32) invalid_combos = { (ids_to_regime_names[int(regime_id_arr[i])], float(age_values[i])) for i in invalid_indices @@ -392,7 +406,7 @@ def _collect_feasibility_errors( errors: list[str] = [] for regime_name, internal_regime in internal_regimes.items(): regime_id = regime_names_to_ids[regime_name] - idx_arr = jnp.where(regime_id_arr == regime_id)[0] + idx_arr = jnp.where(regime_id_arr == regime_id)[0].astype(jnp.int32) subject_indices = idx_arr.tolist() if idx_arr.size > 0 else [] if not subject_indices: continue @@ -653,15 +667,107 @@ def _check_combo(action_kw: dict[str, Array]) -> Array: if not infeasible_indices: return None + per_constraint_admits_any = _per_constraint_feasibility( + internal_regime=internal_regime, + subject_states=subject_states, + regime_params=regime_params, + flat_actions=flat_actions, + idx_arr=idx_arr, + infeasible_indices=infeasible_indices, + ) + return _format_infeasibility_message( infeasible_indices=infeasible_indices, internal_regime=internal_regime, regime_name=regime_name, initial_states=initial_states, state_names=state_names, + per_constraint_admits_any=per_constraint_admits_any, ) +def _admits_any_action( + *, + feasibility_func: Callable[..., Array], + action_kwargs: Mapping[str, Array], + params: Mapping[str, object], +) -> bool: + """Return True iff the feasibility function admits ≥ 1 action under params.""" + if action_kwargs: + + def _check_combo(action_kw: dict[str, Array]) -> Array: + return feasibility_func(**action_kw, **params) + + per_combo = jax.vmap(_check_combo)(action_kwargs) + return bool(jnp.any(per_combo)) + return bool(feasibility_func(**params)) + + +def _per_constraint_feasibility( + *, + internal_regime: InternalRegime, + subject_states: Mapping[str, Array], + regime_params: Mapping[str, object], + flat_actions: Mapping[ActionName, Array], + idx_arr: Array, + infeasible_indices: Sequence[int], +) -> dict[str, np.ndarray]: + """Per-constraint feasibility for the infeasible subjects. + + For each constraint, returns a boolean array (one entry per infeasible + subject) indicating whether that constraint *individually* admits at + least one action. Combined with the regime's feasibility verdict, this + distinguishes "constraint X rejects every action by itself" from + "constraints jointly reject everything despite each admitting some". + + Each constraint's feasibility function has its own argument set (a + subset of the combined feasibility's union); filter `subject_states`, + `action_kwargs`, and `filtered_params` per constraint so dags doesn't + raise on stray kwargs. + """ + constraints = internal_regime.simulate_functions.constraints + functions = internal_regime.simulate_functions.functions + if not constraints or not subject_states: + return {} + + infeasible_positions = np.flatnonzero( + np.isin(np.asarray(idx_arr), np.asarray(infeasible_indices)) + ) + infeasible_states = { + name: arr[infeasible_positions] for name, arr in subject_states.items() + } + + out: dict[str, np.ndarray] = {} + for name, constraint_func in constraints.items(): + single_feasibility = _get_feasibility( + functions=functions, + constraints=MappingProxyType({name: constraint_func}), + ) + accepted = get_union_of_args([single_feasibility]) + single_states = {k: v for k, v in infeasible_states.items() if k in accepted} + single_actions = {k: v for k, v in flat_actions.items() if k in accepted} + single_params = {k: v for k, v in regime_params.items() if k in accepted} + n = len(infeasible_indices) + if not single_states: + # Action-only / parameter-only constraint — identical for all subjects. + admits_any = _admits_any_action( + feasibility_func=single_feasibility, + action_kwargs=single_actions, + params=single_params, + ) + out[name] = np.full(n, admits_any, dtype=bool) + continue + any_feasible = _batched_feasibility_check( + feasibility_func=single_feasibility, + subject_states=single_states, + action_kwargs=single_actions, + filtered_params=single_params, + flat_actions=flat_actions, + ) + out[name] = np.asarray(any_feasible) + return out + + def _raise_feasibility_type_error( *, exc: TypeError, @@ -715,6 +821,7 @@ def _format_infeasibility_message( regime_name: RegimeName, initial_states: Mapping[str, Array], state_names: Sequence[str], + per_constraint_admits_any: Mapping[str, np.ndarray], ) -> str: """Format an error message for infeasible subjects. @@ -724,6 +831,12 @@ def _format_infeasibility_message( regime_name: Name of the regime. initial_states: Mapping of state names to arrays. state_names: List of state variable names. + per_constraint_admits_any: Mapping from constraint name to a boolean + array (one entry per infeasible subject) — True where that + constraint *individually* admits at least one action. False + entries identify constraints that reject every action on their + own; rows with all-True entries are infeasible only because the + constraints jointly reject the action set. Returns: Formatted error message string. @@ -745,9 +858,10 @@ def _format_infeasibility_message( if isinstance(grid, DiscreteGrid) and name in state_df.columns: state_df[name] = [grid.categories[int(v)] for v in state_df[name]] - # Constraint names - constraint_names = list(internal_regime.simulate_functions.constraints.keys()) - constraints_str = "\n".join(f" - {name}" for name in constraint_names) + # Append one boolean column per constraint: True = admits ≥ 1 action, + # False = rejects every action by itself for that subject. + for name, mask in per_constraint_admits_any.items(): + state_df[name] = list(mask) # Truncate for large groups n = len(infeasible_indices) @@ -761,10 +875,11 @@ def _format_infeasibility_message( return ( f"All actions are infeasible for {n} subject(s) " f"in regime '{regime_name}'.\n\n" - f"Active constraints:\n{constraints_str}\n\n" - f"Infeasible subjects:\n{table_str}\n\n" - f"No action combination satisfies all constraints for these " - f"initial states." + f"Per-constraint admissibility (True = constraint admits ≥ 1 " + f"action by itself; False = constraint rejects every action):\n" + f"{table_str}\n\n" + f"No action combination satisfies all constraints jointly for " + f"these initial states." ) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index d54d2674..d1ab42ab 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -99,7 +99,9 @@ def simulate( starting_periods = _compute_starting_periods( initial_ages=initial_states["age"], ages=ages ) - subject_regime_ids = jnp.full_like(initial_conditions["regime"], MISSING_CAT_CODE) + subject_regime_ids = jnp.full_like( + initial_conditions["regime"], MISSING_CAT_CODE, dtype=jnp.int32 + ) # Forward simulation simulation_results: dict[RegimeName, dict[int, PeriodRegimeSimulationData]] = { diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index d8152e13..ba792609 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -510,6 +510,15 @@ def _raise_at( """Run the enriched NaN diagnostic on a single offending row and raise.""" internal_regime = internal_regimes[row.regime_name] regime_params = internal_params[row.regime_name] + # `compute_intermediates` was built from the regime's full `flat_param_names` + # (per-iteration params + fixed params); the live solve loop merges + # `resolved_fixed_params` into `regime_params` implicitly via the partialled + # closures, but we have to do it by hand here to call the diagnostic + # directly. Same merge order as `interfaces.state_action_space` and + # `simulation.result`. + diag_params = MappingProxyType( + {**internal_regime.resolved_fixed_params, **regime_params} + ) state_action_space = internal_regime.state_action_space(regime_params=regime_params) next_regime_to_V_arr = _reconstruct_next_regime_to_V_arr( period=row.period, @@ -529,7 +538,7 @@ def _raise_at( compute_intermediates=compute_intermediates, state_action_space=state_action_space, next_regime_to_V_arr=next_regime_to_V_arr, - internal_params=regime_params, + internal_params=diag_params, period=row.period, ) diff --git a/src/lcm/typing.py b/src/lcm/typing.py index 310c7bc6..da6815cf 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -4,27 +4,27 @@ import pandas as pd from jax import Array -from jaxtyping import Bool, Float, Int, Scalar +from jaxtyping import Bool, Float, Int32, Scalar from lcm.params import MappingLeaf from lcm.params.sequence_leaf import SequenceLeaf type ContinuousState = Float[Array, "..."] type ContinuousAction = Float[Array, "..."] -type DiscreteState = Int[Array, "..."] -type DiscreteAction = Int[Array, "..."] +type DiscreteState = Int32[Array, "..."] +type DiscreteAction = Int32[Array, "..."] type FloatND = Float[Array, "..."] -type IntND = Int[Array, "..."] +type IntND = Int32[Array, "..."] type BoolND = Bool[Array, "..."] type Float1D = Float[Array, "_"] # noqa: F821 -type Int1D = Int[Array, "_"] # noqa: F821 +type Int1D = Int32[Array, "_"] # noqa: F821 type Bool1D = Bool[Array, "_"] # noqa: F821 # Many JAX functions are designed to work with scalar numerical values. This also # includes zero dimensional jax arrays. -type ScalarInt = int | Int[Scalar, ""] +type ScalarInt = int | Int32[Scalar, ""] type ScalarFloat = float | Float[Scalar, ""] type Period = int | Int1D diff --git a/src/lcm/utils/error_handling.py b/src/lcm/utils/error_handling.py index ef631349..faeed9fe 100644 --- a/src/lcm/utils/error_handling.py +++ b/src/lcm/utils/error_handling.py @@ -372,7 +372,7 @@ def _format_sum_violation( {name: jnp.atleast_1d(arr) for name, arr in state_action_values.items()} ) failing_mask = ~jnp.isclose(sum_all, 1.0) - failing_indices = jnp.where(failing_mask)[0] + failing_indices = jnp.where(failing_mask)[0].astype(jnp.int32) failing_sums = sum_all[failing_mask] n_failing = int(failing_indices.shape[0]) n_show = min(n_failing, 5) diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py new file mode 100644 index 00000000..3b8e57a1 --- /dev/null +++ b/tests/simulation/test_simulate_aot.py @@ -0,0 +1,180 @@ +"""Tests for simulate-AOT compilation via `Model.n_subjects`. + +When `Model(n_subjects=N)` is set, the first matching `simulate(...)` call +parallel-compiles all simulate functions for batch shape `N`. Subsequent calls +with size `N` reuse the cache; calls with a mismatching size warn once per size +and fall back to the runtime-traced path. +""" + +import logging + +import jax.numpy as jnp +import jax.stages +import pytest + +from lcm import Model +from lcm.ages import AgeGrid +from tests.test_models.deterministic.regression import ( + RegimeId, + dead, + get_params, + working_life, +) + + +def _build_test_model(*, n_periods: int, n_subjects: int | None = None) -> Model: + """Construct the small 2-regime regression model with optional n_subjects.""" + final_age_alive = 18 + n_periods - 2 + return Model( + regimes={ + "working_life": working_life.replace( + active=lambda age: age <= final_age_alive, + ), + "dead": dead, + }, + ages=AgeGrid(start=18, stop=final_age_alive + 1, step="Y"), + regime_id_class=RegimeId, + n_subjects=n_subjects, + ) + + +def _build_initial_conditions(*, n_subjects: int) -> dict: + """Subject array of size `n_subjects` matching the regression test model.""" + wealths = jnp.linspace(20.0, 320.0, num=n_subjects) + return { + "wealth": wealths, + "age": jnp.full((n_subjects,), 18.0), + "regime": jnp.array([RegimeId.working_life] * n_subjects), + } + + +@pytest.mark.parametrize("invalid", [0, -3]) +def test_n_subjects_validation_rejects_non_positive(invalid: int) -> None: + with pytest.raises(ValueError, match="n_subjects"): + _build_test_model(n_periods=3, n_subjects=invalid) + + +def test_n_subjects_validation_rejects_non_int() -> None: + with pytest.raises(TypeError, match="n_subjects"): + _build_test_model(n_periods=3, n_subjects=1.5) # ty: ignore[invalid-argument-type] + + +def test_n_subjects_none_keeps_lazy_behavior() -> None: + """Without n_subjects, simulate works and no AOT cache is populated.""" + n_periods = 3 + model = _build_test_model(n_periods=n_periods, n_subjects=None) + params = get_params(n_periods=n_periods) + + result = model.simulate( + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=_build_initial_conditions(n_subjects=4), + ) + + assert result.n_subjects == 4 + assert model.n_subjects is None + assert not getattr(model, "_simulate_compile_cache", {}) + + +def test_simulate_compiles_only_once_with_matching_n_subjects( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """First simulate call AOT-compiles; second call hits the cache.""" + n_periods = 3 + n_subjects = 4 + model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + + counter = {"count": 0} + original_compile = jax.stages.Lowered.compile + + def counting_compile(self: jax.stages.Lowered, *args, **kwargs): + counter["count"] += 1 + return original_compile(self, *args, **kwargs) + + monkeypatch.setattr(jax.stages.Lowered, "compile", counting_compile) + + initial_conditions = _build_initial_conditions(n_subjects=n_subjects) + + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + n_first = counter["count"] + counter["count"] = 0 + + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + n_second = counter["count"] + + assert n_first > 0, "First simulate must trigger compilation." + assert n_second == 0, "Second simulate must hit the AOT cache." + assert n_subjects in model._simulate_compile_cache + + +def test_simulate_warns_on_n_subjects_mismatch( + caplog: pytest.LogCaptureFixture, +) -> None: + """Mismatching size logs WARNING naming both N and M, falls back to lazy path.""" + n_periods = 3 + declared_n = 4 + actual_n = 7 + model = _build_test_model(n_periods=n_periods, n_subjects=declared_n) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + + with caplog.at_level(logging.WARNING, logger="lcm"): + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=_build_initial_conditions(n_subjects=actual_n), + ) + + mismatch_warnings = [ + r + for r in caplog.records + if r.levelno == logging.WARNING and "n_subjects" in r.getMessage() + ] + assert len(mismatch_warnings) == 1 + msg = mismatch_warnings[0].getMessage() + assert str(declared_n) in msg + assert str(actual_n) in msg + # Cache is NOT populated for mismatching size — fallback path was taken. + assert actual_n not in model._simulate_compile_cache + + +def test_simulate_caches_recompiled_size_no_second_warning( + caplog: pytest.LogCaptureFixture, +) -> None: + """Two calls with the same mismatching size produce only one WARNING.""" + n_periods = 3 + declared_n = 4 + actual_n = 7 + model = _build_test_model(n_periods=n_periods, n_subjects=declared_n) + params = get_params(n_periods=n_periods) + period_to_regime_to_V_arr = model.solve(params=params) + initial_conditions = _build_initial_conditions(n_subjects=actual_n) + + with caplog.at_level(logging.WARNING, logger="lcm"): + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + model.simulate( + params=params, + period_to_regime_to_V_arr=period_to_regime_to_V_arr, + initial_conditions=initial_conditions, + ) + + mismatch_warnings = [ + r + for r in caplog.records + if r.levelno == logging.WARNING and "n_subjects" in r.getMessage() + ] + assert len(mismatch_warnings) == 1 diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py new file mode 100644 index 00000000..3147de3f --- /dev/null +++ b/tests/test_int_dtype_invariants.py @@ -0,0 +1,41 @@ +"""Integer dtypes are pinned to int32 across pylcm regardless of x64 mode.""" + +import jax.numpy as jnp + +from lcm.simulation.initial_conditions import ( + MISSING_CAT_CODE, + build_initial_states, +) +from tests.test_models.deterministic.regression import get_model + + +def test_discrete_grid_to_jax_is_int32() -> None: + model = get_model(n_periods=3) + for regime in model.regimes.values(): + for grid in {**regime.states, **regime.actions}.values(): + jax_arr = grid.to_jax() + if jax_arr.dtype.kind == "i": + assert jax_arr.dtype == jnp.int32, ( + f"Discrete grid yielded {jax_arr.dtype}, expected int32." + ) + + +def test_build_initial_states_discrete_dtype_is_int32() -> None: + model = get_model(n_periods=3) + initial_states = { + "wealth": jnp.array([20.0, 50.0]), + "age": jnp.array([18.0, 18.0]), + } + flat = build_initial_states( + initial_states=initial_states, + internal_regimes=model.internal_regimes, + ) + for key, arr in flat.items(): + if arr.dtype.kind == "i": + assert arr.dtype == jnp.int32, ( + f"Initial state {key} has dtype {arr.dtype}, expected int32." + ) + + +def test_missing_cat_code_is_int32_minimum() -> None: + assert jnp.iinfo(jnp.int32).min == MISSING_CAT_CODE diff --git a/tests/test_next_state.py b/tests/test_next_state.py index 162650cb..a1f688c5 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -3,7 +3,6 @@ import jax.numpy as jnp import pandas as pd -from pybaum import tree_equal from lcm.ages import AgeGrid from lcm.grids import DiscreteGrid @@ -13,10 +12,7 @@ get_next_state_function_for_simulation, get_next_state_function_for_solution, ) -from lcm.typing import ( - ContinuousState, - FloatND, -) +from lcm.typing import ContinuousState from tests.test_models.deterministic.regression import dead, working_life @@ -59,14 +55,19 @@ def test_get_next_state_function_with_solve_target(): def test_get_next_state_function_with_simulate_target(): + """Outputs are nested by target regime: `{target: {next_state: array}}`. + + The combined function dispatches inputs to the per-target DAG and + returns a mapping from target regime name to that target's + `{next_: array}` outputs, matching what + `_update_states_for_subjects` consumes. + """ + def f_a(state: ContinuousState) -> ContinuousState: - return state[0] + return state * 2.0 def f_b(state: ContinuousState) -> ContinuousState: - return None # ty: ignore[invalid-return-type] - - def f_weight_b(state: ContinuousState) -> FloatND: - return jnp.array([0.0, 1.0]) + return state + 1.0 @dataclass class MockCategory: @@ -76,11 +77,12 @@ class MockCategory: all_grids = MappingProxyType( {"mock": MappingProxyType({"b": DiscreteGrid(MockCategory)})} ) - variable_info = pd.DataFrame({"is_shock": [False]}) + variable_info = pd.DataFrame({"is_shock": [False]}, index=["b"]) transitions = MappingProxyType( {"mock": MappingProxyType({"next_a": f_a, "next_b": f_b})} ) - functions = MappingProxyType({"utility": lambda: 0, "f_weight_b": f_weight_b}) + functions = MappingProxyType({"utility": lambda: 0}) + got_func = get_next_state_function_for_simulation( transitions=transitions, # ty: ignore[invalid-argument-type] functions=functions, # ty: ignore[invalid-argument-type] @@ -88,11 +90,12 @@ class MockCategory: variable_info=variable_info, ) - key = jnp.arange(2, dtype="uint32") - got = got_func(state=jnp.arange(2), key_b=key) + got = got_func(state=jnp.array([1.0, 2.0])) - expected = {"a": jnp.array([0]), "b": jnp.array([1])} - assert tree_equal(expected, got) + assert set(got.keys()) == {"mock"} + assert set(got["mock"].keys()) == {"next_a", "next_b"} + assert jnp.array_equal(got["mock"]["next_a"], jnp.array([2.0, 4.0])) + assert jnp.array_equal(got["mock"]["next_b"], jnp.array([2.0, 3.0])) def test_create_stochastic_next_func():