Skip to content

Counting tropical#42

Closed
xuanzhaogao wants to merge 61 commits into
TensorBFS:mainfrom
xuanzhaogao:counting-tropical
Closed

Counting tropical#42
xuanzhaogao wants to merge 61 commits into
TensorBFS:mainfrom
xuanzhaogao:counting-tropical

Conversation

@xuanzhaogao
Copy link
Copy Markdown

No description provided.

Xuanzhao Gao and others added 30 commits April 21, 2026 12:34
CPU-first plan for exact ground-state counting via CountingTropical with
Mod<P> count field and host-side Chinese Remainder reconstruction to BigInt.
Mirrors GenericTensorNetworks.jl big_integer_solve pattern.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Addresses codex review: the original single spec conflated an ABI change
(thread T end-to-end, drop the T::Scalar repr-transparent transmute)
with the CRT/BigInt driver and Python surface. Split per codex's scope
recommendation.

Spec A: make Mat<CountingTropical<T, C, D>> work through the existing
GEMM pipeline. Parameterize direction via Max/Min marker trait. No CRT,
no BigInt, no Python changes. Count type is plain u64.

Spec B (new file): Mod<const P: i32> count scalar, CRT driver with a
caller-supplied count_upper_bound for unique reconstruction (fixing the
unsound consecutive-equality termination criterion), Python binding.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
14-task plan: Phase 1 adds the TropicalDirection marker and parameterizes
CountingTropical by it; Phase 2 widens the internal GEMM ABI from
*const T::Scalar to *const T (atomic across kernel/packing/dispatch/mat,
single commit at Task 10); Phase 3 wires CountingTropical through and
adds end-to-end integration tests for both Max and Min directions.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a third generic D: TropicalDirection to CountingTropical. The Max
default preserves source compatibility for all existing call sites.
tropical_zero, tropical_add, and tropical_add_argmax route through
D to support both Max (ground state = largest) and Min (ground state
= smallest) semantics.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous SIMD_AVAILABLE=true, SIMD_WIDTH=8 claim was aspirational:
no SIMD kernel exists for CountingTropical. Set to (false, 1) to
reflect reality. Vectorized counting will be a later follow-up once an
SoA-capable kernel lands.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Encodes the repr(transparent) layout contract as an unsafe marker
trait. Implemented for MaxPlus, MinPlus, MaxMul, AndOr. Will be used
in Phase 2 to let the scalar-slice public API safely reinterpret
&[S::Scalar] as &[S] after the kernel ABI widens from *const T::Scalar
to *const T.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…Phase 2, Tasks 5-10)

- Microkernel::execute now takes *const T / *mut T instead of scalar pointers; SIMD kernels cast internally using repr(transparent) guarantee
- pack_a/pack_b relaxed from TropicalScalar to Copy + Default; packed buffers are Vec<T>
- tropical_gemm_portable/inner and argmax variants thread *const T end-to-end
- KernelDispatch::dispatch_gemm and tropical_gemm_dispatch accept *const T
- MatRef stores &[S] instead of &[S::Scalar]; from_slice/as_slice kept with ReprTransparentTropical bound; from_elements/as_element_slice added for &[S]
- Public API (tropical_matmul, tropical_matmul_with_argmax, TropicalGemm, batched variants) adds ReprTransparentTropical + Default bounds and casts scalar slices on entry
- New tropical_matmul_t<T>(a: &[T], ...) -> Vec<T> function added and re-exported
- ops.rs Mul impls and Mat::matmul_batched impl updated with + Default bound
- tropical-gemm-cuda from_matref/from_mats updated with ReprTransparentTropical bound

All 281 lib tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Blanket KernelDispatch impl for CountingTropical<T, C, D>: all variants
route to the portable (non-SIMD) kernel, which is the only kernel that
supports compound elements today.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Covers both Max and Min directions through the public tropical_matmul_t
entry point: small hand-checked matmul, tie merging, and count
multiplication along a single k path.

Also export Max and Min direction markers from lib.rs for public API.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The type alias MaxPlus cannot be used as a tuple-struct constructor;
.map(MaxPlus) fails with E0423. Use the underlying TropicalMaxPlus.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
12-task plan: Phase 1 adds Mod<P> scalar; Phase 2 adds num-bigint and
CRT primitives (prime table, pairwise combine); Phase 3 adds the
count_ground_states driver with caller-supplied upper bound; Phase 4
adds correctness tests including a BigInt oracle semiring; Phase 5
adds the Python binding (build gated on cluster Python 3.7+);
Phase 6 runs the final regression gate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Residue type with normalizing constructor, modular add/mul using i64
intermediates, and raw accessor. P must satisfy (P-1)^2 < 2^62 to keep
scalar_mul overflow-free. No TropicalScalar impl yet — added in the
next commit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Routes scalar_add/scalar_mul through modular arithmetic, returns 0/1
for identities. pos_infinity/neg_infinity/scalar_max/scalar_min panic
with a clear message — Mod<P> is only valid in the count field of
CountingTropical, never as a tropical value.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CRT_PRIMES: 16 distinct 30-bit primes satisfying (P-1)^2 < 2^60.
crt_combine: fold one (residue, prime) into a running (value, modulus)
accumulator via the standard extended-gcd inverse method.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Takes a caller-supplied BigInt count_upper_bound, selects the smallest
subset of CRT_PRIMES whose product exceeds 2*bound (ensuring unique
reconstruction), runs spec-A matmul once per prime with Mod<P> count,
and folds residues into per-cell BigInt via crt_combine.

Asserts the tropical value field is identical across primes — any
divergence indicates an internal invariant bug, not a numerical issue.

