[AMD] Fix bugs about AMD FA kernel #1701
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
No actionable comments were generated in the recent review. 🎉

📝 Walkthrough

Adds a compat alias `kernel_global_source` for `device_kernel_source` in `CythonKernelAdapter`, and updates the AMD flash attention example (dtype handling in `supply_tensors_gpu`, removal of problematic autotuner configurations, and a `while` → `T.While` change).
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Pull request overview
This PR fixes two issues with the AMD Flash Attention kernel:
- Adds the missing `kernel_global_source` attribute to `CythonKernelAdapter`, which is required when saving/loading tuned kernels
- Removes problematic configuration parameters (`block_M=32`, `block_N=32`, `threads=512`) that don't work with the current implementation
Changes:
- Added `kernel_global_source` attribute as an alias for `device_kernel_source` in `CythonKernelAdapter` for compatibility (see the sketch after this list)
- Updated `supply_tensors_gpu` to properly map TileLang dtypes to PyTorch dtypes
- Removed problematic tuning configurations and updated from `while` to `T.While` construct
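For reference, a minimal sketch of the compat-alias pattern described in the first bullet. Only the `kernel_global_source` / `device_kernel_source` names and the `__init__` / `from_database` construction paths come from this PR; the class name and constructor signature below are illustrative simplifications, not the real `CythonKernelAdapter` API.

```python
# Illustrative sketch of the compat-alias pattern (not the real adapter class).
class AdapterSketch:

    def __init__(self, device_kernel_source: str):
        self.device_kernel_source = device_kernel_source
        # Compat alias: code paths that save/load tuned kernels still read
        # `kernel_global_source`, so keep it pointing at the same source string.
        self.kernel_global_source = device_kernel_source

    @classmethod
    def from_database(cls, device_kernel_source: str) -> "AdapterSketch":
        adapter = cls.__new__(cls)
        adapter.device_kernel_source = device_kernel_source
        # The alias must also be set on this construction path; otherwise kernels
        # restored from the tuning cache would be missing the attribute.
        adapter.kernel_global_source = device_kernel_source
        return adapter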
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tilelang/jit/adapter/cython/adapter.py | Added kernel_global_source field and properly initialized it in both __init__ and from_database methods as an alias to device_kernel_source |
| examples/amd/example_amd_flash_attn_fwd.py | Fixed dtype mapping, removed problematic configs, updated to use T.While construct, and improved comments |
```diff
+        # TileLang dtype to PyTorch dtype mapping
+        dtype_map = {
+            T.float16: torch.float16,
+            T.float32: torch.float32,
+            T.int32: torch.int32,
+            T.int64: torch.int64,
+        }

         tensors = []
         for param in params:
             if hasattr(param, "shape") and hasattr(param, "dtype"):
                 # Force creation on GPU device
                 shape = [int(s) for s in param.shape]
-                tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
+                # Convert TileLang dtype to PyTorch dtype
+                torch_dtype = dtype_map.get(param.dtype, torch.float16)
+                tensor = torch.randn(shape, dtype=torch_dtype, device="cuda")
```
The manual dtype mapping could be simplified by using the built-in `torch_dtype()` method from `KernelParam`. Instead of maintaining a manual `dtype_map` dictionary, you can replace lines 14-19 and 27 with:

`torch_dtype = param.torch_dtype()`
This would eliminate the need for manual mapping and automatically handle all dtype conversions, including edge cases like float8 types that require special handling for HIP vs CUDA backends.
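A minimal sketch of what the suggested simplification could look like, assuming `KernelParam.torch_dtype()` behaves as described above. The integer and float8 handling branches are extra assumptions added to keep the example runnable; they are not part of this PR.

```python
import torch

def supply_tensors_gpu(params):
    """Create random GPU input tensors for each kernel parameter (sketch)."""
    tensors = []
    for param in params:
        if hasattr(param, "shape") and hasattr(param, "dtype"):
            shape = [int(s) for s in param.shape]
            # No manual dtype_map and no silent float16 fallback.
            torch_dtype = param.torch_dtype()
            if torch_dtype.is_floating_point:
                # randn cannot emit float8 directly, so sample in float32 and cast.
                tensor = torch.randn(shape, device="cuda").to(torch_dtype)
            else:
                tensor = torch.randint(0, 100, shape, device="cuda", dtype=torch_dtype)
            tensors.append(tensor)
    return tensors
```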
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@examples/amd/example_amd_flash_attn_fwd.py`:
- Around line 13-28: Replace the manual dtype mapping in supply_tensors_gpu: the
current dtype_map and lookup using param.dtype should be removed and replaced by
calling param.torch_dtype() to obtain the correct PyTorch dtype; update the code
that sets torch_dtype (and remove the dtype_map variable) so tensor creation
uses torch_dtype = param.torch_dtype() ensuring all KernelParam dtypes
(bfloat16, float8, etc.) are handled and no silent fallback to torch.float16
occurs.
🧹 Nitpick comments (1)
examples/amd/example_amd_flash_attn_fwd.py (1)
53-57: Add a comment explaining why specific configurations were removed.

The docstring mentions "avoiding problematic combinations" but doesn't explain the root cause. Per the PR description, `block_M=32` and `block_N=32` cause layout inference conflicts and divide-by-zero errors. Adding this context prevents future maintainers from re-introducing these values:

📝 Suggested improvement
```diff
 def get_configs():
-    """Generates configurations for the autotuner, avoiding problematic combinations."""
-    block_M = [64, 128, 256]
-    block_N = [64, 128, 256]
-    threads = [128, 256]
+    """Generates configurations for the autotuner, avoiding problematic combinations.
+
+    Note: block_M/N=32 removed due to layout inference conflicts between acc_s and
+    acc_s_cast fragments. threads=512 removed to avoid divide-by-zero check failures.
+    """
+    block_M = [64, 128, 256]  # 32 causes layout conflict
+    block_N = [64, 128, 256]  # 32 causes layout conflict
+    threads = [128, 256]  # 512 causes divide-by-zero
```
LeiWang1999
left a comment
LGTM. Sorry that I forgot to submit the review, though I left some messages.
```diff
                 shape = [int(s) for s in param.shape]
-                tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
                 # Convert TileLang dtype to PyTorch dtype
                 torch_dtype = dtype_map.get(param.dtype, torch.float16)
```
```diff
         bx = b_split

-        while bx < num_q_blocks:
+        with T.While(bx < num_q_blocks):
```
Why do we need to change `while` into `T.While`?
…plify loop condition in fast_flashattn function. The dtype mapping has been removed in favor of a direct conversion method, improving clarity and error handling.
As tilelang has been updated, I found some issues when running the AMD FA kernel. This PR proposes a workaround for the AMD FA example.
and
are the two issues expected? Thanks.
Summary by CodeRabbit

- Chores: added a `kernel_global_source` compatibility alias for `device_kernel_source` in the Cython kernel adapter.
- Refactor: updated the AMD flash attention example (dtype handling, removed problematic tuning configurations, switched to `T.While`).