GPU hang with Triton kernel using Warp Specialization #38082

@Corendos

Description

I've encountered a hang bug while trying to run a Triton kernel using warp specialization.

The original reproducer was a complex batched matmul kernel, but I managed to reduce the hanging kernel to around 20 lines.

Here is the reduced kernel (tested on an RTX 5090):

```mlir
module {
  tt.func public @tma_kernel(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}) {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c2_i32 = arith.constant 2 : i32
    %c64_i32 = arith.constant 64 : i32
    %c64_i64 = arith.constant 64 : i64
    %c1_i64 = arith.constant 1 : i64

    %desc0 = tt.make_tensor_descriptor %arg0, [%c64_i32, %c64_i32], [%c64_i64, %c1_i64] : <f16>, <tensor<64x64xf16>>
    %desc1 = tt.make_tensor_descriptor %arg1, [%c64_i32, %c64_i32], [%c64_i64, %c1_i64] : <f16>, <tensor<64x64xf16>>
    scf.for %i = %c0_i32 to %c2_i32 step %c1_i32 : i32 {
      %8 = arith.muli %c0_i32, %c64_i32 : i32
      %9 = arith.muli %c0_i32, %c64_i32 : i32

      %10 = tt.descriptor_load %desc0[%8, %9] : !tt.tensordesc<tensor<64x64xf16>> -> tensor<64x64xf16>
      tt.descriptor_store %desc1[%8, %9], %10 : !tt.tensordesc<tensor<64x64xf16>>, tensor<64x64xf16>
      scf.yield
    } {tt.flatten, tt.warp_specialize}
    tt.return
  }
}
```

This example is deliberately odd (the loop is useless), but it's the only way I've found to trigger the bug.

Here are my observations:

  • If you set the loop upper bound to 1 instead of 2 (by changing the line %c2_i32 = arith.constant 2 : i32 to %c2_i32 = arith.constant 1 : i32), the hang disappears
  • If you remove the tt.warp_specialize attribute from the loop, the hang also disappears
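
For reference, this is the non-hanging variant of the loop from the second observation, with tt.warp_specialize dropped and only tt.flatten kept:

```mlir
    scf.for %i = %c0_i32 to %c2_i32 step %c1_i32 : i32 {
      %8 = arith.muli %c0_i32, %c64_i32 : i32
      %9 = arith.muli %c0_i32, %c64_i32 : i32

      %10 = tt.descriptor_load %desc0[%8, %9] : !tt.tensordesc<tensor<64x64xf16>> -> tensor<64x64xf16>
      tt.descriptor_store %desc1[%8, %9], %10 : !tt.tensordesc<tensor<64x64xf16>>, tensor<64x64xf16>
      scf.yield
    } {tt.flatten}
```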

Note that this kernel runs fine with the Python version of Triton (3.6.0), which is why I'm opening the issue here.
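
For context, a rough Python-level equivalent of the reduced kernel might look like the sketch below. This is an assumption on my part: it relies on tl.range exposing the warp_specialize and flatten keyword arguments (present in recent Triton releases, which lower to the tt.warp_specialize and tt.flatten attributes), and the exact Python kernel the IR was reduced from may differ.

```python
import triton
import triton.language as tl


@triton.jit
def tma_kernel(in_ptr, out_ptr):
    # In-kernel TMA descriptors over a 64x64 f16 tensor (row-major strides [64, 1]).
    in_desc = tl.make_tensor_descriptor(
        in_ptr, shape=[64, 64], strides=[64, 1], block_shape=[64, 64]
    )
    out_desc = tl.make_tensor_descriptor(
        out_ptr, shape=[64, 64], strides=[64, 1], block_shape=[64, 64]
    )
    # The deliberately useless two-iteration loop; the offsets are always (0, 0).
    # warp_specialize=True / flatten=True are assumed to produce the
    # tt.warp_specialize and tt.flatten attributes seen on the scf.for in the IR.
    for _ in tl.range(0, 2, warp_specialize=True, flatten=True):
        block = in_desc.load([0, 0])
        out_desc.store([0, 0], block)
```

Launching this requires a GPU with TMA support and a descriptor allocator registered via triton.set_allocator; the snippet only defines the kernel.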

If you want to reproduce, you can use this branch: https://github.com/Corendos/xla/tree/corendos/warp-specialize-hang
I added a test in xla/service/gpu/tests/gpu_triton_custom_call_test.cc.

If you need any additional information, feel free to ping me!