Also exports bound_for_single_matmul(k: usize) -> BigInt for the common
case where both input matrices have per-cell counts of 1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Test-only gadget exposing a slow but exact BigInt-count matmul. Used
to verify the CRT driver produces the same counts as a direct BigInt
accumulator computation. Module is gated on cfg(test) or the 'testing'
feature so it does not bloat the public surface.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Cross-checks count_ground_states against reference_matmul on small
random matrices (both Max and Min), an all-ties corner case, and a
minimal 1x1 sanity shape. Test requires the 'testing' feature.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Forces a BigInt bound exceeding u128::MAX so the driver must use
multiple CRT_PRIMES. Asserts the reconstructed count still equals
the true (small) count — exercises the multi-prime fold path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
NaN comparisons are reflexively false, so they don't satisfy the
is_strictly_better condition in either direction. This triggers the
tie path, confirming the driver handles NaN without corrupting results.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The original plan claimed NaN inputs would trip the cross-prime value
invariant and panic. That assumption was wrong: is_strictly_better is
'>' which is false for both NaN>x and x>NaN, so every NaN candidate
takes the tie branch. With the accumulator starting at -inf for Max,
NaN values merge into the -inf count and get discarded on the first
real comparison — no panic, and the final counts silently ignore the
NaN input. Test renamed and rewritten to pin that behavior so any
future change is intentional. Real NaN safety is a separate question
worth revisiting if it matters to users.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Python API:
    count_ground_states_py(a, b, direction='min', count_upper_bound=None)
      -> (values: ndarray[float32], counts: ndarray[object] of Python int)

Default bound uses bound_for_single_matmul(k). Validates shape compat
and direction string. Releases the GIL during the CRT heavy compute.

Build requires python >= 3.7; load via `module load python` on the
cluster.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Skips gracefully if the extension is not installed (e.g., when running
from a plain cargo checkout without maturin develop). Covers trivial
1x1, tie merging in both directions, Python-int (not numpy int) return
type, and validation of direction.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The counting_crt.rs integration test imports from tropical_gemm::testing,
which is only available when the testing feature is enabled. Without this
gate, cargo test with default features would fail to compile. Now:
- cargo test -p tropical-gemm --features testing: all 335 tests pass
- cargo test -p tropical-gemm: 333 tests pass (excludes 2 gated by testing)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
SoA element layout (two GpuMatrix buffers: values + counts) to reuse
existing GpuMatrix<Scalar> infrastructure. Four kernel instantiations
(f32/f64 × Max/Min) stamped via preprocessor macro, matching the
existing tropical_gemm.cu pattern. Modulus P stays a runtime argument
to avoid 16x kernel cache pressure. GPU CRT driver mirrors the CPU
spec-B driver: upload value matrices once, launch per-prime, download
residues, reconstruct to BigInt host-side.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
8-task plan across 7 phases: Task 1 promotes CPU CRT helpers to pub;
Tasks 2-3 add counting_gemm.cu and wire it into CudaContext; Task 4
adds the Rust launch wrapper; Task 5 adds count_ground_states_gpu
and cross-check tests; Task 6-7 add the Python binding; Task 8 is
the final regression gate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
tropical-gemm-cuda will reuse these helpers verbatim in its GPU CRT
driver (spec C).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Four specializations via preprocessor macro: {f32, f64} × {Max, Min}.
SoA layout (parallel value + count pointers), runtime modulus P.
One thread per output cell, 2D grid.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a second NVRTC compile + load_ptx call for the 4 new counting
kernels alongside the existing tropical_gemm module. Adds
counting_grid_dims / counting_block_dims helpers for launch config.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Xuanzhao Gao and others added 25 commits April 21, 2026 19:43
5-task plan: Task 1 rewrites counting_gemm.cu as a BLIS-style tiled
macro; Tasks 2-3 wire the new grid/block dims per T via a launch_dims
trait method (landed together since they depend on each other);
Task 4 adds large, off-boundary, and f64 medium cross-check tests;
Task 5 is the final regression gate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two-level blocking mirroring tropical_gemm.cu: shared-memory block
tile (64x32x64 for f32, 32x16x32 for f64) with a 4x4 register tile
per thread. Four parallel tiles per block (A value, A count, B value,
B count). Four kernel specializations via macro: f32/f64 × Max/Min.

Replaces the naive one-thread-per-cell kernel. Same kernel names, same
API, same signature — internal change only.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Split counting_block_dims / counting_grid_dims into _f32 and _f64
variants matching the tile sizes of the tiled kernel (64x64 for f32,
32x32 for f64). Add a launch_dims associated method on
CountingCudaKernel so the launch wrapper dispatches per (T, D)
via the type system rather than branching on T at runtime.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- 512x512x512 f32 Max: exercises the tile loop over many iterations.
- 17x19x23 f32 Max: every dim prime, stresses predicated tile-load
  bounds checks at every edge simultaneously.
- 128x128x128 f64 Max: validates the f64 macro on a non-trivial shape.

All three cross-check against the CPU count_ground_states oracle.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces per-step software `x % P` (~50 cycles on Turing) with Barrett
reduction using a host-precomputed reciprocal mu = floor(2^64 / P):
  q = __umul64hi(x, mu);   // 4-8 cycles
  r = x - q*P;             // [0, 2P)
  if (r >= P) r -= P;      // correction
Verified via PTX: 0 div/rem instructions, 1600 mul.hi in the compiled
module.

Also:
- Tie-branch modular reduction deferred to write-back: cnt_accum
  stays in u64 un-reduced during the K loop (bounded by K*P < 2^61).
- Divergent three-way if/else replaced with predicated selects to
  reduce warp divergence.
- Kernel signature adds `unsigned long long MU` arg; host wrapper
  computes it via u128 math.

Benchmark (Quadro RTX 6000, f32 Max, 1 prime):
  size   GPU before   GPU after   speedup
   128   2.11 ms      1.66 ms     +27%
   256   6.68 ms      5.91 ms     +13%
   512   25.9 ms      24.5 ms     +6%
  1024   122  ms      114  ms     +7%
  2048   501  ms      473  ms     +6%
  4096   2.16 s       1.97 s      +10%

Also adds `examples/bench_counting.rs` for future re-runs.

Modest gain (~6-10% on large sizes) confirms the modular reduction
was not the sole bottleneck. Further wins (3-5×) would require
structural changes: async copy (needs sm_80+), warp-level reduction
across k, or vectorized loads.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Reduces ptxas register usage from 150 to 64 per thread, fitting the
1024 threads/block exactly in the Turing register file (64*1024 =
65536 = SM reg file size). Zero spills.

