@@ -119,6 +119,7 @@ struct CollectiveMainloopFwdSm90 {
// Thread-count constants for the MMA (consumer) side, taken from the sizes of the
// QK and PV tiled MMAs. NOTE(review): assumes size(TiledMma{}) yields the number of
// participating threads, as the names suggest -- confirm against CuTe.
119119 static constexpr int NumMmaThreadsQK = size(TiledMmaQK{});
120120 static constexpr int NumMmaThreads = size(TiledMmaPV{});
// The producer is a single warp only when V is not transposed and both K/V and Q
// use TMA; in every other configuration it is a full warp group.
121121 static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup;
// Convenience flag: true exactly when the producer was sized to one warp above.
122+ static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
// Both MMA thread counts must be whole multiples of a warp group; the division
// below relies on this for NumMmaThreads.
122123 static_assert (NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0 );
123124 static_assert (NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0 );
// Number of warp groups taking part in the PV MMA.
124125 static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup;
@@ -646,11 +647,11 @@ struct CollectiveMainloopFwdSm90 {
646647 }
647648 }
648649
649- template <typename SharedStorage>
650+ template <typename SharedStorage, typename TensorQ, typename TensorQv >
650651 CUTLASS_DEVICE void load_q_impl (
651652 std::true_type /* UseTmaQ*/ , Params const & params, SharedStorage &shared_storage,
652653 SeqlenInfo_t const & seqlen_info, cute::tuple<int32_t , int32_t , int32_t , int32_t > block_coord,
653- int thread_idx, bool is_varlen_q, int warp_idx_in_warpgroup, Tensor sQ , Tensor sQv ) {
654+ int thread_idx, bool is_varlen_q, int warp_idx_in_warpgroup, TensorQ sQ , TensorQv sQv ) {
654655
655656 int const m_block = get<0 >(block_coord);
656657 int const bidh = get<1 >(block_coord);
@@ -694,11 +695,11 @@ struct CollectiveMainloopFwdSm90 {
694695 }
695696 }
696697
697- template <typename SharedStorage>
698+ template <typename SharedStorage, typename TensorQ, typename TensorQv >
698699 CUTLASS_DEVICE void load_q_impl (
699700 std::false_type /* UseTmaQ*/ , Params const & params, SharedStorage &shared_storage,
700701 SeqlenInfo_t const & seqlen_info, cute::tuple<int32_t , int32_t , int32_t , int32_t > block_coord,
701- int thread_idx, bool is_varlen_q, int /* warp_idx_in_warpgroup*/ , Tensor sQ , Tensor sQv ) {
702+ int thread_idx, bool is_varlen_q, int /* warp_idx_in_warpgroup*/ , TensorQ sQ , TensorQv sQv ) {
702703
703704 int const m_block = get<0 >(block_coord);
704705 int const bidh = get<1 >(block_coord);
0 commit comments