pass params_dtype to qk_norm creation#2718

Open
pstjohn wants to merge 2 commits into NVIDIA:main from pstjohn:pstjohn/qk-norm-dtype

Conversation

@pstjohn
Contributor

@pstjohn pstjohn commented Feb 28, 2026

Previously layers would fail with

            assert (
>               query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
E           AssertionError: Queries, keys and values must have the same data type!

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py:1063: AssertionError

if you created a layer with dtype != float32. This change ensures the dtype of the QK layernorm modules matches that of the base attention layer.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Feb 28, 2026

Greptile Summary

Fixed a dtype mismatch bug in QK normalization that caused assertion failures when creating MultiheadAttention layers with dtype != float32.

Changes made:

  • Added params_dtype parameter to _create_qk_norm_modules method in multi_head_attention.py
  • Passed params_dtype to RMSNorm and LayerNorm constructors (L2Normalization correctly excluded as it's parameter-free)
  • Updated test_qk_norm_functionality to parameterize over params_dtype (float32, bfloat16)
  • Updated test input tensors to use the specified dtype
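
The test-side changes above can be sketched roughly as follows. This is a hypothetical stand-in using a plain torch.nn.LayerNorm, not the actual test in tests/pytorch/test_qk_norm.py, which exercises Transformer Engine's MultiheadAttention; the function and dimension names here are illustrative only.

```python
import pytest
import torch

# Illustrative sketch of parameterizing a QK-norm test over params_dtype,
# as the PR does for test_qk_norm_functionality. A plain LayerNorm stands
# in for the real Transformer Engine normalization module.
@pytest.mark.parametrize("params_dtype", [torch.float32, torch.bfloat16])
def test_qk_norm_dtype(params_dtype):
    norm = torch.nn.LayerNorm(32, dtype=params_dtype)
    # Input tensors use the specified dtype, matching the updated tests.
    query = torch.randn(4, 32, dtype=params_dtype)
    assert norm.weight.dtype == params_dtype
    assert norm(query).dtype == params_dtype
```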

How it fixes the bug:
Previously, RMSNorm and LayerNorm were created without specifying params_dtype, causing them to default to float32. When the base attention layer used bfloat16, the queries/keys/values would have mismatched dtypes, triggering the assertion: "Queries, keys and values must have the same data type!"

The fix ensures normalization layers inherit the same dtype as the parent MultiheadAttention module.
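
A minimal sketch of the idea in plain PyTorch (not Transformer Engine's actual `_create_qk_norm_modules` or its RMSNorm/LayerNorm classes; the helper name and sizes here are hypothetical):

```python
import torch

def create_qk_norm(hidden_dim, params_dtype=torch.float32):
    # Hypothetical sketch: the fix amounts to forwarding the parent
    # module's params_dtype when constructing the normalization module,
    # instead of letting its parameters default to float32.
    return torch.nn.LayerNorm(hidden_dim, dtype=params_dtype)

# Before the fix: norm parameters default to float32 even when the
# attention layer runs in bfloat16, so Q/K dtypes diverge downstream.
buggy_norm = torch.nn.LayerNorm(64)
assert buggy_norm.weight.dtype == torch.float32

# After the fix: the norm inherits the layer's dtype, so normalized
# queries/keys stay in bfloat16 and the Q/K/V dtype assertion passes.
fixed_norm = create_qk_norm(64, params_dtype=torch.bfloat16)
query = torch.randn(2, 64, dtype=torch.bfloat16)
assert fixed_norm(query).dtype == torch.bfloat16
```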

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The fix is minimal, targeted, and well-tested. It addresses a specific dtype mismatch bug that caused assertion failures when creating layers with dtype != float32. The implementation correctly passes params_dtype to RMSNorm and LayerNorm constructors, and the test coverage has been expanded to verify the fix works with both float32 and bfloat16
  • No files require special attention

Important Files Changed

  • transformer_engine/pytorch/attention/multi_head_attention.py: Added params_dtype parameter to _create_qk_norm_modules method and passed it to RMSNorm/LayerNorm constructors to fix the dtype mismatch
  • tests/pytorch/test_qk_norm.py: Added params_dtype test parameterization (float32, bfloat16) and updated input tensors to use the correct dtype

Last reviewed commit: 37eaccf

Contributor

@greptile-apps greptile-apps bot left a comment


2 files reviewed, no comments


Member

@ksivaman ksivaman left a comment


LGTM

@yaox12
Member

yaox12 commented Mar 2, 2026

/te-ci pytorch