Occupancy is NOT the primary bottleneck (74 vs 72 G-ops/s at 4096²
— modest gain). The counting workload is fundamentally bound by:
  - 2x global memory traffic vs MaxPlus (value + count arrays)
  - 2x shared memory pressure (32 KB vs 16 KB per block, halving
    blocks-per-SM)
  - ~5x arithmetic per inner step (Barrett + tie-merge vs FADD+FMAX)

The existing MaxPlus kernel reaches ~1500 G-ops/s at 4096² on the same
RTX 6000; our counting kernel peaks at ~75 G-ops/s. Closing that gap
would need structural changes: packed (val,cnt) element loads,
warp-level K-reductions, or async copy pipelining (requires sm_80+).

Also adds examples/bench_maxplus.rs as a ceiling-reference bench.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
After measurement, the hand-tiled BLIS-style counting kernel provided
no perf advantage over GPUArrays.jl-style naive (one thread per output
cell). Compared pure kernel launch times:

  Pure kernel (data on device, f32 Max):
    size      Tiled BLIS   Naive
    128²      0.02 ms      0.02 ms
    1024²     5.1 ms       5.1 ms
    2048²     36 ms        36 ms
    4096²     276 ms       276 ms   (497 G tropical-ops/s)

Reason the tile didn't help: counting's 2x memory traffic + u64
accumulator + Barrett state costs ~150 registers/thread in the tiled
version, capping occupancy at 25%. The naive version uses ~30 regs/
thread and lets L1/L2 caches handle data reuse, which turns out to be
good enough for counting's access pattern (broadcast A-row, coalesced
B-col within a warp).

Comparison to Julia CUDA.jl generic matmul on the same hardware
(Quadro RTX 6000):
    Julia (includes a*b allocation): 158 G tropical-ops/s at 2048²
    Our kernel (pure launch):        475 G tropical-ops/s at 2048²
    -> ~3x faster on the hot path.

Earlier end-to-end bench showed Rust at 37 G-ops/s vs Julia at 158
because count_ground_states_gpu had ~400 ms of host-side BigInt
allocations (4M BigInts at ~100 ns each). That's driver overhead
inherent to the CountedMat<BigInt> return type, not kernel slowness.

Driver cleanups in this commit:
- Allocate count_a, count_b, value_c, count_c ONCE (not per prime).
- Use GpuMatrix::alloc (GPU-side zero) instead of from_host(zeros).
- Skip downloading value_c on subsequent primes (invariant guaranteed
  except under NaN, which we already pinned as known-limitation).
- Debug-only invariant check on prime TensorBFS#2.
- Fast-path Vec<BigInt> construction for single-prime case.

End-to-end driver bench improvement (f32 Max, 1 prime, via
count_ground_states_gpu):
    size     before   after   kernel-only (ceiling)
    128      2.11 ms  1.09 ms  0.02 ms
    256      6.68 ms  4.01 ms  0.10 ms
    512      25.9 ms  17.2 ms  0.66 ms
    1024     122 ms   99 ms    5.1 ms
    2048     501 ms   459 ms   36 ms     <-- BigInt alloc dominates
    4096     2.16 s   1.98 s   276 ms

The remaining driver overhead (400 ms at 2048², 1.7 s at 4096²) is
almost entirely BigInt allocation: 4M cells × ~100 ns heap-alloc per
BigInt. Follow-up optimization: add a Vec<u64> fast-path variant when
count_upper_bound fits in u64 (bound < ~2^60, i.e. num_primes ≤ 2).

Also adds examples/bench_kernel_only.rs and bench_maxplus.rs for
continued performance validation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Design doc for the next round of counting-kernel optimization on A100+.
Stage A: 32 threads cooperate per output cell, K-stride inner loop, warp
shuffle reduction with the tropical-add operator. Stage B: pack
(value, count) into 8-byte AoS pairs to halve global loads.

Also adds bench_kernel_single — single-launch variant of the kernel
bench, complementing bench_kernel_only's amortized timing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add warp-K-reduction counting kernel: 32 threads cooperate on each
output cell, K-stride inner loop, 5-step shfl_xor tree reduction with
the tropical-add operator (u64 acc shuffled hi/lo).

Measured on A100-SXM4-80GB (f32 Max, 1 prime, kernel-only):
  - Square shapes (128²–4096²): warpk LOSES 0.15-0.77x vs naive.
    Strided-B reads (32 lanes at fixed j) are non-coalesced; naive's
    coalesced-B pattern dominates whenever M*N is large enough to fill
    the GPU.
  - Tall-skinny / parallelism-starved (M=N=32, K=4096): warpk WINS 9.0x
    (98 vs 11 G/s); M=N=64, K=4096: 3.1x.

Dispatch in launch_counting_gemm: use_warpk = (K>=64) && (M*N<=64*64).
Naive remains the default for square/large; warpk handles small-shape
high-K regime where the SMs are otherwise idle.

4 new tests (boundary, non-aligned, all-ties, Min+f64), 13/13 green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Successor to spec E. Pack (value, count) into 8-byte (f32) or 16-byte
(f64) structs on-device, halving global LDG instructions in the inner
loop. Targets the dominant large-shape regime where naive runs and
where memory pressure on the inner loop is the next bottleneck after
Barrett. Expected ~1.3-1.5x kernel speedup.

Output stays SoA (callers consume value/count separately). Pack runs
once in count_ground_states_gpu before the prime loop, amortized.

Spec includes three open decisions for user review before implementation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pack (value, count) into 8-byte (f32) or 16-byte (f64) structs on-device,
halving global LDG instructions in the counting kernel inner loop. Output
buffers stay SoA. count_ground_states_gpu packs inputs once before the
prime loop, amortized.

Adds:
- src/pair.rs: PairF32/PairF64 with DeviceRepr+ValidAsZeroBits, pack
  helpers, PackPair trait. 7 unit tests for layout + roundtrip.
- 8 new AoS kernels in counting_gemm.cu (4 naive + 4 warpk).
- launch_counting_gemm_aos trait method + free function. Same shape-aware
  dispatch as SoA path (warpk for small M*N high-K, naive otherwise).
- bench_kernel_aos example for AoS-vs-SoA head-to-head.

Driver in crt.rs now packs to AoS and routes through the AoS kernels;
SoA kernels remain reachable via launch_counting_gemm for benchmarking.

