MoEGEMM as an extension of GroupGEMM #520
Closed
sanchitintel wants to merge 8 commits into intel:main
Conversation
This comment was marked as off-topic.
sanchitintel commented Sep 26, 2025
@sanchitintel how is the progress of the PR?
Author
Hi @airMeng, this PR doesn't have the updated code with performance optimizations, which I otherwise have locally. It does have the updated API interface in the example, though. However, if @Antonyvance and the cutlass team wouldn't want to have MoE GEMM as a separate kernel, then it's better to port it to the… Can you please explain why you asked? Thanks
Close this PR, which hasn't been updated in over 90 days. The project has changed a lot and this PR is no longer applicable; please create a new PR in case you still need it.
Background

When multiple GEMMs are to be computed, each with its own canonical `A`, `B`, `C`, `D` matrices, GroupGEMM is useful for ensuring high GPU utilization and avoiding the launch overhead that would otherwise be incurred by multiple GEMM kernel launches. In cutlass, the vanilla GroupGEMM uses a persistent-kernel approach: the number of workgroups launched equals the number of Xe cores, and each workgroup loops as long as it has work (here, "work" is the mainloop that computes one of the output tiles of any one of the GEMMs submitted through the GroupGEMM API).

For Mixture of Experts (MoE) in deep-learning models such as LLMs, the MoE GEMM use-case is as follows: each expert (corresponding to a group) has an associated weight of size `N * K`, which is essentially a column-major `B` matrix. All the `B` matrices are contiguous with respect to each other, i.e. their total size is `num_groups * N * K`. `N` and `K` are compile-time constants, while `M` for each group is variable. All `A` matrices are also contiguous with respect to each other; the set of tokens routed to an expert makes up the `A` matrix for that group. MoE GEMM therefore seems a natural candidate for leveraging GroupGEMM.
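To make the variable-`M` setup concrete, here is a small host-side sketch (hypothetical helper names, not code from this PR) showing how the per-group `M` values arise from token routing: each group's `M` is simply the number of tokens the router assigned to that expert, so it is only known at runtime.

```cpp
#include <cassert>
#include <vector>

// Hypothetical illustration: given each token's routed expert id, count how
// many rows (M) each expert's GEMM will have. N and K are compile-time
// constants in the PR's setup; only M varies per group, and it is known
// only after routing.
std::vector<int> group_sizes(const std::vector<int>& token_to_expert,
                             int num_experts) {
  std::vector<int> m_per_group(num_experts, 0);
  for (int e : token_to_expert) {
    ++m_per_group[e];  // histogram of routed tokens per expert
  }
  return m_per_group;
}
```

Since the tokens routed to each expert are stored contiguously, these counts fully describe where each group's `A` matrix begins inside the packed `A` buffer.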
The problem

The cutlass GroupGEMM API is generic in that it requires pointers to the `A`, `B`, `C`, `D` tensors of each group. To launch the kernel, the CPU needs to provide an array of these GPU pointers (the array itself also resides on the GPU). However, for practical use-cases such as Mixture of Experts (each GroupGEMM group corresponds to one MoE expert), such lists can't conveniently be pre-computed in advance. (It is indeed possible to create them at the beginning of the kernel and then synchronize across all workgroups, but that code can't be part of a generic GroupGEMM.)
Solution proposed in this PR

Provide only the base `A`, `B`, `C`, `D` pointers, and also pass `N` and `K`, so that the canonical `A`, `B`, `C`, `D` pointers for each group can be computed on the fly. (A prefix-sum algorithm to compute a cumulative sum of `M` might help, but based on our experimentation it doesn't seem to make much difference, as the small-`M` case is memory-bound anyway.) To keep changes to the existing code minimal, lists of size one are passed instead of lists with size equal to the number of groups, as happens in the default case.
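As a sketch of the idea (illustrative code, not the PR's implementation): because every `B` matrix is a fixed `N * K` block and the packed `A`/`C`/`D` rows of group `g` start right after the rows of all preceding groups, the per-group element offsets can be derived from the base pointers, `N`, `K`, and a running sum of the `M` values.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Element offsets (from the base A, B, C, D pointers) for one group.
// Illustrative only; the layout assumptions (row-major M x K A matrices,
// column-major N x K B matrices, M x N C/D, all packed back-to-back)
// follow the PR's description.
struct GroupOffsets {
  std::int64_t a, b, c, d;
};

GroupOffsets group_offsets(const std::vector<int>& m_per_group, int g,
                           std::int64_t n, std::int64_t k) {
  std::int64_t rows_before = 0;  // prefix sum of M over groups preceding g
  for (int i = 0; i < g; ++i) rows_before += m_per_group[i];
  return {
      rows_before * k,                       // A: skip M_0..M_{g-1} rows of K
      static_cast<std::int64_t>(g) * n * k,  // B: fixed-size N x K blocks
      rows_before * n,                       // C: skip the same rows of N
      rows_before * n,                       // D: likewise
  };
}
```

Note that `B`'s stride is fixed, so its offset needs no prefix sum at all; only `A`, `C`, and `D` depend on the cumulative `M`.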
The PR adds a new kernel and a tile scheduler for MoE GEMM, while reusing the existing MMA and epilogue collectives (but with modified code for the `A`, `B`, `C`, `D` pointer computation). We could instead add a template parameter to make these changes in the existing kernels and use `if constexpr` to separate them from the default GroupGEMM; while the current implementation in this PR introduces duplication, that alternative would make the code messier.
Performance

With a small `M` dimension for each GEMM problem, performance is worse than with a large `M` dimension due to the lower arithmetic intensity of the former case, but it is still better than launching a separate kernel for each GEMM problem.
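The arithmetic-intensity argument can be made concrete with a back-of-envelope model (an illustration, not a measurement from this PR): for fixed `N` and `K`, the `B`-matrix traffic is constant, so FLOPs per byte grow roughly linearly with `M`, and small-`M` groups stay memory-bound no matter how the kernel is tuned.

```cpp
#include <cassert>
#include <cstdint>

// Rough arithmetic intensity (FLOPs per byte) of one M x N x K GEMM,
// counting each multiply-accumulate as two FLOPs and assuming every
// operand element is read or written exactly once (2 bytes per element,
// e.g. bf16). This ignores caches and data re-use, so it only shows the
// trend, not an achievable number.
double arithmetic_intensity(std::int64_t m, std::int64_t n, std::int64_t k,
                            std::int64_t bytes_per_elem = 2) {
  double flops = 2.0 * static_cast<double>(m) * n * k;
  double bytes = static_cast<double>(m * k + n * k + m * n) * bytes_per_elem;
  return flops / bytes;
}
```

Under these assumptions, with `N = K = 4096`, the intensity is about `M` FLOPs per byte when `M` is small (≈4 at `M = 4`), versus ≈1365 at `M = 4096`, which is why batching many small GEMMs into one persistent launch matters more than optimizing each small GEMM individually.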
Caveat

The example just portrays one way to use the API. Also, it has mostly been copy-pasted from an existing example, so it can be revised further.