@rolandschulz, could you help review? Thanks very much. The context is that we are integrating the flash-attention-2 kernel into Hugging Face, so we need to align the API with CUDA to make the integration easy. Thanks!
We need to figure out how to place the Python package, the torch extension, and the associated Triton kernel. Let's hold this since it requires a redesign.
We have refined the Flash Attention implementation. If you still need this PR, please update it based on the latest source code.
Author: Got it, thanks!
This PR provides a template for a flash-attn Python API. Based on a sycl-tla kernel, it currently reproduces part of flash-attn's functionality, such as fwd and varlen_fwd.
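For context, a minimal sketch of the interface this targets, assuming the XPU build mirrors the public Dao-AILab/flash-attention signatures (the `xpu` device string and the tensor shapes are illustrative, not taken from this PR):

```python
# Sketch of the interface this PR targets, assuming it mirrors the public
# Dao-AILab/flash-attention API; shapes and the "xpu" device are illustrative.
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func

# Dense forward (fwd): q, k, v have shape (batch, seqlen, nheads, headdim).
q, k, v = (torch.randn(2, 1024, 16, 64, device="xpu", dtype=torch.float16) for _ in range(3))
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)

# Variable-length forward (varlen_fwd): tokens from all sequences are packed
# into (total_tokens, nheads, headdim); cu_seqlens marks the batch boundaries.
seqlens = torch.tensor([512, 1024], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0)).to("xpu")
qp, kp, vp = (torch.randn(1536, 16, 64, device="xpu", dtype=torch.float16) for _ in range(3))
out_varlen = flash_attn_varlen_func(
    qp, kp, vp,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=1024, max_seqlen_k=1024,
    causal=True,
)
```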
I see that the current implementation of the kernel does not expose an external interface, and the test files are quite limited in scope. If possible, perhaps we can jointly maintain a common API, similar to what Dao-AILab/flash-attention does. Subsequent unit tests can be based on `test_flash_attn.py` to remain consistent with the official CUDA interface. During the reproduction I found some edge-case issues when running tests, so I modified `xe_flash_attn_prefill_epilogue.hpp` and `xe_flash_attn_prefill.hpp`. If you are interested, we can discuss this further. If there are any issues with the current code, please let me know. Thanks!
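To illustrate what tests modeled on `test_flash_attn.py` could look like, here is a minimal sketch that checks the fwd path against a plain PyTorch attention reference; the device string and tolerances are assumptions, not something taken from this PR:

```python
# Minimal correctness check in the spirit of Dao-AILab's test_flash_attn.py.
# The "xpu" device and the tolerances are assumptions for illustration only.
import math
import torch
from flash_attn import flash_attn_func

def reference_attention(q, k, v, causal=True):
    # q, k, v: (batch, seqlen, nheads, headdim); compute softmax(QK^T / sqrt(d)) V in fp32.
    q, k, v = [t.transpose(1, 2).float() for t in (q, k, v)]  # -> (batch, nheads, seqlen, headdim)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if causal:
        mask = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(mask, float("-inf"))
    return (scores.softmax(-1) @ v).transpose(1, 2)

def test_fwd_matches_reference():
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 256, 8, 64, device="xpu", dtype=torch.float16) for _ in range(3))
    out = flash_attn_func(q, k, v, causal=True)
    ref = reference_attention(q, k, v, causal=True)
    assert torch.allclose(out.float(), ref, atol=2e-2, rtol=2e-2)
```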
Current method to build the API:
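The build could follow the usual torch C++ extension flow; a hypothetical `setup.py` sketch, where the module name, source list, and compiler flag are placeholders rather than this PR's actual configuration:

```python
# Hypothetical setup.py skeleton for building the SYCL kernel as a torch
# extension; module name, sources, and flags below are placeholders only.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="flash_attn",
    ext_modules=[
        CppExtension(
            name="flash_attn_2_xpu",          # placeholder extension module name
            sources=["csrc/flash_api.cpp"],   # placeholder binding source
            extra_compile_args=["-fsycl"],    # assumes a SYCL-capable compiler (e.g. icpx)
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```

With a layout like this, `pip install -e .` would build the extension in place.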
After that, you can run tests using the same import statements as in `test_flash_attn.py`.
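For example, the test-side imports would stay identical to the upstream CUDA test file (assuming the XPU package keeps the same module layout):

```python
# Same imports as the upstream test_flash_attn.py; only the device under test changes.
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input
```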