Measured on A100-SXM4-80GB (f32 Max, 1 prime, kernel-only):
  Naive path:  1.11-1.24x speedup (4096²: 600 -> 666 G-ops/s)
  Warpk path:  1.64-1.92x speedup (M=N=64,K=4096: 128 -> 246 G/s)

Stacking AoS on spec-E warpk dispatch lifts the small-shape regime by
6.0x total over the original SoA naive kernel.

69/69 tests green (56 lib + 13 counting_gpu integration).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Spec F's rollout plan called for retiring the SoA path once AoS was
measured to win at all shapes. AoS now wins everywhere (1.11x at 4096²
square, 1.92x in the warpk small-shape regime), so:

- Drop 8 SoA kernels (4 naive + 4 warpk) from counting_gemm.cu.
- Rename the 8 _aos kernels to canonical (no suffix). Single layout,
  single source of truth.
- Drop _AOS / _WARPK_AOS trait consts; KERNEL_NAME / KERNEL_NAME_WARPK
  now refer to the AoS kernels.
- Drop SoA `launch_counting_gemm` trait method + free fn. The remaining
  `launch_counting_gemm` takes pair buffers (the former _aos signature).
- Driver in crt.rs updated to canonical name.
- Benches: delete bench_kernel_single, bench_kernel_warpk, bench_kernel_aos
  (head-to-head benches no longer meaningful with one layout). Rewrite
  bench_kernel_only as the canonical AoS bench covering both naive and
  warpk dispatch paths.

Verified on A100-SXM4-80GB: 56 lib + 13 integration = 69/69 green.
Bench numbers unchanged (662 G/s naive @ 4096², 241 G/s warpk @ M=N=64).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Insight: count_ground_states_gpu's only entry point passes all-ones
input counts. The AoS general kernel does u32×u32->u64 multiply +
Barrett mod every k-step, but for ones inputs the count product is
literally 1 — all that arithmetic is wasted. Specialize.

Inner loop becomes: 2 fp loads, 1 add, 1 compare, 1 increment. No
count loads, no multiply, no per-step Barrett. Output count fits in
u32 (max value K) with single Barrett at end. Warp reduction uses
single shuffle per acc_cnt (5 shuffles total) instead of the hi/lo
u64 split (10 shuffles).

Targets 1.6-2.1x kernel speedup on the dominant 4096² square path —
~1100-1400 G-ops/s, approaching the MaxPlus 1500 G/s reference.

AoS general kernels remain in place as fallback for future non-ones
callers (chained matmul, etc.).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Insight from codex review: count_ground_states_gpu's only entry point
passes all-ones input counts. The AoS general kernel does u32×u32->u64
multiply + Barrett mod every k-step, but for ones inputs the count
product is literally 1 — that arithmetic is wasted.

Specialize. Inner loop becomes 2 fp loads + 1 add + 1 compare + 1
conditional set/inc. Output count fits in u32 (max value K), single
Barrett at end. Warp reduction uses one shfl per acc_cnt instead of
the hi/lo u64 split (5 shuffles vs 10).

Measured on A100-SXM4-80GB (f32 Max, 1 prime, kernel-only):
  Naive path:
    128²: 163  -> 348  G/s  (2.13x)
    1024²: 625 -> 1944 G/s  (3.11x)
    2048²: 665 -> 2136 G/s  (3.21x)
    4096²: 665 -> 1946 G/s  (2.93x)
  Warpk path: 1.0-1.2x (memory-bound on strided-B, not count-arith).

Exceeds the prior MaxPlus reference (~1500 G/s, no counting) at >= 512²
square. The ones kernel removes both count loads (halving bandwidth)
and count arithmetic; the freed cycles flip the regime from
mixed-bound to streamlined compute-bound.

Adds:
- 8 new kernels in counting_gemm.cu (naive + warpk × 4 type/dir).
- KERNEL_NAME_ONES + WARPK_ONES trait consts, launch_counting_gemm_ones
  method + free fn taking value-only GpuMatrix<T>.
- count_ground_states_gpu driver routes through ones path; AoS pack
  removed.
- 3 ones-path correctness tests; 16/16 integration tests green.

AoS general kernels stay in tree as fallback for future non-ones
callers (chained matmul, etc.) — currently orphaned in production but
reachable via launch_counting_gemm.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Codex's TensorBFS#2 candidate. Current warpk kernel's only memory bottleneck
is non-coalesced B reads (32 lanes at fixed j, strided by N). At
small N this is masked; at moderate N L2 chokes on transactions.

Fix: upload B as B^T (N×K row-major) when the warpk regime is
dispatched. Lanes then read 32 contiguous K elements per warp-step,
identical pattern to A. No kernel-shape change; only the B layout.

Naive path keeps current B layout (already coalesced for that
parallelization). Driver branches on use_warpk and uploads the
appropriate B.

Expected 1.5-2.5x in warpk regime, plus likely expansion of
COUNTING_WARPK_MN_CEILING past 64*64 once warpk is broadly
competitive.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Codex's TensorBFS#2 candidate. The warpk kernel's only memory-bound bottleneck
was non-coalesced B reads (32 lanes at fixed j, strided by N). Fix:
upload B as B^T (N×K row-major) when warpk dispatched. Lanes then
read 32 contiguous K elements per warp-step, identical to A.

Naive path keeps non-transposed B (already coalesced for that
parallelization). Driver branches B upload on use_warpk and applies
host-side transpose. transpose_row_major helper added to crt.rs.

Trait method launch_counting_gemm_ones now takes M, K, N explicitly
because B's GpuMatrix shape is layout-dependent.

Measured on A100-SXM4-80GB (f32 Max, 1 prime, kernel-only):
  M=N=64,   K=4096:  122 ->  1324 G/s  (10.9x over naive-ones)
  M=N=128,  K=4096:  480 ->  1903 G/s  (3.96x)
  M=N=512,  K=4096: 1786 ->  2316 G/s  (1.30x)
  M=N=1024, K=4096: 2031 ->  2525 G/s  (1.24x)

Peak hits 2.5 TG/s — beyond the prior MaxPlus reference.

