Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tests/pytorch/test_qk_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
@pytest.mark.parametrize("qk_norm_type", [None, "L2Normalization", "RMSNorm", "LayerNorm"])
@pytest.mark.parametrize("attention_type", ["self", "cross"])
@pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
-def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None:
+@pytest.mark.parametrize("params_dtype", [torch.float32, torch.bfloat16])
+def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps, params_dtype) -> None:
"""Test QK normalization functionality, module structure, and numerical behavior."""
hidden_size = 256
num_attention_heads = 8
Expand All @@ -26,6 +27,7 @@ def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> Non
qk_norm_eps=qk_norm_eps,
bias=False,
device="cuda",
params_dtype=params_dtype,
).cuda()

# Check module structure based on qk_norm_type parameter
Expand Down Expand Up @@ -78,13 +80,11 @@ def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> Non

# Create input tensors
batch_size = 2 # Use a fixed batch size for testing
-hidden_states = torch.randn(
-seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
-)
+hidden_states = torch.randn(seq_len, batch_size, hidden_size, device="cuda", dtype=params_dtype)

if attention_type == "cross":
encoder_output = torch.randn(
-seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
+seq_len, batch_size, hidden_size, device="cuda", dtype=params_dtype
)
else:
encoder_output = None
Expand All @@ -109,7 +109,7 @@ def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> Non
if attention_type == "self":
head_dim = hidden_size // num_attention_heads
rotary_dim = head_dim // 2
-rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32)
+rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=params_dtype)

with torch.no_grad():
output_with_rope = mha(hidden_states, rotary_pos_emb=rotary_pos_emb)
Expand Down
9 changes: 8 additions & 1 deletion transformer_engine/pytorch/attention/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def __init__(
}

self.q_norm, self.k_norm = self._create_qk_norm_modules(
-qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size
+qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size, params_dtype
)

qkv_parallel_mode = "column" if set_parallel_mode else None
Expand Down Expand Up @@ -485,6 +485,7 @@ def _create_qk_norm_modules(
device: Union[torch.device, str],
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
params_dtype: Optional[torch.dtype] = None,
) -> Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]:
"""
Create query and key normalization modules based on the specified normalization type.
Expand All @@ -501,6 +502,8 @@ def _create_qk_norm_modules(
Sequence length for L2Normalization optimization
micro_batch_size : Optional[int], default = None
Micro batch size for L2Normalization optimization
params_dtype : Optional[torch.dtype], default = None
Data type for the normalization modules

Returns
-------
Expand All @@ -524,11 +527,13 @@ def _create_qk_norm_modules(
normalized_shape=self.hidden_size_per_attention_head,
eps=qk_norm_eps,
device=device,
params_dtype=params_dtype,
)
k_norm = RMSNorm(
normalized_shape=self.hidden_size_per_attention_head,
eps=qk_norm_eps,
device=device,
params_dtype=params_dtype,
)
return q_norm, k_norm

Expand All @@ -537,11 +542,13 @@ def _create_qk_norm_modules(
normalized_shape=self.hidden_size_per_attention_head,
eps=qk_norm_eps,
device=device,
params_dtype=params_dtype,
)
k_norm = LayerNorm(
normalized_shape=self.hidden_size_per_attention_head,
eps=qk_norm_eps,
device=device,
params_dtype=params_dtype,
)
return q_norm, k_norm

Expand Down
Loading