@@ -60,6 +60,61 @@ def kernel(x, pred, i, j, H: cl.Constant[int], W: cl.Constant[int]):
6060 )
6161 self .check_ptx_source (kernel , expect )
6262
@require_blackwell_cc100()
@pytest.mark.parametrize(
    "group,expect_group",
    [
        (cl.CTAGroup.CTA_1, "cta_group::1"),
        (cl.CTAGroup.CTA_2, "cta_group::2"),
    ],
)
def test_shared_cluster_group(self, group, expect_group):
    """Check the PTX produced when a CTA group is passed to the bulk tensor copy.

    Runs once per ``cl.CTAGroup`` value listed in the parametrization. Each
    run hands ``check_ptx_source`` the kernel, the ``shared::cluster`` copy
    instruction name and the ``cta_group::N`` qualifier paired with ``group``.

    NOTE(review): ``check_ptx_source`` is defined outside this block;
    presumably it asserts that each given string occurs in the kernel's
    PTX -- confirm against the helper.
    """

    @cl.kernel
    def kernel(x, pred, i, j, H: cl.Constant[int], W: cl.Constant[int]):
        # Opaque pointer to a tiled tensor map over ``x`` with box (H, W).
        tiled = cl.tensor_map_tiled(x, (H, W))
        desc = tiled.as_opaque_ptr()

        # Destination buffer: a shared array whose base pointer is passed
        # through map_shared_to_cluster with argument 0 (presumably the
        # target CTA rank -- TODO confirm).
        tile = cl.shared_array(shape=(H * W,), dtype=cl.int32, alignment=512)
        cluster_dst = cl.map_shared_to_cluster(tile.get_base_pointer(), 0)

        # Single mbarrier object handed to the copy.
        barrier = cl.shared_array(1, cl.mbarrier, alignment=8)
        barrier_ptr = barrier.get_base_pointer()

        cl.cp_async_bulk_tensor_global_to_shared(
            desc, (i, j), cluster_dst, barrier_ptr, group=group
        )

    copy_instr = "cp.async.bulk.tensor.2d.shared::cluster.global"
    self.check_ptx_source(kernel, copy_instr, expect_group)
92+
@require_blackwell_cc100()
def test_shared_cluster_group_with_predicate_and_multicast(self):
    """Check the PTX when multicast, CTA group and predicate are combined.

    The kernel issues one bulk tensor copy with ``multicast_mask=0x3``,
    ``group=cl.CTAGroup.CTA_2`` and ``predicate=pred``. The test checks the
    ``shared::cluster`` copy instruction together with the
    ``multicast::cluster`` qualifier, and separately together with the
    ``cta_group::2`` qualifier, so that dropping the group qualifier when
    multicast is enabled is detected.

    NOTE(review): ``check_ptx_source`` is defined outside this block;
    presumably it asserts that each given string occurs in the kernel's
    PTX -- confirm against the helper.
    NOTE(review): the predicate is passed to the copy but its effect on the
    PTX is not asserted; the expected textual form is not derivable from
    this file.
    """

    @cl.kernel
    def kernel(x, pred, i, j, H: cl.Constant[int], W: cl.Constant[int]):
        tensor_map = cl.tensor_map_tiled(x, (H, W)).as_opaque_ptr()
        smem = cl.shared_array(shape=(H * W,), dtype=cl.int32, alignment=512)
        smem = cl.map_shared_to_cluster(smem.get_base_pointer(), 0)
        mbar = cl.shared_array(1, cl.mbarrier, alignment=8).get_base_pointer()

        cl.cp_async_bulk_tensor_global_to_shared(
            tensor_map,
            (i, j),
            smem,
            mbar,
            multicast_mask=0x3,
            group=cl.CTAGroup.CTA_2,
            predicate=pred,
        )

    copy_instr = "cp.async.bulk.tensor.2d.shared::cluster.global"
    self.check_ptx_source(kernel, copy_instr, "multicast::cluster")
    # The kernel also requests CTA_2; verify that qualifier is emitted too,
    # using the same call form as test_shared_cluster_group.
    self.check_ptx_source(kernel, copy_instr, "cta_group::2")
117+
63118 def test_unsupported_kwargs_for_cta_mode (self , subtests ):
64119 @cl .kernel
65120 def k1 (x , pred , i , j , H : cl .Constant [int ], W : cl .Constant [int ]):
0 commit comments