COUNTING_WARPK_MN_CEILING raised 64*64 -> 128*128. Conservative
re-tune: kernel wins extend past 1024² but host-side transpose cost
on single-prime calls limits the profitable ceiling. Multi-prime CRT
calls would win further out; on-device transpose (follow-up) would
push the ceiling to ~256-1024².

Adds warpk_transposed_b_layout test (deliberately asymmetric A/B to
catch transpose-helper bugs). 17/17 integration + 56/56 lib green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…rhead

Per memory profiling, BigInt construction dominates e2e at 2048² (~65%
of 459 ms). Kernel itself is now <5% (1946 G/s ones-kernel takes ~70 ms
of e2e). The dominant cost is the output format.

For count_upper_bound < 2^60 (covers any single-matmul case via
bound_for_single_matmul(k)), CRT product fits in u64 with one or two
30-bit primes. Specialize: return Vec<u64> instead of Vec<BigInt>,
zero per-cell allocation.

API: parallel entry point count_ground_states_gpu_u64 + CountedMatU64.
BigInt path stays for general / chained-matmul callers. Driver
refactored to share kernel-loop helper between the two paths.

Targets ~5x e2e speedup (459 ms -> ~80-100 ms at 2048²).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Eliminates per-cell BigInt heap allocation by returning Vec<u64>
counts when the CRT product fits in u63 (≤2 of the 30-bit primes —
covers count_upper_bound < 2^60, all single-matmul use cases).

CPU additions (tropical_gemm::crt):
- crt_combine_u64: pairwise CRT in u64 with extended-Euclidean inverse.
- choose_primes_u64: prime-prefix selection bounded by 2^63.
- bound_for_single_matmul_u64.
- 4 unit tests including BigInt-parity randomized check.

GPU driver (tropical_gemm_cuda::crt):
- run_kernels_per_prime: extracted helper for the shared device side
  (layout choice, B-transpose-or-not, kernel launches, downloads).
  BigInt and u64 entry points both use it.
- count_ground_states_gpu_u64 + CountedMatU64: new entry point.
  Single-prime path skips CRT combine entirely; 2-prime path runs
  pairwise crt_combine_u64.

Measured on A100-SXM4-80GB (f32 Max, single matmul, e2e):
  256²:  3.54  -> 0.26 ms (13.50x)
  512²: 13.92  -> 0.91 ms (15.36x)
  1024²: 63.14 -> 8.38 ms (7.53x)
  2048²: 292.46 -> 29.79 ms (9.82x)
  4096²: 1237.90 -> 152.28 ms (8.13x)

Cumulative e2e improvement at 2048² since branch start: 459 ms -> 30 ms
= ~15x. Combined with kernel-only progress (50 -> 1946 G/s = 39x), the
counting matmul is now genuinely fast end-to-end on A100.

21/21 integration tests + 60/60 lib tests green. BigInt path
unchanged in behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add a Julia interface to count_ground_states_gpu_u64 for the four
(T, D) combos. C ABI added to tropical-gemm-cuda (now built as both
rlib + cdylib). Julia package CountingTropicalGEMM.jl/ at repo root,
~100 LOC main module + 50 LOC tests.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add a C ABI to tropical-gemm-cuda (now built as both rlib + cdylib)
exposing count_ground_states_gpu_u64 to non-Rust callers. Four entry
points (f32/f64 × Max/Min) plus version + last-error TLS helpers,
each wrapped in catch_unwind for panic safety across the FFI boundary.

Julia package CountingTropicalGEMM.jl/ at repo root. Library
resolution via TROPICAL_GEMM_LIB env override, workspace dev-build
fallback, or system search paths. Result type CountedMatU64{T}
mirrors the Rust struct. Typed exceptions (BoundTooLargeError for
the u64 envelope, CountingTropicalGEMMError otherwise).

Wrapper transposes Julia's column-major Matrix{T} to row-major at
the FFI boundary; transpose cost is negligible vs kernel time.

Tested 8/8 green on A100-SXM4-80GB:
  f32 Max small (hand-verifiable), f64 Min vs reference (randomized),
  all-ties large K, BoundTooLargeError, DimensionMismatch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CountingTropicalGEMM.jl/bench/bench.jl exercises the same shapes as
Rust's bench_e2e_u64. Reports tropical-ops/s through the full Julia
path (transpose, ccall, kernel, reconstruction, output reshape).

Measured on A100-SXM4-80GB:
  256²:  0.69 ms (vs Rust 0.26 ms)
  512²:  6.66 ms (vs Rust 0.91 ms)
  1024²: 39.7 ms (vs Rust 8.4 ms)
  2048²: 120 ms  (vs Rust 30 ms)
  4096²: 441 ms  (vs Rust 152 ms)

Julia overhead scales with M*N — confirmed source is the
column→row-major boundary transpose on inputs + outputs (Julia's
Matrix{T} is column-major; the kernel expects row-major). Output
transpose is the expensive one at 4096²: 16M UInt64 elements = 128 MB
of host memory motion.

Overhead is the wrapper's, not the kernel's. Kernel itself still
hits ~1946 G/s; full Julia path at 4096² is ~312 G-ops/s e2e.
Possible follow-up: layout-flag in C ABI to accept column-major
inputs directly, avoiding host transpose.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds tg_bench_kernel_only_u64_<T>_<D> to the C ABI: uploads once,
runs the kernel `iters` times, returns avg per-launch wall time
(ms). Bypasses CRT combine and u64 reconstruction — measures pure
kernel runtime + sync, matching what bench_kernel_only.rs reports.

Julia: bench_kernel_only_u64(dir, A, B, bound; iters) -> Float64.
Bench script rewritten to use this entry point.

Measured on A100-SXM4-80GB (f32 Max, 1 prime, kernel-only via Julia):
  Naive path:
    1024²: 1.112 ms  (1931 G/s)
    2048²: 8.704 ms  (1974 G/s)
    4096²: 70.08 ms  (1961 G/s)
  Warpk path (transposed B, K=4096):
    M=N=64:   0.027 ms (1244 G/s)
    M=N=512:  1.097 ms (1958 G/s)
    M=N=1024: 4.030 ms (2132 G/s) — peak

