Skip to content

Commit 288bb0e

Browse files
[lang] Fix cta group kind in cp_async_bulk_tensor hir2ir lowering
Signed-off-by: Asher Mancinelli <amancinelli@nvidia.com>
1 parent 9114db8 commit 288bb0e

2 files changed

Lines changed: 59 additions & 1 deletion

File tree

experimental/cuda-lang/src/cuda/lang/_ir/op_impl/cp_async.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
require_uniform_int_tuple_type,
2020
tensor_map_descriptor_like,
2121
)
22-
from cuda.tile._ir.op_impl import require_constant_enum
22+
from cuda.tile._ir.op_impl import require_constant_enum, require_optional_constant_enum
2323
import cuda.lang._mlir.nvvm as mlir
2424

2525

@@ -91,6 +91,9 @@ def cp_async_bulk_tensor_global_to_shared_impl(
9191
require_optional(multicast_mask, require_integral_scalar_type)
9292
require_optional(l2_cache_hint, require_integral_scalar_type)
9393
require_optional(predicate, require_boolean_scalar_type)
94+
group_value = require_optional_constant_enum(group, cp_async.CTAGroup)
95+
if group_value is not None:
96+
group = loosely_typed_const(getattr(mlir.CTAGroupKind, group_value.name))
9497

9598
return _raw_nvvm_mlir_operation_impl(
9699
nvvm_mlir_interfaces.cp_async_bulk_tensor_shared_cluster_global,

experimental/cuda-lang/test/test_cp_async.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,61 @@ def kernel(x, pred, i, j, H: cl.Constant[int], W: cl.Constant[int]):
6060
)
6161
self.check_ptx_source(kernel, expect)
6262

63+
@require_blackwell_cc100()
64+
@pytest.mark.parametrize(
65+
"group,expect_group",
66+
(
67+
(cl.CTAGroup.CTA_1, "cta_group::1"),
68+
(cl.CTAGroup.CTA_2, "cta_group::2"),
69+
),
70+
)
71+
def test_shared_cluster_group(self, group, expect_group):
72+
@cl.kernel
73+
def kernel(x, pred, i, j, H: cl.Constant[int], W: cl.Constant[int]):
74+
tensor_map = cl.tensor_map_tiled(x, (H, W)).as_opaque_ptr()
75+
smem = cl.shared_array(shape=(H * W,), dtype=cl.int32, alignment=512)
76+
smem = cl.map_shared_to_cluster(smem.get_base_pointer(), 0)
77+
mbar = cl.shared_array(1, cl.mbarrier, alignment=8).get_base_pointer()
78+
79+
cl.cp_async_bulk_tensor_global_to_shared(
80+
tensor_map,
81+
(i, j),
82+
smem,
83+
mbar,
84+
group=group,
85+
)
86+
87+
self.check_ptx_source(
88+
kernel,
89+
"cp.async.bulk.tensor.2d.shared::cluster.global",
90+
expect_group,
91+
)
92+
93+
@require_blackwell_cc100()
94+
def test_shared_cluster_group_with_predicate_and_multicast(self):
95+
@cl.kernel
96+
def kernel(x, pred, i, j, H: cl.Constant[int], W: cl.Constant[int]):
97+
tensor_map = cl.tensor_map_tiled(x, (H, W)).as_opaque_ptr()
98+
smem = cl.shared_array(shape=(H * W,), dtype=cl.int32, alignment=512)
99+
smem = cl.map_shared_to_cluster(smem.get_base_pointer(), 0)
100+
mbar = cl.shared_array(1, cl.mbarrier, alignment=8).get_base_pointer()
101+
102+
cl.cp_async_bulk_tensor_global_to_shared(
103+
tensor_map,
104+
(i, j),
105+
smem,
106+
mbar,
107+
multicast_mask=0x3,
108+
group=cl.CTAGroup.CTA_2,
109+
predicate=pred,
110+
)
111+
112+
self.check_ptx_source(
113+
kernel,
114+
"cp.async.bulk.tensor.2d.shared::cluster.global",
115+
"multicast::cluster",
116+
)
117+
63118
def test_unsupported_kwargs_for_cta_mode(self, subtests):
64119
@cl.kernel
65120
def k1(x, pred, i, j, H: cl.Constant[int], W: cl.Constant[int]):

0 commit comments

Comments
 (0)