I've encountered a hang while trying to run a Triton kernel that uses warp specialization.
The original kernel was a complex batched matmul, but I managed to reduce the hanging reproducer to around 20 lines.
Here is the kernel (tested on an RTX 5090):
```mlir
module {
  tt.func public @tma_kernel(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}) {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c2_i32 = arith.constant 2 : i32
    %c64_i32 = arith.constant 64 : i32
    %c64_i64 = arith.constant 64 : i64
    %c1_i64 = arith.constant 1 : i64
    %desc0 = tt.make_tensor_descriptor %arg0, [%c64_i32, %c64_i32], [%c64_i64, %c1_i64] : <f16>, <tensor<64x64xf16>>
    %desc1 = tt.make_tensor_descriptor %arg1, [%c64_i32, %c64_i32], [%c64_i64, %c1_i64] : <f16>, <tensor<64x64xf16>>
    scf.for %i = %c0_i32 to %c2_i32 step %c1_i32 : i32 {
      %8 = arith.muli %c0_i32, %c64_i32 : i32
      %9 = arith.muli %c0_i32, %c64_i32 : i32
      %10 = tt.descriptor_load %desc0[%8, %9] : !tt.tensordesc<tensor<64x64xf16>> -> tensor<64x64xf16>
      tt.descriptor_store %desc1[%8, %9], %10 : !tt.tensordesc<tensor<64x64xf16>>, tensor<64x64xf16>
      scf.yield
    } {tt.flatten, tt.warp_specialize}
    tt.return
  }
}
```
This example is deliberately odd (the loop body does nothing useful), but it's the only way I've found to trigger the bug.
Here are my observations:
- If you set the loop upper bound to 1 instead of 2 (by changing `%c2_i32 = arith.constant 2 : i32` to `%c2_i32 = arith.constant 1 : i32`), the hang disappears.
- If you remove the `tt.warp_specialize` attribute from the loop, the hang also disappears.
Note that this kernel runs fine with the Python version of Triton (3.6.0), which is why I'm opening the issue here.
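For comparison, here is a minimal sketch of the Python Triton kernel I believe corresponds to the reduced IR above. It is my approximation, not the exact source the IR was lowered from, and it assumes the Triton ≥ 3.3 device-side descriptor API (`tl.make_tensor_descriptor`, `desc.load`/`desc.store`, and the `warp_specialize`/`flatten` hints on `tl.range`):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def tma_kernel(in_ptr, out_ptr):
    # 64x64 f16 tensors, row-major, covered by a single 64x64 block,
    # matching the tt.make_tensor_descriptor ops in the IR above.
    in_desc = tl.make_tensor_descriptor(
        in_ptr, shape=[64, 64], strides=[64, 1], block_shape=[64, 64]
    )
    out_desc = tl.make_tensor_descriptor(
        out_ptr, shape=[64, 64], strides=[64, 1], block_shape=[64, 64]
    )
    # Two-iteration loop with constant offsets; warp_specialize/flatten
    # correspond to the tt.warp_specialize/tt.flatten loop attributes.
    for _ in tl.range(0, 2, warp_specialize=True, flatten=True):
        block = in_desc.load([0, 0])
        out_desc.store([0, 0], block)


# Device-side descriptor creation requires a TMA workspace allocator.
def alloc_fn(size: int, alignment: int, stream):
    return torch.empty(size, dtype=torch.int8, device="cuda")


triton.set_allocator(alloc_fn)

src = torch.randn(64, 64, dtype=torch.float16, device="cuda")
dst = torch.empty_like(src)
tma_kernel[(1,)](src, dst)
torch.testing.assert_close(src, dst)  # completes fine in Python Triton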
If you want to reproduce, you can use this branch: https://github.com/Corendos/xla/tree/corendos/warp-specialize-hang
I added a test to `xla/service/gpu/tests/gpu_triton_custom_call_test.cc`.
If you need any additional information, feel free to ping me!