[Frontend/template] add SDPA modules#214

Open
student-Jungmin wants to merge 4 commits into PSAL-POSTECH:torch_v2.8 from student-Jungmin:feature/SDPA

student-Jungmin commented Mar 2, 2026

This PR implements Scaled Dot-Product Attention (SDPA) modules for the torch_v2.8 branch. The changes span 5 files:

First, I overloaded the SDPA dispatcher function in PyTorchSimDevice/torch_openreg/openreg/__init__.py. On the frontend side, I implemented a new lowering pass for SDPA in mlir_lowering.py and introduced mlir_sdpa_template.py, which contains the dispatcher function, the MLIR templates, and the necessary helper classes.

Additionally, I updated mlir_template.py by improving the tile candidate generation logic and modifying def_sram_buffer() to support SRAM-only buffers. Finally, I added test cases in tests/test_sdpa.py to verify the functional correctness of the entire SDPA implementation.
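
For reviewers, here is a minimal sketch of the computation that a functional-correctness test for SDPA would compare against: softmax(Q K^T * scale) V, with the scale defaulting to 1/sqrt(d). It is written in plain Python so it runs without PyTorch or the simulator. The names sdpa_reference and allclose are hypothetical and are not taken from tests/test_sdpa.py, whose contents are not shown here.

```python
import math


def softmax(row):
    # Subtract the row maximum before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sdpa_reference(q, k, v, scale=None):
    """Hypothetical reference: softmax(Q K^T * scale) V on nested lists.

    q has shape (L, d), k has shape (S, d), v has shape (S, dv).
    """
    d = len(q[0])
    if scale is None:
        scale = 1.0 / math.sqrt(d)  # assumed default scaling factor
    out = []
    for q_row in q:
        # Scaled dot product of this query against every key.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        out.append([
            sum(w * v_row[j] for w, v_row in zip(weights, v))
            for j in range(len(v[0]))
        ])
    return out


def allclose(a, b, tol=1e-6):
    # Element-wise comparison of two equally shaped nested lists.
    return all(abs(x - y) <= tol for ra, rb in zip(a, b) for x, y in zip(ra, rb))
```

A device-side test could compute the same inputs through the new SDPA path and compare against this reference within a tolerance; the tolerance that suits the simulator's numerics is an assumption left to the test author.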

Important: Prerequisites for Execution
This PR has two prerequisites. Both must be in place before running the code; otherwise, it will fail to execute:
[1] LLVM Dependency: My previous llvm-project PR must be applied, and mlir-opt must be recompiled.
[2] Lowering Pass Configuration: The TestDmaFineGrained lowering pass must be excluded during execution.
