Pin user-supplied floats to canonical dtype at every API boundary #345
Open
hmgaudecker wants to merge 29 commits into `feat/simulate-aot-n-subjects`
Pin every float that crosses into pylcm to a single canonical dtype
derived from `jax.config.jax_enable_x64`. Adds the
`canonical_float_dtype()` and `safe_to_float_dtype(value, *, name)`
helpers next to Package A's int counterparts, and applies them at the
same boundaries.
Helpers (`src/lcm/dtypes.py`):
- `canonical_float_dtype()`: `jnp.float64` if x64, else `jnp.float32`.
  Read at call time so toggling JAX config between tests is honoured.
- `safe_to_float_dtype(value, *, name)`: host-side cast with overflow
  check on float64 -> float32 down-casts. Up-casts and same-width
  casts skip the range check; precision loss within range is *not*
  an error (it's an inherent consequence of the user's x64 choice).
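A minimal sketch of what the two helpers might look like, for readers who want the contract in code. The `_sketch` names are hypothetical, and the decision to skip non-finite entries in the range check is an assumption; the real implementations in `src/lcm/dtypes.py` may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np


def canonical_float_dtype_sketch():
    """Return float64 when x64 is enabled, else float32 (read at call time)."""
    return jnp.float64 if jax.config.jax_enable_x64 else jnp.float32


def safe_to_float_dtype_sketch(value, *, name):
    """Host-side cast to the canonical float dtype with an overflow check."""
    target = np.dtype(canonical_float_dtype_sketch())
    host = np.atleast_1d(np.asarray(value))
    # Only a down-cast (wider float -> narrower float) can overflow.
    if host.dtype.kind == "f" and host.dtype.itemsize > target.itemsize:
        # Assumption: inf / NaN entries are passed through unchecked.
        finite = host[np.isfinite(host)]
        if finite.size and np.abs(finite).max() > np.finfo(target).max:
            raise OverflowError(f"{name}: value exceeds {target} range")
    return jnp.asarray(host.astype(target).reshape(np.shape(value)))
```

Reading the flag inside the function body, rather than caching it at import, is what lets a test fixture flip x64 and see the helper follow.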
Boundaries:
- `_cast_int_leaves_to_int32` becomes `_cast_leaves_to_canonical_dtype`:
  one pass that handles both int (via `safe_to_int32`) and float (via
  `safe_to_float_dtype`). Python `int` / `float` / `bool` scalars
  pass through to keep JAX weak-typing semantics. `pd.Series` leaves
  pass through too; `convert_series_in_params` reshapes them later
  based on their multi-index.
- `build_initial_states` casts continuous user arrays to
  `canonical_float_dtype()` and pins the missing-state NaN fallback
  to the same dtype. After this, both discrete (int32) and
  continuous (canonical float) state pools have stable dtypes
  across all simulate periods.
- `_update_states_for_subjects` now unconditionally casts
  `next_state_values` to the storage dtype. Package A's cross-kind
  guard is no longer needed: with the continuous-state cast in
  place, storage dtype is always the canonical one for that kind.
Tests:
- New `tests/test_float_dtype_invariants.py` (10 tests):
  - `canonical_float_dtype()` follows `jax_enable_x64`
  - `safe_to_float_dtype` round-trip / down-cast / up-cast / overflow
  - `build_initial_states` continuous-state casts (float64 input,
    int input, missing-state NaN fallback)
  - `process_params` casts typed float arrays, leaves Python floats
    weak-typed, raises on float-array overflow with qualified name,
    casts inside `MappingLeaf`
  - `LinSpacedGrid` / `LogSpacedGrid` / `IrregSpacedGrid` `to_jax`
    materialise at canonical dtype
  - `model.solve(...)` V-arrays at canonical dtype
  - Multi-period simulate: every state's dtype is stable across
    periods (no silent promotion mid-run)
- `test_validate_param_types`: `numpy_array_param_rejected` ->
  `numpy_array_param_accepted_and_normalised`. With the boundary
  cast in place, numpy arrays are auto-converted; the historical
  rejection-by-isinstance is no longer needed.
- Extended `tests/test_dtypes.py` with 7 float-helper tests.
After Package B, the simulate-AOT path traces against a single
abstract signature (int32 + canonical float) regardless of how
users supply their inputs (Python scalars, mixed-precision JAX
arrays, numpy arrays).
…h numpy
Same fix as Package A's int side. Under `jax_enable_x64=False`, `jnp.asarray(..., dtype=jnp.float64)` of `1e40` saturates to `±inf` at construction time before `safe_to_float_dtype` ever sees it. Use `np.asarray(..., dtype=np.float64)` so the value reaches the boundary helper as a real float64 and the helper produces its own qualified-name `OverflowError`.
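The ordering problem described above can be illustrated with NumPy alone. This is only a sketch: `check_range_sketch` is a hypothetical stand-in for the boundary helper, and it assumes non-finite values are passed through rather than rejected.

```python
import numpy as np

FLOAT32_MAX = np.finfo(np.float32).max  # roughly 3.4e38


def check_range_sketch(values, *, name):
    """Hypothetical range check: raise if a finite value exceeds float32 range."""
    host = np.atleast_1d(np.asarray(values, dtype=np.float64))
    finite = host[np.isfinite(host)]
    if finite.size and np.abs(finite).max() > FLOAT32_MAX:
        raise OverflowError(f"{name}: value exceeds float32 range")
    return host.astype(np.float32)


big = np.asarray([1e40], dtype=np.float64)  # still a finite float64

# Narrowing before the check turns the value into inf, so the check
# (which skips non-finite entries) has nothing left to reject.
with np.errstate(over="ignore"):
    already_narrowed = big.astype(np.float32)
narrowed_passes = check_range_sketch(already_narrowed, name="beta")

# Keeping the float64 intact lets the check see the real magnitude.
try:
    check_range_sketch(big, name="beta")
    raised = False
except OverflowError:
    raised = True
```

In a test, this is why the input has to be built on the host with `np.asarray` so that the down-cast path is the one under test.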
…-dtype
# Conflicts:
#	src/lcm/params/processing.py
Benchmark comparison (main → HEAD)
Source:
- `_update_states_for_subjects`: drop `next_state_values.astype(target_dtype)`. With every input boundary pinned to the canonical dtype, a pure-JAX user transition consuming canonical inputs returns canonical outputs and the cast is a no-op. The only case it covered (a user transition that explicitly produces a non-canonical dtype) is now surfaced loudly via AOT cache mismatch instead of silently coerced.
- `safe_to_float_dtype` docstring: bullet list for the cast-direction enumeration (down-cast vs up/same-width).
- `process_params` module + function docstring: extend to mention float cast and `OverflowError` alongside the int cast and `ValueError`.
- `_cast_leaves_to_canonical_dtype`: rephrase `pd.Series` justification to describe the *property* (multi-index structure) rather than the internal helper that handles it.
- `convert_series_in_params` and `initial_conditions_from_dataframe`: route every `pd.Series` -> JAX-array conversion through `canonical_float_dtype()` so the boundary contract holds for pandas-backed params and pandas-backed initial conditions.
- `dtypes.py` module docstring: drop the contradictory note about downstream `.astype` casts (downstream no longer casts).

Tests:
- Move `x64_disabled` / `x64_enabled` fixtures to `tests/conftest.py` (were duplicated across two test files).
- `test_safe_to_float_dtype_casts_float64_array_to_float32`: switch to `np.asarray(..., dtype=np.float64)`. With `jnp.asarray`, JAX silently truncated to `float32` at construction time under no-x64, so the helper's down-cast path was never exercised.
- Same fix applied to every `tests/test_float_dtype_invariants.py` test that built a float64 input under the `x64_disabled` fixture.
- Split / parametrise multi-assertion tests in `test_float_dtype_invariants.py`: continuous-grid `to_jax` over `LinSpacedGrid` / `LogSpacedGrid` / `IrregSpacedGrid`; `MappingLeaf` float keys parametrised over `["low", "high"]`.
- Add `test_process_params_casts_float_array_inside_sequence_leaf_to_canonical` to mirror the `MappingLeaf` test (parametrised over `[0, 1]`).
- Add `test_build_initial_states_missing_continuous_fallback_values_are_nan` asserting the fallback is actually NaN (not just at canonical dtype).
- `test_continuous_grid_to_jax_dtype_is_canonical_under_no_x64`: assert against the literal `jnp.float32` instead of `canonical_float_dtype()` so a future grid implementation that hardcodes `float64` would surface here (current form passed trivially because both sides are driven by the same x64 flag).
- `test_simulate_state_pool_dtype_stable_across_periods` and `test_solve_v_arrays_at_canonical_float_dtype`: collect violations into a single dict and assert that it is empty, so failures still name the offending state / V-array but the test has one assertion.
- `tests/test_validate_param_types.py`: rewrite the three numpy / JAX / Python-scalar tests to assert the *normalised* leaf type and dtype via `_process_params`, not just "no exception raised". Update module docstring to describe the current behaviour.
- `tests/test_int_dtype_invariants.py::test_update_states_for_subjects_*`: rewrite as a positive same-dtype round-trip; the previous form pinned a guard that the cast removal eliminated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aca-model dropped the `pref_type_grid` default on `create_benchmark_model`. Forward `DiscreteGrid(BenchmarkPrefType)` explicitly to keep the benchmark on its 2-type pref-type axis.
for more information, see https://pre-commit.ci
Per AGENTS.md: no in-function imports.
hmgaudecker added a commit that referenced this pull request on May 8, 2026
Drops the "Python scalar pass-through to keep JAX weak-typing semantics" line from `_cast_int_leaves_to_int32`: a Python `int` arriving at a DAG function as a Python scalar now becomes `jnp.int32(value)` so downstream code sees a JAX-typed scalar rather than a Python int that JAX would promote per call site. This finishes the "no Python scalars inside JIT'd loops" goal that motivated the dtype-barrier work. `bool` is short-circuited (the float-side cast pass on #345 will handle it; bool is a Python `int` subclass so the bool branch must come before the int one). Module + function docstrings refreshed accordingly. Test `test_process_params_passes_python_int_through_for_jax_weak_typing` renamed and flipped to assert `dtype == jnp.int32`. `ScalarInt` keeps the `int | Int32[Scalar, ""]` union: tightening to JAX-only cascades into 13 call-site mismatches at internal Python metadata sites (`n_points`, `period`) that legitimately pass Python `int` outside the JIT'd DAG. Tightening the alias is a separate audit and follow-up. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Folds #340's int Python-scalar cast into the broader #345 work:
- `_cast_leaves_to_canonical_dtype` becomes a strict whitelist with explicit raises (no silent fallthrough). Dispatch order: `MappingLeaf`/`SequenceLeaf` → recurse, `pd.Series` → defensive raise (orchestrator must reshape first), `bool` (before `int`), `int`, `float`, `np.ndarray | jax.Array` keyed on `dtype.kind` (`b/i/u/f`). Anything else (`str`, complex/object arrays, custom objects) raises `InvalidParamsError` with the leaf's qualified name.
- `broadcast_to_template` becomes broadcast-only; the inline cast loop moves out into the new top-level `cast_params_to_canonical_dtypes(internal_params)` helper.
- `process_params` chains broadcast + cast for callers that have no `pd.Series` leaves; pd.Series-bearing callers must orchestrate the three steps explicitly.
- `_process_params` (model.py) and `_build_regimes_and_template_with_fixed_params` (model_processing.py) now sequence broadcast → reshape → cast → validate. The cast walks a uniform tree (no `pd.Series` to special-case) which lets the whitelist drop the pass-through branch.
- `bool` casts to `jnp.bool_`, `float` casts to `canonical_float_dtype()`, completing the "no Python scalars inside JIT'd loops" goal; DAG functions now receive JAX-typed scalars end-to-end.
- `ScalarBool` alias added (`bool | Bool[Scalar, ""]`). `ScalarInt` / `ScalarFloat` keep their `int |` / `float |` unions for now; the tighter forms cascade into call-site mismatches at internal Python metadata sites (`n_points`, `period`) and warrant a separate audit.

Tests:
- `test_process_params_passes_python_float_through_for_jax_weak_typing` → `..._casts_python_float_to_canonical`, asserts dtype.
- `test_python_float_param_passed_through_for_weak_typing` → `..._cast_to_canonical_dtype`, asserts dtype.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
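The dispatch order listed in this commit can be sketched roughly as follows. This is an illustration under stated assumptions, not the pylcm code: `dict` and `list`/`tuple` stand in for `MappingLeaf` and `SequenceLeaf`, `InvalidParamsErrorSketch` is a local stand-in for `InvalidParamsError`, the `pd.Series` raise is omitted, and plain `jnp.asarray` replaces the overflow-checked `safe_to_*` helpers.

```python
import jax
import jax.numpy as jnp
import numpy as np


class InvalidParamsErrorSketch(Exception):
    """Local stand-in for pylcm's InvalidParamsError."""


def cast_leaf_sketch(leaf, *, name):
    """Cast one params leaf to a canonical JAX dtype, or raise."""
    float_dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32
    if isinstance(leaf, dict):  # stand-in for MappingLeaf: recurse
        return {k: cast_leaf_sketch(v, name=f"{name}.{k}") for k, v in leaf.items()}
    if isinstance(leaf, (list, tuple)):  # stand-in for SequenceLeaf: recurse
        return [cast_leaf_sketch(v, name=f"{name}[{i}]") for i, v in enumerate(leaf)]
    if isinstance(leaf, bool):  # must precede int: bool is an int subclass
        return jnp.asarray(leaf, dtype=jnp.bool_)
    if isinstance(leaf, int):
        return jnp.asarray(leaf, dtype=jnp.int32)
    if isinstance(leaf, float):
        return jnp.asarray(leaf, dtype=float_dtype)
    if isinstance(leaf, (np.ndarray, jax.Array)):
        kind = leaf.dtype.kind
        if kind == "b":
            return jnp.asarray(leaf, dtype=jnp.bool_)
        if kind in ("i", "u"):
            return jnp.asarray(leaf, dtype=jnp.int32)
        if kind == "f":
            return jnp.asarray(leaf, dtype=float_dtype)
    # No silent fallthrough: anything not whitelisted is an error.
    raise InvalidParamsErrorSketch(f"{name}: unsupported leaf type {type(leaf)!r}")
```

If the `int` branch came first, `True` would match it and silently become `int32(1)`, which is why the order matters.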
After `cast_params_to_canonical_dtypes`, every leaf is either a JAX `Array` or a `MappingLeaf` / `SequenceLeaf`. The validator's Python scalar branch and duck-typed-array branch can never fire, so collapse the dispatch to the three live cases. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…uction
ScalarFloat, ScalarInt, and ScalarBool now stand for JAX scalars only, so downstream annotations (e.g. aca-model DAG functions) carry the "post-cast invariant" guarantee accurately.

Changes that follow from the tightening:
- UniformContinuousGrid (LinSpacedGrid, LogSpacedGrid) and IrregSpacedGrid use a manual __init__ to accept Python literals at the user-facing API and store start/stop/points as JAX scalars at canonical_float_dtype(). Grid dtype is now sticky to construction time x64 mode.
- Coordinate helpers (linspace, logspace, get_*_coordinate, Grid.get_coordinate) widen each numeric slot to `float | ScalarFloat` / `int | ScalarInt` so they remain callable from setup-time Python code as well as the JIT'd DAG.
- simulate.py replaces `enumerate(ages.values)` with index-based iteration so `age` carries a proper JAX-scalar type; transitions.py follows.
- Display/diagnostic age parameters in error_handling.py and logging.py widen to `int | float | ScalarInt | ScalarFloat` so Python literals from `_DiagnosticRow` keep working.

Test changes: parametrised dtype-invariant test now constructs grids inside the test body so the x64_disabled fixture is in effect; the returning-int test in test_regime_state_mismatch flips to `-> int`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`linspace`, `logspace`, `get_*_coordinate` are pylcm-internal: every production caller (Grid methods, piecewise dispatchers) hands them JAX scalars. Drop the `float | ScalarFloat` widening on `start` / `stop` / `value` so the helpers pin the post-cast contract. Conversion of user input now happens once at the public-API boundary, inside `Grid.get_coordinate`, via a small `_to_jax_scalar` helper. The helper-direct tests in test_grid_helpers.py wrap their literals with `jnp.asarray` to match. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`Model.__init__` lifts `fixed_params` Python scalars to JAX arrays via the boundary dtype cast, which initialises CUDA in the parent process when running under cuda12. ASV forks the benchmark worker from that parent; the inherited CUDA context is unusable in the child and surfaces as `CUDA_ERROR_NOT_INITIALIZED` on the first device op. Wrap `_build()` in `jax.default_device(cpu)` so all setup-time array creations stay on CPU. The worker process initialises CUDA freshly when `simulate(...)` runs in `setup`/method bodies; JAX moves the deserialised arrays to GPU on demand. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
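A rough sketch of the CPU-pinning pattern this commit describes. `build_on_cpu_sketch` and its `fixed_params` argument are hypothetical names used only for illustration; the actual wrapping of `_build()` may look different.

```python
import jax
import jax.numpy as jnp


def build_on_cpu_sketch(fixed_params):
    """Create setup-time arrays with the CPU as the default device."""
    cpu = jax.devices("cpu")[0]
    # Inside this context, uncommitted array creation defaults to the CPU,
    # so lifting Python scalars does not place data on an accelerator.
    with jax.default_device(cpu):
        return {name: jnp.asarray(value) for name, value in fixed_params.items()}


params = build_on_cpu_sketch({"beta": 0.95, "n_periods": 3})
```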
…sult
When `Model(n_subjects=N)` triggers an AOT compile, every `InternalRegime.simulate_functions` field carries a `jax.stages.Compiled` that holds an unpicklable `LoadedExecutable`. The snapshot already side-loads the V-array via HDF5; widen the strip pass to overwrite `SimulationResult._internal_regimes` with `model.internal_regimes` (the lazy regimes: same metadata, JIT'd `PjitFunction`s pickle cleanly, which is why `model.pkl` survives the same round-trip).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ASV's forkserver runs `preimport` to discover benchmarks across every `bench_*.py` module before forking workers. Importing JAX at module top loads the multithreaded XLA backend into the forkserver; every subsequent `os.fork()` (for any benchmark, not just this one) inherits a corrupted CUDA context and the first device op in the worker aborts with `CUDA_ERROR_NOT_INITIALIZED`. Per-call imports keep JAX out of the forkserver and confine it to the worker process. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
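The per-call import pattern might look like the sketch below. The class and method names are hypothetical and only follow ASV's `setup` / `time_*` naming convention; the real benchmark modules are not shown on this page.

```python
class SimulateBenchmarkSketch:
    """Hypothetical ASV-style benchmark that keeps JAX out of module scope."""

    def setup(self):
        # Deferred import: JAX is loaded in the worker that runs the
        # benchmark, not in the process that merely imports this module.
        import jax.numpy as jnp

        self.x = jnp.arange(1024, dtype=jnp.float32)

    def time_sum(self):
        # block_until_ready waits for the asynchronous computation to finish.
        self.x.sum().block_until_ready()
```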
Continues the dtype-barrier work by promoting internal scalar metadata to JAX-typed forms wherever it lives strictly inside pylcm:
- `UniformContinuousGrid.n_points` and `Piece.n_points` are stored as `jnp.int32` JAX scalars, converted from the Python literals at construction. `_init_uniform_grid` casts `start` / `stop` / `n_points` at the boundary before validation; the validator can then assume strict `ScalarFloat` / `ScalarInt` arguments and only check value invariants. Coordinate helpers (`linspace`, `logspace`, `get_*_coordinate`) tighten `n_points` to `ScalarInt` so the conversion happens once at the boundary instead of at every call.
- `Grid.get_coordinate` reverts to `ScalarFloat | Array` (no Python float). The single production caller in `regime_building/V.py` always passes a JAX array; tests that called the helpers with Python literals wrap them with `jnp.asarray` / `jnp.int32`.
- `Period` aliases `ScalarInt` and `Age` aliases `ScalarInt | ScalarFloat` for the JIT-internal scalar contexts. `AgeGrid.period_to_age` and `age_to_period` use plain `int | float` directly since they are user-facing API methods returning Python values.
- `_simulate_regime_in_period` and the `transitions.py` helpers now take `period: ScalarInt`. The simulation loop derives `period = jnp.int32(period_idx)` once per iteration and passes it through; dict-key lookups (`argmax_and_max_Q_over_a[period_idx]`, `period_to_regime_to_V_arr.get(period_idx + 1)`) keep using the Python int.
- `FlatRegimeParams` tightens to `MappingProxyType[str, Array]`: post-whitelist every leaf is a JAX array, the prior `bool | float | Array` union was stale.
- `safe_to_int32` renamed to `safe_to_int_dtype` to mirror `safe_to_float_dtype`.
- `_strip_V_arr_from_result` made fully kw-only.
- `pyproject.toml` ignores `ARG001` for `tests/test_float_dtype_invariants.py` so per-test `# noqa: ARG001` comments drop out and signatures collapse to a single line.
- `Piece` becomes `init=False` with a manual `__init__` that lifts `n_points` to `jnp.int32`, mirroring `UniformContinuousGrid`.

Test-side fallout addressed in the same commit: literals wrapped with `jnp.asarray` / `jnp.int32` where helpers tightened, redundant `# ty: ignore` comments dropped, and three "validator rejects non-numeric" tests reframed to assert the boundary cast catches them.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`jnp.linspace`/`jnp.logspace`'s `num` parameter is annotated `int` in JAX's stubs but accepts `jnp.int32` JAX scalars in eager mode (verified on cuda12). Pass `n_points: ScalarInt` through directly and silence the type-check mismatch at the single call site rather than materialising the JAX scalar to a Python int. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the Python `sum(generator, start=jnp.int32(0))` with a single `_piece_n_points.sum()` reduction. The cached `Int1D` is already populated by `_init_piecewise_grid_cache`, the property is read after `__post_init__`, and the result is the same `ScalarInt`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pull in the consumption-grid pinning, borrowing-constraint kink fix, and precision-workaround cleanups so the GPU benchmark CI runs the benchmark-aca-baseline kernel that aca-dev currently tracks. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The period_idx / period split was noisy: every loop iteration computed both a Python int (for dict-key indexing and `period in active_periods`) and a JAX scalar (for the JIT'd compute call). Drop the JAX-scalar shadow; iterate `for period, age in enumerate(ages.values)` once. `_simulate_regime_in_period(period: int)` keeps the integer through dict lookups and casts to `jnp.int32(period)` only at the `argmax_and_max_Q_over_a` / next-state JIT boundaries. Same pattern for transitions.py. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
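The resulting loop shape can be sketched as below. `step_sketch`, `ages`, and `active_periods` are illustrative stand-ins, not the pylcm functions; the point is only that the Python int serves the lookups and the `jnp.int32` cast happens at the JIT call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step_sketch(period, wealth):
    """Hypothetical JIT'd compute call that takes the period as a JAX scalar."""
    return wealth * (1.0 + 0.01 * period)


ages = [25.0, 26.0, 27.0]
active_periods = {0, 2}  # Python ints work directly as set / dict keys
wealth = jnp.ones(4, dtype=jnp.float32)
results = {}

for period, age in enumerate(ages):
    if period not in active_periods:
        continue
    # Cast only at the JIT boundary; `period` itself stays a Python int.
    results[period] = step_sketch(jnp.int32(period), wealth)
```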
Pulls in the aca-model CI workflow's matching pylcm pin so the GPU benchmark CI runs the same aca-model rev that aca-dev now tracks. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
solve() no longer touches simulate-side compile state. simulate() is the sole driver: spawns the AOT compile in a background thread when n_subjects is set and the batch shape matches, then runs solve (if period_to_regime_to_V_arr is None) and awaits the future at the state-action-space dispatch point. Both public methods share an internal _solve_compiled() body for the snapshot/error handling. Drops _simulate_compile_future from instance state — the future lives in a local variable on the simulate() stack, so there's no per-process state to gate against. The lock keeps protecting _simulate_compile_cache and _warned_n_subjects; the rest of the "maybe spawn" logic collapses into a single inline check at the simulate() call site. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move the ARG001 ignore for the x64_disabled / x64_enabled fixture pattern into pyproject.toml's per-file-ignores for test_dtypes.py and test_float_dtype_invariants.py, then drop the per-call noqa comments and the now-redundant -> None return annotations (tests/* already ignores ANN). Single-arg signatures collapse to one line; longer ones stay wrapped, but without the trailing comma noise. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`period=1, age=1.0, **flat_regime_params={...float...}` was suppressed
with `# ty: ignore[invalid-argument-type]` to keep the call site
short. Once `ScalarInt` / `ScalarFloat` tightened to JAX-only, the
fix is to pass `jnp.int32(1)` / `jnp.asarray(1.0)` (and to wrap the
float param leaves in `jnp.asarray`). The ignore comments come out
and the call site genuinely type-checks.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`SimulationResult.to_pickle()` (and any cloudpickle.dumps on the result) hit `cannot pickle 'jaxlib._jax.LoadedExecutable'` when the result carried the AOT-compiled `internal_regimes`. The compiled callables (`argmax_and_max_Q_over_a`, `next_state`, `compute_regime_transition_probs`) wrap a `LoadedExecutable` that can't survive a process boundary. `to_dataframe` only reads `simulate_functions.functions / constraints / transitions / stochastic_transition_names` — none of which the AOT pass replaces. So after `simulate(...)` runs, the result has no use for the compiled callables: `model.simulate()` swaps them out for the lazy `self.internal_regimes` before returning. Add a TDD test that round-trips the result through cloudpickle under `n_subjects` matching, which is the failure mode pytask hit on HPC. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
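The swap described here can be illustrated with the standard library alone. In this sketch a `threading.Lock` stands in for the unpicklable `LoadedExecutable`, plain dicts stand in for `SimulationResult` and the regimes, and `swap_in_lazy_regimes_sketch` is a hypothetical name.

```python
import pickle
import threading


def swap_in_lazy_regimes_sketch(result, lazy_regimes):
    """Return a copy of the result that carries the picklable lazy regimes."""
    return {**result, "internal_regimes": lazy_regimes}


lazy_regimes = {"working": "lazy-jit-placeholder"}
compiled_regimes = {"working": threading.Lock()}  # cannot be pickled
result = {"states": [1.0, 2.0], "internal_regimes": compiled_regimes}

try:
    pickle.dumps(result)
    compiled_pickles = True
except TypeError:
    compiled_pickles = False

# After the swap, the result survives a pickle round-trip.
restored = pickle.loads(
    pickle.dumps(swap_in_lazy_regimes_sketch(result, lazy_regimes))
)
```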
Context
Continues the int-side normalisation in #340 with the float side, though for a completely different reason.
The constraint:
What lands on each side without dtype barriers (under `jax_enable_x64=True`, which `aca_model/__init__.py` sets at import):
- `consumption`: action grid quantized to `jnp.float32` in the runtime-consumption-points path. Promoted to `fp64` for the comparison, but promotion preserves the quantization error, it doesn't undo it.
- `consumption_floor * equivalence_scale`: `consumption_floor` is a Python float (fp64 precision), so the RHS keeps fp64 throughout.

When `cash_on_hand` took large negative values, the two sides differed by less than the smallest gap fp64 can represent at that magnitude (a fraction of a single fp32 quantization step, leaked into fp64 by the promotion). `<=` flips, and `validate_initial_conditions` raises `InvalidInitialConditionsError`.
This was very annoying to debug. To have one less thing to worry about, this PR makes sure all floats have a consistent dtype.
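A toy illustration of the mechanism, using NumPy and made-up numbers rather than the aca-model constraint itself: a value quantized to float32 keeps its rounding error when promoted to float64, so a comparison against a float64 threshold that "should" hold with equality can flip.

```python
import numpy as np

floor = 0.7                        # Python float, i.e. float64 precision
consumption_f32 = np.float32(0.7)  # grid point quantized to float32

# Promotion keeps the float32 rounding error; float32(0.7) sits just below 0.7.
promoted = np.float64(consumption_f32)
mixed_ok = bool(promoted >= floor)

# With both sides at the same dtype the rounding is identical on each side.
canonical_ok = bool(consumption_f32 >= np.float32(floor))
```

Here `mixed_ok` is `False` while `canonical_ok` is `True`, which is the kind of flip a single canonical dtype avoids.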
Overview
Adds `canonical_float_dtype()` and `safe_to_float_dtype` next to the int helpers from #340, and applies them at the same boundaries (params, initial conditions, transition outputs, V-arrays).
What lands
`src/lcm/dtypes.py`
- `canonical_float_dtype()` returns `jnp.float64` under `jax_enable_x64=True`, else `jnp.float32`. Read at call time.
- `safe_to_float_dtype(value, *, name)` casts to the canonical dtype and raises `OverflowError` (with the leaf's qualified name) when down-casting a value above `float32` magnitude. Up-casts and same-width casts skip the range check; precision loss within range is not an error.
Params boundary (`src/lcm/params/processing.py`)
`_cast_leaves_to_canonical_dtype` extends the int-only pass from #340 (Model.n_subjects: AOT-compile simulate, lock integer dtype to int32) to also cast typed float arrays. Python scalars (`int` / `float` / `bool`) pass through to keep JAX weak-typing. `pd.Series` leaves pass through too; `convert_series_in_params` reshapes them later based on the multi-index.
Simulate boundary (`src/lcm/simulation/initial_conditions.py`)
`build_initial_states` casts continuous user arrays to the canonical float dtype and pins the missing-state `nan` fallback to the same dtype. Together with the discrete-state int32 cast from #340, the simulate state pool has a stable abstract signature across all periods.
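As a rough sketch of the simulate-boundary behaviour, assuming a simplified signature: `build_initial_states_sketch`, `continuous_names`, and `n_subjects` here are illustrative and do not reproduce the real `build_initial_states` interface.

```python
import jax
import jax.numpy as jnp
import numpy as np


def build_initial_states_sketch(user_states, *, continuous_names, n_subjects):
    """Cast continuous states to the canonical float dtype; NaN-fill missing ones."""
    float_dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32
    states = {}
    for name in continuous_names:
        if name in user_states:
            states[name] = jnp.asarray(user_states[name], dtype=float_dtype)
        else:
            # The missing-state fallback is pinned to the same dtype.
            states[name] = jnp.full(n_subjects, jnp.nan, dtype=float_dtype)
    return states


states = build_initial_states_sketch(
    {"wealth": np.array([1, 2, 3])},  # int input for a continuous state
    continuous_names=("wealth", "health"),
    n_subjects=3,
)
```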
Transition boundary (`src/lcm/simulation/transitions.py`)
`_update_states_for_subjects` unconditionally casts `next_state_values` to the storage dtype. The cross-kind escape hatch added in #340 (so an int-typed user initial condition for a continuous state would not be coerced) is no longer needed; the initial-state cast above pins storage to the canonical float dtype upstream of this site.
Tests
- `tests/test_float_dtype_invariants.py` (10 tests): helper round-trips, initial-state casts, params casts, grid materialisation, V-array dtype, multi-period state-dtype stability.
- `tests/test_dtypes.py`: 7 additional float-helper unit tests.
- `tests/test_validate_param_types.py`: `numpy_array_param_rejected` -> `numpy_array_param_accepted_and_normalised`. With the boundary cast in place numpy arrays are auto-converted; the historical rejection-by-isinstance is obsolete.

928 pass, 5 skip; prek + ty clean.
Stacked on
This branch is stacked on `feat/simulate-aot-n-subjects` (#340), the base ref for this PR. Merge order: #340 first, then this. The diff view here only shows the float-side changes.
Out of scope
`IntND` / `FloatND` aliases (see the #340 review thread; permanently out).
x64 by default; this normalises within whichever mode the user
picked).