[Frontend/template] add SDPA modules #214
Open
student-Jungmin wants to merge 4 commits into PSAL-POSTECH:torch_v2.8
This PR implements Scaled Dot-Product Attention (SDPA) modules to support the torch_v2.8 branch. The implementation includes the following key updates across 5 files:
First, I added code to overload the SDPA dispatcher function in PyTorchSimDevice/torch_openreg/openreg/init.py. On the frontend side, I implemented a new lowering pass for SDPA in mlir_lowering.py and introduced mlir_sdpa_template.py, which contains the dispatcher function, MLIR templates, and necessary helper classes.
Additionally, I enhanced mlir_template.py by improving the tile candidate generation logic and modifying def_sram_buffer() to support SRAM-only buffers. Finally, I included test cases in tests/test_sdpa.py to verify the functional correctness of the entire SDPA implementation.
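For reviewers who want a quick mental model of what the SDPA tests have to check, here is a minimal, dependency-free sketch of the operation itself, softmax(Q K^T * scale) V with a default scale of 1/sqrt(head_dim). The name `sdpa_reference` is hypothetical and this is not the code in this PR. It works on plain 2-D Python lists for a single head, with no masking or dropout.

```python
import math


def sdpa_reference(q, k, v, scale=None):
    """Reference scaled dot-product attention on 2-D lists.

    q: [L, d], k: [S, d], v: [S, dv]; returns [L, dv].
    Hypothetical helper for illustration only (single head, no mask).
    """
    d = len(q[0])
    # Default scaling factor is 1/sqrt(head_dim).
    scale = 1.0 / math.sqrt(d) if scale is None else scale
    out = []
    for q_row in q:
        # Scaled dot product of this query against every key.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        # Numerically stable softmax: subtract the row maximum first.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value rows.
        out.append(
            [sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))]
        )
    return out
```

A correctness test would typically compare the device output against a reference like this, or against `torch.nn.functional.scaled_dot_product_attention` on CPU, within a floating-point tolerance. Whether tests/test_sdpa.py does exactly that is an assumption here.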
Important: Prerequisites for Execution
This PR has specific dependencies. Please ensure the following steps are taken before running the code; otherwise, it will fail to execute:
[1] LLVM Dependency: My previous llvm-project PR must be applied, and mlir-opt must be recompiled.
[2] Lowering Pass Configuration: The TestDmaFineGrained lowering pass must be excluded during execution.