Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions benchmarks/benchmark_linear_logp.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ def run_benchmark(args):
for num_tokens, hidden_dim, vocab in args.configs:
hidden, weight, target = _make_inputs(num_tokens, hidden_dim, vocab, device, dtype)

def fwd(op, h=hidden, w=weight):
def fwd(op, h=hidden, w=weight, t=target):
with torch.no_grad():
op(h, w, target)
op(h, w, t)

def fwd_bwd(op):
h = hidden.clone().requires_grad_(True)
w = weight.clone().requires_grad_(True)
op(h, w, target).sum().backward()
def fwd_bwd(op, h_src=hidden, w_src=weight, t=target):
h = h_src.clone().requires_grad_(True)
w = w_src.clone().requires_grad_(True)
op(h, w, t).sum().backward()

n_fwd = _time_ms(lambda: fwd(native), args.warmup, args.iters)
t_fwd = _time_ms(lambda: fwd(triton_op), args.warmup, args.iters)
Expand Down
20 changes: 12 additions & 8 deletions csrc/cuda/fused_linear_logp_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,16 @@ __global__ void fused_linear_logp_sm90_kernel(const __grid_constant__ CUtensorMa
float *sZt = sSum + BM;
int *mbar_base = reinterpret_cast<int *>(sZt + BM); // STAGES mbarriers (8B each)

const uint32_t sH_base = static_cast<uint32_t>(__cvta_generic_to_shared(sH));
const uint32_t sW_base = static_cast<uint32_t>(__cvta_generic_to_shared(sW));
int mbar[STAGES];
const uint64_t sH_base_tma = __cvta_generic_to_shared(sH);
const uint64_t sW_base_tma = __cvta_generic_to_shared(sW);
const uint32_t sH_base = static_cast<uint32_t>(sH_base_tma);
const uint32_t sW_base = static_cast<uint32_t>(sW_base_tma);
// mbarrier PTX expects the 64-bit shared address, while ldmatrix below uses
// the narrowed 32-bit shared address form.
uint64_t mbar[STAGES];
#pragma unroll
for (int s = 0; s < STAGES; ++s)
mbar[s] = static_cast<int>(__cvta_generic_to_shared(mbar_base + 2 * s));
mbar[s] = __cvta_generic_to_shared(mbar_base + 2 * s);
Comment on lines +88 to +97

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these lines are for CUDA12.4 fix? Because some features in 12.8 and 12.4 are really different, the code dev in 12.8 may have problem in 12.4. When I previously developed linear-logp cuda kernel, I was using cuda 12.4 but can not compile some features in tma_utils.cuh by @Flink-ddd . I found the previous works were done in CUDA12.8, so I later switch in CUDA 12.8, which you can see the #Performance section in docs/operators/linear-logp.md.
I think we should stick our CUDA version to 12.8? Maybe we should discuss on this aligness.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these lines are for CUDA12.4 fix? Because some features in 12.8 and 12.4 are really different, the code dev in 12.8 may have problem in 12.4. When I previously developed linear-logp cuda kernel, I was using cuda 12.4 but can not compile some features in tma_utils.cuh by @Flink-ddd . I found the previous works were done in CUDA12.8, so I later switch in CUDA 12.8, which you can see the #Performance section in docs/operators/linear-logp.md. I think we should stick our CUDA version to 12.8? Maybe we should discuss on this aligness.

Good catch. This is about keeping the mbarrier shared-address operand width explicit, not changing the CUDA baseline. I added a short comment there. Happy to align separately on whether we want CUDA 12.8 as the standard perf baseline.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, the explicit operand width is fine.Makes sense, the explicit operand width is fine. Thanks for clarifying. Agreed on aligning separately on whether CUDA 12.8 should be our standard perf baseline, let's track that as its own discussion.


for (int r = tid; r < num_rows; r += WG_THREADS) {
sMax[r] = -CUDART_INF_F;
Expand All @@ -111,11 +115,11 @@ __global__ void fused_linear_logp_sm90_kernel(const __grid_constant__ CUtensorMa
auto issue_load = [&](int k, int col_base) {
const int buf = k % STAGES;
const int k_off = k * BK;
tma_2d_g2s(static_cast<int>(sH_base + buf * BM * BK * sizeof(nv_bfloat16)), &h_tmap, k_off,
row_base, mbar[buf]);
tma_2d_g2s(static_cast<int>(sW_base + buf * BN * BK * sizeof(nv_bfloat16)), &w_tmap, k_off,
col_base, mbar[buf]);
mbarrier_arrive_expect_tx(mbar[buf], tile_bytes);
tma_2d_g2s(sH_base_tma + buf * BM * BK * sizeof(nv_bfloat16), &h_tmap, k_off, row_base,
mbar[buf]);
tma_2d_g2s(sW_base_tma + buf * BN * BK * sizeof(nv_bfloat16), &w_tmap, k_off, col_base,
mbar[buf]);
};

int phase[STAGES];
Expand Down
29 changes: 18 additions & 11 deletions csrc/utils/tma_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <cuda.h>
#include <cuda_bf16.h>
#include <cudaTypedefs.h>
#include <cstdint>
#include <iostream>

// Type Traits for TMA
Expand Down Expand Up @@ -51,33 +52,39 @@ inline void init_tensor_map(
}

// Device API
__device__ inline void mbarrier_init(int addr, int count) {
asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(addr), "r"(count));
__device__ inline void mbarrier_init(uint64_t addr, int count) {
asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "l"(addr), "r"(count));
}

__device__ inline void mbarrier_arrive(int addr) {
asm volatile("mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];" :: "r"(addr) : "memory");
__device__ inline void mbarrier_arrive(uint64_t addr) {
asm volatile("mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];" :: "l"(addr) : "memory");
}

__device__ inline void mbarrier_arrive_expect_tx(int addr, int size) {
__device__ inline void mbarrier_arrive_expect_tx(uint64_t addr, int size) {
asm volatile("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 _, [%0], %1;"
:: "r"(addr), "r"(size) : "memory");
:: "l"(addr), "r"(size) : "memory");
}

__device__ inline void mbarrier_wait(int mbar_addr, int phase) {
__device__ inline void mbarrier_wait(uint64_t mbar_addr, int phase) {
int ticks = 0x989680;
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n"
"@!P1 bra.uni LAB_WAIT;\n"
"}" :: "r"(mbar_addr), "r"(phase), "r"(ticks)
"}" :: "l"(mbar_addr), "r"(phase), "r"(ticks)
);
}

__device__ inline void tma_2d_g2s(int dst_smem_addr, const void *tmap_ptr, int x, int y, int mbar_addr) {
asm volatile("cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes "
__device__ inline void tma_2d_g2s(
uint64_t dst_smem_addr,
const void *tmap_ptr,
int x,
int y,
uint64_t mbar_addr
) {
asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes "

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line changes the TMA destination from shared::cta.global.mbarrier to shared::cluster.global.tile.mbarrier. That's a scope/semantics change, not just a CUDA 12.4 syntax fix and we're not doing a cluster launch here. Can you confirm why this is correct in the non-cluster case? We ran the SM90 path on 8x H100 SXM and results looked fine, so I'm fairly confident, but I'd like the reasoning documented in this thread before merge.

"[%0], [%1, {%2, %3}], [%4];"
:: "r"(dst_smem_addr), "l"(tmap_ptr), "r"(x), "r"(y), "r"(mbar_addr) : "memory");
:: "l"(dst_smem_addr), "l"(tmap_ptr), "r"(x), "r"(y), "l"(mbar_addr) : "memory");
}
1 change: 1 addition & 0 deletions docs/.nav.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ nav:
- operators/README.md
- operators/fused-logp.md
- operators/linear-logp.md
- operators/linear-logp-tp-test.md
- operators/grpo-loss.md
- operators/ratio-kl.md
- operators/sampling.md
Expand Down
8 changes: 5 additions & 3 deletions docs/design/runtime-dispatch.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ logical type, and the registry selects the first available backend for the curre

| Platform | Priority |
| --- | --- |
| CUDA | SM90 fused LogP when available, CUDA generic, FlashInfer, Triton generic, PyTorch native |
| CUDA | CUDA generic LogP by default; experimental SM90 fused LogP only when explicitly enabled, FlashInfer, Triton generic, PyTorch native |
| ROCm | AITER, Triton generic, PyTorch native |
| CPU | PyTorch native |

For CUDA devices with compute capability 9.0 or newer, the registry inserts the SM90
LogP backend at the front of the CUDA priority list.
For CUDA devices with compute capability 9.0 or newer, the registry only inserts
the legacy SM90 LogP backend when `RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP=1` is
set. The fused linear logp SM90 backend is gated separately and remains the
default linear logp backend when the extension is built on Hopper.
Comment thread
Flink-ddd marked this conversation as resolved.

## Relevant Files

Expand Down
1 change: 1 addition & 0 deletions docs/operators/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Every operator page should include:

- [Fused LogP](fused-logp.md)
- [Fused Linear LogP](linear-logp.md)
- [Fused Linear LogP TP Test Runbook](linear-logp-tp-test.md)
- [GRPO Loss](grpo-loss.md)
- [Policy Ratio + KL Penalty](ratio-kl.md)
- [Sampling](sampling.md)
Expand Down
4 changes: 2 additions & 2 deletions docs/operators/fused-logp.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ output = logp_op(logits, token_ids)

| Backend | Wrapper | Native symbol | Notes |
| --- | --- | --- | --- |
| CUDA SM90 | `FusedLogpSM90Op` | `_C.fused_logp_sm90` | TMA-oriented path for Hopper-class GPUs. |
| CUDA SM90 | `FusedLogpSM90Op` | `_C.fused_logp_sm90` | Experimental TMA-oriented path for 2D contiguous bf16 logits on Hopper-class GPUs. It is disabled by default and requires `RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP=1`; otherwise the wrapper delegates to the CUDA generic fallback. |
| CUDA generic | `FusedLogpGenericOp` | `_C.fused_logp` | Generic compiled extension fallback. |
| PyTorch native | `NativeOp` | None | Baseline fallback path. |

## Tensor Contract

| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `logits` | `[N, V]` | `bfloat16` for SM90 path | Contiguous, on the target device. |
| `logits` | `[N, V]` | `bfloat16` for the experimental SM90 fast path; fp16/fp32 use generic fallback | Contiguous, on the target device for the experimental SM90 fast path. |
| `token_ids` / `labels` | `[N]` | Converted to `int32` | Same logical device as `logits`. |
| Output | `[N]` | Backend-defined tensor dtype | One selected log probability per row. |

Expand Down
Loading
Loading