Skip to content

[WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd)#180

Open
Flink-ddd wants to merge 4 commits into
mainfrom
feat/add-ws1-gemm
Open

[WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd)#180
Flink-ddd wants to merge 4 commits into
mainfrom
feat/add-ws1-gemm

Conversation

@Flink-ddd

@Flink-ddd Flink-ddd commented Jun 22, 2026

Copy link
Copy Markdown
Collaborator

Fixes #146

Purpose

In RL train–inference alignment, a token's logprob must not drift between rollout (vLLM-side, varying batch shapes) and training (Megatron-side). A major source of this drift is the GEMM: cuBLAS selects kernels by problem shape and may use split-K, so the K-reduction order — and therefore the low bits of the output changes with batch size, chunked-prefill splitting, and padding layout. The same row produces different results depending on the batch it rides in.

This PR implements a batch-invariant deterministic GEMM (forward + backward), one op in the WS1 forward chain (#146). A row's output is bitwise invariant to batch size, chunked-prefill splitting, and padding layout. Invariance is achieved by pinning the tile shape, fixing the K-accumulation order, and forbidding split-K / shape-based kernel selection. BF16 inputs, FP32 accumulation, no TF32.

This is PR2 of the planned series (design note → kernel → tests → LM-head wiring → benchmark). Scope here: the kernel(s), op wiring, and invariance tests.

Invariance contract. The same (M, N, K) row produces the same output regardless of the surrounding batch, chunked-prefill split, or padding. Two design points protect this:

  • M is padded to a multiple of the tile (BM) so every M — including M=1 and non-tile-aligned M — takes the same kernel. Selecting a kernel by M would itself break batch-invariance, since M is the batch dimension.
  • Single CTA per output tile, fixed ascending K loop. No split-K, no cross-CTA K reduction, no atomics.

Backends. CUDA (SM90 TMA + mma.sync tensor cores; naive FP32 scalar fallback below SM90), Triton (portable / ROCm fallback + cross-backend invariance eference), and a PyTorch reference (torch.matmul, intentionally non-deterministic — reference/benchmark only, deliberately excluded from dispatch, since a non-deterministic fallback would silently defeat the op's purpose).

Out of scope (per #146): tensor-parallel GEMM (WS2), FP8, and occupancy/throughput tuning of the tensor-core path — correctness and invariance first; a slower-than-cuBLAS deterministic baseline is the accepted first milestone. Throughput tuning will be a separate perf PR (see Follow-ups).

Test Plan

  1. Build: compile the _C extension with the SM90 path enabled (KERNEL_ALIGN_DET_GEMM_SM90=1) on H100 (SM90), CUDA 12.4.
  2. Single-tile correctness: SM90 TMA + mma.sync kernel vs an FP32 reference on a 64×64 tile, then on multi-tile shapes (M-tiles, long K), to validate the ldmatrix addressing, B-operand layout, and epilogue.
  3. Invariance (hard gate, bitwise via torch.equal):
    • forward batch-invariance (same row, different batch sizes),
    • chunked-prefill split == full GEMM,
    • padding rows do not affect valid rows,
    • backward dA batch-invariance,
      on both the CUDA tensor-core kernel and the Triton path.
  4. Correctness vs FP32: forward and backward, placeholder tolerances pending the [WS1] Ground-truth harness + numerical contract for batch-invariant ops #108 numerical contract.
  5. Target shapes: invariance across the 5 real projection shapes (QKV / o_proj / MLP up / MLP down / LM-head).
  6. Benchmark: overhead vs cuBLAS (TF32 disabled), as the fair baseline.

Test Result

Build: compiled and linked successfully with the SM90 TMA + mma.sync path enabled (CUDA 12.4, H100). det_gemm uses its own csrc/cuda/gemm/det_gemm_tma.cuh (shared::cluster); the shared csrc/utils/tma_utils.cuh emits shared::cta, which CUDA 12.4 ptxas rejects for cp.async.bulk.tensor — scoped the fix to this PR rather than touching the shared logp helpers (tracked separately in #). The KERNEL_ALIGN_DET_GEMM_SM90 build flag is independent of the fused_logp SM90 sources, which are left untouched.

Tests: pytest tests/test_det_gemm.py → 22/22 passed on H100 SXM (CUDA 12.4) — forward/backward batch-invariance, chunked-prefill,
padding-invariance (all bitwise), correctness vs FP32, and all 5 target
projection shapes, on both the CUDA tensor-core kernel and Triton.

-ltorch_cuda -o /tmp/tmptxrii14y.build-lib/rl_engine/_C.cpython-311-x86_64-linux-gnu.so -lcuda
===================================================================================== test session starts ======================================================================================
platform linux -- Python 3.11.10, pytest-9.1.1, pluggy-1.6.0 -- /usr/bin/python
cachedir: .pytest_cache
rootdir: /root/RL-Kernel
configfile: pyproject.toml
plugins: anyio-4.6.0
collected 22 items                                                                                                                                                                             

tests/test_det_gemm.py::test_forward_batch_invariance[cuda-deterministic_gemm] PASSED                                                                                                    [  4%]
tests/test_det_gemm.py::test_forward_batch_invariance[triton-deterministic_gemm_triton] PASSED                                                                                           [  9%]
tests/test_det_gemm.py::test_forward_chunked_prefill[cuda-deterministic_gemm] PASSED                                                                                                     [ 13%]
tests/test_det_gemm.py::test_forward_chunked_prefill[triton-deterministic_gemm_triton] PASSED                                                                                            [ 18%]
tests/test_det_gemm.py::test_forward_padding_invariance[cuda-deterministic_gemm] PASSED                                                                                                  [ 22%]
tests/test_det_gemm.py::test_forward_padding_invariance[triton-deterministic_gemm_triton] PASSED                                                                                         [ 27%]
tests/test_det_gemm.py::test_forward_correctness[cuda-deterministic_gemm] PASSED                                                                                                         [ 31%]
tests/test_det_gemm.py::test_forward_correctness[triton-deterministic_gemm_triton] PASSED                                                                                                [ 36%]
tests/test_det_gemm.py::test_backward_batch_invariance[cuda-deterministic_gemm] PASSED                                                                                                   [ 40%]
tests/test_det_gemm.py::test_backward_batch_invariance[triton-deterministic_gemm_triton] PASSED                                                                                          [ 45%]
tests/test_det_gemm.py::test_backward_correctness[cuda-deterministic_gemm] PASSED                                                                                                        [ 50%]
tests/test_det_gemm.py::test_backward_correctness[triton-deterministic_gemm_triton] PASSED                                                                                               [ 54%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape0-cuda-deterministic_gemm] PASSED                                                                                             [ 59%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape0-triton-deterministic_gemm_triton] PASSED                                                                                    [ 63%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape1-cuda-deterministic_gemm] PASSED                                                                                             [ 68%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape1-triton-deterministic_gemm_triton] PASSED                                                                                    [ 72%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape2-cuda-deterministic_gemm] PASSED                                                                                             [ 77%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape2-triton-deterministic_gemm_triton] PASSED                                                                                    [ 81%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape3-cuda-deterministic_gemm] PASSED                                                                                             [ 86%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape3-triton-deterministic_gemm_triton] PASSED                                                                                    [ 90%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape4-cuda-deterministic_gemm] PASSED                                                                                             [ 95%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape4-triton-deterministic_gemm_triton] PASSED                                                                                    [100%]

====================================================================================== 22 passed in 2.85s ======================================================================================

Benchmark (NVIDIA H100 80GB HBM3, SM90): overhead = det CUDA vs cuBLAS (TF32 disabled). Both deterministic paths trade speed for bitwise invariance by design (no split-K, fixed accumulation, FP32, no TF32).

shape M K N cuBLAS tf32 cuBLAS fp32 det CUDA det Triton overhead
qkv 4096 4096 12288 0.538 0.538 3.280 1.421 6.1x
o_proj 4096 4096 4096 0.190 0.190 1.164 0.478 6.1x
mlp_up 4096 4096 14336 0.656 0.704 3.800 1.688 5.4x
mlp_dn 4096 14336 4096 0.629 0.685 3.779 1.787 5.5x
lm_head 4096 4096 32000 1.513 1.528 8.269 3.897 5.4x

(The det CUDA path uses SM90 TMA + mma.sync tensor cores, 128×128 tile, single-CTA-per-tile, no split-K. Occupancy/throughput tuning is deferred per #146 and will be an ncu-driven perf PR; this milestone is correctness- and invariance-first.)

Files

  • csrc/cuda/gemm/det_gemm_kernel.cu — naive + SM90 TMA + mma.sync kernels, dispatch by compute capability (M padded to tile).
  • csrc/cuda/gemm/det_gemm_tma.cuh — det_gemm-local TMA / mbarrier primitives (shared::cluster).
  • csrc/ops.cpp — pybind registration (det_gemm_fwd / det_gemm_da / det_gemm_db).
  • rl_engine/kernels/ops/{cuda,triton,pytorch}/matmul/det_gemm.py — autograd wrappers + ops.
  • rl_engine/kernels/registry.py — det_gemm dispatch (CUDA primary, Triton fallback; PyTorch reference excluded).
  • tests/test_det_gemm.py — invariance + correctness (CUDA + Triton).
  • benchmarks/benchmark_det_gemm.py — overhead vs cuBLAS.
  • setup.py — det_gemm SM90 build flag; CUTLASS include removed; gencode fix.

Follow-ups

  • Perf PR (separate, ncu-driven): occupancy/throughput tuning of the SM90 tensor-core path to close the gap vs cuBLAS — to be tracked in a new perf issue.
  • PR3: replace placeholder test tolerances with the [WS1] Ground-truth harness + numerical contract for batch-invariant ops #108 threshold table once it lands.
  • PR4: wire one real projection (LM head) through the deterministic path.
  • PR5: benchmark doc — overhead + supported shapes.
  • Fix csrc/utils/tma_utils.cuh (shared::cta → shared::cluster) so fused_logp SM90 compiles on CUDA 12.4 — separate issue #.

Summary by CodeRabbit

  • New Features

    • Added a new deterministic GEMM option for CUDA workloads, with support for a native PyTorch baseline and an optional Triton-backed implementation.
    • Exposed the new GEMM capability through the operator interface and documentation, including usage guidance and backend selection.
    • Added a benchmark report and script to compare performance across common GEMM shapes.
  • Bug Fixes

    • Improved batch-size and padding invariance for GEMM results.
    • Added validation and test coverage for forward and backward correctness on supported CUDA devices.

@coderabbitai

coderabbitai Bot commented Jun 22, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

This PR adds a deterministic GEMM CUDA path with SM90 TMA support, Python/C++ bindings, Triton and native reference backends, registry wiring, build flags, docs, benchmark artifacts, and CUDA tests covering batch invariance and backward correctness.

Changes

Deterministic GEMM rollout

Layer / File(s) Summary
Kernel primitives and SM90 path
csrc/cuda/gemm/det_gemm_tma.cuh, csrc/cuda/gemm/det_gemm_kernel.cu
BF16 TMA helpers and deterministic GEMM kernels add the naive fixed-order path and the SM90 tiled path.
CUDA dispatch and C++ exports
csrc/cuda/gemm/det_gemm_kernel.cu, csrc/ops.cpp
The dispatch path pads M for batch invariance, selects SM90 or naive execution, and exposes forward/backward CUDA entrypoints through PyBind11.
CUDA Python op surface
rl_engine/kernels/ops/cuda/__init__.py, rl_engine/kernels/ops/cuda/matmul/__init__.py, rl_engine/kernels/ops/cuda/matmul/det_gemm.py, csrc/ops.cpp
The CUDA matmul package exports the deterministic op, and the wrapper calls the new CUDA entrypoints while enforcing BF16 CUDA inputs.
Native and Triton backends
rl_engine/kernels/ops/pytorch/matmul/__init__.py, rl_engine/kernels/ops/pytorch/matmul/det_gemm.py, rl_engine/kernels/ops/triton/matmul/__init__.py, rl_engine/kernels/ops/triton/matmul/det_gemm.py, rl_engine/kernels/registry.py
The reference matmul, Triton deterministic matmul, and registry backend enums and priority maps add alternate implementations and dispatch targets for det_gemm.
Build flags and extension sources
setup.py
The CUDA extension build includes the new det_gemm source and adds SM90-specific gencode, linking, and compiler definitions controlled by the new environment flags.
Docs, benchmark, and tests
docs/.nav.yml, docs/operators/det-gemm.md, benchmarks/benchmark_det_gemm.py, benchmarks/results/det_gemm_h100_tma.md, tests/test_det_gemm.py
The docs sidebar and det_gemm page describe the operator, the benchmark script and H100 result file report timings and overhead, and the CUDA tests cover batch invariance and gradient correctness across supported shapes.

Sequence Diagram(s)

sequenceDiagram
  participant DetGemmOp
  participant _DetGemmFn
  participant det_gemm_fwd
  participant gemm_dispatch
  participant det_gemm_da
  participant det_gemm_db
  DetGemmOp->>_DetGemmFn: apply(a, b)
  _DetGemmFn->>det_gemm_fwd: forward(a, b)
  det_gemm_fwd->>gemm_dispatch: choose SM90 or naive
  gemm_dispatch-->>det_gemm_fwd: C
  _DetGemmFn-->>DetGemmOp: C
  _DetGemmFn->>det_gemm_da: grad_out, b
  _DetGemmFn->>det_gemm_db: a, grad_out
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

  • Issue 151: The new deterministic CUDA, Triton, and registry wiring matches the batch-invariant GEMM implementation scope.
  • Issue 153: The new backward entrypoints and gradient-invariance tests match the backward-consistency objective for deterministic GEMM.

Possibly related PRs

  • RL-Align/RL-Kernel#91: Its SM90 build-flag gating and extension-flag changes overlap with the new det_gemm SM90 build path.

Suggested labels

component: kernels

Suggested reviewers

  • bitborne
  • EthanZero2Hero
  • inaniloquentee

Poem

A rabbit hopped through kernels bright,
With fixed-order steps and BF16 light.
SM90 hummed, Triton chimed,
The batch stayed steady, row by row aligned.
🐇 Crunch! Determinism tastes just right.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 11.11% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and concisely summarizes the main change: adding a batch-invariant deterministic GEMM for forward and backward paths.
Linked Issues check ✅ Passed The PR matches #146 by adding a fixed-order BF16/FP32 deterministic GEMM, validating batch and padding invariance, and covering forward/backward paths.
Out of Scope Changes check ✅ Passed The changes stay focused on the deterministic GEMM, its exports, tests, docs, benchmark, and required build wiring, with no obvious unrelated additions.
✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/add-ws1-gemm

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@Flink-ddd Flink-ddd changed the title WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd) [WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd) Jun 22, 2026
Signed-off-by: vensen <vensenmu@gmail.com>
@Flink-ddd Flink-ddd force-pushed the feat/add-ws1-gemm branch from e797bc9 to 90d120c Compare June 24, 2026 07:21
@Flink-ddd Flink-ddd marked this pull request as ready for review June 26, 2026 05:21

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
rl_engine/kernels/registry.py (1)

99-112: 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

det_gemm falls back to the logp backend on CPU.

CUDA and ROCm get explicit det_gemm entries, but CPU does not. Because get_op() defaults missing op types to OpBackend.PYTORCH_NATIVE, get_op("det_gemm") on CPU will try to instantiate NativeLogpOp instead of failing fast.

Suggested fix
             "cpu": {
                 "logp": [OpBackend.PYTORCH_NATIVE],
                 "attn": [OpBackend.PYTORCH_ATTN],
                 "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS],
                 "linear_logp": [OpBackend.PYTORCH_LINEAR_LOGP],
                 "ratio_kl": [OpBackend.PYTORCH_RATIO_KL],
+                "det_gemm": [],
             },
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/registry.py` around lines 99 - 112, The
`KernelRegistry.get_op()` fallback is incorrectly treating missing `det_gemm` as
`OpBackend.PYTORCH_NATIVE`, which routes CPU requests to `NativeLogpOp` instead
of failing or using a valid `det_gemm` backend. Update the registry/lookup logic
so `det_gemm` is explicitly handled for CPU in `KernelRegistry` (or excluded
from the generic native fallback), and ensure `get_op("det_gemm")` selects only
a real `det_gemm` implementation or raises a clear unsupported-op error.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@benchmarks/benchmark_det_gemm.py`:
- Around line 52-63: The benchmark currently starts timing in run() without
checking that the CUDA backend binding is actually available, so a None handle
can still fail later inside deterministic_gemm as NoneType.det_gemm_fwd. Add an
explicit preflight check before the SHAPES loop to verify the CUDA
wrapper/binding used by native_gemm and deterministic_gemm is initialized, and
if not, exit or skip with a clear message instead of letting _time() trigger the
opaque failure. If Triton is part of the path, gate deterministic_gemm_triton
the same way so the benchmark only runs when the required backend symbols are
present.

In `@csrc/cuda/gemm/det_gemm_kernel.cu`:
- Around line 230-238: The gemm dispatch path currently assumes the active CUDA
device matches the input tensors and does not reject mixed-GPU inputs. Update
check_in() and gemm_dispatch() to verify both tensors are on CUDA, are bf16, and
reside on the same device, then add an explicit device guard in gemm_dispatch()
so the allocation of c and the kernel launch via
at::cuda::getCurrentCUDAStream() are pinned to the inputs’ device. Use the
existing gemm_dispatch, check_in, and stream setup to place the guard and
same-device validation close to the launch site.

In `@docs/operators/det-gemm.md`:
- Around line 28-29: The CUDA backend description in the DetGemm documentation
is outdated and still describes a naive FP32 milestone with deferred tensor-core
work, while the merged implementation already uses the SM90 TMA plus mma.sync
path. Update the wording in the DetGemm operator docs table and any related CUDA
backend references so they accurately reflect the implemented DetGemmOp behavior
and current performance profile, keeping the TritonDetGemmOp description
consistent with the same section.

In `@rl_engine/kernels/ops/cuda/matmul/det_gemm.py`:
- Around line 16-20: The functional deterministic GEMM path is bypassing the
compiled-extension availability guard and can hit _C.det_gemm_fwd when _C is
None. Update _DetGemmFn.forward and deterministic_gemm to perform the same
availability check used by DetGemmOp.__call__ before touching _C, and raise the
intended RuntimeError instead of allowing an AttributeError to surface.

In `@rl_engine/kernels/ops/triton/matmul/det_gemm.py`:
- Around line 101-105: The backward path in det_gemm.backward still computes dB
from Aᵀ @ grad_out, so the reduction depends on the live batch dimension and can
vary with chunking/padding. Update the db calculation to use the WS1 invariant
reduction contract instead of directly reducing over the current batch layout,
and ensure the Triton backward path produces the same dW regardless of batch
size or layout changes. Keep the fix localized around backward and the
_triton_gemm calls in det_gemm.py.
- Around line 68-90: The _triton_gemm helper currently derives K from a.shape
and launches _det_gemm_kernel without verifying that b.shape[0] matches, so add
an explicit inner-dimension check before the kernel call and raise the standard
matmul shape error on mismatch. Keep the validation in _triton_gemm near the
existing M, K, N shape extraction so invalid inputs are rejected before any
Triton launch.

In `@setup.py`:
- Around line 64-71: The SM90 build path still emits a conflicting plain SM90
gencode when KERNEL_ALIGN_DET_GEMM_SM90 is enabled, so update the nvcc flag
assembly in setup.py to avoid adding the compute_90/sm_90 pair on SM90 machines.
Make the logic in the startup flag block and the later SM90 append path
consistent so that enabling KERNEL_ALIGN_DET_GEMM_SM90 alone does not require
also setting KERNEL_ALIGN_FORCE_SM90. Use the existing nvcc_flags construction
and the SM90-specific branch to ensure only the intended 90a gencode is present.

In `@tests/test_det_gemm.py`:
- Around line 25-33: The CUDA backend is being added to _BACKENDS even when
deterministic_gemm is not actually available, which causes crashes before the
assertions run. Update the backend selection near pytestmark/_BACKENDS to check
the wrapper’s real availability signal from deterministic_gemm before appending
the CUDA case, or convert it into a clear module-level skip when the extension
binding is None. Keep the Triton entry gated separately with _HAS_TRITON.
- Around line 78-85: The correctness assertions in test_forward_correctness and
the related tests still use temporary loose thresholds instead of the shared
`#108` contract. Update the checks in the gemm test cases to use the `#108`
correctness harness/thresholds rather than hardcoded placeholder max_abs limits,
and make sure the affected test functions (including test_forward_correctness
and the nearby cases in the same range) validate against the contract
consistently.

---

Outside diff comments:
In `@rl_engine/kernels/registry.py`:
- Around line 99-112: The `KernelRegistry.get_op()` fallback is incorrectly
treating missing `det_gemm` as `OpBackend.PYTORCH_NATIVE`, which routes CPU
requests to `NativeLogpOp` instead of failing or using a valid `det_gemm`
backend. Update the registry/lookup logic so `det_gemm` is explicitly handled
for CPU in `KernelRegistry` (or excluded from the generic native fallback), and
ensure `get_op("det_gemm")` selects only a real `det_gemm` implementation or
raises a clear unsupported-op error.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0de95808-0f0e-49ba-ba17-f216e838e757

📥 Commits

Reviewing files that changed from the base of the PR and between be5ec9b and 70a60a5.

📒 Files selected for processing (17)
  • benchmarks/benchmark_det_gemm.py
  • benchmarks/results/det_gemm_h100_tma.md
  • csrc/cuda/gemm/det_gemm_kernel.cu
  • csrc/cuda/gemm/det_gemm_tma.cuh
  • csrc/ops.cpp
  • docs/.nav.yml
  • docs/operators/det-gemm.md
  • rl_engine/kernels/ops/cuda/__init__.py
  • rl_engine/kernels/ops/cuda/matmul/__init__.py
  • rl_engine/kernels/ops/cuda/matmul/det_gemm.py
  • rl_engine/kernels/ops/pytorch/matmul/__init__.py
  • rl_engine/kernels/ops/pytorch/matmul/det_gemm.py
  • rl_engine/kernels/ops/triton/matmul/__init__.py
  • rl_engine/kernels/ops/triton/matmul/det_gemm.py
  • rl_engine/kernels/registry.py
  • setup.py
  • tests/test_det_gemm.py

Comment thread benchmarks/benchmark_det_gemm.py
Comment thread csrc/cuda/gemm/det_gemm_kernel.cu
Comment thread docs/operators/det-gemm.md
Comment thread rl_engine/kernels/ops/cuda/matmul/det_gemm.py
Comment thread rl_engine/kernels/ops/triton/matmul/det_gemm.py
Comment thread rl_engine/kernels/ops/triton/matmul/det_gemm.py
Comment thread setup.py
Comment thread tests/test_det_gemm.py
Comment thread tests/test_det_gemm.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[WS1][kernels] Batch-invariant matmul / GEMM

1 participant