[AMD] Fix bugs about AMD FA kernel #1701

Merged
LeiWang1999 merged 6 commits into tile-ai:main from danielhua23:amd/adaption
Feb 16, 2026

Conversation

@danielhua23
Contributor

@danielhua23 danielhua23 commented Jan 20, 2026

As tilelang has been updated, I found some issues when running the AMD FA kernel; this PR proposes workarounds for the AMD FA example.

  • issue1: when saving the tuned kernel, the AMD FA example goes through kernel_global_source of CythonKernelAdapter, but that attribute is missing
  • issue2: the original tuned config with block_M = 32 and block_N = 32 no longer works, so it is removed; specifically, the failures are a layout-inference conflict between acc_s and acc_s_cast in a T.Parallel loop:
    loop Fragment([32, 32] -> [8], replicate: 2, thread: 256, ...)
    fragment Fragment([32, 32] -> [4], replicate: 1, thread: 256, ...)

and

Check failed: pb->value != 0 (0 vs. 0) : Divide by zero

Are these two issues expected? Thanks.

Summary by CodeRabbit

  • Chores

    • Optimized the AMD Flash Attention example: narrowed tuning options, improved kernel execution for correct softmax-related computation, and removed problematic configurations.
  • Refactor

    • Added a compatibility alias on the kernel adapter to mirror existing kernel source access, improving integration with callers.

Copilot AI review requested due to automatic review settings January 20, 2026 10:04
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai Bot commented Jan 20, 2026

No actionable comments were generated in the recent review. 🎉


📝 Walkthrough

Adds a compat alias kernel_global_source to CythonKernelAdapter. Updates the AMD flash-attention forward example: map TileLang dtypes to PyTorch when allocating tensors, narrow autotuner configs, add explicit softmax steps and a copy to avoid a layout conflict before the final GEMM, and small comment/format edits.
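For orientation, here is a minimal sketch of what such a compat alias can look like, assuming the adapter already stores the generated source under device_kernel_source; the constructor signature, the from_database record layout, and everything besides the two attribute names are illustrative rather than the actual CythonKernelAdapter code:

    class CythonKernelAdapter:
        # Compat alias so callers that still read `kernel_global_source`
        # (e.g. when saving a tuned kernel) keep working.
        kernel_global_source: str | None = None

        def __init__(self, device_kernel_source: str):
            self.device_kernel_source = device_kernel_source
            # Mirror the device kernel source under the legacy attribute name.
            self.kernel_global_source = self.device_kernel_source

        @classmethod
        def from_database(cls, record: dict) -> "CythonKernelAdapter":
            adapter = cls(record["device_kernel_source"])
            # Keep the alias in sync when the adapter is restored from the cache.
            adapter.kernel_global_source = adapter.device_kernel_source
            return adapter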

Changes

Cohort / File(s) Summary
Cython Adapter Alias
tilelang/jit/adapter/cython/adapter.py
Add public attribute kernel_global_source: str | None = None and set it to mirror device_kernel_source in __init__ and from_database.
AMD Flash Attention Example
examples/amd/example_amd_flash_attn_fwd.py
Use param.dtype.as_torch() when creating GPU tensors, narrow the get_configs search (block_M/block_N and threads), insert explicit softmax/pre-GEMM exp + reduction steps, copy acc_s into acc_s_cast to avoid the layout conflict (see the sketch after this table), and add explanatory comments.
Misc metadata (unchanged APIs)
manifest_file, requirements.txt, pyproject.toml
Minor metadata/requirements edits unrelated to public API.
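To make the layout-conflict workaround concrete: the key step is copying the attention-score accumulator into a separate cast fragment before the second GEMM instead of converting it element-wise inside a T.Parallel loop. A rough TileLang-style fragment of that step, with the buffer names acc_s and acc_s_cast taken from the PR description and all shapes, dtypes, and surrounding pipeline code assumed for illustration:

    # Inside the KV loop of the flash-attention kernel (illustrative context).
    # acc_s holds the float32 QK^T scores; acc_s_cast is the float16 operand for the PV GEMM.
    acc_s = T.alloc_fragment((block_M, block_N), "float32")
    acc_s_cast = T.alloc_fragment((block_M, block_N), "float16")

    # ... online-softmax rescaling of acc_s goes here ...

    # Copy (and implicitly cast) in one step so acc_s and acc_s_cast are not
    # forced to share a single inferred fragment layout inside a T.Parallel loop.
    T.copy(acc_s, acc_s_cast)
    T.gemm(acc_s_cast, V_shared, acc_o)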

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

🐰 In kernels snug I found my cue,
I mirrored source and hopped right through.
Softmax hummed, then copies made,
Tuned configs trimmed for a quicker trade.
Compile and run — a carrot parade! 🥕✨

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 50.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (3 passed)
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title '[AMD] Fix bugs about AMD FA kernel' is partially related to the changeset. It correctly identifies the AMD Flash Attention kernel fixes as the main change, but is somewhat vague; 'bugs' and 'FA kernel' lack specificity about the actual issues addressed (the missing kernel_global_source attribute and incompatible configuration values).
  • Merge Conflict Detection ✅ Passed: no merge conflicts detected when merging into main.


Contributor

Copilot AI left a comment


Pull request overview

This PR fixes two issues with the AMD Flash Attention kernel:

  1. Adds the missing kernel_global_source attribute to CythonKernelAdapter which is required when saving/loading tuned kernels
  2. Removes problematic configuration parameters (block_M=32, block_N=32, threads=512) that don't work with the current implementation

Changes:

  • Added kernel_global_source attribute as an alias for device_kernel_source in CythonKernelAdapter for compatibility
  • Updated supply_tensors_gpu to properly map TileLang dtypes to PyTorch dtypes
  • Removed problematic tuning configurations and updated the loop from a plain while to the T.While construct (a sketch follows this list)
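The before/after forms in this sketch come from the diff quoted later in this thread, with the loop body elided and nothing else assumed:

    bx = b_split

    # Old form (plain Python while over the query blocks):
    #   while bx < num_q_blocks:
    #       ...

    # New form: the same condition expressed through TileLang's T.While frame.
    with T.While(bx < num_q_blocks):
        ...  # loop body elided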

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File / Description
  • tilelang/jit/adapter/cython/adapter.py: added the kernel_global_source field and initialized it in both __init__ and from_database as an alias of device_kernel_source
  • examples/amd/example_amd_flash_attn_fwd.py: fixed the dtype mapping, removed problematic configs, switched to the T.While construct, and improved comments


Comment thread on examples/amd/example_amd_flash_attn_fwd.py (outdated), lines +13 to +28:
    # TileLang dtype to PyTorch dtype mapping
    dtype_map = {
        T.float16: torch.float16,
        T.float32: torch.float32,
        T.int32: torch.int32,
        T.int64: torch.int64,
    }

    tensors = []
    for param in params:
        if hasattr(param, "shape") and hasattr(param, "dtype"):
            # Force creation on GPU device
            shape = [int(s) for s in param.shape]
            tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
            # Convert TileLang dtype to PyTorch dtype
            torch_dtype = dtype_map.get(param.dtype, torch.float16)
            tensor = torch.randn(shape, dtype=torch_dtype, device="cuda")

Copilot AI Jan 20, 2026


The manual dtype mapping could be simplified by using the built-in torch_dtype() method from KernelParam. Instead of maintaining a manual dtype_map dictionary, you can replace lines 14-19 and 27 with:

torch_dtype = param.torch_dtype()

This would eliminate the need for manual mapping and automatically handle all dtype conversions, including edge cases like float8 types that require special handling for HIP vs CUDA backends.
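As a sketch of how the reviewer's suggestion would read in the example's tensor-supply helper, assuming KernelParam exposes the torch_dtype() method described above; the exact signature, return handling, and lack of an integer fallback are illustrative:

    import torch

    def supply_tensors_gpu(params):
        """Allocate random GPU tensors matching each kernel parameter (illustrative sketch)."""
        tensors = []
        for param in params:
            if hasattr(param, "shape") and hasattr(param, "dtype"):
                shape = [int(s) for s in param.shape]
                # Let the kernel parameter report its own PyTorch dtype instead of
                # maintaining a manual TileLang-to-PyTorch mapping.
                torch_dtype = param.torch_dtype()
                # torch.randn only supports floating dtypes; integer parameters
                # would need a different initializer.
                tensors.append(torch.randn(shape, dtype=torch_dtype, device="cuda"))
        return tensors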

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@examples/amd/example_amd_flash_attn_fwd.py`:
- Around line 13-28: Replace the manual dtype mapping in supply_tensors_gpu: the
current dtype_map and lookup using param.dtype should be removed and replaced by
calling param.torch_dtype() to obtain the correct PyTorch dtype; update the code
that sets torch_dtype (and remove the dtype_map variable) so tensor creation
uses torch_dtype = param.torch_dtype() ensuring all KernelParam dtypes
(bfloat16, float8, etc.) are handled and no silent fallback to torch.float16
occurs.
🧹 Nitpick comments (1)
examples/amd/example_amd_flash_attn_fwd.py (1)

53-57: Add a comment explaining why specific configurations were removed.

The docstring mentions "avoiding problematic combinations" but doesn't explain the root cause. Per the PR description, block_M=32 and block_N=32 cause layout inference conflicts and divide-by-zero errors. Adding this context prevents future maintainers from re-introducing these values:

📝 Suggested improvement
 def get_configs():
-    """Generates configurations for the autotuner, avoiding problematic combinations."""
-    block_M = [64, 128, 256]
-    block_N = [64, 128, 256]
-    threads = [128, 256]
+    """Generates configurations for the autotuner, avoiding problematic combinations.
+    
+    Note: block_M/N=32 removed due to layout inference conflicts between acc_s and
+    acc_s_cast fragments. threads=512 removed to avoid divide-by-zero check failures.
+    """
+    block_M = [64, 128, 256]  # 32 causes layout conflict
+    block_N = [64, 128, 256]  # 32 causes layout conflict
+    threads = [128, 256]  # 512 causes divide-by-zero

Comment thread on examples/amd/example_amd_flash_attn_fwd.py (outdated):
Member

@LeiWang1999 LeiWang1999 left a comment


LGTM. Sorry that I forgot to submit the review even though I left some messages.

            shape = [int(s) for s in param.shape]
            tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
            # Convert TileLang dtype to PyTorch dtype
            torch_dtype = dtype_map.get(param.dtype, torch.float16)
Member


param.dtype.as_torch()

bx = b_split

while bx < num_q_blocks:
with T.While(bx < num_q_blocks):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to change while into T.While?

…plify loop condition in fast_flashattn function. The dtype mapping has been removed in favor of a direct conversion method, improving clarity and error handling.
@LeiWang1999 LeiWang1999 merged commit 110ef30 into tile-ai:main Feb 16, 2026
6 checks passed
