diff --git a/tests/pytorch/test_qk_norm.py b/tests/pytorch/test_qk_norm.py
index 873bd91863..b182d175e7 100644
--- a/tests/pytorch/test_qk_norm.py
+++ b/tests/pytorch/test_qk_norm.py
@@ -11,7 +11,8 @@
 @pytest.mark.parametrize("qk_norm_type", [None, "L2Normalization", "RMSNorm", "LayerNorm"])
 @pytest.mark.parametrize("attention_type", ["self", "cross"])
 @pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
-def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None:
+@pytest.mark.parametrize("params_dtype", [torch.float32, torch.bfloat16])
+def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps, params_dtype) -> None:
     """Test QK normalization functionality, module structure, and numerical behavior."""
     hidden_size = 256
     num_attention_heads = 8
@@ -26,6 +27,7 @@ def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> Non
         qk_norm_eps=qk_norm_eps,
         bias=False,
         device="cuda",
+        params_dtype=params_dtype,
     ).cuda()
 
     # Check module structure based on qk_norm_type parameter
@@ -78,13 +80,11 @@ def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> Non
 
     # Create input tensors
     batch_size = 2  # Use a fixed batch size for testing
-    hidden_states = torch.randn(
-        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
-    )
+    hidden_states = torch.randn(seq_len, batch_size, hidden_size, device="cuda", dtype=params_dtype)
 
     if attention_type == "cross":
         encoder_output = torch.randn(
-            seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
+            seq_len, batch_size, hidden_size, device="cuda", dtype=params_dtype
         )
     else:
         encoder_output = None
@@ -109,7 +109,7 @@ def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> Non
     if attention_type == "self":
         head_dim = hidden_size // num_attention_heads
         rotary_dim = head_dim // 2
-        rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32)
+        rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=params_dtype)
 
         with torch.no_grad():
             output_with_rope = mha(hidden_states, rotary_pos_emb=rotary_pos_emb)
diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py
index 01c4955d78..d4a73ef64c 100644
--- a/transformer_engine/pytorch/attention/multi_head_attention.py
+++ b/transformer_engine/pytorch/attention/multi_head_attention.py
@@ -355,7 +355,7 @@ def __init__(
         }
 
         self.q_norm, self.k_norm = self._create_qk_norm_modules(
-            qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size
+            qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size, params_dtype
         )
 
         qkv_parallel_mode = "column" if set_parallel_mode else None
@@ -485,6 +485,7 @@ def _create_qk_norm_modules(
         device: Union[torch.device, str],
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
+        params_dtype: Optional[torch.dtype] = None,
     ) -> Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]:
         """
         Create query and key normalization modules based on the specified normalization type.
@@ -501,6 +502,8 @@ def _create_qk_norm_modules(
             Sequence length for L2Normalization optimization
         micro_batch_size : Optional[int], default = None
             Micro batch size for L2Normalization optimization
+        params_dtype : Optional[torch.dtype], default = None
+            Data type for the normalization modules
 
         Returns
         -------
@@ -524,11 +527,13 @@ def _create_qk_norm_modules(
                 normalized_shape=self.hidden_size_per_attention_head,
                 eps=qk_norm_eps,
                 device=device,
+                params_dtype=params_dtype,
             )
             k_norm = RMSNorm(
                 normalized_shape=self.hidden_size_per_attention_head,
                 eps=qk_norm_eps,
                 device=device,
+                params_dtype=params_dtype,
             )
             return q_norm, k_norm
 
@@ -537,11 +542,13 @@ def _create_qk_norm_modules(
                 normalized_shape=self.hidden_size_per_attention_head,
                 eps=qk_norm_eps,
                 device=device,
+                params_dtype=params_dtype,
            )
             k_norm = LayerNorm(
                 normalized_shape=self.hidden_size_per_attention_head,
                 eps=qk_norm_eps,
                 device=device,
+                params_dtype=params_dtype,
             )
             return q_norm, k_norm
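For reference, a minimal usage sketch of what the change enables, not part of the diff: the hidden_size/num_attention_heads values and the weight-dtype assertions are illustrative assumptions modeled on the test above, exercising the path where params_dtype is now forwarded to the RMSNorm/LayerNorm QK-norm modules.

import torch
from transformer_engine.pytorch import MultiheadAttention

# Sketch only: with this PR, params_dtype should reach the q/k norm modules
# instead of their weights defaulting to float32.
mha = MultiheadAttention(
    hidden_size=256,
    num_attention_heads=8,
    qk_norm_type="RMSNorm",
    qk_norm_eps=1e-6,
    bias=False,
    device="cuda",
    params_dtype=torch.bfloat16,
).cuda()

# RMSNorm/LayerNorm QK norms carry learnable weights, so their dtype is
# expected to match params_dtype after this change (assumed attribute layout).
assert mha.q_norm.weight.dtype == torch.bfloat16
assert mha.k_norm.weight.dtype == torch.bfloat16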