Counting tropical#42
Conversation
CPU-first plan for exact ground-state counting via CountingTropical with Mod<P> count field and host-side Chinese Remainder reconstruction to BigInt. Mirrors GenericTensorNetworks.jl big_integer_solve pattern. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Addresses codex review: the original single spec conflated an ABI change (thread T end-to-end, drop the T::Scalar repr-transparent transmute) with the CRT/BigInt driver and Python surface. Split per codex's scope recommendation. Spec A: make Mat<CountingTropical<T, C, D>> work through the existing GEMM pipeline. Parameterize direction via Max/Min marker trait. No CRT, no BigInt, no Python changes. Count type is plain u64. Spec B (new file): Mod<const P: i32> count scalar, CRT driver with a caller-supplied count_upper_bound for unique reconstruction (fixing the unsound consecutive-equality termination criterion), Python binding. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
14-task plan: Phase 1 adds the TropicalDirection marker and parameterizes CountingTropical by it; Phase 2 widens the internal GEMM ABI from *const T::Scalar to *const T (atomic across kernel/packing/dispatch/mat, single commit at Task 10); Phase 3 wires CountingTropical through and adds end-to-end integration tests for both Max and Min directions. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a third generic D: TropicalDirection to CountingTropical. The Max default preserves source compatibility for all existing call sites. tropical_zero, tropical_add, and tropical_add_argmax route through D to support both Max (ground state = largest) and Min (ground state = smallest) semantics. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous SIMD_AVAILABLE=true, SIMD_WIDTH=8 claim was aspirational: no SIMD kernel exists for CountingTropical. Set to (false, 1) to reflect reality. Vectorized counting will be a later follow-up once an SoA-capable kernel lands. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Encodes the repr(transparent) layout contract as an unsafe marker trait. Implemented for MaxPlus, MinPlus, MaxMul, AndOr. Will be used in Phase 2 to let the scalar-slice public API safely reinterpret &[S::Scalar] as &[S] after the kernel ABI widens from *const T::Scalar to *const T. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…Phase 2, Tasks 5-10) - Microkernel::execute now takes *const T / *mut T instead of scalar pointers; SIMD kernels cast internally using repr(transparent) guarantee - pack_a/pack_b relaxed from TropicalScalar to Copy + Default; packed buffers are Vec<T> - tropical_gemm_portable/inner and argmax variants thread *const T end-to-end - KernelDispatch::dispatch_gemm and tropical_gemm_dispatch accept *const T - MatRef stores &[S] instead of &[S::Scalar]; from_slice/as_slice kept with ReprTransparentTropical bound; from_elements/as_element_slice added for &[S] - Public API (tropical_matmul, tropical_matmul_with_argmax, TropicalGemm, batched variants) adds ReprTransparentTropical + Default bounds and casts scalar slices on entry - New tropical_matmul_t<T>(a: &[T], ...) -> Vec<T> function added and re-exported - ops.rs Mul impls and Mat::matmul_batched impl updated with + Default bound - tropical-gemm-cuda from_matref/from_mats updated with ReprTransparentTropical bound All 281 lib tests pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Blanket KernelDispatch impl for CountingTropical<T, C, D>: all variants route to the portable (non-SIMD) kernel, which is the only kernel that supports compound elements today. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Covers both Max and Min directions through the public tropical_matmul_t entry point: small hand-checked matmul, tie merging, and count multiplication along a single k path. Also export Max and Min direction markers from lib.rs for public API. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The type alias MaxPlus cannot be used as a tuple-struct constructor; .map(MaxPlus) fails with E0423. Use the underlying TropicalMaxPlus. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
12-task plan: Phase 1 adds Mod<P> scalar; Phase 2 adds num-bigint and CRT primitives (prime table, pairwise combine); Phase 3 adds the count_ground_states driver with caller-supplied upper bound; Phase 4 adds correctness tests including a BigInt oracle semiring; Phase 5 adds the Python binding (build gated on cluster Python 3.7+); Phase 6 runs the final regression gate. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Residue type with normalizing constructor, modular add/mul using i64 intermediates, and raw accessor. P must satisfy (P-1)^2 < 2^62 to keep scalar_mul overflow-free. No TropicalScalar impl yet — added in the next commit. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Routes scalar_add/scalar_mul through modular arithmetic, returns 0/1 for identities. pos_infinity/neg_infinity/scalar_max/scalar_min panic with a clear message — Mod<P> is only valid in the count field of CountingTropical, never as a tropical value. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CRT_PRIMES: 16 distinct 30-bit primes satisfying (P-1)^2 < 2^60. crt_combine: fold one (residue, prime) into a running (value, modulus) accumulator via the standard extended-gcd inverse method. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Takes a caller-supplied BigInt count_upper_bound, selects the smallest subset of CRT_PRIMES whose product exceeds 2*bound (ensuring unique reconstruction), runs spec-A matmul once per prime with Mod<P> count, and folds residues into per-cell BigInt via crt_combine. Asserts the tropical value field is identical across primes — any divergence indicates an internal invariant bug, not a numerical issue. Also exports bound_for_single_matmul(k: usize) -> BigInt for the common case where both input matrices have per-cell counts of 1. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Test-only gadget exposing a slow but exact BigInt-count matmul. Used to verify the CRT driver produces the same counts as a direct BigInt accumulator computation. Module is gated on cfg(test) or the 'testing' feature so it does not bloat the public surface. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Cross-checks count_ground_states against reference_matmul on small random matrices (both Max and Min), an all-ties corner case, and a minimal 1x1 sanity shape. Test requires the 'testing' feature. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Forces a BigInt bound exceeding u128::MAX so the driver must use multiple CRT_PRIMES. Asserts the reconstructed count still equals the true (small) count — exercises the multi-prime fold path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
NaN comparisons are reflexively false, so they don't satisfy the is_strictly_better condition in either direction. This triggers the tie path, confirming the driver handles NaN without corrupting results. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The original plan claimed NaN inputs would trip the cross-prime value invariant and panic. That assumption was wrong: is_strictly_better is '>' which is false for both NaN>x and x>NaN, so every NaN candidate takes the tie branch. With the accumulator starting at -inf for Max, NaN values merge into the -inf count and get discarded on the first real comparison — no panic, and the final counts silently ignore the NaN input. Test renamed and rewritten to pin that behavior so any future change is intentional. Real NaN safety is a separate question worth revisiting if it matters to users. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Python API:
count_ground_states_py(a, b, direction='min', count_upper_bound=None)
-> (values: ndarray[float32], counts: ndarray[object] of Python int)
Default bound uses bound_for_single_matmul(k). Validates shape compat
and direction string. Releases the GIL during the CRT heavy compute.
Build requires python >= 3.7; load via `module load python` on the
cluster.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Skips gracefully if the extension is not installed (e.g., when running from a plain cargo checkout without maturin develop). Covers trivial 1x1, tie merging in both directions, Python-int (not numpy int) return type, and validation of direction. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The counting_crt.rs integration test imports from tropical_gemm::testing, which is only available when the testing feature is enabled. Without this gate, cargo test with default features would fail to compile. Now: - cargo test -p tropical-gemm --features testing: all 335 tests pass - cargo test -p tropical-gemm: 333 tests pass (excludes 2 gated by testing) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
SoA element layout (two GpuMatrix buffers: values + counts) to reuse existing GpuMatrix<Scalar> infrastructure. Four kernel instantiations (f32/f64 × Max/Min) stamped via preprocessor macro, matching the existing tropical_gemm.cu pattern. Modulus P stays a runtime argument to avoid 16x kernel cache pressure. GPU CRT driver mirrors the CPU spec-B driver: upload value matrices once, launch per-prime, download residues, reconstruct to BigInt host-side. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
8-task plan across 7 phases: Task 1 promotes CPU CRT helpers to pub; Tasks 2-3 add counting_gemm.cu and wire it into CudaContext; Task 4 adds the Rust launch wrapper; Task 5 adds count_ground_states_gpu and cross-check tests; Task 6-7 add the Python binding; Task 8 is the final regression gate. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
tropical-gemm-cuda will reuse these helpers verbatim in its GPU CRT driver (spec C). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Four specializations via preprocessor macro: {f32, f64} × {Max, Min}.
SoA layout (parallel value + count pointers), runtime modulus P.
One thread per output cell, 2D grid.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a second NVRTC compile + load_ptx call for the 4 new counting kernels alongside the existing tropical_gemm module. Adds counting_grid_dims / counting_block_dims helpers for launch config. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
5-task plan: Task 1 rewrites counting_gemm.cu as a BLIS-style tiled macro; Tasks 2-3 wire the new grid/block dims per T via a launch_dims trait method (landed together since they depend on each other); Task 4 adds large, off-boundary, and f64 medium cross-check tests; Task 5 is the final regression gate. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two-level blocking mirroring tropical_gemm.cu: shared-memory block tile (64x32x64 for f32, 32x16x32 for f64) with a 4x4 register tile per thread. Four parallel tiles per block (A value, A count, B value, B count). Four kernel specializations via macro: f32/f64 × Max/Min. Replaces the naive one-thread-per-cell kernel. Same kernel names, same API, same signature — internal change only. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Split counting_block_dims / counting_grid_dims into _f32 and _f64 variants matching the tile sizes of the tiled kernel (64x64 for f32, 32x32 for f64). Add a launch_dims associated method on CountingCudaKernel so the launch wrapper dispatches per (T, D) via the type system rather than branching on T at runtime. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- 512x512x512 f32 Max: exercises the tile loop over many iterations. - 17x19x23 f32 Max: every dim prime, stresses predicated tile-load bounds checks at every edge simultaneously. - 128x128x128 f64 Max: validates the f64 macro on a non-trivial shape. All three cross-check against the CPU count_ground_states oracle. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces per-step software `x % P` (~50 cycles on Turing) with Barrett reduction using a host-precomputed reciprocal mu = floor(2^64 / P): q = __umul64hi(x, mu); // 4-8 cycles r = x - q*P; // [0, 2P) if (r >= P) r -= P; // correction Verified via PTX: 0 div/rem instructions, 1600 mul.hi in the compiled module. Also: - Tie-branch modular reduction deferred to write-back: cnt_accum stays in u64 un-reduced during the K loop (bounded by K*P < 2^61). - Divergent three-way if/else replaced with predicated selects to reduce warp divergence. - Kernel signature adds `unsigned long long MU` arg; host wrapper computes it via u128 math. Benchmark (Quadro RTX 6000, f32 Max, 1 prime): size GPU before GPU after speedup 128 2.11 ms 1.66 ms +27% 256 6.68 ms 5.91 ms +13% 512 25.9 ms 24.5 ms +6% 1024 122 ms 114 ms +7% 2048 501 ms 473 ms +6% 4096 2.16 s 1.97 s +10% Also adds `examples/bench_counting.rs` for future re-runs. Modest gain (~6-10% on large sizes) confirms the modular reduction was not the sole bottleneck. Further wins (3-5×) would require structural changes: async copy (needs sm_80+), warp-level reduction across k, or vectorized loads. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Reduces ptxas register usage from 150 to 64 per thread, fitting the
1024 threads/block exactly in the Turing register file (64*1024 =
65536 = SM reg file size). Zero spills.
Occupancy is NOT the primary bottleneck (74 vs 72 G-ops/s at 4096²
— modest gain). The counting workload is fundamentally bound by:
- 2x global memory traffic vs MaxPlus (value + count arrays)
- 2x shared memory pressure (32 KB vs 16 KB per block, halving
blocks-per-SM)
- ~5x arithmetic per inner step (Barrett + tie-merge vs FADD+FMAX)
The existing MaxPlus kernel reaches ~1500 G-ops/s at 4096² on the same
RTX 6000; our counting kernel peaks at ~75 G-ops/s. Closing that gap
would need structural changes: packed (val,cnt) element loads,
warp-level K-reductions, or async copy pipelining (requires sm_80+).
Also adds examples/bench_maxplus.rs as a ceiling-reference bench.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
After measurement, the hand-tiled BLIS-style counting kernel provided
no perf advantage over GPUArrays.jl-style naive (one thread per output
cell). Compared pure kernel launch times:
Pure kernel (data on device, f32 Max):
size Tiled BLIS Naive
128² 0.02 ms 0.02 ms
1024² 5.1 ms 5.1 ms
2048² 36 ms 36 ms
4096² 276 ms 276 ms (497 G tropical-ops/s)
Reason the tile didn't help: counting's 2x memory traffic + u64
accumulator + Barrett state costs ~150 registers/thread in the tiled
version, capping occupancy at 25%. The naive version uses ~30 regs/
thread and lets L1/L2 caches handle data reuse, which turns out to be
good enough for counting's access pattern (broadcast A-row, coalesced
B-col within a warp).
Comparison to Julia CUDA.jl generic matmul on the same hardware
(Quadro RTX 6000):
Julia (includes a*b allocation): 158 G tropical-ops/s at 2048²
Our kernel (pure launch): 475 G tropical-ops/s at 2048²
-> ~3x faster on the hot path.
Earlier end-to-end bench showed Rust at 37 G-ops/s vs Julia at 158
because count_ground_states_gpu had ~400 ms of host-side BigInt
allocations (4M BigInts at ~100 ns each). That's driver overhead
inherent to the CountedMat<BigInt> return type, not kernel slowness.
Driver cleanups in this commit:
- Allocate count_a, count_b, value_c, count_c ONCE (not per prime).
- Use GpuMatrix::alloc (GPU-side zero) instead of from_host(zeros).
- Skip downloading value_c on subsequent primes (invariant guaranteed
except under NaN, which we already pinned as known-limitation).
- Debug-only invariant check on prime TensorBFS#2.
- Fast-path Vec<BigInt> construction for single-prime case.
End-to-end driver bench improvement (f32 Max, 1 prime, via
count_ground_states_gpu):
size before after kernel-only (ceiling)
128 2.11 ms 1.09 ms 0.02 ms
256 6.68 ms 4.01 ms 0.10 ms
512 25.9 ms 17.2 ms 0.66 ms
1024 122 ms 99 ms 5.1 ms
2048 501 ms 459 ms 36 ms <-- BigInt alloc dominates
4096 2.16 s 1.98 s 276 ms
The remaining driver overhead (400 ms at 2048², 1.7 s at 4096²) is
almost entirely BigInt allocation: 4M cells × ~100 ns heap-alloc per
BigInt. Follow-up optimization: add a Vec<u64> fast-path variant when
count_upper_bound fits in u64 (bound < ~2^60, i.e. num_primes ≤ 2).
Also adds examples/bench_kernel_only.rs and bench_maxplus.rs for
continued performance validation.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Design doc for the next round of counting-kernel optimization on A100+. Stage A: 32 threads cooperate per output cell, K-stride inner loop, warp shuffle reduction with the tropical-add operator. Stage B: pack (value, count) into 8-byte AoS pairs to halve global loads. Also adds bench_kernel_single — single-launch variant of the kernel bench, complementing bench_kernel_only's amortized timing. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add warp-K-reduction counting kernel: 32 threads cooperate on each
output cell, K-stride inner loop, 5-step shfl_xor tree reduction with
the tropical-add operator (u64 acc shuffled hi/lo).
Measured on A100-SXM4-80GB (f32 Max, 1 prime, kernel-only):
- Square shapes (128²–4096²): warpk LOSES 0.15-0.77x vs naive.
Strided-B reads (32 lanes at fixed j) are non-coalesced; naive's
coalesced-B pattern dominates whenever M*N is large enough to fill
the GPU.
- Tall-skinny / parallelism-starved (M=N=32, K=4096): warpk WINS 9.0x
(98 vs 11 G/s); M=N=64, K=4096: 3.1x.
Dispatch in launch_counting_gemm: use_warpk = (K>=64) && (M*N<=64*64).
Naive remains the default for square/large; warpk handles small-shape
high-K regime where the SMs are otherwise idle.
4 new tests (boundary, non-aligned, all-ties, Min+f64), 13/13 green.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Successor to spec E. Pack (value, count) into 8-byte (f32) or 16-byte (f64) structs on-device, halving global LDG instructions in the inner loop. Targets the dominant large-shape regime where naive runs and where memory pressure on the inner loop is the next bottleneck after Barrett. Expected ~1.3-1.5x kernel speedup. Output stays SoA (callers consume value/count separately). Pack runs once in count_ground_states_gpu before the prime loop, amortized. Spec includes three open decisions for user review before implementation. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pack (value, count) into 8-byte (f32) or 16-byte (f64) structs on-device, halving global LDG instructions in the counting kernel inner loop. Output buffers stay SoA. count_ground_states_gpu packs inputs once before the prime loop, amortized. Adds: - src/pair.rs: PairF32/PairF64 with DeviceRepr+ValidAsZeroBits, pack helpers, PackPair trait. 7 unit tests for layout + roundtrip. - 8 new AoS kernels in counting_gemm.cu (4 naive + 4 warpk). - launch_counting_gemm_aos trait method + free function. Same shape-aware dispatch as SoA path (warpk for small M*N high-K, naive otherwise). - bench_kernel_aos example for AoS-vs-SoA head-to-head. Driver in crt.rs now packs to AoS and routes through the AoS kernels; SoA kernels remain reachable via launch_counting_gemm for benchmarking. Measured on A100-SXM4-80GB (f32 Max, 1 prime, kernel-only): Naive path: 1.11-1.24x speedup (4096²: 600 -> 666 G-ops/s) Warpk path: 1.64-1.92x speedup (M=N=64,K=4096: 128 -> 246 G/s) Stacking AoS on spec-E warpk dispatch lifts the small-shape regime by 6.0x total over the original SoA naive kernel. 69/69 tests green (56 lib + 13 counting_gpu integration). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Spec F's rollout plan called for retiring the SoA path once AoS was measured to win at all shapes. AoS now wins everywhere (1.11x at 4096² square, 1.92x in the warpk small-shape regime), so: - Drop 8 SoA kernels (4 naive + 4 warpk) from counting_gemm.cu. - Rename the 8 _aos kernels to canonical (no suffix). Single layout, single source of truth. - Drop _AOS / _WARPK_AOS trait consts; KERNEL_NAME / KERNEL_NAME_WARPK now refer to the AoS kernels. - Drop SoA `launch_counting_gemm` trait method + free fn. The remaining `launch_counting_gemm` takes pair buffers (the former _aos signature). - Driver in crt.rs updated to canonical name. - Benches: delete bench_kernel_single, bench_kernel_warpk, bench_kernel_aos (head-to-head benches no longer meaningful with one layout). Rewrite bench_kernel_only as the canonical AoS bench covering both naive and warpk dispatch paths. Verified on A100-SXM4-80GB: 56 lib + 13 integration = 69/69 green. Bench numbers unchanged (662 G/s naive @ 4096², 241 G/s warpk @ M=N=64). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Insight: count_ground_states_gpu's only entry point passes all-ones input counts. The AoS general kernel does u32×u32->u64 multiply + Barrett mod every k-step, but for ones inputs the count product is literally 1 — all that arithmetic is wasted. Specialize. Inner loop becomes: 2 fp loads, 1 add, 1 compare, 1 increment. No count loads, no multiply, no per-step Barrett. Output count fits in u32 (max value K) with single Barrett at end. Warp reduction uses single shuffle per acc_cnt (5 shuffles total) instead of the hi/lo u64 split (10 shuffles). Targets 1.6-2.1x kernel speedup on the dominant 4096² square path — ~1100-1400 G-ops/s, approaching the MaxPlus 1500 G/s reference. AoS general kernels remain in place as fallback for future non-ones callers (chained matmul, etc.). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Insight from codex review: count_ground_states_gpu's only entry point
passes all-ones input counts. The AoS general kernel does u32×u32->u64
multiply + Barrett mod every k-step, but for ones inputs the count
product is literally 1 — that arithmetic is wasted.
Specialize. Inner loop becomes 2 fp loads + 1 add + 1 compare + 1
conditional set/inc. Output count fits in u32 (max value K), single
Barrett at end. Warp reduction uses one shfl per acc_cnt instead of
the hi/lo u64 split (5 shuffles vs 10).
Measured on A100-SXM4-80GB (f32 Max, 1 prime, kernel-only):
Naive path:
128²: 163 -> 348 G/s (2.13x)
1024²: 625 -> 1944 G/s (3.11x)
2048²: 665 -> 2136 G/s (3.21x)
4096²: 665 -> 1946 G/s (2.93x)
Warpk path: 1.0-1.2x (memory-bound on strided-B, not count-arith).
Exceeds the prior MaxPlus reference (~1500 G/s, no counting) at >= 512²
square. The ones kernel removes both count loads (halving bandwidth)
and count arithmetic; the freed cycles flip the regime from
mixed-bound to streamlined compute-bound.
Adds:
- 8 new kernels in counting_gemm.cu (naive + warpk × 4 type/dir).
- KERNEL_NAME_ONES + WARPK_ONES trait consts, launch_counting_gemm_ones
method + free fn taking value-only GpuMatrix<T>.
- count_ground_states_gpu driver routes through ones path; AoS pack
removed.
- 3 ones-path correctness tests; 16/16 integration tests green.
AoS general kernels stay in tree as fallback for future non-ones
callers (chained matmul, etc.) — currently orphaned in production but
reachable via launch_counting_gemm.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Codex's TensorBFS#2 candidate. Current warpk kernel's only memory bottleneck is non-coalesced B reads (32 lanes at fixed j, strided by N). At small N this is masked; at moderate N L2 chokes on transactions. Fix: upload B as B^T (N×K row-major) when the warpk regime is dispatched. Lanes then read 32 contiguous K elements per warp-step, identical pattern to A. No kernel-shape change; only the B layout. Naive path keeps current B layout (already coalesced for that parallelization). Driver branches on use_warpk and uploads the appropriate B. Expected 1.5-2.5x in warpk regime, plus likely expansion of COUNTING_WARPK_MN_CEILING past 64*64 once warpk is broadly competitive. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Codex's TensorBFS#2 candidate. The warpk kernel's only memory-bound bottleneck was non-coalesced B reads (32 lanes at fixed j, strided by N). Fix: upload B as B^T (N×K row-major) when warpk dispatched. Lanes then read 32 contiguous K elements per warp-step, identical to A. Naive path keeps non-transposed B (already coalesced for that parallelization). Driver branches B upload on use_warpk and applies host-side transpose. transpose_row_major helper added to crt.rs. Trait method launch_counting_gemm_ones now takes M, K, N explicitly because B's GpuMatrix shape is layout-dependent. Measured on A100-SXM4-80GB (f32 Max, 1 prime, kernel-only): M=N=64, K=4096: 122 -> 1324 G/s (10.9x over naive-ones) M=N=128, K=4096: 480 -> 1903 G/s (3.96x) M=N=512, K=4096: 1786 -> 2316 G/s (1.30x) M=N=1024, K=4096: 2031 -> 2525 G/s (1.24x) Peak hits 2.5 TG/s — beyond the prior MaxPlus reference. COUNTING_WARPK_MN_CEILING raised 64*64 -> 128*128. Conservative re-tune: kernel wins extend past 1024² but host-side transpose cost on single-prime calls limits the profitable ceiling. Multi-prime CRT calls would win further out; on-device transpose (follow-up) would push the ceiling to ~256-1024². Adds warpk_transposed_b_layout test (deliberately asymmetric A/B to catch transpose-helper bugs). 17/17 integration + 56/56 lib green. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…rhead Per memory profiling, BigInt construction dominates e2e at 2048² (~65% of 459 ms). Kernel itself is now <5% (1946 G/s ones-kernel takes ~70 ms of e2e). The dominant cost is the output format. For count_upper_bound < 2^60 (covers any single-matmul case via bound_for_single_matmul(k)), CRT product fits in u64 with one or two 30-bit primes. Specialize: return Vec<u64> instead of Vec<BigInt>, zero per-cell allocation. API: parallel entry point count_ground_states_gpu_u64 + CountedMatU64. BigInt path stays for general / chained-matmul callers. Driver refactored to share kernel-loop helper between the two paths. Targets ~5x e2e speedup (459 ms -> ~80-100 ms at 2048²). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Eliminates per-cell BigInt heap allocation by returning Vec<u64> counts when the CRT product fits in u63 (≤2 of the 30-bit primes — covers count_upper_bound < 2^60, all single-matmul use cases). CPU additions (tropical_gemm::crt): - crt_combine_u64: pairwise CRT in u64 with extended-Euclidean inverse. - choose_primes_u64: prime-prefix selection bounded by 2^63. - bound_for_single_matmul_u64. - 4 unit tests including BigInt-parity randomized check. GPU driver (tropical_gemm_cuda::crt): - run_kernels_per_prime: extracted helper for the shared device side (layout choice, B-transpose-or-not, kernel launches, downloads). BigInt and u64 entry points both use it. - count_ground_states_gpu_u64 + CountedMatU64: new entry point. Single-prime path skips CRT combine entirely; 2-prime path runs pairwise crt_combine_u64. Measured on A100-SXM4-80GB (f32 Max, single matmul, e2e): 256²: 3.54 -> 0.26 ms (13.50x) 512²: 13.92 -> 0.91 ms (15.36x) 1024²: 63.14 -> 8.38 ms (7.53x) 2048²: 292.46 -> 29.79 ms (9.82x) 4096²: 1237.90 -> 152.28 ms (8.13x) Cumulative e2e improvement at 2048² since branch start: 459 ms -> 30 ms = ~15x. Combined with kernel-only progress (50 -> 1946 G/s = 39x), the counting matmul is now genuinely fast end-to-end on A100. 21/21 integration tests + 60/60 lib tests green. BigInt path unchanged in behavior. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add a Julia interface to count_ground_states_gpu_u64 for the four (T, D) combos. C ABI added to tropical-gemm-cuda (now built as both rlib + cdylib). Julia package CountingTropicalGEMM.jl/ at repo root, ~100 LOC main module + 50 LOC tests. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add a C ABI to tropical-gemm-cuda (now built as both rlib + cdylib)
exposing count_ground_states_gpu_u64 to non-Rust callers. Four entry
points (f32/f64 × Max/Min) plus version + last-error TLS helpers,
each wrapped in catch_unwind for panic safety across the FFI boundary.
Julia package CountingTropicalGEMM.jl/ at repo root. Library
resolution via TROPICAL_GEMM_LIB env override, workspace dev-build
fallback, or system search paths. Result type CountedMatU64{T}
mirrors the Rust struct. Typed exceptions (BoundTooLargeError for
the u64 envelope, CountingTropicalGEMMError otherwise).
Wrapper transposes Julia's column-major Matrix{T} to row-major at
the FFI boundary; transpose cost is negligible vs kernel time.
Tested 8/8 green on A100-SXM4-80GB:
f32 Max small (hand-verifiable), f64 Min vs reference (randomized),
all-ties large K, BoundTooLargeError, DimensionMismatch.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CountingTropicalGEMM.jl/bench/bench.jl exercises the same shapes as
Rust's bench_e2e_u64. Reports tropical-ops/s through the full Julia
path (transpose, ccall, kernel, reconstruction, output reshape).
Measured on A100-SXM4-80GB:
256²: 0.69 ms (vs Rust 0.26 ms)
512²: 6.66 ms (vs Rust 0.91 ms)
1024²: 39.7 ms (vs Rust 8.4 ms)
2048²: 120 ms (vs Rust 30 ms)
4096²: 441 ms (vs Rust 152 ms)
Julia overhead scales with M*N — confirmed source is the
column→row-major boundary transpose on inputs + outputs (Julia's
Matrix{T} is column-major; the kernel expects row-major). Output
transpose is the expensive one at 4096²: 16M UInt64 elements = 128 MB
of host memory motion.
Overhead is the wrapper's, not the kernel's. Kernel itself still
hits ~1946 G/s; full Julia path at 4096² is ~312 G-ops/s e2e.
Possible follow-up: layout-flag in C ABI to accept column-major
inputs directly, avoiding host transpose.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds tg_bench_kernel_only_u64_<T>_<D> to the C ABI: uploads once,
runs the kernel `iters` times, returns avg per-launch wall time
(ms). Bypasses CRT combine and u64 reconstruction — measures pure
kernel runtime + sync, matching what bench_kernel_only.rs reports.
Julia: bench_kernel_only_u64(dir, A, B, bound; iters) -> Float64.
Bench script rewritten to use this entry point.
Measured on A100-SXM4-80GB (f32 Max, 1 prime, kernel-only via Julia):
Naive path:
1024²: 1.112 ms (1931 G/s)
2048²: 8.704 ms (1974 G/s)
4096²: 70.08 ms (1961 G/s)
Warpk path (transposed B, K=4096):
M=N=64: 0.027 ms (1244 G/s)
M=N=512: 1.097 ms (1958 G/s)
M=N=1024: 4.030 ms (2132 G/s) — peak
Numbers match the Rust bench_kernel_only example to within timing
noise, confirming the binding adds no kernel-side overhead.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Confirms throughput scales: 1829 G/s at 16384² × 16384 = 8.8 × 10^12 tropical-ops in 4.81 sec (within ~7% of 4096²'s 1961 G/s peak). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds two larger sizes to bench_julia_vs_rust.jl. Confirms
GenericTensorNetworks/CUDA.jl's auto-compiled broadcast kernel for
CountingTropical{Float32, Mod{P}} saturates at ~220-265 G/s across
sizes from 512² to 8192² — vs our specialized kernel at ~1960 G/s.
Measured A100-SXM4-80GB (kernel-only ms / G tropical-ops/s):
4096²: GTN 614 ms (224 G/s) vs ours 70 ms (1961 G/s) — 8.8x
8192²: GTN 4979 ms (221 G/s) vs ours ~560 ms (~1960) — ~9x
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #42 +/- ##
==========================================
- Coverage 96.41% 94.19% -2.22%
==========================================
Files 18 22 +4
Lines 892 1086 +194
==========================================
+ Hits 860 1023 +163
- Misses 32 63 +31 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Pull request overview
This PR adds full-stack support for “counting tropical” GEMM: new semiring direction tags (Max/Min), CRT-based exact counting on CPU and GPU (including a u64 fast-path), plus Python and Julia bindings and accompanying design specs/bench scripts.
Changes:
- Introduce
TropicalDirection(Max/Min) and makeCountingTropicaldirection-aware; addReprTransparentTropicalto safely support scalar-slice reinterpretation for transparent semiring wrappers. - Add GPU counting CRT driver (BigInt + u64 fast-path), AoS pair layout helpers, counting CUDA kernels (including ones-specialized + warpk-transposed-B variants), and new CUDA benchmarks.
- Add Python bindings/tests for CPU+GPU count-ground-states and a Julia package (
CountingTropicalGEMM.jl) with tests/benchmarks; add multiple design spec documents.
Reviewed changes
Copilot reviewed 69 out of 71 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| sanity_julia.jl | Local Julia sanity script for CountingTropical GPU matmul (currently contains machine-specific paths). |
| bench_julia_vs_rust.jl | Julia benchmark to compare CUDA.jl generic CountingTropical matmul vs Rust (currently contains machine-specific paths). |
| docs/superpowers/specs/2026-04-27-julia-binding-design.md | Design spec for Julia binding architecture and FFI surface. |
| docs/superpowers/specs/2026-04-27-counting-u64-fastpath-design.md | Design spec for u64 fast-path in GPU driver + CRT helpers. |
| docs/superpowers/specs/2026-04-27-counting-kernel-warpk-transposed-design.md | Design spec for warpk kernel reading transposed B for coalesced loads. |
| docs/superpowers/specs/2026-04-27-counting-kernel-ones-specialized-design.md | Design spec for ones-specialized counting kernels. |
| docs/superpowers/specs/2026-04-27-counting-kernel-aos-design.md | Design spec for AoS (value,count) layout for counting kernels. |
| docs/superpowers/specs/2026-04-21-counting-tropical-cuda-design.md | Prior design spec for GPU counting kernel + CRT driver. |
| docs/superpowers/specs/2026-04-21-counting-tropical-crt-design.md | Prior design spec for CRT/BigInt CPU driver and bound contract. |
| docs/superpowers/specs/2026-04-21-counting-tropical-compose-design.md | Prior design spec for making CountingTropical compose with GEMM pipeline. |
| docs/superpowers/specs/2026-04-21-counting-kernel-tiled-design.md | Prior design spec for tiled CUDA counting kernel. |
| crates/tropical-gemm/tests/counting_crt.rs | New/updated CRT correctness tests vs BigInt oracle + edge cases. |
| crates/tropical-gemm/tests/counting_compose.rs | Tests for CountingTropical composing through tropical_matmul_t (Max/Min, ties, multiplication). |
| crates/tropical-gemm/src/types/traits.rs | Adds ReprTransparentTropical marker trait for safe scalar↔element reinterpretation. |
| crates/tropical-gemm/src/types/modp.rs | Adds Mod<P> scalar for CRT residue arithmetic. |
| crates/tropical-gemm/src/types/mod.rs | Wires new direction + Mod modules and re-exports Max/Min/Mod/ReprTransparentTropical. |
| crates/tropical-gemm/src/types/min_plus.rs | Implements ReprTransparentTropical for TropicalMinPlus. |
| crates/tropical-gemm/src/types/max_plus.rs | Implements ReprTransparentTropical for TropicalMaxPlus. |
| crates/tropical-gemm/src/types/max_mul.rs | Implements ReprTransparentTropical for TropicalMaxMul. |
| crates/tropical-gemm/src/types/direction.rs | Introduces TropicalDirection with Max and Min tags. |
| crates/tropical-gemm/src/types/counting.rs | Makes CountingTropical direction-aware (Max/Min), disables SIMD claim, adds Min tests. |
| crates/tropical-gemm/src/types/and_or.rs | Implements ReprTransparentTropical for TropicalAndOr. |
| crates/tropical-gemm/src/testing/mod.rs | Adds test-only module wiring (feature-gated). |
| crates/tropical-gemm/src/testing/bigint_semiring.rs | Adds BigInt-based reference implementation for CRT tests. |
| crates/tropical-gemm/src/simd/kernels/portable.rs | Updates microkernel signatures to take *const T and fixes tests accordingly. |
| crates/tropical-gemm/src/simd/kernels/neon.rs | Updates NEON kernels to accept *const T and casts internally to scalar pointers. |
| crates/tropical-gemm/src/simd/kernels/avx2.rs | Updates AVX2 kernels to accept *const T and casts internally to scalar pointers; updates tests. |
| crates/tropical-gemm/src/mat/ref_.rs | Refactors MatRef to store element slices, adds safe scalar conversion via ReprTransparentTropical. |
| crates/tropical-gemm/src/mat/owned.rs | Refactors Mat APIs to use element slices; adds from_elements convenience alias. |
| crates/tropical-gemm/src/mat/ops.rs | Updates multiplication trait bounds to include Default. |
| crates/tropical-gemm/src/lib.rs | Exposes new CRT module and Max/Min exports; wires test helpers. |
| crates/tropical-gemm/src/core/packing.rs | Packing now operates on element type T (requires Copy + Default) instead of TropicalScalar. |
| crates/tropical-gemm/src/core/kernel.rs | Microkernel interface now uses *const T rather than *const T::Scalar. |
| crates/tropical-gemm/src/api.rs | Adds tropical_matmul_t for compound elements; updates scalar APIs to require ReprTransparentTropical. |
| crates/tropical-gemm/Cargo.toml | Adds testing feature and BigInt/Integer deps. |
| crates/tropical-gemm-python/tests/test_count_ground_states_gpu.py | Adds GPU round-trip tests for counting. |
| crates/tropical-gemm-python/tests/test_count_ground_states.py | Adds/updates CPU round-trip tests for counting + typed errors. |
| crates/tropical-gemm-python/src/lib.rs | Adds CPU + GPU counting bindings (CRT BigInt counts). |
| crates/tropical-gemm-python/Cargo.toml | Adds num-bigint dependency. |
| crates/tropical-gemm-cuda/src/pair.rs | Adds AoS pair types + packing helpers + unit tests. |
| crates/tropical-gemm-cuda/src/lib.rs | Exposes counting CRT APIs + CountedMatU64; wires new modules. |
| crates/tropical-gemm-cuda/src/gpu_mat.rs | Updates MatRef upload path to require ReprTransparentTropical. |
| crates/tropical-gemm-cuda/src/error.rs | Adds InvalidState error variant. |
| crates/tropical-gemm-cuda/src/crt.rs | Implements GPU CRT driver (BigInt + u64 fast-path) and shared per-prime runner. |
| crates/tropical-gemm-cuda/src/counting_kernel.rs | Adds counting kernel launcher and dispatch, including ones-specialized path. |
| crates/tropical-gemm-cuda/src/context.rs | Adds compilation/registration of counting kernels and warpk dispatch constants/dims. |
| crates/tropical-gemm-cuda/examples/bench_warpk_crossover.rs | Bench for naive vs warpk (transposed-B) crossover. |
| crates/tropical-gemm-cuda/examples/bench_maxplus.rs | Bench for MaxPlus GPU kernel throughput reference. |
| crates/tropical-gemm-cuda/examples/bench_kernel_only.rs | Kernel-only bench for AoS vs ones-specialized kernels. |
| crates/tropical-gemm-cuda/examples/bench_e2e_u64.rs | End-to-end bench: BigInt vs u64 fast-path. |
| crates/tropical-gemm-cuda/examples/bench_counting.rs | End-to-end bench: GPU vs CPU counting. |
| crates/tropical-gemm-cuda/Cargo.toml | Builds as cdylib + adds BigInt-related deps. |
| CountingTropicalGEMM.jl/test/runtests.jl | Julia package tests incl. reference parity, bounds, and dimension mismatch. |
| CountingTropicalGEMM.jl/src/CountingTropicalGEMM.jl | Julia FFI wrapper to C ABI with library resolution and error mapping. |
| CountingTropicalGEMM.jl/bench/bench_huge.jl | Julia kernel-only huge-shape benchmark. |
| CountingTropicalGEMM.jl/bench/bench.jl | Julia kernel-only benchmark driver. |
| CountingTropicalGEMM.jl/README.md | Julia package build/install/usage docs. |
| CountingTropicalGEMM.jl/Project.toml | Julia package metadata. |
| Cargo.lock | Locks new Rust dependencies (num-bigint/num-integer). |
| .gitignore | Ignores Julia Manifest for the new Julia package. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| assert_eq!(a_values.len(), m * k); | ||
| assert_eq!(b_values.len(), k * n); |
There was a problem hiding this comment.
Same as the BigInt path: avoid assert_eq! on user-controlled input lengths in this Result-returning API; return a CudaError instead of panicking.
| assert_eq!(a_values.len(), m * k); | |
| assert_eq!(b_values.len(), k * n); | |
| if a_values.len() != m * k { | |
| return Err(CudaError::InvalidState( | |
| format!( | |
| "invalid a_values length: got {}, expected {} (m * k)", | |
| a_values.len(), | |
| m * k | |
| ) | |
| .into(), | |
| )); | |
| } | |
| if b_values.len() != k * n { | |
| return Err(CudaError::InvalidState( | |
| format!( | |
| "invalid b_values length: got {}, expected {} (k * n)", | |
| b_values.len(), | |
| k * n | |
| ) | |
| .into(), | |
| )); | |
| } |
| Pkg.add(["CUDA", "TropicalNumbers"]) | ||
| Pkg.add(path = "/mnt/home/xgao1/work/better_gpu_gemm/GenericTensorNetworks.jl") |
There was a problem hiding this comment.
This script hard-codes a machine-specific local path (Pkg.add(path = "/mnt/home/.../GenericTensorNetworks.jl")), which will fail for other developers/CI. Consider removing this file from the repo, or make the path configurable (e.g., via ENV) / use a relative path / documented Pkg.develop instructions instead.
| Pkg.add(path = "/mnt/home/xgao1/work/better_gpu_gemm/GenericTensorNetworks.jl") | ||
|
|
There was a problem hiding this comment.
This benchmark script hard-codes a local filesystem path for GenericTensorNetworks.jl, which makes it non-portable and likely to break in other environments. Prefer documenting setup steps and using Pkg.develop/relative paths or environment configuration rather than committing an absolute path.
| Pkg.add(path = "/mnt/home/xgao1/work/better_gpu_gemm/GenericTensorNetworks.jl") | |
| if !haskey(ENV, "GENERIC_TENSOR_NETWORKS_PATH") | |
| error("Set GENERIC_TENSOR_NETWORKS_PATH to a local GenericTensorNetworks.jl checkout before running this benchmark.") | |
| end | |
| gtn_path = abspath(expanduser(ENV["GENERIC_TENSOR_NETWORKS_PATH"])) | |
| isdir(gtn_path) || error("GENERIC_TENSOR_NETWORKS_PATH does not exist or is not a directory: " * gtn_path) | |
| Pkg.develop(path = gtn_path) |
| /// The inner `i32` is always in `[0, P)` (the normalized representative). | ||
| /// Construct via `Mod::new` (which normalizes) or reconstruct from a raw | ||
| /// representative via `raw`. See module docs for the size constraint on `P`. |
There was a problem hiding this comment.
Doc comment is misleading: raw() is an accessor, not a way to reconstruct/create a Mod<P> from an existing raw representative. Consider rewording to avoid implying there is a from_raw constructor, or add an explicit from_raw_unchecked/from_raw API if that capability is intended.
| let ctx = CudaContext::new().map_err(|e| format!("CUDA init: {}", e))?; | ||
| match direction { | ||
| "max" => count_ground_states_gpu::<f32, Max>( | ||
| &ctx, &a_data, m, k, &b_data, n, &bound, | ||
| ) | ||
| .map_err(|e| format!("GPU compute: {}", e)), | ||
| "min" => count_ground_states_gpu::<f32, Min>( | ||
| &ctx, &a_data, m, k, &b_data, n, &bound, |
There was a problem hiding this comment.
count_ground_states_gpu_py constructs a fresh CudaContext::new() on every call, which will redo NVRTC compilation/context setup and defeats the crate’s existing global-context caching. Use the shared/global context (or a cached context per process/device) instead of creating a new one per invocation.
| let ctx = CudaContext::new().map_err(|e| format!("CUDA init: {}", e))?; | |
| match direction { | |
| "max" => count_ground_states_gpu::<f32, Max>( | |
| &ctx, &a_data, m, k, &b_data, n, &bound, | |
| ) | |
| .map_err(|e| format!("GPU compute: {}", e)), | |
| "min" => count_ground_states_gpu::<f32, Min>( | |
| &ctx, &a_data, m, k, &b_data, n, &bound, | |
| // Reuse the process-global CUDA context so we do not pay | |
| // initialization / NVRTC compilation costs on every call. | |
| let ctx = CudaContext::global().map_err(|e| format!("CUDA init: {}", e))?; | |
| match direction { | |
| "max" => count_ground_states_gpu::<f32, Max>( | |
| ctx, &a_data, m, k, &b_data, n, &bound, | |
| ) | |
| .map_err(|e| format!("GPU compute: {}", e)), | |
| "min" => count_ground_states_gpu::<f32, Min>( | |
| ctx, &a_data, m, k, &b_data, n, &bound, |
| let counts_py: Vec<PyObject> = result | ||
| .counts | ||
| .into_iter() | ||
| .map(|bn| { | ||
| let s = bn.to_string(); | ||
| // Use Python's built-in int() on the decimal string. | ||
| py.eval( | ||
| std::ffi::CString::new(format!("int({})", s)) | ||
| .unwrap() | ||
| .as_c_str(), | ||
| None, | ||
| None, | ||
| ) | ||
| .map(|b| b.unbind()) |
There was a problem hiding this comment.
Converting BigInt to Python int via py.eval("int(<decimal>)") is unnecessarily slow (parses Python code per element) and relies on code-eval. Prefer constructing PyLong directly (e.g., via the CPython API PyLong_FromString / PyLong_FromUnicodeObject, or calling builtins.int with the decimal string without eval).
| let counts_py: Vec<PyObject> = result | |
| .counts | |
| .into_iter() | |
| .map(|bn| { | |
| let s = bn.to_string(); | |
| // Use Python's built-in int() on the decimal string. | |
| py.eval( | |
| std::ffi::CString::new(format!("int({})", s)) | |
| .unwrap() | |
| .as_c_str(), | |
| None, | |
| None, | |
| ) | |
| .map(|b| b.unbind()) | |
| let py_int = py.import("builtins")?.getattr("int")?; | |
| let counts_py: Vec<PyObject> = result | |
| .counts | |
| .into_iter() | |
| .map(|bn| { | |
| let s = bn.to_string(); | |
| py_int.call1((s,)).map(|b| b.unbind()) |
| // Convert BigInt → Python int via decimal string representation. | ||
| let counts_py: Vec<PyObject> = result | ||
| .counts | ||
| .into_iter() | ||
| .map(|bn| { | ||
| let s = bn.to_string(); | ||
| // Use Python's built-in int() on the decimal string. | ||
| py.eval( | ||
| std::ffi::CString::new(format!("int({})", s)) | ||
| .unwrap() | ||
| .as_c_str(), | ||
| None, | ||
| None, | ||
| ) | ||
| .map(|b| b.unbind()) | ||
| }) |
There was a problem hiding this comment.
Same issue as the GPU binding: BigInt → Python int conversion uses py.eval("int(<decimal>)"), which is very expensive for large outputs and should be replaced with direct PyLong creation (no eval).
| // Convert BigInt → Python int via decimal string representation. | |
| let counts_py: Vec<PyObject> = result | |
| .counts | |
| .into_iter() | |
| .map(|bn| { | |
| let s = bn.to_string(); | |
| // Use Python's built-in int() on the decimal string. | |
| py.eval( | |
| std::ffi::CString::new(format!("int({})", s)) | |
| .unwrap() | |
| .as_c_str(), | |
| None, | |
| None, | |
| ) | |
| .map(|b| b.unbind()) | |
| }) | |
| // Convert BigInt → Python int directly via CPython's PyLong API. | |
| fn bigint_to_pyobject(py: Python<'_>, bn: BigInt) -> PyResult<PyObject> { | |
| let bytes = bn.to_signed_bytes_le(); | |
| let ptr = unsafe { | |
| pyo3::ffi::_PyLong_FromByteArray( | |
| bytes.as_ptr(), | |
| bytes.len(), | |
| 1, | |
| 1, | |
| ) | |
| }; | |
| unsafe { Bound::from_owned_ptr_or_err(py, ptr).map(|obj| obj.unbind()) } | |
| } | |
| let counts_py: Vec<PyObject> = result | |
| .counts | |
| .into_iter() | |
| .map(|bn| bigint_to_pyobject(py, bn)) |
| assert_eq!(a_values.len(), m * k); | ||
| assert_eq!(b_values.len(), k * n); |
There was a problem hiding this comment.
This public API uses assert_eq! for input length validation, which will panic on user error. For a Result-returning GPU entry point, prefer returning a structured error (e.g., CudaError::DimensionMismatch / CudaError::InvalidState) so callers (including Python/Julia bindings) can handle it without a panic.
| assert_eq!(a_values.len(), m * k); | |
| assert_eq!(b_values.len(), k * n); | |
| if a_values.len() != m * k { | |
| return Err(CudaError::DimensionMismatch); | |
| } | |
| if b_values.len() != k * n { | |
| return Err(CudaError::DimensionMismatch); | |
| } |
No description provided.