-
Notifications
You must be signed in to change notification settings - Fork 42
[FEAT][kernels] Add tensor-parallel linear_logp path #189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5e040e3
31ba11b
4858bbd
81ed9b2
ae7402e
ddf65a7
355a2e0
2c0137d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| #include <cuda.h> | ||
| #include <cuda_bf16.h> | ||
| #include <cudaTypedefs.h> | ||
| #include <cstdint> | ||
| #include <iostream> | ||
|
|
||
| // Type Traits for TMA | ||
|
|
@@ -51,33 +52,39 @@ inline void init_tensor_map( | |
| } | ||
|
|
||
| // Device API | ||
| __device__ inline void mbarrier_init(int addr, int count) { | ||
| asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(addr), "r"(count)); | ||
| __device__ inline void mbarrier_init(uint64_t addr, int count) { | ||
| asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "l"(addr), "r"(count)); | ||
| } | ||
|
|
||
| __device__ inline void mbarrier_arrive(int addr) { | ||
| asm volatile("mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];" :: "r"(addr) : "memory"); | ||
| __device__ inline void mbarrier_arrive(uint64_t addr) { | ||
| asm volatile("mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];" :: "l"(addr) : "memory"); | ||
| } | ||
|
|
||
| __device__ inline void mbarrier_arrive_expect_tx(int addr, int size) { | ||
| __device__ inline void mbarrier_arrive_expect_tx(uint64_t addr, int size) { | ||
| asm volatile("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 _, [%0], %1;" | ||
| :: "r"(addr), "r"(size) : "memory"); | ||
| :: "l"(addr), "r"(size) : "memory"); | ||
| } | ||
|
|
||
| __device__ inline void mbarrier_wait(int mbar_addr, int phase) { | ||
| __device__ inline void mbarrier_wait(uint64_t mbar_addr, int phase) { | ||
| int ticks = 0x989680; | ||
| asm volatile( | ||
| "{\n" | ||
| ".reg .pred P1;\n" | ||
| "LAB_WAIT:\n" | ||
| "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n" | ||
| "@!P1 bra.uni LAB_WAIT;\n" | ||
| "}" :: "r"(mbar_addr), "r"(phase), "r"(ticks) | ||
| "}" :: "l"(mbar_addr), "r"(phase), "r"(ticks) | ||
| ); | ||
| } | ||
|
|
||
| __device__ inline void tma_2d_g2s(int dst_smem_addr, const void *tmap_ptr, int x, int y, int mbar_addr) { | ||
| asm volatile("cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes " | ||
| __device__ inline void tma_2d_g2s( | ||
| uint64_t dst_smem_addr, | ||
| const void *tmap_ptr, | ||
| int x, | ||
| int y, | ||
| uint64_t mbar_addr | ||
| ) { | ||
| asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes " | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This line changes the TMA destination from shared::cta.global.mbarrier to shared::cluster.global.tile.mbarrier. That's a scope/semantics change, not just a CUDA 12.4 syntax fix, and we're not doing a cluster launch here. Can you confirm why this is correct in the non-cluster case? We ran the SM90 path on 8x H100 SXM and results looked fine, so I'm fairly confident, but I'd like the reasoning documented in this thread before merge. |
||
| "[%0], [%1, {%2, %3}], [%4];" | ||
| :: "r"(dst_smem_addr), "l"(tmap_ptr), "r"(x), "r"(y), "r"(mbar_addr) : "memory"); | ||
| :: "l"(dst_smem_addr), "l"(tmap_ptr), "r"(x), "r"(y), "l"(mbar_addr) : "memory"); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think these lines are a fix for CUDA 12.4? Some features differ significantly between CUDA 12.8 and 12.4, so code developed on 12.8 may have problems on 12.4. When I previously developed the linear-logp CUDA kernel, I was using CUDA 12.4 but could not compile some features in tma_utils.cuh by @Flink-ddd. I found that the previous work was done on CUDA 12.8, so I later switched to CUDA 12.8, as you can see in the #Performance section of docs/operators/linear-logp.md.
I think we should standardize on CUDA 12.8? Maybe we should discuss how to align on this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. This is about keeping the mbarrier shared-address operand width explicit, not changing the CUDA baseline. I added a short comment there. Happy to align separately on whether we want CUDA 12.8 as the standard perf baseline.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, the explicit operand width is fine. Thanks for clarifying. Agreed on aligning separately on whether CUDA 12.8 should be our standard perf baseline; let's track that as its own discussion.