
Pin user-supplied floats to canonical dtype at every API boundary #345

Open
hmgaudecker wants to merge 29 commits into feat/simulate-aot-n-subjects from
feat/canonical-float-dtype

Conversation


@hmgaudecker hmgaudecker commented May 6, 2026

Context

Continues the int-side normalisation in #340 with the float side, though for a completely different reason.

The constraint:

def borrowing_constraint(
    consumption: ContinuousAction,    # action grid, fp32 (quantized via jnp.float32)
    cash_on_hand: FloatND,
    consumption_floor: float,         # Python float — fp64
    equivalence_scale: FloatND,
) -> BoolND:
    return consumption <= cash_on_hand + consumption_floor * equivalence_scale

What lands on each side without dtype barriers (under jax_enable_x64=True,
which aca_model/__init__.py sets at import):

  • LHS consumption: action grid quantized to jnp.float32 in the
    runtime-consumption-points path. Promoted to fp64 for the comparison —
    but promotion preserves the quantization error, it doesn't undo it.
  • RHS consumption_floor * equivalence_scale: consumption_floor is a
    Python float (fp64 precision), so the RHS keeps fp64 throughout.

When cash_on_hand takes large negative values, the two sides differ by less
than the smallest gap fp64 can represent at that magnitude (a fraction of a
single fp32 quantization step, leaked into fp64 by the promotion). The <=
comparison flips, and validate_initial_conditions raises
InvalidInitialConditionsError.
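A minimal NumPy illustration of the leak (the promotion rule is the same in JAX; the actual failure involved large negative cash_on_hand, but the mechanism is identical):

```python
import numpy as np

floor = 0.1                    # Python float: the exact fp64 value nearest 0.1
quantized = np.float32(floor)  # fp32 grid point: ~0.10000000149011612

# Promotion to fp64 preserves the fp32 rounding error, it doesn't undo it:
promoted = np.float64(quantized)

assert promoted != floor           # the quantization error survives the up-cast
assert not (promoted <= floor)     # a <= that holds in exact arithmetic flips
```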

This was very annoying to debug. To remove this class of bug, this PR ensures that all floats crossing into pylcm share a single canonical dtype.

Overview

Adds canonical_float_dtype() and safe_to_float_dtype next to the int
helpers from #340, and applies them at the same boundaries (params,
initial conditions, transition outputs, V-arrays).

What lands

src/lcm/dtypes.py

  • canonical_float_dtype() returns jnp.float64 under
    jax_enable_x64=True, else jnp.float32. Read at call time.
  • safe_to_float_dtype(value, *, name) casts to the canonical dtype
    and raises OverflowError (with the leaf's qualified name) when
    down-casting a value above float32 magnitude. Up-casts and
    same-width casts skip the range check; precision loss within range
    is not an error.
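A sketch of the two helpers' contract, using NumPy in place of JAX and an explicit `x64` flag standing in for `jax_enable_x64` (the real helpers read `jax.config` at call time):

```python
import numpy as np

def canonical_float_dtype(*, x64: bool = True) -> np.dtype:
    # float64 under x64 mode, float32 otherwise; evaluated per call so
    # toggling the flag between tests is honoured
    return np.dtype(np.float64) if x64 else np.dtype(np.float32)

def safe_to_float_dtype(value, *, name: str, x64: bool = True) -> np.ndarray:
    target = canonical_float_dtype(x64=x64)
    arr = np.asarray(value)
    # Only a float64 -> float32 down-cast needs the range check; up-casts
    # and same-width casts cannot overflow, and in-range precision loss
    # is an accepted consequence of the user's x64 choice.
    if arr.dtype == np.float64 and target == np.dtype(np.float32):
        finite = arr[np.isfinite(arr)]
        if finite.size and np.abs(finite).max() > np.finfo(np.float32).max:
            raise OverflowError(f"{name}: value exceeds float32 range")
    return arr.astype(target)
```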

Params boundary (src/lcm/params/processing.py)

Simulate boundary (src/lcm/simulation/initial_conditions.py)

Transition boundary (src/lcm/simulation/transitions.py)

  • _update_states_for_subjects unconditionally casts
    next_state_values to the storage dtype. The cross-kind escape
    hatch added in #340 ("Model.n_subjects: AOT-compile simulate, lock
    integer dtype to int32"), which kept an int-typed user initial
    condition for a continuous state from being coerced, is no longer
    needed: the initial-state cast above pins storage to the canonical
    float dtype upstream of this site.

Tests

  • tests/test_float_dtype_invariants.py (10 tests): helper round-trips,
    initial-state casts, params casts, grid materialisation, V-array
    dtype, multi-period state-dtype stability.
  • tests/test_dtypes.py: 7 additional float-helper unit tests.
  • tests/test_validate_param_types.py:
    numpy_array_param_rejected -> numpy_array_param_accepted_and_normalised.
    With the boundary cast in place, numpy arrays are auto-converted;
    the historical rejection-by-isinstance is obsolete.

928 pass, 5 skip; prek + ty clean.

Stacked on

This branch is stacked on feat/simulate-aot-n-subjects (#340) — the
base ref for this PR. Merge order: #340 first, then this. The diff
view here only shows the float-side changes.

Out of scope

Pin every float that crosses into pylcm to a single canonical dtype
derived from `jax.config.jax_enable_x64`. Adds the
`canonical_float_dtype()` and `safe_to_float_dtype(value, *, name)`
helpers next to Package A's int counterparts, and applies them at the
same boundaries.

Helpers (`src/lcm/dtypes.py`):

- `canonical_float_dtype()`: `jnp.float64` if x64, else `jnp.float32`.
  Read at call time so toggling JAX config between tests is honoured.
- `safe_to_float_dtype(value, *, name)`: host-side cast with overflow
  check on float64 -> float32 down-casts. Up-casts and same-width
  casts skip the range check; precision loss within range is *not*
  an error (it's an inherent consequence of the user's x64 choice).

Boundaries:

- `_cast_int_leaves_to_int32` becomes `_cast_leaves_to_canonical_dtype`:
  one pass that handles both int (via `safe_to_int32`) and float (via
  `safe_to_float_dtype`). Python `int` / `float` / `bool` scalars
  pass through to keep JAX weak-typing semantics. `pd.Series` leaves
  pass through too — `convert_series_in_params` reshapes them later
  based on their multi-index.
- `build_initial_states` casts continuous user arrays to
  `canonical_float_dtype()` and pins the missing-state NaN fallback
  to the same dtype. After this, both discrete (int32) and
  continuous (canonical float) state pools have stable dtypes
  across all simulate periods.
- `_update_states_for_subjects` now unconditionally casts
  `next_state_values` to the storage dtype. Package A's cross-kind
  guard is no longer needed: with the continuous-state cast in
  place, storage dtype is always the canonical one for that kind.
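The `build_initial_states` behaviour described above can be sketched as follows (NumPy stand-in for JAX; the signature is hypothetical and heavily simplified):

```python
import numpy as np

CANONICAL_FLOAT = np.float64  # stand-in for canonical_float_dtype()

def build_initial_states(
    user_states: dict, continuous_names: list, n_subjects: int
) -> dict:
    out = {}
    for name in continuous_names:
        if name in user_states:
            # user-supplied values (any float or int dtype) -> canonical float
            out[name] = np.asarray(user_states[name], dtype=CANONICAL_FLOAT)
        else:
            # missing-state fallback: NaN, pinned to the same dtype
            out[name] = np.full(n_subjects, np.nan, dtype=CANONICAL_FLOAT)
    return out
```

After this pass, the continuous state pool carries one dtype into every simulate period.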

Tests:

- New `tests/test_float_dtype_invariants.py` (10 tests):
  - `canonical_float_dtype()` follows `jax_enable_x64`
  - `safe_to_float_dtype` round-trip / down-cast / up-cast / overflow
  - `build_initial_states` continuous-state casts (float64 input,
    int input, missing-state NaN fallback)
  - `process_params` casts typed float arrays, leaves Python floats
    weak-typed, raises on float-array overflow with qualified name,
    casts inside `MappingLeaf`
  - `LinSpacedGrid` / `LogSpacedGrid` / `IrregSpacedGrid` `to_jax`
    materialise at canonical dtype
  - `model.solve(...)` V-arrays at canonical dtype
  - Multi-period simulate: every state's dtype is stable across
    periods (no silent promotion mid-run)
- `test_validate_param_types`: `numpy_array_param_rejected` ->
  `numpy_array_param_accepted_and_normalised`. With the boundary
  cast in place, numpy arrays are auto-converted; the historical
  rejection-by-isinstance is no longer needed.
- Extended `tests/test_dtypes.py` with 7 float-helper tests.

After Package B, the simulate-AOT path traces against a single
abstract signature (int32 + canonical float) regardless of how
users supply their inputs (Python scalars, mixed-precision JAX
arrays, numpy arrays).

read-the-docs-community Bot commented May 6, 2026

…h numpy

Same fix as Package A's int side. Under `jax_enable_x64=False`,
`jnp.asarray(..., dtype=jnp.float64)` of `1e40` saturates to `±inf`
at construction time before `safe_to_float_dtype` ever sees it. Use
`np.asarray(..., dtype=np.float64)` so the value reaches the boundary
helper as a real float64 and the helper produces its own
qualified-name `OverflowError`.
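The construction-time saturation can be shown with the analogous NumPy casts (under `jax_enable_x64=False`, a request for `jnp.float64` silently builds at 32-bit width and behaves like the first case):

```python
import numpy as np

# Constructing directly at 32-bit width saturates before any helper runs:
at_f32 = np.asarray(1e40, dtype=np.float32)
assert np.isinf(at_f32)

# Constructing at 64-bit width preserves the value, so a boundary helper
# can still see the real magnitude and raise its own OverflowError:
at_f64 = np.asarray(1e40, dtype=np.float64)
assert at_f64 == 1e40
```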
…-dtype

# Conflicts:
#	src/lcm/params/processing.py

github-actions Bot commented May 6, 2026

Benchmark comparison (main → HEAD)

Comparing a4eca9bf (main) → ca66ba9b (HEAD)

| Benchmark | Statistic | Before | After | Ratio | Alert |
|---|---|---|---|---|---|
| aca-baseline | execution time | 46.946 s | 27.690 s | 0.59 | |
| | peak GPU mem | 671 MB | 1.24 GB | 1.85 | |
| | compilation time | 427.20 s | 310.71 s | 0.73 | |
| | peak CPU mem | 8.39 GB | 7.51 GB | 0.90 | |
| Mahler-Yum | execution time | 4.756 s | 4.733 s | 1.00 | |
| | peak GPU mem | 522 MB | 529 MB | 1.01 | |
| | compilation time | 16.46 s | 16.21 s | 0.98 | |
| | peak CPU mem | 1.68 GB | 1.68 GB | 1.00 | |
| Precautionary Savings - Solve | execution time | 45.5 ms | 49.0 ms | 1.08 | |
| | peak GPU mem | 101 MB | 101 MB | 1.00 | |
| | compilation time | 2.52 s | 2.79 s | 1.11 | |
| | peak CPU mem | 1.12 GB | 1.13 GB | 1.00 | |
| Precautionary Savings - Simulate | execution time | 119.6 ms | 119.1 ms | 1.00 | |
| | peak GPU mem | 340 MB | 344 MB | 1.01 | |
| | compilation time | 5.99 s | 5.07 s | 0.85 | |
| | peak CPU mem | 1.29 GB | 1.31 GB | 1.01 | |
| Precautionary Savings - Solve & Simulate | execution time | 136.7 ms | 155.2 ms | 1.14 | |
| | peak GPU mem | 577 MB | 578 MB | 1.00 | |
| | compilation time | 7.78 s | 6.78 s | 0.87 | |
| | peak CPU mem | 1.28 GB | 1.29 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate (irreg) | execution time | 280.2 ms | 284.1 ms | 1.01 | |
| | peak GPU mem | 2.19 GB | 2.19 GB | 1.00 | |
| | compilation time | 8.59 s | 7.48 s | 0.87 | |
| | peak CPU mem | 1.33 GB | 1.34 GB | 1.00 | |

hmgaudecker and others added 7 commits May 7, 2026 06:03
Source:

- `_update_states_for_subjects`: drop `next_state_values.astype(target_dtype)`.
  With every input boundary pinned to the canonical dtype, a pure-JAX
  user transition consuming canonical inputs returns canonical outputs
  and the cast is a no-op. The only case it covered — a user transition
  that explicitly produces a non-canonical dtype — is now surfaced
  loudly via AOT cache mismatch instead of silently coerced.
- `safe_to_float_dtype` docstring: bullet list for the cast-direction
  enumeration (down-cast vs up/same-width).
- `process_params` module + function docstring: extend to mention float
  cast and `OverflowError` alongside the int cast and `ValueError`.
- `_cast_leaves_to_canonical_dtype`: rephrase `pd.Series` justification
  to describe the *property* (multi-index structure) rather than the
  internal helper that handles it.
- `convert_series_in_params` and `initial_conditions_from_dataframe`:
  route every `pd.Series` -> JAX-array conversion through
  `canonical_float_dtype()` so the boundary contract holds for
  pandas-backed params and pandas-backed initial conditions.
- `dtypes.py` module docstring: drop the contradictory note about
  downstream `.astype` casts (downstream no longer casts).

Tests:

- Move `x64_disabled` / `x64_enabled` fixtures to `tests/conftest.py`
  (were duplicated across two test files).
- `test_safe_to_float_dtype_casts_float64_array_to_float32`: switch to
  `np.asarray(..., dtype=np.float64)`. With `jnp.asarray`, JAX silently
  truncated to `float32` at construction time under no-x64, so the
  helper's down-cast path was never exercised.
- Same fix applied to every `tests/test_float_dtype_invariants.py`
  test that built a float64 input under the `x64_disabled` fixture.
- Split / parametrise multi-assertion tests in
  `test_float_dtype_invariants.py`: continuous-grid `to_jax` over
  `LinSpacedGrid` / `LogSpacedGrid` / `IrregSpacedGrid`; `MappingLeaf`
  float keys parametrised over `["low", "high"]`.
- Add `test_process_params_casts_float_array_inside_sequence_leaf_to_canonical`
  to mirror the `MappingLeaf` test (parametrised over `[0, 1]`).
- Add `test_build_initial_states_missing_continuous_fallback_values_are_nan`
  asserting the fallback is actually NaN (not just at canonical dtype).
- `test_continuous_grid_to_jax_dtype_is_canonical_under_no_x64`:
  assert against the literal `jnp.float32` instead of
  `canonical_float_dtype()` so a future grid implementation that
  hardcodes `float64` would surface here (current form passed
  trivially because both sides are driven by the same x64 flag).
- `test_simulate_state_pool_dtype_stable_across_periods` and
  `test_solve_v_arrays_at_canonical_float_dtype`: collect violations
  into a single dict and assert non-emptiness, so failures still name
  the offending state / V-array but the test has one assertion.
- `tests/test_validate_param_types.py`: rewrite the three numpy / JAX
  / Python-scalar tests to assert the *normalised* leaf type and
  dtype via `_process_params`, not just "no exception raised". Update
  module docstring to describe the current behaviour.
- `tests/test_int_dtype_invariants.py::test_update_states_for_subjects_*`:
  rewrite as a positive same-dtype round-trip — the previous form
  pinned a guard that the cast removal eliminated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aca-model dropped the `pref_type_grid` default on
`create_benchmark_model`. Forward `DiscreteGrid(BenchmarkPrefType)`
explicitly to keep the benchmark on its 2-type pref-type axis.
hmgaudecker added a commit that referenced this pull request May 8, 2026
Drops the "Python scalar pass-through to keep JAX weak-typing
semantics" line from `_cast_int_leaves_to_int32`: a Python `int`
arriving at a DAG function as a Python scalar now becomes
`jnp.int32(value)` so downstream code sees a JAX-typed scalar
rather than a Python int that JAX would promote per call site.
This finishes the "no Python scalars inside JIT'd loops" goal
that motivated the dtype-barrier work.

`bool` is short-circuited (the float-side cast pass on #345 will
handle it; bool is a Python `int` subclass so the bool branch must
come before the int one). Module + function docstrings refreshed
accordingly.
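The ordering constraint can be pinned down in a small sketch (NumPy scalars standing in for the `jnp` casts):

```python
import numpy as np

def cast_scalar(value):
    # bool first: isinstance(True, int) is True in Python, so an int
    # branch listed first would silently swallow booleans.
    if isinstance(value, bool):
        return np.bool_(value)
    if isinstance(value, int):
        return np.int32(value)
    raise TypeError(f"unsupported scalar: {type(value)!r}")
```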

Test `test_process_params_passes_python_int_through_for_jax_weak_typing`
renamed and flipped to assert `dtype == jnp.int32`.

`ScalarInt` keeps the `int | Int32[Scalar, ""]` union: tightening
to JAX-only cascades into 13 call-site mismatches at internal
Python metadata sites (`n_points`, `period`) that legitimately
pass Python `int` outside the JIT'd DAG. Tightening the alias is
a separate audit and follow-up.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
hmgaudecker and others added 15 commits May 8, 2026 10:27
Folds #340's int Python-scalar cast into the broader #345 work:

- `_cast_leaves_to_canonical_dtype` becomes a strict whitelist with
  explicit raises (no silent fallthrough). Dispatch order:
  `MappingLeaf`/`SequenceLeaf` → recurse, `pd.Series` → defensive
  raise (orchestrator must reshape first), `bool` (before `int`),
  `int`, `float`, `np.ndarray | jax.Array` keyed on `dtype.kind`
  (`b/i/u/f`). Anything else (`str`, complex/object arrays, custom
  objects) raises `InvalidParamsError` with the leaf's qualified name.
- `broadcast_to_template` becomes broadcast-only — the inline cast
  loop moves out into the new top-level
  `cast_params_to_canonical_dtypes(internal_params)` helper.
- `process_params` chains broadcast + cast for callers that have no
  `pd.Series` leaves; pd.Series-bearing callers must orchestrate the
  three steps explicitly.
- `_process_params` (model.py) and `_build_regimes_and_template_with_fixed_params`
  (model_processing.py) now sequence broadcast → reshape → cast →
  validate. The cast walks a uniform tree (no `pd.Series` to special-
  case) which lets the whitelist drop the pass-through branch.
- `bool` casts to `jnp.bool_`, `float` casts to
  `canonical_float_dtype()`, completing the "no Python scalars inside
  JIT'd loops" goal — DAG functions now receive JAX-typed scalars
  end-to-end.
- `ScalarBool` alias added (`bool | Bool[Scalar, ""]`). `ScalarInt` /
  `ScalarFloat` keep their `int |` / `float |` unions for now; the
  tighter forms cascade into call-site mismatches at internal Python
  metadata sites (`n_points`, `period`) and warrant a separate audit.
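The `dtype.kind` dispatch for array leaves can be sketched like this (NumPy stand-in; `ValueError` replaces `InvalidParamsError`):

```python
import numpy as np

def cast_array_leaf(arr: np.ndarray, *, name: str) -> np.ndarray:
    # keyed on dtype.kind: b = bool, i/u = signed/unsigned int, f = float
    kind = arr.dtype.kind
    if kind == "b":
        return arr
    if kind in "iu":
        return arr.astype(np.int32)
    if kind == "f":
        return arr.astype(np.float64)  # stand-in for canonical_float_dtype()
    # anything else (complex, object, string arrays) is rejected loudly
    raise ValueError(f"{name}: unsupported dtype kind {kind!r}")
```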

Tests:

- `test_process_params_passes_python_float_through_for_jax_weak_typing`
  → `..._casts_python_float_to_canonical`, asserts dtype.
- `test_python_float_param_passed_through_for_weak_typing` →
  `..._cast_to_canonical_dtype`, asserts dtype.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
After `cast_params_to_canonical_dtypes`, every leaf is either a JAX
`Array` or a `MappingLeaf` / `SequenceLeaf`. The validator's Python
scalar branch and duck-typed-array branch can never fire, so collapse
the dispatch to the three live cases.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…uction

ScalarFloat, ScalarInt, and ScalarBool now stand for JAX scalars only,
so downstream annotations (e.g. aca-model DAG functions) carry the
"post-cast invariant" guarantee accurately.

Changes that follow from the tightening:

- UniformContinuousGrid (LinSpacedGrid, LogSpacedGrid) and
  IrregSpacedGrid use a manual __init__ to accept Python literals at
  the user-facing API and store start/stop/points as JAX scalars at
  canonical_float_dtype(). Grid dtype is now sticky to construction
  time x64 mode.

- Coordinate helpers (linspace, logspace, get_*_coordinate,
  Grid.get_coordinate) widen each numeric slot to
  `float | ScalarFloat` / `int | ScalarInt` so they remain callable
  from setup-time Python code as well as the JIT'd DAG.

- simulate.py replaces `enumerate(ages.values)` with index-based
  iteration so `age` carries a proper JAX-scalar type; transitions.py
  follows.

- Display/diagnostic age parameters in error_handling.py and
  logging.py widen to `int | float | ScalarInt | ScalarFloat` so
  Python literals from `_DiagnosticRow` keep working.

Test changes: parametrised dtype-invariant test now constructs grids
inside the test body so the x64_disabled fixture is in effect; the
returning-int test in test_regime_state_mismatch flips to `-> int`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
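The manual `__init__` pattern described above can be sketched as follows (NumPy stand-in; the class and field names follow the description, but the implementation is hypothetical):

```python
from dataclasses import dataclass

import numpy as np

CANONICAL_FLOAT = np.float64  # stand-in for canonical_float_dtype()

@dataclass(frozen=True, init=False)
class LinSpacedGrid:
    start: np.floating
    stop: np.floating
    n_points: np.integer

    def __init__(self, start: float, stop: float, n_points: int):
        # accept Python literals at the user-facing API, store typed
        # scalars so the grid dtype is sticky to construction-time mode
        object.__setattr__(self, "start", CANONICAL_FLOAT(start))
        object.__setattr__(self, "stop", CANONICAL_FLOAT(stop))
        object.__setattr__(self, "n_points", np.int32(n_points))

    def to_jax(self) -> np.ndarray:
        return np.linspace(
            self.start, self.stop, int(self.n_points), dtype=CANONICAL_FLOAT
        )
```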
`linspace`, `logspace`, `get_*_coordinate` are pylcm-internal: every
production caller (Grid methods, piecewise dispatchers) hands them
JAX scalars. Drop the `float | ScalarFloat` widening on `start` /
`stop` / `value` so the helpers pin the post-cast contract.

Conversion of user input now happens once at the public-API boundary,
inside `Grid.get_coordinate`, via a small `_to_jax_scalar` helper. The
helper-direct tests in test_grid_helpers.py wrap their literals with
`jnp.asarray` to match.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`Model.__init__` lifts `fixed_params` Python scalars to JAX arrays via
the boundary dtype cast, which initialises CUDA in the parent process
when running under cuda12. ASV forks the benchmark worker from that
parent; the inherited CUDA context is unusable in the child and
surfaces as `CUDA_ERROR_NOT_INITIALIZED` on the first device op.

Wrap `_build()` in `jax.default_device(cpu)` so all setup-time array
creations stay on CPU. The worker process initialises CUDA freshly
when `simulate(...)` runs in `setup`/method bodies; JAX moves the
deserialised arrays to GPU on demand.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…sult

When `Model(n_subjects=N)` triggers an AOT compile, every
`InternalRegime.simulate_functions` field carries a `jax.stages.Compiled`
that holds an unpicklable `LoadedExecutable`. The snapshot already
side-loads the V-array via HDF5; widen the strip pass to overwrite
`SimulationResult._internal_regimes` with `model.internal_regimes`
(the lazy regimes — same metadata, JIT'd `PjitFunction`s pickle cleanly,
which is why `model.pkl` survives the same round-trip).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ASV's forkserver runs `preimport` to discover benchmarks across every
`bench_*.py` module before forking workers. Importing JAX at module
top loads the multithreaded XLA backend into the forkserver; every
subsequent `os.fork()` (for any benchmark, not just this one) inherits
a corrupted CUDA context and the first device op in the worker aborts
with `CUDA_ERROR_NOT_INITIALIZED`. Per-call imports keep JAX out of
the forkserver and confine it to the worker process.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Continues the dtype-barrier work by promoting internal scalar metadata
to JAX-typed forms wherever it lives strictly inside pylcm:

- `UniformContinuousGrid.n_points` and `Piece.n_points` are stored as
  `jnp.int32` JAX scalars, converted from the Python literals at
  construction. `_init_uniform_grid` casts `start` / `stop` /
  `n_points` at the boundary before validation; the validator can then
  assume strict `ScalarFloat` / `ScalarInt` arguments and only check
  value invariants. Coordinate helpers (`linspace`, `logspace`,
  `get_*_coordinate`) tighten `n_points` to `ScalarInt` so the
  conversion happens once at the boundary instead of at every call.
- `Grid.get_coordinate` reverts to `ScalarFloat | Array` (no Python
  float). The single production caller in `regime_building/V.py`
  always passes a JAX array; tests that called the helpers with
  Python literals wrap them with `jnp.asarray` / `jnp.int32`.
- `Period` aliases `ScalarInt` and `Age` aliases `ScalarInt | ScalarFloat`
  for the JIT-internal scalar contexts. `AgeGrid.period_to_age` and
  `age_to_period` use plain `int | float` directly since they are
  user-facing API methods returning Python values.
- `_simulate_regime_in_period` and the `transitions.py` helpers now take
  `period: ScalarInt`. The simulation loop derives `period = jnp.int32
  (period_idx)` once per iteration and passes it through; dict-key
  lookups (`argmax_and_max_Q_over_a[period_idx]`,
  `period_to_regime_to_V_arr.get(period_idx + 1)`) keep using the
  Python int.
- `FlatRegimeParams` tightens to `MappingProxyType[str, Array]` —
  post-whitelist every leaf is a JAX array, the prior `bool | float |
  Array` union was stale.
- `safe_to_int32` renamed to `safe_to_int_dtype` to mirror
  `safe_to_float_dtype`.
- `_strip_V_arr_from_result` made fully kw-only.
- `pyproject.toml` ignores `ARG001` for
  `tests/test_float_dtype_invariants.py` so per-test
  `# noqa: ARG001` comments drop out and signatures collapse to a
  single line.
- `Piece` becomes `init=False` with a manual `__init__` that lifts
  `n_points` to `jnp.int32`, mirroring `UniformContinuousGrid`.

Test-side fallout addressed in the same commit: literals wrapped with
`jnp.asarray` / `jnp.int32` where helpers tightened, redundant
`# ty: ignore` comments dropped, and three "validator rejects
non-numeric" tests reframed to assert the boundary cast catches them.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`jnp.linspace`/`jnp.logspace`'s `num` parameter is annotated `int` in
JAX's stubs but accepts `jnp.int32` JAX scalars in eager mode (verified
on cuda12). Pass `n_points: ScalarInt` through directly and silence the
type-check mismatch at the single call site rather than materialising
the JAX scalar to a Python int.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the Python `sum(generator, start=jnp.int32(0))` with a single
`_piece_n_points.sum()` reduction. The cached `Int1D` is already
populated by `_init_piecewise_grid_cache`, the property is read after
`__post_init__`, and the result is the same `ScalarInt`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
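The before/after forms, compared with a NumPy stand-in for the cached `Int1D`:

```python
import numpy as np

piece_n_points = np.array([3, 5, 7], dtype=np.int32)  # stand-in for _piece_n_points

# before: Python-level generator reduction with a typed start value
total_before = sum((n for n in piece_n_points), start=np.int32(0))
# after: a single array reduction
total_after = piece_n_points.sum()

assert total_before == total_after == 15
```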
Pull in the consumption-grid pinning, borrowing-constraint kink fix, and
precision-workaround cleanups so the GPU benchmark CI runs the
benchmark-aca-baseline kernel that aca-dev currently tracks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The period_idx / period split was noisy: every loop iteration computed
both a Python int (for dict-key indexing and `period in active_periods`)
and a JAX scalar (for the JIT'd compute call). Drop the JAX-scalar
shadow; iterate `for period, age in enumerate(ages.values)` once.
`_simulate_regime_in_period(period: int)` keeps the integer through
dict lookups and casts to `jnp.int32(period)` only at the
`argmax_and_max_Q_over_a` / next-state JIT boundaries. Same pattern
for transitions.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
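The consolidated loop shape, sketched with NumPy scalars in place of `jnp.int32`:

```python
import numpy as np

ages = np.array([25.0, 26.0, 27.0])  # stand-in for ages.values
processed = {}

for period, age in enumerate(ages):
    # the Python int `period` serves dict-key lookups and membership
    # checks directly; cast to a typed scalar only at the JIT boundary
    typed_period = np.int32(period)
    processed[period] = (typed_period, age)
```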
Pulls in the aca-model CI workflow's matching pylcm pin so the GPU
benchmark CI runs the same aca-model rev that aca-dev now tracks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
solve() no longer touches simulate-side compile state. simulate() is the
sole driver: spawns the AOT compile in a background thread when
n_subjects is set and the batch shape matches, then runs solve (if
period_to_regime_to_V_arr is None) and awaits the future at the
state-action-space dispatch point. Both public methods share an internal
_solve_compiled() body for the snapshot/error handling.

Drops _simulate_compile_future from instance state — the future lives in
a local variable on the simulate() stack, so there's no per-process
state to gate against. The lock keeps protecting _simulate_compile_cache
and _warned_n_subjects; the rest of the "maybe spawn" logic collapses
into a single inline check at the simulate() call site.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
hmgaudecker and others added 2 commits May 9, 2026 13:27
Move the ARG001 ignore for the x64_disabled / x64_enabled fixture
pattern into pyproject.toml's per-file-ignores for test_dtypes.py and
test_float_dtype_invariants.py, then drop the per-call noqa comments
and the now-redundant -> None return annotations (tests/* already
ignores ANN). Single-arg signatures collapse to one line; longer ones
stay wrapped, but without the trailing comma noise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`period=1, age=1.0, **flat_regime_params={...float...}` was suppressed
with `# ty: ignore[invalid-argument-type]` to keep the call site
short. Once `ScalarInt` / `ScalarFloat` tightened to JAX-only, the
fix is to pass `jnp.int32(1)` / `jnp.asarray(1.0)` (and to wrap the
float param leaves in `jnp.asarray`). The ignore comments come out
and the call site genuinely type-checks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker hmgaudecker marked this pull request as ready for review May 9, 2026 11:29
@hmgaudecker hmgaudecker requested review from mj023 and timmens May 9, 2026 11:29
`SimulationResult.to_pickle()` (and any cloudpickle.dumps on the
result) hit `cannot pickle 'jaxlib._jax.LoadedExecutable'` when the
result carried the AOT-compiled `internal_regimes`. The compiled
callables (`argmax_and_max_Q_over_a`, `next_state`,
`compute_regime_transition_probs`) wrap a `LoadedExecutable` that
can't survive a process boundary.

`to_dataframe` only reads `simulate_functions.functions /
constraints / transitions / stochastic_transition_names` — none of
which the AOT pass replaces. So after `simulate(...)` runs, the
result has no use for the compiled callables: `model.simulate()`
swaps them out for the lazy `self.internal_regimes` before
returning.

Add a TDD test that round-trips the result through cloudpickle
under `n_subjects` matching, which is the failure mode pytask hit
on HPC.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
