1313See the License for the specific language governing permissions and
1414limitations under the License.
1515"""
16-
16+ import paddle
17+ paddle .compat .enable_torch_proxy ()
1718import einops
1819import pytest
1920import torch
21+ import numpy as np
2022from tests .test_helpers .sink_attention_reference import sink_attention_unified
2123
2224import flashinfer
2325from flashinfer .utils import get_compute_capability
2426
2527
26- @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
27- @pytest .mark .parametrize ("batch_size" , [1 , 4 , 16 ])
28+ # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
29+ # @pytest.mark.parametrize("batch_size", [1, 4, 16])
30+ # @pytest.mark.parametrize("page_size", [32])
31+ # @pytest.mark.parametrize("seq_len", [32, 128, 1024])
32+ # @pytest.mark.parametrize("num_qo_heads", [32])
33+ # @pytest.mark.parametrize("num_kv_heads", [8, 32])
34+ # @pytest.mark.parametrize("head_dim", [64, 128])
35+ @pytest .mark .parametrize ("dtype" , [torch .bfloat16 ])
36+ @pytest .mark .parametrize ("batch_size" , [4 ])
2837@pytest .mark .parametrize ("page_size" , [32 ])
29- @pytest .mark .parametrize ("seq_len" , [32 , 128 , 1024 ])
38+ @pytest .mark .parametrize ("seq_len" , [32 ])
3039@pytest .mark .parametrize ("num_qo_heads" , [32 ])
31- @pytest .mark .parametrize ("num_kv_heads" , [8 , 32 ])
32- @pytest .mark .parametrize ("head_dim" , [64 , 128 ])
40+ @pytest .mark .parametrize ("num_kv_heads" , [8 ])
41+ @pytest .mark .parametrize ("head_dim" , [64 ])
3342def test_blackwell_trtllm_gen_decode_attention_sink (
3443 dtype ,
3544 batch_size ,
@@ -39,11 +48,11 @@ def test_blackwell_trtllm_gen_decode_attention_sink(
3948 num_kv_heads ,
4049 head_dim ,
4150):
42- compute_capability = get_compute_capability (torch .device (device = "cuda" ))
43- if compute_capability [0 ] in [11 , 12 ]:
44- pytest .skip ("trtllm-gen does not support SM110/SM120/SM121 GPUs." )
45- seed = 0
46- torch .manual_seed (seed )
51+ # compute_capability = get_compute_capability(torch.device(device="cuda"))
52+ # if compute_capability[0] in [11, 12]:
53+ # pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
54+ # seed = 0
55+ # torch.manual_seed(seed)
4756 device = "cuda:0"
4857
4958 seq_lens = torch .full ((batch_size ,), seq_len , dtype = torch .int32 , device = device )
@@ -121,16 +130,24 @@ def test_blackwell_trtllm_gen_decode_attention_sink(
121130 else :
122131 raise ValueError (f"Unsupported dtype: { dtype } " )
123132
124- torch .testing .assert_close (o_ref , output , atol = atol , rtol = rtol )
133+ # torch.testing.assert_close(o_ref, output, atol=atol, rtol=rtol)
134+ np .testing .assert_allclose (o_ref .float (), output .float (), atol = atol , rtol = rtol )
125135
126136
127- @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
128- @pytest .mark .parametrize ("batch_size" , [1 , 4 , 16 ])
137+ # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
138+ # @pytest.mark.parametrize("batch_size", [1, 4, 16])
139+ # @pytest.mark.parametrize("page_size", [32])
140+ # @pytest.mark.parametrize("seq_len", [32, 128, 1024])
141+ # @pytest.mark.parametrize("num_qo_heads", [32])
142+ # @pytest.mark.parametrize("num_kv_heads", [8, 32])
143+ # @pytest.mark.parametrize("head_dim", [64, 128])
144+ @pytest .mark .parametrize ("dtype" , [torch .bfloat16 ])
145+ @pytest .mark .parametrize ("batch_size" , [1 ])
129146@pytest .mark .parametrize ("page_size" , [32 ])
130- @pytest .mark .parametrize ("seq_len" , [32 , 128 , 1024 ])
147+ @pytest .mark .parametrize ("seq_len" , [32 ])
131148@pytest .mark .parametrize ("num_qo_heads" , [32 ])
132- @pytest .mark .parametrize ("num_kv_heads" , [8 , 32 ])
133- @pytest .mark .parametrize ("head_dim" , [64 , 128 ])
149+ @pytest .mark .parametrize ("num_kv_heads" , [8 ])
150+ @pytest .mark .parametrize ("head_dim" , [64 ])
134151def test_blackwell_trtllm_gen_context_attention_sink (
135152 dtype ,
136153 batch_size ,
@@ -140,11 +157,12 @@ def test_blackwell_trtllm_gen_context_attention_sink(
140157 num_kv_heads ,
141158 head_dim ,
142159):
143- compute_capability = get_compute_capability (torch .device (device = "cuda" ))
144- if compute_capability [0 ] in [11 , 12 ]:
145- pytest .skip ("trtllm-gen does not support SM110/SM120/SM121 GPUs." )
160+ # compute_capability = get_compute_capability(torch.device(device="cuda"))
161+ # if compute_capability[0] in [11, 12]:
162+ # pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
146163 seed = 0
147- torch .manual_seed (seed )
164+ paddle .seed (seed )
165+ # torch.manual_seed(seed)
148166 device = "cuda:0"
149167
150168 seq_lens = torch .full ((batch_size ,), seq_len , dtype = torch .int32 , device = device )
@@ -233,4 +251,5 @@ def test_blackwell_trtllm_gen_context_attention_sink(
233251 else :
234252 raise ValueError (f"Unsupported dtype: { dtype } " )
235253
236- torch .testing .assert_close (o_ref , output , atol = atol , rtol = rtol )
254+ ref_o = o_ref .float ().numpy ()
255+ np .testing .assert_allclose (ref_o , paddle_o , atol = atol , rtol = rtol )
0 commit comments