
import flashqla to speedup gdn prefill #1295

Open
WANDY666 wants to merge 4 commits into main from pr-flashqla

Conversation

@WANDY666 (Contributor) commented May 8, 2026

No description provided.

@gemini-code-assist (Bot) left a comment


Code Review

This pull request implements a FlashQLA backend dispatch for the chunk_gated_delta_rule in the Qwen3Next model, including a comprehensive parity test and benchmark suite. Key feedback highlights that the PyTorch and CUDA version checks (>= 2.8 and >= 12.8) appear to be placeholders for unreleased versions and should be corrected. Additionally, the scale parameter should be assigned its default value before the FlashQLA dispatch to prevent potential errors if the backend receives a null value.

Comment on lines +41 to +47
    if (int(tv[0]), int(tv[1])) < (2, 8):
        return None
    cv = torch.version.cuda
    if cv is None:
        return None
    cv_parts = cv.split(".")
    if (int(cv_parts[0]), int(cv_parts[1])) < (12, 8):
        return None

Severity: high

The version checks for PyTorch (>= 2.8) and CUDA (>= 12.8) appear to be placeholders or typos, as these versions are not yet released (current stable versions are typically PyTorch 2.5/2.6 and CUDA 12.4/12.6). As written, this logic will disable the FlashQLA backend for almost all current environments. Please verify if these should be lower versions (e.g., PyTorch 2.4 and CUDA 12.1).
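For reference, the same gate written with packaging.version reads more directly and makes the minimums easy to adjust once the correct versions are confirmed. A minimal sketch, assuming the packaging library is installed; the helper name _flashqla_version_ok is hypothetical:

import torch
from packaging.version import Version

def _flashqla_version_ok(min_torch: str = "2.8", min_cuda: str = "12.8") -> bool:
    # Note: Version treats pre-releases as below the release they precede,
    # so a "2.8.0a0" nightly fails a "2.8" minimum here, unlike the
    # (major, minor) tuple comparison in the diff above.
    if Version(torch.__version__) < Version(min_torch):
        return False
    cuda = torch.version.cuda  # None on CPU-only builds
    return cuda is not None and Version(cuda) >= Version(min_cuda)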

Comment on lines +219 to +233
    flashqla_fn = _flashqla_chunk_gated_delta_rule()
    if flashqla_fn is not None and not head_first:
        return flashqla_fn(
            q=q.contiguous(),
            k=k.contiguous(),
            v=v.contiguous(),
            g=g.contiguous(),
            beta=beta.contiguous(),
            scale=scale,
            initial_state=initial_state.contiguous() if initial_state is not None else None,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )

Severity: high

The scale parameter is passed to flashqla_fn before it is assigned its default value. In the fallback path (lines 257-258), scale defaults to k.shape[-1] ** -0.5 if it is None. If flash_qla.chunk_gated_delta_rule does not handle None for the scale argument, this will lead to incorrect results or a crash. You should move the default scale calculation before the FlashQLA dispatch logic.

Suggested change
-    flashqla_fn = _flashqla_chunk_gated_delta_rule()
-    if flashqla_fn is not None and not head_first:
-        return flashqla_fn(
-            q=q.contiguous(),
-            k=k.contiguous(),
-            v=v.contiguous(),
-            g=g.contiguous(),
-            beta=beta.contiguous(),
-            scale=scale,
-            initial_state=initial_state.contiguous() if initial_state is not None else None,
-            output_final_state=output_final_state,
-            cu_seqlens=cu_seqlens,
-            head_first=head_first,
-            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
-        )
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    flashqla_fn = _flashqla_chunk_gated_delta_rule()
+    if flashqla_fn is not None and not head_first:
+        return flashqla_fn(
+            q=q.contiguous(),
+            k=k.contiguous(),
+            v=v.contiguous(),
+            g=g.contiguous(),
+            beta=beta.contiguous(),
+            scale=scale,
+            initial_state=initial_state.contiguous() if initial_state is not None else None,
+            output_final_state=output_final_state,
+            cu_seqlens=cu_seqlens,
+            head_first=head_first,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+        )
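
To make the failure mode concrete, a minimal sketch, assuming the downstream kernel multiplies q by scale (flash_qla's actual handling of scale=None is not confirmed here):

import torch

q = torch.randn(1, 16, 4, 64)  # (batch, seq_len, num_heads, head_dim)

scale = None
try:
    q * scale  # what a kernel computing q * scale would hit
except TypeError as err:
    print(err)  # e.g. unsupported operand type(s) for *: 'Tensor' and 'NoneType'

scale = q.shape[-1] ** -0.5  # the fallback default: 1 / sqrt(64) = 0.125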
