diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index 5adacd0ce2..1d7094a5fa 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -176,7 +176,7 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( "MXFPX is not supported for the selected weight combination"); } - if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { + if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 110) { bool const dynamic_cga = gemm_config.dynamic_cluster_shape != cutlass_extensions::ClusterShape::Undefined; bool const swap_ab = hopper_input.swap_ab; @@ -204,7 +204,8 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( gemm_config.epilogue_schedule, dynamic_cga, swap_ab); selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); - } else if constexpr (Arch::kMinComputeCapability >= 120 || Arch::kMinComputeCapability == 90) { + } else if constexpr (Arch::kMinComputeCapability >= 120 || Arch::kMinComputeCapability == 90 || + Arch::kMinComputeCapability == 110) { using EpilogueSchedule = void; // These are hardcoded in the launcher constexpr bool dynamic_cga = false; auto selected_func = @@ -225,10 +226,12 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( template constexpr bool are_tile_shapes_supported_sm100() { - // We use a runtime cluster shape for SM100, so we only support 1x1x1 and 2x1x1 cluster shapes. - if (cute::size<0>(ClusterShape{}) > 2 || cute::size<1>(ClusterShape{}) != 1 || - cute::size<2>(ClusterShape{}) != 1) { - return false; + if constexpr (Arch::kMinComputeCapability != 110) { + // We use a runtime cluster shape for SM100, so we only support 1x1x1 and 2x1x1 cluster shapes. + if (cute::size<0>(ClusterShape{}) > 2 || cute::size<1>(ClusterShape{}) != 1 || + cute::size<2>(ClusterShape{}) != 1) { + return false; + } } using namespace cute;