Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
1c65f45
Support runtime-supplied points on continuous-action IrregSpacedGrids
hmgaudecker Apr 29, 2026
1792279
benchmarks: bump aca-model to runtime-consumption-points version
hmgaudecker Apr 29, 2026
cf00e99
Fail loudly when reading runtime IrregSpacedGrid before substitution
hmgaudecker Apr 29, 2026
db98cde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
03ba800
Fix remaining ruff check errors in test_single_feasible_action.py
hmgaudecker Apr 29, 2026
3769d2d
Raise on regime-function-output indexed by discrete state in a consumer
hmgaudecker Apr 29, 2026
282542f
benchmarks: bump aca-model to dead-regime-NaN fix
hmgaudecker Apr 29, 2026
72c83f7
Substitute runtime-supplied action gridpoints in simulate's state-act…
hmgaudecker Apr 29, 2026
db6214f
Guard log-spaced grids against non-positive start; reject non-finite …
hmgaudecker Apr 30, 2026
f56b9be
Address PR review: docstrings, type hints, validation, drop separator…
hmgaudecker Apr 30, 2026
3b7be82
LogSpacedGrid docstring: drop redundant rationale
hmgaudecker May 1, 2026
8589146
IrregSpacedGrid.to_jax docstring: shorter, point at the substituted-g…
hmgaudecker May 1, 2026
541392c
IrregSpacedGrid.to_jax error message: same shape as docstring
hmgaudecker May 1, 2026
54d22b0
Use jnp.isfinite in grid validators; drop math import
hmgaudecker May 1, 2026
ba38876
Drop cryptic aca_model reference from validator docstring
hmgaudecker May 1, 2026
4d11dea
Validator: also flag function-output indexed by a derived categorical
hmgaudecker May 1, 2026
01608bb
create_regime_state_action_space docstring: trim rationale
hmgaudecker May 1, 2026
6241056
state_action_space: move private helpers below public function (deep …
hmgaudecker May 1, 2026
2efe9e1
Drop application-specific (aca) references from test docstrings
hmgaudecker May 1, 2026
f1c4d5d
Validator: rename to discrete_grid_names, also include discrete actions
hmgaudecker May 1, 2026
d9f37ce
Validator: tighten to actual footgun shape; correct behaviour descrip…
hmgaudecker May 1, 2026
2cc46ff
solve_brute: stream NaN/Inf reductions instead of stacking-and-flushing
hmgaudecker May 1, 2026
365da07
solve_brute: stop pinning per-period V templates in diagnostic_rows
hmgaudecker May 1, 2026
bf1cdf4
solve_brute: fail-fast on NaN per period; rewrite stale diagnostic hint
hmgaudecker May 4, 2026
c7745f3
solve/simulate: surface snapshot path in NaN exception note
hmgaudecker May 4, 2026
bc067c1
Model.n_subjects: AOT-compile simulate functions for fixed batch shape
hmgaudecker May 1, 2026
8bb8259
simulate AOT: match runtime's sparse pytree, drop runtime padding
hmgaudecker May 2, 2026
54c72a0
bench_aca_baseline: pass n_subjects=_N_SUBJECTS to create_benchmark_m…
hmgaudecker May 2, 2026
92d038c
benchmarks: bump aca-model rev to carry n_subjects on factories
hmgaudecker May 2, 2026
648afcc
benchmarks: bump aca-model rev + pass max_consumption to factory
hmgaudecker May 2, 2026
596f150
simulate AOT: re-jit `next_state` / `compute_regime_transition_probs`
hmgaudecker May 2, 2026
9fd0524
simulate AOT: only compile active-period argmax, not the full age range
hmgaudecker May 3, 2026
dfb0e8b
simulate AOT: int32 for discrete state lower-args (match runtime)
hmgaudecker May 3, 2026
43723c4
build_initial_states: cast discrete states to grid dtype (one-shot)
hmgaudecker May 3, 2026
99f3f15
DiscreteGrid: pin to_jax() to int32 regardless of x64 mode
hmgaudecker May 3, 2026
458d36f
Lock integer dtype to int32 end-to-end (#341)
hmgaudecker May 3, 2026
e8ede00
benchmarks: bump aca-model rev; drop max_consumption kwarg
hmgaudecker May 3, 2026
4316b6e
solve_brute: merge resolved_fixed_params into NaN diagnostic regime_p…
hmgaudecker May 3, 2026
866a5bb
benchmarks: bump aca-model rev to 714fee0 (assets-floor margin)
hmgaudecker May 4, 2026
a51edae
regime_template: exempt next_<state> names from fixed_param extraction
hmgaudecker May 4, 2026
71a6146
Merge feature/next-state-deps-in-transitions: exempt next_<state> fro…
hmgaudecker May 4, 2026
ac93eec
Merge improve/lazy-solve-diagnostics (incl. next_<state> exempt-set fix)
hmgaudecker May 4, 2026
d66d85a
Boilerplate refresh: dags module, current pixi/hook pins, drop stale …
hmgaudecker May 4, 2026
f18d27c
Bump .ai-instructions: pyproject-fmt + ruff + check-jsonschema rev pins
hmgaudecker May 4, 2026
5e01de6
Merge feature/runtime-action-grids: boilerplate refresh + ai-instruct…
hmgaudecker May 4, 2026
f730e78
Merge feature/next-state-deps-in-transitions: boilerplate refresh + a…
hmgaudecker May 4, 2026
749f83a
Merge improve/lazy-solve-diagnostics: boilerplate refresh + ai-instru…
hmgaudecker May 4, 2026
c3e1838
Merge main: pick up #338 (runtime action grids + validator tightening)
hmgaudecker May 4, 2026
01ba1f3
Merge feature/next-state-deps-in-transitions (carries main → #338)
hmgaudecker May 4, 2026
3229a15
Merge improve/lazy-solve-diagnostics (carries main → #338 → #342)
hmgaudecker May 4, 2026
e45cf4f
regime_template: reject next_<state> params on regular DAG functions
hmgaudecker May 4, 2026
5ff174d
Merge feature/next-state-deps-in-transitions: harden next_<state> val…
hmgaudecker May 4, 2026
5bae789
Merge improve/lazy-solve-diagnostics: harden next_<state> validator
hmgaudecker May 4, 2026
c585f4b
Bump aca-model benchmark pin to 83f22500 (post pension correction)
hmgaudecker May 4, 2026
ffe8820
Merge feature/next-state-deps-in-transitions: bump aca-model benchmar…
hmgaudecker May 4, 2026
e912de4
Merge improve/lazy-solve-diagnostics: bump aca-model benchmark pin to…
hmgaudecker May 4, 2026
f0dd7b5
Revert aca-model pin: this branch lacks Model.n_subjects (introduced …
hmgaudecker May 4, 2026
149ac78
Merge feature/next-state-deps-in-transitions: revert aca-model pin to…
hmgaudecker May 4, 2026
588c9c4
Merge improve/lazy-solve-diagnostics
hmgaudecker May 4, 2026
c00b610
Bump aca-model pin to 83f22500 on #340 (carries pension correction)
hmgaudecker May 4, 2026
e89d5e4
Revert "regime_template: reject next_<state> params on regular DAG fu…
hmgaudecker May 4, 2026
d07d897
Merge feature/next-state-deps-in-transitions: revert validator that c…
hmgaudecker May 4, 2026
53443fc
Merge improve/lazy-solve-diagnostics: revert validator + carry over
hmgaudecker May 4, 2026
e92eeec
Bump aca-model pin to 3453080 (filters stale benchmark_params key)
hmgaudecker May 4, 2026
c117f4c
Bump aca-model pin to b2e90bb (synthesise shifted imputation arrays)
hmgaudecker May 4, 2026
e422876
Bump aca-model pin to 35eddcc (declare target_his derived categorical)
hmgaudecker May 4, 2026
4b9bea3
Bump aca-model pin to 64d6567 (rename shifted-array level to target_his)
hmgaudecker May 4, 2026
9c7edb6
Bump aca-model pin to f09b5e3 (per-target next_assets, dead-target te…
hmgaudecker May 4, 2026
e6066fa
Roll #340 aca-model pin back to 63d2a38 (pre-pension-correction)
hmgaudecker May 4, 2026
a908c84
get_next_state_function_for_simulation: per-target DAG mirrors solve
hmgaudecker May 5, 2026
18d4ade
Revert aca-model rollback: restore f09b5e3 pin (with pension correction)
hmgaudecker May 5, 2026
6c64a77
get_next_state_function_for_simulation: per-target DAG mirrors solve
hmgaudecker May 5, 2026
fed28cd
Merge feature/next-state-deps-in-transitions: pull pylcm simulate-pat…
hmgaudecker May 5, 2026
c969b1a
next_state: real signature for combined; fix trivially-passing test
hmgaudecker May 6, 2026
8a2ad4f
Address #342 review: simulate-path uses concatenate_functions; cleanups
hmgaudecker May 6, 2026
bcdd358
Merge feature/next-state-deps-in-transitions: address #342 review fee…
hmgaudecker May 6, 2026
17347c8
Get rid of H_variables entirely in regime_template.
hmgaudecker May 6, 2026
076b9b6
Bump .ai-instructions: TDD always; behavior-focused docstrings
hmgaudecker May 6, 2026
e6f8e41
Merge feature/next-state-deps-in-transitions: bump .ai-instructions f…
hmgaudecker May 6, 2026
164a88b
AGENTS.md: inline TDD-always testing section directly
hmgaudecker May 6, 2026
07c5ae0
Merge feature/next-state-deps-in-transitions: inline TDD section in A…
hmgaudecker May 6, 2026
5261e29
Address #339 review: drop field-count test; tighten claudish docstrings
hmgaudecker May 6, 2026
110cc0b
regime_template: collapse H_variables into single variables set
hmgaudecker May 6, 2026
4b2895c
Merge feature/next-state-deps-in-transitions: collapse H_variables
hmgaudecker May 6, 2026
a7b9e9a
solve_brute: drop misleading "~2 MB each" magic number from comment
hmgaudecker May 6, 2026
16b570f
AGENTS.md: docstring style — describe state, no PR refs, bulleted lists
hmgaudecker May 6, 2026
1251f0b
Merge feature/next-state-deps-in-transitions: docstring style + .ai-i…
hmgaudecker May 6, 2026
ff65261
solve_brute: apply docstring style — drop PR ref, magic number; bulle…
hmgaudecker May 6, 2026
2539374
Merge improve/lazy-solve-diagnostics: docstring style, .ai-instructio…
hmgaudecker May 6, 2026
e9b7cc5
test_next_state: update to assert nested output shape from #339 merge
hmgaudecker May 6, 2026
eb69432
Merge main: pick up squash of #342
hmgaudecker May 6, 2026
838473e
validate_initial_conditions: per-constraint admissibility in error me…
hmgaudecker May 6, 2026
07f951a
Merge improve/lazy-solve-diagnostics: pick up main-squash + downstrea…
hmgaudecker May 6, 2026
e4cae2a
_per_constraint_feasibility: filter args per single-constraint feasib…
hmgaudecker May 6, 2026
62392c1
Merge branch 'main' into feat/simulate-aot-n-subjects
hmgaudecker May 6, 2026
50f78f0
Address #340 review: docstring style, TDD, x64+AOT guard, period dtype
hmgaudecker May 6, 2026
7e81532
Package A: int dtype barriers at the API boundary
hmgaudecker May 6, 2026
63b740c
Fix Package A 32-bit precision tests: build overflow fixtures with numpy
hmgaudecker May 6, 2026
568cbcb
Address #340 review-2: counterfactuals, multi-assertion tests, dedup …
hmgaudecker May 6, 2026
2f19aa1
compile: free lower-args after lowering, free Lowered after compile
hmgaudecker May 7, 2026
143a3ae
solve_brute: rename diag_params to effective_regime_params
hmgaudecker May 7, 2026
1e26926
process_params: cast Python int leaves to jnp.int32
hmgaudecker May 8, 2026
14f81fc
solve: kick off simulate AOT compile in a background thread
hmgaudecker May 8, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/bench_aca_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _build() -> tuple[object, object, object]:
get_benchmark_params,
)

model = create_benchmark_model()
model = create_benchmark_model(n_subjects=_N_SUBJECTS)
_, model_params = get_benchmark_params(model=model)
initial_conditions = get_benchmark_initial_conditions(
model=model, n_subjects=_N_SUBJECTS, seed=0
Expand Down
8 changes: 4 additions & 4 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" }
tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" }
type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" }
[tool.pixi.feature.benchmarks.pypi-dependencies]
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" }
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "f09b5e34102ff42f739b95be5a9d388795b734a1" }
[tool.pixi.feature.cuda12]
platforms = [ "linux-64" ]
system-requirements = { cuda = "12" }
Expand Down
48 changes: 48 additions & 0 deletions src/lcm/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Boundary-cast helpers that pin user-supplied data to canonical pylcm dtypes.

Used at every API boundary that accepts user data (params, initial
conditions, regime-id arrays) — always called from Python, never inside
JIT. Each helper validates that the value fits the target dtype and
raises a clearly-named error if not.

Casts further down the simulate stack (e.g. transition outputs landing
in the state pool) use plain `.astype` and rely on the boundary cast
above them having already pinned the canonical dtype.
"""

import jax.numpy as jnp
import numpy as np
from jax import Array

_INT32_MIN = int(np.iinfo(np.int32).min)
_INT32_MAX = int(np.iinfo(np.int32).max)


def safe_to_int32(value: object, *, name: str) -> Array:
"""Cast a scalar, sequence, or array to `jnp.int32`, checking int32 range.

Args:
value: A Python int, numpy/JAX integer scalar, or array-like of
integer values.
name: Qualified name of the leaf — surfaced in the error message
so the user can locate the offending input.

Returns:
A `jnp.int32` array (0-d if `value` was a scalar).

Raises:
ValueError: If any element of `value` is outside the int32 range
`[-2**31, 2**31 - 1]`. The message names the leaf via `name`.

"""
np_value = np.asarray(value)
if np_value.size > 0:
lo = int(np_value.min())
hi = int(np_value.max())
if lo < _INT32_MIN or hi > _INT32_MAX:
msg = (
f"{name}: int32 overflow — value range [{lo}, {hi}] "
f"exceeds [{_INT32_MIN}, {_INT32_MAX}]."
)
raise ValueError(msg)
return jnp.asarray(np_value, dtype=jnp.int32)
12 changes: 10 additions & 2 deletions src/lcm/grids/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,13 @@ def batch_size(self) -> int:
return self.__batch_size

def to_jax(self) -> Int1D:
"""Convert the grid to a Jax array."""
return jnp.array(self.codes)
"""Convert the grid to a Jax array.

Discrete state/action codes are pinned to `int32` regardless of the
ambient `jax_enable_x64` setting. A single integer dtype across
transitions, V-array indexing, and action lookups keeps the JIT cache
unsplit and lets AOT-compiled programs ship one signature. `int32`
accommodates any realistic category count and matches the
`MISSING_CAT_CODE` sentinel.
"""
return jnp.array(self.codes, dtype=jnp.int32)
196 changes: 195 additions & 1 deletion src/lcm/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Collection of classes that are used by the user to define the model and grids."""

import dataclasses
import logging
import threading
from collections.abc import Mapping
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from types import MappingProxyType

Expand Down Expand Up @@ -30,6 +33,7 @@
)
from lcm.regime import Regime
from lcm.regime_building.processing import InternalRegime
from lcm.simulation.compile import compile_all_simulate_functions
from lcm.simulation.initial_conditions import validate_initial_conditions
from lcm.simulation.result import SimulationResult, get_simulation_output_dtypes
from lcm.simulation.simulate import simulate
Expand Down Expand Up @@ -85,9 +89,56 @@ class Model:
fixed_params: UserParams
"""Parameters fixed at model initialization."""

n_subjects: int | None = None
"""Expected simulate batch size; enables AOT compile of simulate functions.

Dispatch by call shape:

- `None`: purely lazy behaviour, no AOT.
- First `simulate(...)` with `actual_n == n_subjects`: AOT-compiles all
simulate functions for that batch shape in parallel and caches them.
- Subsequent `simulate(...)` with the same matching size: reuses the
cached compiled programs.
- `simulate(...)` with a mismatching size: warns once per size and falls
back to the runtime-traced path.

Param-shape contract: the cache is keyed only on `n_subjects`. The shapes
and dtypes of `internal_params` leaves at the first matching call become
part of the AOT signature; subsequent calls must keep them stable. MSM-
style estimation (varying values, fixed shapes) is the target use case;
construct a fresh `Model` whenever a param array's shape or dtype changes.
"""

_params_template: ParamsTemplate
"""Template for the model parameters."""

_simulate_compile_cache: dict[int, MappingProxyType[RegimeName, InternalRegime]]
"""AOT-compiled `internal_regimes` per matching `n_subjects`."""

_warned_n_subjects: set[int]
"""Mismatching `actual_n_subjects` already warned about (one warning each)."""

_simulate_compile_future: (
Future[MappingProxyType[RegimeName, InternalRegime]] | None
)
"""Pending background AOT compile started by `solve(...)`, or `None`.

`solve(...)` kicks off `compile_all_simulate_functions` in a single
background thread so XLA compilation overlaps with the GPU-bound
backward induction. `simulate(...)` awaits the future before
dispatching the AOT-compiled program. Cleared after the result lands
in `_simulate_compile_cache`.
"""

_simulate_compile_lock: threading.Lock
"""Serialises mutations of `_simulate_compile_cache`, `_warned_n_subjects`,
and `_simulate_compile_future`.

The check-then-set on each container is held under this lock. The
consequent `log.warning` call sits outside the lock so concurrent
simulate() calls don't serialise on logging I/O.
"""

def __init__(
self,
*,
Expand All @@ -100,6 +151,7 @@ def __init__(
derived_categoricals: Mapping[FunctionName, DiscreteGrid] = MappingProxyType(
{}
),
n_subjects: int | None = None,
) -> None:
"""Initialize the Model.

Expand All @@ -115,17 +167,27 @@ def __init__(
not in states/actions. Broadcast to all regimes (merged with
each regime's own `derived_categoricals`). Raises if a regime
already has a conflicting entry.
n_subjects: Expected simulate batch size; if set, the first matching
`simulate(...)` call AOT-compiles all simulate functions for
batch shape `n_subjects` in parallel. `None` keeps the purely
lazy behaviour.

"""
self.description = description
self.ages = ages
self.n_periods = ages.n_periods
self.fixed_params = ensure_containers_are_immutable(fixed_params)
self.n_subjects = n_subjects
self._simulate_compile_cache = {}
self._warned_n_subjects = set()
self._simulate_compile_future = None
self._simulate_compile_lock = threading.Lock()

validate_model_inputs(
n_periods=self.n_periods,
regimes=regimes,
regime_id_class=regime_id_class,
n_subjects=n_subjects,
)
self.regime_names_to_ids = MappingProxyType(
dict(
Expand All @@ -149,6 +211,31 @@ def __init__(
regime_names_to_ids=self.regime_names_to_ids,
)

def __getstate__(self) -> dict[str, object]:
"""Return a copy of `__dict__` with per-process AOT compile state removed.

Drops `_simulate_compile_lock` (a `threading.Lock`, not pickleable),
`_simulate_compile_cache` (compiled XLA programs that can't survive
a process boundary), `_warned_n_subjects` (its companion set), and
`_simulate_compile_future` (a `Future` tied to the originating thread
pool).
`__setstate__` restores all three to their fresh state.
"""
state = self.__dict__.copy()
state.pop("_simulate_compile_lock", None)
state.pop("_simulate_compile_cache", None)
state.pop("_warned_n_subjects", None)
state.pop("_simulate_compile_future", None)
return state

def __setstate__(self, state: dict[str, object]) -> None:
"""Restore AOT compile state to a fresh empty cache."""
self.__dict__.update(state)
self._simulate_compile_cache = {}
self._warned_n_subjects = set()
self._simulate_compile_future = None
self._simulate_compile_lock = threading.Lock()

def get_params_template(self) -> UserFacingParamsTemplate:
"""Get a human-readable params template.

Expand Down Expand Up @@ -210,6 +297,11 @@ def solve(
internal_params=internal_params,
ages=self.ages,
)
self._maybe_start_simulate_compile_async(
internal_params=internal_params,
max_compilation_workers=max_compilation_workers,
logger=get_logger(log_level=log_level),
)
try:
period_to_regime_to_V_arr = solve(
internal_params=internal_params,
Expand Down Expand Up @@ -240,6 +332,101 @@ def solve(
)
return period_to_regime_to_V_arr

def _maybe_start_simulate_compile_async(
self,
*,
internal_params: InternalParams,
max_compilation_workers: int | None,
logger: logging.Logger,
) -> None:
"""Spawn `compile_all_simulate_functions` in a background thread.

Called from `solve(...)` so the simulate-side XLA compilation runs in
parallel with the GPU-bound backward induction. No-op when
`n_subjects is None`, when the cache for this size is already
populated, or when a compile is already in flight.
"""
if self.n_subjects is None:
return
with self._simulate_compile_lock:
if self.n_subjects in self._simulate_compile_cache:
return
if self._simulate_compile_future is not None:
return
executor = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="lcm-simulate-compile"
)
self._simulate_compile_future = executor.submit(
compile_all_simulate_functions,
internal_regimes=self.internal_regimes,
internal_params=internal_params,
ages=self.ages,
n_subjects=self.n_subjects,
max_compilation_workers=max_compilation_workers,
logger=logger,
)
executor.shutdown(wait=False)

def _resolve_simulate_internal_regimes(
self,
*,
actual_n_subjects: int,
internal_params: InternalParams,
log: logging.Logger,
max_compilation_workers: int | None,
) -> MappingProxyType[RegimeName, InternalRegime]:
"""Return internal_regimes to use for simulate; AOT cache when matching.

Three dispatch cases:

- `n_subjects is None`: return the original `internal_regimes`
(purely lazy path).
- `actual_n_subjects != n_subjects`: return the original
`internal_regimes` and log a warning the first time each
mismatching size is seen.
- `actual_n_subjects == n_subjects`: return the cached AOT-compiled
regimes. If `solve(...)` started a background compile, await it
here; otherwise compile synchronously.
"""
if self.n_subjects is None:
return self.internal_regimes
if actual_n_subjects != self.n_subjects:
with self._simulate_compile_lock:
already_warned = actual_n_subjects in self._warned_n_subjects
if not already_warned:
self._warned_n_subjects.add(actual_n_subjects)
if not already_warned:
log.warning(
"simulate called with n_subjects=%d but model declared "
"n_subjects=%d; falling back to runtime compile.",
actual_n_subjects,
self.n_subjects,
)
return self.internal_regimes
with self._simulate_compile_lock:
if self.n_subjects in self._simulate_compile_cache:
return self._simulate_compile_cache[self.n_subjects]
future = self._simulate_compile_future
if future is not None:
compiled = future.result()
with self._simulate_compile_lock:
self._simulate_compile_cache[self.n_subjects] = compiled
self._simulate_compile_future = None
return compiled
with self._simulate_compile_lock:
if self.n_subjects not in self._simulate_compile_cache:
self._simulate_compile_cache[self.n_subjects] = (
compile_all_simulate_functions(
internal_regimes=self.internal_regimes,
internal_params=internal_params,
ages=self.ages,
n_subjects=self.n_subjects,
max_compilation_workers=max_compilation_workers,
logger=log,
)
)
return self._simulate_compile_cache[self.n_subjects]

def simulate(
self,
*,
Expand Down Expand Up @@ -341,10 +528,17 @@ def simulate(
)
exc.add_note(f"Snapshot saved to {snap_dir}")
raise
actual_n_subjects = len(next(iter(initial_conditions.values())))
simulate_internal_regimes = self._resolve_simulate_internal_regimes(
actual_n_subjects=actual_n_subjects,
internal_params=internal_params,
log=log,
max_compilation_workers=max_compilation_workers,
)
result = simulate(
internal_params=internal_params,
initial_conditions=initial_conditions,
internal_regimes=self.internal_regimes,
internal_regimes=simulate_internal_regimes,
regime_names_to_ids=self.regime_names_to_ids,
logger=log,
period_to_regime_to_V_arr=period_to_regime_to_V_arr,
Expand Down
Loading
Loading