Numbers match the Rust bench_kernel_only example to within timing
noise, confirming the binding adds no kernel-side overhead.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Confirms throughput scales: 1829 G/s at 16384² × 16384 = 8.8 × 10^12
tropical-ops in 4.81 sec (within ~7% of 4096²'s 1961 G/s peak).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds two larger sizes to bench_julia_vs_rust.jl. Confirms
GenericTensorNetworks/CUDA.jl's auto-compiled broadcast kernel for
CountingTropical{Float32, Mod{P}} saturates at ~220-265 G/s across
sizes from 512² to 8192² — vs our specialized kernel at ~1960 G/s.

Measured A100-SXM4-80GB (kernel-only ms / G tropical-ops/s):
  4096²: GTN 614 ms (224 G/s)  vs ours 70 ms (1961 G/s) — 8.8x
  8192²: GTN 4979 ms (221 G/s) vs ours ~560 ms (~1960)  — ~9x

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings April 28, 2026 03:32
@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 28, 2026

Codecov Report

❌ Patch coverage is 86.40000% with 34 lines in your changes missing coverage. Please review.
✅ Project coverage is 94.19%. Comparing base (6d22701) to head (64f1b08).

Files with missing lines Patch % Lines
crates/tropical-gemm/src/crt.rs 79.46% 23 Missing ⚠️
crates/tropical-gemm/src/types/modp.rs 83.33% 4 Missing ⚠️
...rates/tropical-gemm/src/testing/bigint_semiring.rs 89.65% 3 Missing ⚠️
crates/tropical-gemm/src/api.rs 93.10% 2 Missing ⚠️
crates/tropical-gemm/src/mat/ref_.rs 75.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #42      +/-   ##
==========================================
- Coverage   96.41%   94.19%   -2.22%     
==========================================
  Files          18       22       +4     
  Lines         892     1086     +194     
==========================================
+ Hits          860     1023     +163     
- Misses         32       63      +31     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds full-stack support for “counting tropical” GEMM: new semiring direction tags (Max/Min), CRT-based exact counting on CPU and GPU (including a u64 fast-path), plus Python and Julia bindings and accompanying design specs/bench scripts.

Changes:

  • Introduce TropicalDirection (Max/Min) and make CountingTropical direction-aware; add ReprTransparentTropical to safely support scalar-slice reinterpretation for transparent semiring wrappers.
  • Add GPU counting CRT driver (BigInt + u64 fast-path), AoS pair layout helpers, counting CUDA kernels (including ones-specialized + warpk-transposed-B variants), and new CUDA benchmarks.
  • Add Python bindings/tests for CPU+GPU count-ground-states and a Julia package (CountingTropicalGEMM.jl) with tests/benchmarks; add multiple design spec documents.

Reviewed changes

Copilot reviewed 69 out of 71 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
sanity_julia.jl Local Julia sanity script for CountingTropical GPU matmul (currently contains machine-specific paths).
bench_julia_vs_rust.jl Julia benchmark to compare CUDA.jl generic CountingTropical matmul vs Rust (currently contains machine-specific paths).
docs/superpowers/specs/2026-04-27-julia-binding-design.md Design spec for Julia binding architecture and FFI surface.
docs/superpowers/specs/2026-04-27-counting-u64-fastpath-design.md Design spec for u64 fast-path in GPU driver + CRT helpers.
docs/superpowers/specs/2026-04-27-counting-kernel-warpk-transposed-design.md Design spec for warpk kernel reading transposed B for coalesced loads.
docs/superpowers/specs/2026-04-27-counting-kernel-ones-specialized-design.md Design spec for ones-specialized counting kernels.
docs/superpowers/specs/2026-04-27-counting-kernel-aos-design.md Design spec for AoS (value,count) layout for counting kernels.
docs/superpowers/specs/2026-04-21-counting-tropical-cuda-design.md Prior design spec for GPU counting kernel + CRT driver.
docs/superpowers/specs/2026-04-21-counting-tropical-crt-design.md Prior design spec for CRT/BigInt CPU driver and bound contract.
docs/superpowers/specs/2026-04-21-counting-tropical-compose-design.md Prior design spec for making CountingTropical compose with GEMM pipeline.
docs/superpowers/specs/2026-04-21-counting-kernel-tiled-design.md Prior design spec for tiled CUDA counting kernel.
crates/tropical-gemm/tests/counting_crt.rs New/updated CRT correctness tests vs BigInt oracle + edge cases.
crates/tropical-gemm/tests/counting_compose.rs Tests for CountingTropical composing through tropical_matmul_t (Max/Min, ties, multiplication).
crates/tropical-gemm/src/types/traits.rs Adds ReprTransparentTropical marker trait for safe scalar↔element reinterpretation.
crates/tropical-gemm/src/types/modp.rs Adds Mod<P> scalar for CRT residue arithmetic.
crates/tropical-gemm/src/types/mod.rs Wires new direction + Mod modules and re-exports Max/Min/Mod/ReprTransparentTropical.
crates/tropical-gemm/src/types/min_plus.rs Implements ReprTransparentTropical for TropicalMinPlus.
crates/tropical-gemm/src/types/max_plus.rs Implements ReprTransparentTropical for TropicalMaxPlus.
crates/tropical-gemm/src/types/max_mul.rs Implements ReprTransparentTropical for TropicalMaxMul.
crates/tropical-gemm/src/types/direction.rs Introduces TropicalDirection with Max and Min tags.
crates/tropical-gemm/src/types/counting.rs Makes CountingTropical direction-aware (Max/Min), disables SIMD claim, adds Min tests.
crates/tropical-gemm/src/types/and_or.rs Implements ReprTransparentTropical for TropicalAndOr.
crates/tropical-gemm/src/testing/mod.rs Adds test-only module wiring (feature-gated).
crates/tropical-gemm/src/testing/bigint_semiring.rs Adds BigInt-based reference implementation for CRT tests.
crates/tropical-gemm/src/simd/kernels/portable.rs Updates microkernel signatures to take *const T and fixes tests accordingly.
crates/tropical-gemm/src/simd/kernels/neon.rs Updates NEON kernels to accept *const T and casts internally to scalar pointers.
crates/tropical-gemm/src/simd/kernels/avx2.rs Updates AVX2 kernels to accept *const T and casts internally to scalar pointers; updates tests.
crates/tropical-gemm/src/mat/ref_.rs Refactors MatRef to store element slices, adds safe scalar conversion via ReprTransparentTropical.
crates/tropical-gemm/src/mat/owned.rs Refactors Mat APIs to use element slices; adds from_elements convenience alias.
crates/tropical-gemm/src/mat/ops.rs Updates multiplication trait bounds to include Default.
crates/tropical-gemm/src/lib.rs Exposes new CRT module and Max/Min exports; wires test helpers.
crates/tropical-gemm/src/core/packing.rs Packing now operates on element type T (requires Copy + Default) instead of TropicalScalar.
crates/tropical-gemm/src/core/kernel.rs Microkernel interface now uses *const T rather than *const T::Scalar.
crates/tropical-gemm/src/api.rs Adds tropical_matmul_t for compound elements; updates scalar APIs to require ReprTransparentTropical.
crates/tropical-gemm/Cargo.toml Adds testing feature and BigInt/Integer deps.
crates/tropical-gemm-python/tests/test_count_ground_states_gpu.py Adds GPU round-trip tests for counting.
crates/tropical-gemm-python/tests/test_count_ground_states.py Adds/updates CPU round-trip tests for counting + typed errors.
crates/tropical-gemm-python/src/lib.rs Adds CPU + GPU counting bindings (CRT BigInt counts).
crates/tropical-gemm-python/Cargo.toml Adds num-bigint dependency.
crates/tropical-gemm-cuda/src/pair.rs Adds AoS pair types + packing helpers + unit tests.
crates/tropical-gemm-cuda/src/lib.rs Exposes counting CRT APIs + CountedMatU64; wires new modules.
crates/tropical-gemm-cuda/src/gpu_mat.rs Updates MatRef upload path to require ReprTransparentTropical.
crates/tropical-gemm-cuda/src/error.rs Adds InvalidState error variant.
crates/tropical-gemm-cuda/src/crt.rs Implements GPU CRT driver (BigInt + u64 fast-path) and shared per-prime runner.
crates/tropical-gemm-cuda/src/counting_kernel.rs Adds counting kernel launcher and dispatch, including ones-specialized path.
crates/tropical-gemm-cuda/src/context.rs Adds compilation/registration of counting kernels and warpk dispatch constants/dims.
crates/tropical-gemm-cuda/examples/bench_warpk_crossover.rs Bench for naive vs warpk (transposed-B) crossover.
crates/tropical-gemm-cuda/examples/bench_maxplus.rs Bench for MaxPlus GPU kernel throughput reference.
crates/tropical-gemm-cuda/examples/bench_kernel_only.rs Kernel-only bench for AoS vs ones-specialized kernels.
crates/tropical-gemm-cuda/examples/bench_e2e_u64.rs End-to-end bench: BigInt vs u64 fast-path.
crates/tropical-gemm-cuda/examples/bench_counting.rs End-to-end bench: GPU vs CPU counting.
crates/tropical-gemm-cuda/Cargo.toml Builds as cdylib + adds BigInt-related deps.
CountingTropicalGEMM.jl/test/runtests.jl Julia package tests incl. reference parity, bounds, and dimension mismatch.
CountingTropicalGEMM.jl/src/CountingTropicalGEMM.jl Julia FFI wrapper to C ABI with library resolution and error mapping.
CountingTropicalGEMM.jl/bench/bench_huge.jl Julia kernel-only huge-shape benchmark.
CountingTropicalGEMM.jl/bench/bench.jl Julia kernel-only benchmark driver.
CountingTropicalGEMM.jl/README.md Julia package build/install/usage docs.
CountingTropicalGEMM.jl/Project.toml Julia package metadata.
Cargo.lock Locks new Rust dependencies (num-bigint/num-integer).
.gitignore Ignores Julia Manifest for the new Julia package.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +233 to +234
assert_eq!(a_values.len(), m * k);
assert_eq!(b_values.len(), k * n);
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the BigInt path: avoid assert_eq! on user-controlled input lengths in this Result-returning API; return a CudaError instead of panicking.

Suggested change
assert_eq!(a_values.len(), m * k);
assert_eq!(b_values.len(), k * n);
if a_values.len() != m * k {
return Err(CudaError::InvalidState(
format!(
"invalid a_values length: got {}, expected {} (m * k)",
a_values.len(),
m * k
)
.into(),
));
}
if b_values.len() != k * n {
return Err(CudaError::InvalidState(
format!(
"invalid b_values length: got {}, expected {} (k * n)",
b_values.len(),
k * n
)
.into(),
));
}

Copilot uses AI. Check for mistakes.
Comment thread sanity_julia.jl
Comment on lines +3 to +4
Pkg.add(["CUDA", "TropicalNumbers"])
Pkg.add(path = "/mnt/home/xgao1/work/better_gpu_gemm/GenericTensorNetworks.jl")
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This script hard-codes a machine-specific local path (Pkg.add(path = "/mnt/home/.../GenericTensorNetworks.jl")), which will fail for other developers/CI. Consider removing this file from the repo, or make the path configurable (e.g., via ENV) / use a relative path / documented Pkg.develop instructions instead.

Copilot uses AI. Check for mistakes.
Comment thread bench_julia_vs_rust.jl
Comment on lines +13 to +14
Pkg.add(path = "/mnt/home/xgao1/work/better_gpu_gemm/GenericTensorNetworks.jl")

Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This benchmark script hard-codes a local filesystem path for GenericTensorNetworks.jl, which makes it non-portable and likely to break in other environments. Prefer documenting setup steps and using Pkg.develop/relative paths or environment configuration rather than committing an absolute path.

Suggested change
Pkg.add(path = "/mnt/home/xgao1/work/better_gpu_gemm/GenericTensorNetworks.jl")
if !haskey(ENV, "GENERIC_TENSOR_NETWORKS_PATH")
error("Set GENERIC_TENSOR_NETWORKS_PATH to a local GenericTensorNetworks.jl checkout before running this benchmark.")
end
gtn_path = abspath(expanduser(ENV["GENERIC_TENSOR_NETWORKS_PATH"]))
isdir(gtn_path) || error("GENERIC_TENSOR_NETWORKS_PATH does not exist or is not a directory: " * gtn_path)
Pkg.develop(path = gtn_path)

Copilot uses AI. Check for mistakes.
Comment on lines +17 to +19
/// The inner `i32` is always in `[0, P)` (the normalized representative).
/// Construct via `Mod::new` (which normalizes) or reconstruct from a raw
/// representative via `raw`. See module docs for the size constraint on `P`.
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doc comment is misleading: raw() is an accessor, not a way to reconstruct/create a Mod<P> from an existing raw representative. Consider rewording to avoid implying there is a from_raw constructor, or add an explicit from_raw_unchecked/from_raw API if that capability is intended.

Copilot uses AI. Check for mistakes.
Comment on lines +2626 to +2633
let ctx = CudaContext::new().map_err(|e| format!("CUDA init: {}", e))?;
match direction {
"max" => count_ground_states_gpu::<f32, Max>(
&ctx, &a_data, m, k, &b_data, n, &bound,
)
.map_err(|e| format!("GPU compute: {}", e)),
"min" => count_ground_states_gpu::<f32, Min>(
&ctx, &a_data, m, k, &b_data, n, &bound,
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

count_ground_states_gpu_py constructs a fresh CudaContext::new() on every call, which will redo NVRTC compilation/context setup and defeats the crate’s existing global-context caching. Use the shared/global context (or a cached context per process/device) instead of creating a new one per invocation.

Suggested change
let ctx = CudaContext::new().map_err(|e| format!("CUDA init: {}", e))?;
match direction {
"max" => count_ground_states_gpu::<f32, Max>(
&ctx, &a_data, m, k, &b_data, n, &bound,
)
.map_err(|e| format!("GPU compute: {}", e)),
"min" => count_ground_states_gpu::<f32, Min>(
&ctx, &a_data, m, k, &b_data, n, &bound,
// Reuse the process-global CUDA context so we do not pay
// initialization / NVRTC compilation costs on every call.
let ctx = CudaContext::global().map_err(|e| format!("CUDA init: {}", e))?;
match direction {
"max" => count_ground_states_gpu::<f32, Max>(
ctx, &a_data, m, k, &b_data, n, &bound,
)
.map_err(|e| format!("GPU compute: {}", e)),
"min" => count_ground_states_gpu::<f32, Min>(
ctx, &a_data, m, k, &b_data, n, &bound,

Copilot uses AI. Check for mistakes.
Comment on lines +2651 to +2664
let counts_py: Vec<PyObject> = result
.counts
.into_iter()
.map(|bn| {
let s = bn.to_string();
// Use Python's built-in int() on the decimal string.
py.eval(
std::ffi::CString::new(format!("int({})", s))
.unwrap()
.as_c_str(),
None,
None,
)
.map(|b| b.unbind())
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Converting BigInt to Python int via py.eval("int(<decimal>)") is unnecessarily slow (parses Python code per element) and relies on code-eval. Prefer constructing PyLong directly (e.g., via the CPython API PyLong_FromString / PyLong_FromUnicodeObject, or calling builtins.int with the decimal string without eval).

Suggested change
let counts_py: Vec<PyObject> = result
.counts
.into_iter()
.map(|bn| {
let s = bn.to_string();
// Use Python's built-in int() on the decimal string.
py.eval(
std::ffi::CString::new(format!("int({})", s))
.unwrap()
.as_c_str(),
None,
None,
)
.map(|b| b.unbind())
let py_int = py.import("builtins")?.getattr("int")?;
let counts_py: Vec<PyObject> = result
.counts
.into_iter()
.map(|bn| {
let s = bn.to_string();
py_int.call1((s,)).map(|b| b.unbind())

Copilot uses AI. Check for mistakes.
Comment on lines +2791 to +2806
// Convert BigInt → Python int via decimal string representation.
let counts_py: Vec<PyObject> = result
.counts
.into_iter()
.map(|bn| {
let s = bn.to_string();
// Use Python's built-in int() on the decimal string.
py.eval(
std::ffi::CString::new(format!("int({})", s))
.unwrap()
.as_c_str(),
None,
None,
)
.map(|b| b.unbind())
})
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as the GPU binding: BigInt → Python int conversion uses py.eval("int(<decimal>)"), which is very expensive for large outputs and should be replaced with direct PyLong creation (no eval).

Suggested change
// Convert BigInt → Python int via decimal string representation.
let counts_py: Vec<PyObject> = result
.counts
.into_iter()
.map(|bn| {
let s = bn.to_string();
// Use Python's built-in int() on the decimal string.
py.eval(
std::ffi::CString::new(format!("int({})", s))
.unwrap()
.as_c_str(),
None,
None,
)
.map(|b| b.unbind())
})
// Convert BigInt → Python int directly via CPython's PyLong API.
fn bigint_to_pyobject(py: Python<'_>, bn: BigInt) -> PyResult<PyObject> {
let bytes = bn.to_signed_bytes_le();
let ptr = unsafe {
pyo3::ffi::_PyLong_FromByteArray(
bytes.as_ptr(),
bytes.len(),
1,
1,
)
};
unsafe { Bound::from_owned_ptr_or_err(py, ptr).map(|obj| obj.unbind()) }
}
let counts_py: Vec<PyObject> = result
.counts
.into_iter()
.map(|bn| bigint_to_pyobject(py, bn))

Copilot uses AI. Check for mistakes.
Comment on lines +162 to +163
assert_eq!(a_values.len(), m * k);
assert_eq!(b_values.len(), k * n);
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This public API uses assert_eq! for input length validation, which will panic on user error. For a Result-returning GPU entry point, prefer returning a structured error (e.g., CudaError::DimensionMismatch / CudaError::InvalidState) so callers (including Python/Julia bindings) can handle it without a panic.

Suggested change
assert_eq!(a_values.len(), m * k);
assert_eq!(b_values.len(), k * n);
if a_values.len() != m * k {
return Err(CudaError::DimensionMismatch);
}
if b_values.len() != k * n {
return Err(CudaError::DimensionMismatch);
}

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants