Skip to content

Commit 67d10cd

Browse files
Merge pull request #20 from ivandobskygithub/codex/fix-compilation-errors-in-tensor
Fix load_q_impl tensor templates
2 parents 80774da + 0e35b90 commit 67d10cd

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,7 @@ struct CollectiveMainloopFwdSm90 {
119119
static constexpr int NumMmaThreadsQK = size(TiledMmaQK{});
120120
static constexpr int NumMmaThreads = size(TiledMmaPV{});
121121
static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup;
122+
static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
122123
static_assert(NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0);
123124
static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0);
124125
static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup;
@@ -646,11 +647,11 @@ struct CollectiveMainloopFwdSm90 {
646647
}
647648
}
648649

649-
template <typename SharedStorage>
650+
template <typename SharedStorage, typename TensorQ, typename TensorQv>
650651
CUTLASS_DEVICE void load_q_impl(
651652
std::true_type /*UseTmaQ*/, Params const& params, SharedStorage &shared_storage,
652653
SeqlenInfo_t const& seqlen_info, cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
653-
int thread_idx, bool is_varlen_q, int warp_idx_in_warpgroup, Tensor sQ, Tensor sQv) {
654+
int thread_idx, bool is_varlen_q, int warp_idx_in_warpgroup, TensorQ sQ, TensorQv sQv) {
654655

655656
int const m_block = get<0>(block_coord);
656657
int const bidh = get<1>(block_coord);
@@ -694,11 +695,11 @@ struct CollectiveMainloopFwdSm90 {
694695
}
695696
}
696697

697-
template <typename SharedStorage>
698+
template <typename SharedStorage, typename TensorQ, typename TensorQv>
698699
CUTLASS_DEVICE void load_q_impl(
699700
std::false_type /*UseTmaQ*/, Params const& params, SharedStorage &shared_storage,
700701
SeqlenInfo_t const& seqlen_info, cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
701-
int thread_idx, bool is_varlen_q, int /*warp_idx_in_warpgroup*/, Tensor sQ, Tensor sQv) {
702+
int thread_idx, bool is_varlen_q, int /*warp_idx_in_warpgroup*/, TensorQ sQ, TensorQv sQv) {
702703

703704
int const m_block = get<0>(block_coord);
704705
int const bidh = get<1>(block_coord);

0 commit comments

Comments
 (0)