diff --git a/python/tvm/backend/cuda/op.py b/python/tvm/backend/cuda/op.py index bb3c59599e56..fc3efbf5f549 100644 --- a/python/tvm/backend/cuda/op.py +++ b/python/tvm/backend/cuda/op.py @@ -70,7 +70,7 @@ def cuda_func_call(func_name, *args, source_code, return_type="void"): return_type: str The return type of the CUDA function. """ - return call_intrin(return_type, "tirx.cuda_func_call", func_name, *args, source_code) + return call_intrin(return_type, "tirx.cuda.func_call", func_name, *args, source_code) def cuda_warp_reduce(value, op, width=32): @@ -97,7 +97,7 @@ def cuda_warp_reduce(value, op, width=32): call : PrimExpr The reduced value (same dtype as *value*). """ - return call_intrin(value.dtype, "tirx.cuda_warp_reduce", value, op, width) + return call_intrin(value.dtype, "tirx.cuda.warp_reduce", value, op, width) def cuda_warp_sum(value, width=32): @@ -141,7 +141,7 @@ def cuda_cta_reduce(value, op, num_warps, scratch): call : PrimExpr The reduced value broadcast to all threads (same dtype as *value*). """ - return call_intrin(value.dtype, "tirx.cuda_cta_reduce", value, op, num_warps, scratch) + return call_intrin(value.dtype, "tirx.cuda.cta_reduce", value, op, num_warps, scratch) def cuda_cta_sum(value, num_warps, scratch): @@ -182,7 +182,7 @@ def cuda_copy_bytes(dst, src, num_bytes): call : PrimExpr A void call expression. """ - return call_intrin("void", "tirx.cuda_copy_bytes", dst, src, num_bytes) + return call_intrin("void", "tirx.cuda.copy_bytes", dst, src, num_bytes) def cuda_copy_128b(dst, src): @@ -220,7 +220,7 @@ def cuda_warp_sync(): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_warp_sync") + return call_intrin("", "tirx.cuda.warp_sync") def cuda_cta_sync(): @@ -231,7 +231,7 @@ def cuda_cta_sync(): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_cta_sync") + return call_intrin("", "tirx.cuda.cta_sync") def cuda_grid_sync(): @@ -242,7 +242,7 @@ def cuda_grid_sync(): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_grid_sync") + return call_intrin("", "tirx.cuda.grid_sync") def cuda_cluster_sync(): @@ -253,7 +253,7 @@ def cuda_cluster_sync(): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_cluster_sync") + return call_intrin("", "tirx.cuda.cluster_sync") def cuda_thread_rank(): @@ -271,7 +271,7 @@ def cuda_thread_rank(): call : PrimExpr The call expression (``int32``). """ - return call_intrin("int32", "tirx.cuda_thread_rank") + return call_intrin("int32", "tirx.cuda.thread_rank") def cuda_half2float(src): @@ -287,7 +287,7 @@ def cuda_half2float(src): call : PrimExpr The call expression. """ - return call_intrin("float32", "tirx.cuda_half2float", src) + return call_intrin("float32", "tirx.cuda.half2float", src) def cuda_bfloat162float(src): @@ -303,7 +303,7 @@ def cuda_bfloat162float(src): call : PrimExpr The call expression. """ - return call_intrin("float32", "tirx.cuda_bfloat162float", src) + return call_intrin("float32", "tirx.cuda.bfloat162float", src) def cuda_float22half2(dst, src): @@ -322,7 +322,7 @@ def cuda_float22half2(dst, src): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_float22half2", dst, src) + return call_intrin("", "tirx.cuda.float22half2", dst, src) def cuda_trap_when_assert_failed(cond): @@ -338,7 +338,7 @@ def cuda_trap_when_assert_failed(cond): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_trap_when_assert_failed", cond) + return call_intrin("", "tirx.cuda.trap_when_assert_failed", cond) def cuda_runtime_instr_desc(desc, sf_id): @@ -357,7 +357,7 @@ def cuda_runtime_instr_desc(desc, sf_id): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_runtime_instr_desc", desc, sf_id) + return call_intrin("", "tirx.cuda.runtime_instr_desc", desc, sf_id) def cuda_half8tofloat8(src_addr, dst_addr): @@ -376,7 +376,7 @@ def cuda_half8tofloat8(src_addr, dst_addr): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_half8tofloat8", src_addr, dst_addr) + return call_intrin("", "tirx.cuda.half8tofloat8", src_addr, dst_addr) def cuda_float8tohalf8(src_addr, dst_addr): @@ -395,7 +395,7 @@ def cuda_float8tohalf8(src_addr, dst_addr): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_float8tohalf8", src_addr, dst_addr) + return call_intrin("", "tirx.cuda.float8tohalf8", src_addr, dst_addr) def ptx_mma_sp( @@ -480,7 +480,7 @@ def ptx_mma_sp( """ return call_intrin( dtype, - "tirx.ptx_mma_sp", + "tirx.ptx.mma_sp", shape, A_layout, B_layout, @@ -536,7 +536,7 @@ def ptx_cp_async_bulk( """ return call_intrin( dtype, - "tirx.ptx_cp_async_bulk", + "tirx.ptx.cp_async_bulk", shared_ptr, shared_offset, global_ptr, @@ -572,7 +572,7 @@ def ptx_cp_async_bulk_shared_to_cluster(dst_ptr, src_ptr, size, mbar): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_cp_async_bulk_shared_to_cluster", dst_ptr, src_ptr, size, mbar) + return call_intrin("", "tirx.ptx.cp_async_bulk_shared_to_cluster", dst_ptr, src_ptr, size, mbar) def ptx_cp_async_mbarrier_arrive(barrier_id): @@ -589,7 +589,7 @@ def ptx_cp_async_mbarrier_arrive(barrier_id): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_cp_async_mbarrier_arrive", barrier_id) + return call_intrin("", "tirx.ptx.cp_async_mbarrier_arrive", barrier_id) def ptx_fence(sem: str, scope: str): @@ -611,7 +611,7 @@ def ptx_fence(sem: str, scope: str): """ _choice("sem", sem, _FENCE_SEM) _choice("scope", scope, _FENCE_SCOPE) - return call_intrin("", "tirx.ptx_fence", sem, scope) + return call_intrin("", "tirx.ptx.fence", sem, scope) def ptx_fence_proxy_async(space: str = ""): @@ -631,7 +631,7 @@ def ptx_fence_proxy_async(space: str = ""): The call expression. """ _choice("space", space, _FENCE_PROXY_ASYNC_SPACE) - return call_intrin("", "tirx.ptx_fence_proxy_async", space) + return call_intrin("", "tirx.ptx.fence_proxy_async", space) def ptx_mbarrier_init(bar, thread_count): @@ -650,7 +650,7 @@ def ptx_mbarrier_init(bar, thread_count): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_mbarrier_init", bar, thread_count) + return call_intrin("", "tirx.ptx.mbarrier_init", bar, thread_count) def ptx_mbarrier_arrive(bar, cta_id=None, pred=None, count=None): @@ -677,11 +677,11 @@ def ptx_mbarrier_arrive(bar, cta_id=None, pred=None, count=None): ``mbarrier.arrive.shared::cluster.b64 _, [addr], count``. """ if cta_id is None and pred is None: - return call_intrin("", "tirx.ptx_mbarrier_arrive", bar) + return call_intrin("", "tirx.ptx.mbarrier_arrive", bar) assert cta_id is not None and pred is not None if count is None: - return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred) - return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred, count) + return call_intrin("", "tirx.ptx.mbarrier_arrive", bar, cta_id, pred) + return call_intrin("", "tirx.ptx.mbarrier_arrive", bar, cta_id, pred, count) def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count): @@ -691,7 +691,7 @@ def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count): ``@p mapa.shared::cluster.u32`` + ``@p mbarrier.arrive.shared::cluster.b64 _, [addr], count`` with the guard defaulted to 1. """ - return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, True, count) + return call_intrin("", "tirx.ptx.mbarrier_arrive", bar, cta_id, True, count) def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None): @@ -722,13 +722,13 @@ def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None): The call expression. """ if cta_id is None and pred is None: - return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count) + return call_intrin("", "tirx.ptx.mbarrier_arrive_expect_tx", bar, byte_count) assert cta_id is not None # Cross-CTA expect_tx from an already-elected thread: default the guard to 1 # (the caller has elected a single lane), so callers can pass cta_id alone. if pred is None: pred = True - return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count, cta_id, pred) + return call_intrin("", "tirx.ptx.mbarrier_arrive_expect_tx", bar, byte_count, cta_id, pred) def ptx_mbarrier_try_wait(bar, phase): @@ -747,7 +747,7 @@ def ptx_mbarrier_try_wait(bar, phase): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_mbarrier_try_wait", bar, phase) + return call_intrin("", "tirx.ptx.mbarrier_try_wait", bar, phase) def ptx_mbarrier_try_wait_acquire_cluster(bar, phase): @@ -764,7 +764,7 @@ def ptx_mbarrier_try_wait_acquire_cluster(bar, phase): phase : int The phase of the barrier. """ - return call_intrin("", "tirx.ptx_mbarrier_try_wait_acquire_cluster", bar, phase) + return call_intrin("", "tirx.ptx.mbarrier_try_wait_acquire_cluster", bar, phase) def ptx_mbarrier_try_wait_once(bar, phase, ticks): @@ -774,7 +774,7 @@ def ptx_mbarrier_try_wait_once(bar, phase, ticks): This is intended for bounded debug waits; production waits should use :func:`ptx_mbarrier_try_wait`. """ - return call_intrin("uint32", "tirx.ptx_mbarrier_try_wait_once", bar, phase, ticks) + return call_intrin("uint32", "tirx.ptx.mbarrier_try_wait_once", bar, phase, ticks) def ptx_bar_arrive(name_bar_id, thread_count): @@ -793,7 +793,7 @@ def ptx_bar_arrive(name_bar_id, thread_count): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_bar_arrive", name_bar_id, thread_count) + return call_intrin("", "tirx.ptx.bar_arrive", name_bar_id, thread_count) def ptx_bar_sync(name_bar_id, thread_count): @@ -812,7 +812,7 @@ def ptx_bar_sync(name_bar_id, thread_count): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_bar_sync", name_bar_id, thread_count) + return call_intrin("", "tirx.ptx.bar_sync", name_bar_id, thread_count) def ptx_cp_async( @@ -870,7 +870,7 @@ def ptx_cp_async( _choice("fill_mode", fill_mode, _CP_ASYNC_FILL_MODE) return call_intrin( "", - "tirx.ptx_cp_async", + "tirx.ptx.cp_async", dst_ptr, src_ptr, cp_size, @@ -921,7 +921,7 @@ def ptx_cp_async_commit_group(): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_cp_async_commit_group") + return call_intrin("", "tirx.ptx.cp_async_commit_group") def ptx_cp_async_wait_group(num=0): @@ -938,7 +938,7 @@ def ptx_cp_async_wait_group(num=0): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_cp_async_wait_group", num) + return call_intrin("", "tirx.ptx.cp_async_wait_group", num) def ptx_cp_async_bulk_tensor_global_to_cluster( @@ -985,7 +985,7 @@ def ptx_cp_async_bulk_tensor_global_to_cluster( has_cache_policy, *coords = coords return call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_global_to_cluster", + "tirx.ptx.cp_async_bulk_tensor_global_to_cluster", dim, dst_ptr, bar, @@ -999,7 +999,7 @@ def ptx_cp_async_bulk_tensor_global_to_cluster( cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) return call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_global_to_cluster", + "tirx.ptx.cp_async_bulk_tensor_global_to_cluster", dim, dst_ptr, bar, @@ -1054,7 +1054,7 @@ def ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster( has_cache_policy, *coords = coords return call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster", + "tirx.ptx.cp_async_bulk_tensor_tile_gather4_global_to_cluster", dim, dst_ptr, bar, @@ -1068,7 +1068,7 @@ def ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster( cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) return call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster", + "tirx.ptx.cp_async_bulk_tensor_tile_gather4_global_to_cluster", dim, dst_ptr, bar, @@ -1112,7 +1112,7 @@ def ptx_cp_async_bulk_tensor_shared_to_global( has_cache_policy, *coords = coords return call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_shared_to_global", + "tirx.ptx.cp_async_bulk_tensor_shared_to_global", dim, src_ptr, tensormap_addr, @@ -1123,7 +1123,7 @@ def ptx_cp_async_bulk_tensor_shared_to_global( cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) return call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_shared_to_global", + "tirx.ptx.cp_async_bulk_tensor_shared_to_global", dim, src_ptr, tensormap_addr, @@ -1161,7 +1161,7 @@ def ptx_cp_async_bulk_tensor_global_to_cluster_prefetch( has_cache_policy, *coords = coords return call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch", + "tirx.ptx.cp_async_bulk_tensor_global_to_cluster_prefetch", dim, tensormap_addr, cache_hint, @@ -1171,7 +1171,7 @@ def ptx_cp_async_bulk_tensor_global_to_cluster_prefetch( cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) return call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch", + "tirx.ptx.cp_async_bulk_tensor_global_to_cluster_prefetch", dim, tensormap_addr, cache_policy, @@ -1216,7 +1216,7 @@ def ptx_cp_async_bulk_tensor_shared_to_global_reduce( _choice("red_op", red_op, _CP_ASYNC_BULK_RED_OP) return call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_shared_to_global_reduce", + "tirx.ptx.cp_async_bulk_tensor_shared_to_global_reduce", dim, src_ptr, tensormap_addr, @@ -1229,7 +1229,7 @@ def ptx_cp_async_bulk_tensor_shared_to_global_reduce( _choice("red_op", red_op, _CP_ASYNC_BULK_RED_OP) return call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_shared_to_global_reduce", + "tirx.ptx.cp_async_bulk_tensor_shared_to_global_reduce", dim, src_ptr, tensormap_addr, @@ -1248,7 +1248,7 @@ def ptx_cp_async_bulk_commit_group(): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_cp_async_bulk_commit_group") + return call_intrin("", "tirx.ptx.cp_async_bulk_commit_group") def ptx_cp_async_bulk_wait_group(n=0, read=True): @@ -1267,7 +1267,7 @@ def ptx_cp_async_bulk_wait_group(n=0, read=True): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_cp_async_bulk_wait_group", n, read) + return call_intrin("", "tirx.ptx.cp_async_bulk_wait_group", n, read) def ptx_barrier_cluster_arrive(sem="", aligned=True): @@ -1282,7 +1282,7 @@ def ptx_barrier_cluster_arrive(sem="", aligned=True): Whether all threads in the warp must execute the same instruction. """ _choice("sem", sem, _CLUSTER_BARRIER_SEM) - return call_intrin("", "tirx.ptx_barrier_cluster_arrive", sem, aligned) + return call_intrin("", "tirx.ptx.barrier_cluster_arrive", sem, aligned) def ptx_barrier_cluster_wait(acquire=False, aligned=True): @@ -1296,7 +1296,7 @@ def ptx_barrier_cluster_wait(acquire=False, aligned=True): aligned : bool Whether all threads in the warp must execute the same instruction. """ - return call_intrin("", "tirx.ptx_barrier_cluster_wait", acquire, aligned) + return call_intrin("", "tirx.ptx.barrier_cluster_wait", acquire, aligned) def ptx_clc_try_cancel(handle, mbar): @@ -1314,7 +1314,7 @@ def ptx_clc_try_cancel(handle, mbar): mbar : PrimExpr Pointer to the mbarrier signalled when the handle lands. """ - return call_intrin("", "tirx.ptx_clc_try_cancel", handle, mbar) + return call_intrin("", "tirx.ptx.clc_try_cancel", handle, mbar) def ptx_clc_query_cancel(handle): @@ -1328,12 +1328,12 @@ def ptx_clc_query_cancel(handle): handle : PrimExpr Pointer to the 16B (uint4) smem response handle. """ - return call_intrin("uint32", "tirx.ptx_clc_query_cancel", handle) + return call_intrin("uint32", "tirx.ptx.clc_query_cancel", handle) def ptx_elect_sync(): """TVM intrinsic to call elect.sync""" - return call_intrin("uint32", "tirx.ptx_elect_sync") + return call_intrin("uint32", "tirx.ptx.elect_sync") def ptx_fence_mbarrier_init(): @@ -1346,7 +1346,7 @@ def ptx_fence_mbarrier_init(): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_fence_mbarrier_init") + return call_intrin("", "tirx.ptx.fence_mbarrier_init") def ptx_fetch_register(bits, reg_name): @@ -1365,7 +1365,7 @@ def ptx_fetch_register(bits, reg_name): call : PrimExpr The call expression. """ - return call_intrin("int" + str(bits), "tirx.ptx_fetch_register", bits, reg_name) + return call_intrin("int" + str(bits), "tirx.ptx.fetch_register", bits, reg_name) def ptx_mma( @@ -1460,7 +1460,7 @@ def ptx_mma( base = [ "", - "tirx.ptx_mma", + "tirx.ptx.mma", shape, a_layout, b_layout, @@ -1552,7 +1552,7 @@ def ptx_mma_legacy(*all_args, operator=None): ] if operator is not None: call_args.append(operator) - return call_intrin("", "tirx.ptx_mma_legacy", *call_args) + return call_intrin("", "tirx.ptx.mma_legacy", *call_args) def ptx_mma_sp_legacy(*all_args): @@ -1679,7 +1679,7 @@ def ptx_ldmatrix(trans, num, dtype, smem_ptr, *dst_handles): f"ldmatrix .x{int(num)}.{dtype_bare} expects {n_regs} destination " f"handles, got {len(dst_handles)}" ) - return call_intrin("", "tirx.ptx_ldmatrix", trans, num, dtype, smem_ptr, *dst_handles) + return call_intrin("", "tirx.ptx.ldmatrix", trans, num, dtype, smem_ptr, *dst_handles) _PTX_TO_NUMPY_DTYPE = { @@ -1773,7 +1773,7 @@ def ptx_ldmatrix_legacy(*all_args): # int8+trans manual-loop fallback (ldmatrix can't transpose int8). return call_intrin( elem_dtype, - "tirx.ptx_ldmatrix_legacy", + "tirx.ptx.ldmatrix_legacy", trans, num, dtype, @@ -1826,7 +1826,7 @@ def ptx_stmatrix(trans, num, dtype, smem_ptr, *src_handles, shape="m8n8", space= f"handles, got {len(src_handles)}" ) return call_intrin( - "", "tirx.ptx_stmatrix", trans, num, dtype, shape, space, smem_ptr, *src_handles + "", "tirx.ptx.stmatrix", trans, num, dtype, shape, space, smem_ptr, *src_handles ) @@ -1850,7 +1850,7 @@ def ptx_wgmma_encode_matrix_descriptor(desc, addr, ldo, sdo, swizzle): swizzle : int The swizzle value (CUtensorMapSwizzle_enum). """ - return call_intrin("", "tirx.ptx_wgmma_encode_matrix_descriptor", desc, addr, ldo, sdo, swizzle) + return call_intrin("", "tirx.ptx.wgmma_encode_matrix_descriptor", desc, addr, ldo, sdo, swizzle) def ptx_wgmma_noop_barrier(reg): @@ -1866,7 +1866,7 @@ def ptx_wgmma_noop_barrier(reg): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_wgmma_noop_barrier", reg) + return call_intrin("", "tirx.ptx.wgmma_noop_barrier", reg) def ptx_wgmma_mma_async_ss( @@ -1917,7 +1917,7 @@ def ptx_wgmma_mma_async_ss( """ # noqa: E501 return call_intrin( "", - "tirx.ptx_wgmma_mma_async_ss", + "tirx.ptx.wgmma_mma_async_ss", M, N, K, @@ -1980,7 +1980,7 @@ def ptx_wgmma_mma_async_rs( """ return call_intrin( "", - "tirx.ptx_wgmma_mma_async_rs", + "tirx.ptx.wgmma_mma_async_rs", M, N, K, @@ -2004,7 +2004,7 @@ def ptx_wgmma_fence(): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_wgmma_fence") + return call_intrin("", "tirx.ptx.wgmma_fence") def ptx_wgmma_commit_group(): @@ -2015,7 +2015,7 @@ def ptx_wgmma_commit_group(): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_wgmma_commit_group") + return call_intrin("", "tirx.ptx.wgmma_commit_group") def ptx_wgmma_wait_group(n): @@ -2031,7 +2031,7 @@ def ptx_wgmma_wait_group(n): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_wgmma_wait_group", n) + return call_intrin("", "tirx.ptx.wgmma_wait_group", n) def ptx_setmaxnreg(inc: bool, reg_count): @@ -2045,7 +2045,7 @@ def ptx_setmaxnreg(inc: bool, reg_count): reg_count : int The register count. """ - return call_intrin("", "tirx.ptx_setmaxnreg", inc, reg_count) + return call_intrin("", "tirx.ptx.setmaxnreg", inc, reg_count) def ptx_tcgen05_alloc(dst_ptr, n_cols, cta_group=1): @@ -2068,7 +2068,7 @@ def ptx_tcgen05_alloc(dst_ptr, n_cols, cta_group=1): one warp from each of the peer CTAs perform the allocation. """ _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - return call_intrin("", "tirx.ptx_tcgen05_alloc", dst_ptr, n_cols, cta_group) + return call_intrin("", "tirx.ptx.tcgen05_alloc", dst_ptr, n_cols, cta_group) def ptx_tcgen05_dealloc(taddr, n_cols, cta_group=1): @@ -2090,7 +2090,7 @@ def ptx_tcgen05_dealloc(taddr, n_cols, cta_group=1): one warp from each of the peer CTAs perform the deallocation. """ _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - return call_intrin("", "tirx.ptx_tcgen05_dealloc", taddr, n_cols, cta_group) + return call_intrin("", "tirx.ptx.tcgen05_dealloc", taddr, n_cols, cta_group) def ptx_tcgen05_relinquish_alloc_permit(cta_group=1): @@ -2106,7 +2106,7 @@ def ptx_tcgen05_relinquish_alloc_permit(cta_group=1): one warp from each of the peer CTAs perform the relinquishing. """ _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - return call_intrin("", "tirx.ptx_tcgen05_relinquish_alloc_permit", cta_group) + return call_intrin("", "tirx.ptx.tcgen05_relinquish_alloc_permit", cta_group) def ptx_tcgen05_encode_matrix_descriptor(desc, addr, ldo, sdo, swizzle): @@ -2130,7 +2130,7 @@ def ptx_tcgen05_encode_matrix_descriptor(desc, addr, ldo, sdo, swizzle): The swizzle value (CUtensorMapSwizzle_enum). """ return call_intrin( - "", "tirx.ptx_tcgen05_encode_matrix_descriptor", desc, addr, ldo, sdo, swizzle + "", "tirx.ptx.tcgen05_encode_matrix_descriptor", desc, addr, ldo, sdo, swizzle ) @@ -2202,7 +2202,7 @@ def ptx_tcgen05_encode_instr_descriptor( _choice("n_cta_groups", n_cta_groups, _TCGEN05_CTA_GROUP) return call_intrin( "", - "tirx.ptx_tcgen05_encode_instr_descriptor", + "tirx.ptx.tcgen05_encode_instr_descriptor", desc, d_dtype, a_dtype, @@ -2300,7 +2300,7 @@ def ptx_tcgen05_encode_instr_descriptor_block_scaled( _choice("n_cta_groups", n_cta_groups, _TCGEN05_CTA_GROUP) return call_intrin( "", - "tirx.ptx_tcgen05_encode_instr_descriptor_block_scaled", + "tirx.ptx.tcgen05_encode_instr_descriptor_block_scaled", desc, d_dtype, a_dtype, @@ -2407,7 +2407,7 @@ def ptx_tcgen05_mma( ] if pred is not None: args.append(pred) - return call_intrin("", "tirx.ptx_tcgen05_mma", *args) + return call_intrin("", "tirx.ptx.tcgen05_mma", *args) def ptx_tcgen05_mma_block_scale( @@ -2481,7 +2481,7 @@ def ptx_tcgen05_mma_block_scale( _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) return call_intrin( "", - "tirx.ptx_tcgen05_mma_block_scale", + "tirx.ptx.tcgen05_mma_block_scale", d_dtype, a_dtype, b_dtype, @@ -2569,7 +2569,7 @@ def ptx_tcgen05_mma_sp( return call_intrin( "", - "tirx.ptx_tcgen05_mma_sp", + "tirx.ptx.tcgen05_mma_sp", d_dtype, a_dtype, b_dtype, @@ -2660,7 +2660,7 @@ def ptx_tcgen05_mma_sp_block_scale( _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) return call_intrin( "", - "tirx.ptx_tcgen05_mma_sp_block_scale", + "tirx.ptx.tcgen05_mma_sp_block_scale", d_dtype, a_dtype, b_dtype, @@ -2683,14 +2683,14 @@ def ptx_tcgen05_fence_before_thread_sync(): """TVM intrinsic to call tcgen05.fence::before_thread_sync Orders all prior asynchronous tcgen05 operations relative to subsequent operations. """ - return call_intrin("", "tirx.ptx_tcgen05_fence_before_thread_sync") + return call_intrin("", "tirx.ptx.tcgen05_fence_before_thread_sync") def ptx_tcgen05_fence_after_thread_sync(): """TVM intrinsic to call tcgen05.fence::after_thread_sync Orders all subsequent asynchronous tcgen05 operations relative to previous operations. """ - return call_intrin("", "tirx.ptx_tcgen05_fence_after_thread_sync") + return call_intrin("", "tirx.ptx.tcgen05_fence_after_thread_sync") def _choice(name: str, value, options): @@ -2770,7 +2770,7 @@ def ptx_tcgen05_cp( return call_intrin( "", - "tirx.ptx_tcgen05_cp", + "tirx.ptx.tcgen05_cp", taddr, src_desc, shape, @@ -2798,7 +2798,7 @@ def ptx_tcgen05_shift(taddr, cta_group=1): the peer CTA. """ _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - return call_intrin("", "tirx.ptx_tcgen05_shift", taddr, cta_group) + return call_intrin("", "tirx.ptx.tcgen05_shift", taddr, cta_group) def ptx_tcgen05_ld(src_addr, *regs, shape, num, row=0, col=0, pack=False): @@ -2828,7 +2828,7 @@ def ptx_tcgen05_ld(src_addr, *regs, shape, num, row=0, col=0, pack=False): Pack two 16-bit chunks into a single 32-bit register. """ _choice("shape", shape, _TCGEN05_LDST_SHAPES) - return call_intrin("", "tirx.ptx_tcgen05_ld", src_addr, row, col, shape, num, pack, *regs) + return call_intrin("", "tirx.ptx.tcgen05_ld", src_addr, row, col, shape, num, pack, *regs) def ptx_tcgen05_st(dst_addr, *regs, shape, num, row=0, col=0, unpack=False): @@ -2858,21 +2858,21 @@ def ptx_tcgen05_st(dst_addr, *regs, shape, num, row=0, col=0, unpack=False): Unpack a 32-bit register into two 16-bit chunks. """ _choice("shape", shape, _TCGEN05_LDST_SHAPES) - return call_intrin("", "tirx.ptx_tcgen05_st", dst_addr, row, col, shape, num, unpack, *regs) + return call_intrin("", "tirx.ptx.tcgen05_st", dst_addr, row, col, shape, num, unpack, *regs) def ptx_tcgen05_wait_ld(): """TVM intrinsic to call tcgen05.wait::ld.sync.aligned Wait for the completion of all prior async tcgen05.ld operations. """ - return call_intrin("", "tirx.ptx_tcgen05_wait_ld") + return call_intrin("", "tirx.ptx.tcgen05_wait_ld") def ptx_tcgen05_wait_st(): """TVM intrinsic to call tcgen05.wait::st.sync.aligned Wait for the completion of all prior async tcgen05.st operations. """ - return call_intrin("", "tirx.ptx_tcgen05_wait_st") + return call_intrin("", "tirx.ptx.tcgen05_wait_st") def ptx_tcgen05_commit(bar, cta_group=1, cta_mask=0, *, pred=None): @@ -2904,7 +2904,7 @@ def ptx_tcgen05_commit(bar, cta_group=1, cta_mask=0, *, pred=None): args = [bar, cta_group, cta_mask] if pred is not None: args.append(pred) - return call_intrin("", "tirx.ptx_tcgen05_commit", *args) + return call_intrin("", "tirx.ptx.tcgen05_commit", *args) def timer_init_cuda(profiler_buffer, profiler_tag, profiler_write_offset, num_groups, group_id): @@ -3100,7 +3100,7 @@ def cuda_atomic_add(res_addr, value): The call expression. """ value = tir.convert(value) - return call_intrin(value.dtype, "tirx.cuda_atomic_add", res_addr, value) + return call_intrin(value.dtype, "tirx.cuda.atomic_add", res_addr, value) def cuda_thread_fence(): @@ -3111,7 +3111,7 @@ def cuda_thread_fence(): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_thread_fence") + return call_intrin("", "tirx.cuda.thread_fence") def cuda_warpgroup_sync(bar_no): @@ -3131,7 +3131,7 @@ def cuda_warpgroup_sync(bar_no): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_warpgroup_sync", bar_no) + return call_intrin("", "tirx.cuda.warpgroup_sync", bar_no) def cuda_syncthreads_and(cond): @@ -3147,7 +3147,7 @@ def cuda_syncthreads_and(cond): call : PrimExpr The call expression. """ - return call_intrin("int64", "tirx.cuda_syncthreads_and", cond) + return call_intrin("int64", "tirx.cuda.syncthreads_and", cond) def cuda_syncthreads_or(cond): @@ -3163,7 +3163,7 @@ def cuda_syncthreads_or(cond): call : PrimExpr The call expression. """ - return call_intrin("int64", "tirx.cuda_syncthreads_or", cond) + return call_intrin("int64", "tirx.cuda.syncthreads_or", cond) def cuda_nano_sleep(time): @@ -3179,7 +3179,7 @@ def cuda_nano_sleep(time): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_nano_sleep", time) + return call_intrin("", "tirx.cuda.nano_sleep", time) def cuda_printf(fmt, *args): @@ -3198,7 +3198,7 @@ def cuda_printf(fmt, *args): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_printf", fmt, *args) + return call_intrin("", "tirx.cuda.printf", fmt, *args) def cuda_ldg(addr, dtype): @@ -3214,7 +3214,7 @@ def cuda_ldg(addr, dtype): Returns """ - return call_intrin(dtype, "tirx.cuda_ldg", addr, dtype) + return call_intrin(dtype, "tirx.cuda.ldg", addr, dtype) def cuda_get_tmem_addr(addr, row_offset, col_offset): @@ -3236,7 +3236,7 @@ def cuda_get_tmem_addr(addr, row_offset, col_offset): call : PrimExpr The call expression. """ - return call_intrin("uint32", "tirx.cuda_get_tmem_addr", addr, row_offset, col_offset) + return call_intrin("uint32", "tirx.cuda.get_tmem_addr", addr, row_offset, col_offset) def cuda_cvta_generic_to_shared(ptr): @@ -3246,7 +3246,7 @@ def cuda_cvta_generic_to_shared(ptr): precompute the shared-memory address at the wrapper layer instead of inside the asm helper body. """ - return call_intrin("uint32", "tirx.cuda_cvta_generic_to_shared", ptr) + return call_intrin("uint32", "tirx.cuda.cvta_generic_to_shared", ptr) def cuda_smem_addr_from_uint64(cluster_addr): @@ -3255,7 +3255,7 @@ def cuda_smem_addr_from_uint64(cluster_addr): Wraps ``static_cast(cluster_addr)``. Used by cp.async.bulk.shared::cluster.* op-wrappers. """ - return call_intrin("uint32", "tirx.cuda_smem_addr_from_uint64", cluster_addr) + return call_intrin("uint32", "tirx.cuda.smem_addr_from_uint64", cluster_addr) def cuda_sm100_tma_2sm_mbarrier_addr(bar): @@ -3276,7 +3276,7 @@ def ptx_exp2(x): call : PrimExpr The call expression returning 2^x (approximate). """ - return call_intrin("float32", "tirx.ptx_exp2", x) + return call_intrin("float32", "tirx.ptx.exp2", x) def ptx_rcp(x): @@ -3292,7 +3292,7 @@ def ptx_rcp(x): call : PrimExpr The call expression returning 1/x (approximate). """ - return call_intrin("float32", "tirx.ptx_rcp", x) + return call_intrin("float32", "tirx.ptx.rcp", x) def ptx_any_sync(mask, pred): @@ -3310,7 +3310,7 @@ def ptx_any_sync(mask, pred): call : PrimExpr The call expression returning 1 if any thread in mask has pred != 0. """ - return call_intrin("int32", "tirx.ptx_any_sync", mask, pred) + return call_intrin("int32", "tirx.ptx.any_sync", mask, pred) def ptx_reduce3_max_f32(a, b, c): @@ -3326,7 +3326,7 @@ def ptx_reduce3_max_f32(a, b, c): call : PrimExpr The call expression returning max(a, b, c). """ - return call_intrin("float32", "tirx.ptx_reduce3_max_f32", a, b, c) + return call_intrin("float32", "tirx.ptx.reduce3_max_f32", a, b, c) def ptx_reduce3_min_f32(a, b, c): @@ -3342,7 +3342,7 @@ def ptx_reduce3_min_f32(a, b, c): call : PrimExpr The call expression returning min(a, b, c). """ - return call_intrin("float32", "tirx.ptx_reduce3_min_f32", a, b, c) + return call_intrin("float32", "tirx.ptx.reduce3_min_f32", a, b, c) def _ptx_binary_arith(op_name, dtype, d, a, b, *, rounding="rn", ftz=False, sat=False): @@ -3354,7 +3354,7 @@ def _ptx_binary_arith(op_name, dtype, d, a, b, *, rounding="rn", ftz=False, sat= raise ValueError(f"PTX {op_name}.f32x2 does not accept .sat") return call_intrin( "", - f"tirx.ptx_{op_name}_{dtype}", + f"tirx.ptx.{op_name}_{dtype}", d, a, b, @@ -3373,7 +3373,7 @@ def _ptx_fma(dtype, d, a, b, c, *, rounding="rn", ftz=False, sat=False): raise ValueError("PTX fma.f32x2 does not accept .sat") return call_intrin( "", - f"tirx.ptx_fma_{dtype}", + f"tirx.ptx.fma_{dtype}", d, a, b, @@ -3466,7 +3466,7 @@ def ptx_max_f32(a, b, *, ftz=False, nan=False): nan : bool If True, propagate NaN inputs (``.NaN``). """ - return call_intrin("float32", "tirx.ptx_max_f32", a, b, int(ftz), int(nan)) + return call_intrin("float32", "tirx.ptx.max_f32", a, b, int(ftz), int(nan)) def ptx_griddepcontrol_wait(): @@ -3476,7 +3476,7 @@ def ptx_griddepcontrol_wait(): :func:`ptx_griddepcontrol_launch_dependents` have finished. Acts as a full memory barrier. """ - return call_intrin("", "tirx.ptx_griddepcontrol_wait") + return call_intrin("", "tirx.ptx.griddepcontrol_wait") def ptx_griddepcontrol_launch_dependents(): @@ -3485,7 +3485,7 @@ def ptx_griddepcontrol_launch_dependents(): Signals that the current grid has reached a point where dependent grids may begin execution. """ - return call_intrin("", "tirx.ptx_griddepcontrol_launch_dependents") + return call_intrin("", "tirx.ptx.griddepcontrol_launch_dependents") _PTX_LD_SCOPE = {"cta", "cluster", "gpu", "sys"} @@ -3565,7 +3565,7 @@ def ptx_ld_acquire(addr, return_type, ptx_type, *, scope="gpu", space="global"): _choice("space", space, _PTX_LD_SPACE) _choice("ptx_type", ptx_type, _PTX_LD_TYPE) return call_intrin( - return_type, "tirx.ptx_ld_acquire", addr, return_type, ptx_type, scope, space + return_type, "tirx.ptx.ld_acquire", addr, return_type, ptx_type, scope, space ) @@ -3591,7 +3591,7 @@ def ptx_ld( cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) return call_intrin( return_type, - "tirx.ptx_ld", + "tirx.ptx.ld", addr, cache_policy, return_type, @@ -3610,7 +3610,7 @@ def ptx_ld_volatile(addr, return_type, ptx_type, *, space="global"): """ _choice("space", space, _PTX_LD_VOLATILE_SPACE) _choice("ptx_type", ptx_type, _PTX_LD_TYPE) - return call_intrin(return_type, "tirx.ptx_ld_volatile", addr, return_type, ptx_type, space) + return call_intrin(return_type, "tirx.ptx.ld_volatile", addr, return_type, ptx_type, space) def ptx_ld_global_acquire(res, addr): @@ -3629,7 +3629,7 @@ def ptx_ld_global_acquire(res, addr): call : PrimExpr The call expression. """ - return call_intrin("", "tirx.ptx_ld_global_acquire", res, addr) + return call_intrin("", "tirx.ptx.ld_global_acquire", res, addr) def ptx_red_scalar( @@ -3655,7 +3655,7 @@ def ptx_red_scalar( raise ValueError(f"Unsupported PTX red sem {sem!r}") return call_intrin( "", - "tirx.ptx_red_scalar", + "tirx.ptx.red_scalar", address, value, cache_policy, @@ -3689,7 +3689,7 @@ def ptx_atom_scalar( raise ValueError(f"Unsupported PTX atom sem {sem!r}") return call_intrin( _PTX_SCALAR_RETURN_TYPE[ptx_type], - "tirx.ptx_atom_scalar", + "tirx.ptx.atom_scalar", address, value, cache_policy, @@ -3720,7 +3720,7 @@ def ptx_st( cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) return call_intrin( "", - "tirx.ptx_st", + "tirx.ptx.st", address, *values, cache_policy, @@ -3736,12 +3736,12 @@ def ptx_st( def ptx_st_bulk(ptr, num_bytes, *, weak=False, space="shared::cta"): if space not in ("", "shared::cta"): raise ValueError(f"Unsupported PTX st.bulk space {space!r}") - return call_intrin("", "tirx.ptx_st_bulk", ptr, num_bytes, int(bool(weak)), space) + return call_intrin("", "tirx.ptx.st_bulk", ptr, num_bytes, int(bool(weak)), space) def ptx_prefetch_tensormap(tensormap_addr, space=""): _choice("space", space, _PTX_PREFETCH_TENSORMAP_SPACE) - return call_intrin("", "tirx.ptx_prefetch_tensormap", tensormap_addr, space) + return call_intrin("", "tirx.ptx.prefetch_tensormap", tensormap_addr, space) def ptx_mbarrier_test_wait_parity(barrier, phase, *, sem="", scope="", space="shared::cta"): @@ -3754,7 +3754,7 @@ def ptx_mbarrier_test_wait_parity(barrier, phase, *, sem="", scope="", space="sh if space not in ("shared", "shared::cta"): raise ValueError(f"Unsupported mbarrier.test_wait.parity space {space!r}") return call_intrin( - "uint32", "tirx.ptx_mbarrier_test_wait_parity", barrier, phase, sem, scope, space + "uint32", "tirx.ptx.mbarrier_test_wait_parity", barrier, phase, sem, scope, space ) @@ -3773,7 +3773,7 @@ def ptx_cp_async_bulk_g2s_cta( cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) return call_intrin( "", - "tirx.ptx_cp_async_bulk_g2s_cta", + "tirx.ptx.cp_async_bulk_g2s_cta", dst_ptr, src_ptr, num_bytes, @@ -3800,7 +3800,7 @@ def ptx_cp_async_bulk_g2s_cluster( cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) return call_intrin( "", - "tirx.ptx_cp_async_bulk_g2s_cluster", + "tirx.ptx.cp_async_bulk_g2s_cluster", dst_ptr, src_ptr, num_bytes, @@ -3814,7 +3814,7 @@ def ptx_cp_async_bulk_g2s_cluster( def ptx_cp_async_bulk_s2s_cluster(dst_ptr, src_ptr, num_bytes, mbarrier): return call_intrin( - "", "tirx.ptx_cp_async_bulk_s2s_cluster", dst_ptr, src_ptr, num_bytes, mbarrier + "", "tirx.ptx.cp_async_bulk_s2s_cluster", dst_ptr, src_ptr, num_bytes, mbarrier ) @@ -3824,7 +3824,7 @@ def ptx_cp_async_bulk_s2g( cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) return call_intrin( "", - "tirx.ptx_cp_async_bulk_s2g", + "tirx.ptx.cp_async_bulk_s2g", dst_ptr, src_ptr, num_bytes, @@ -3836,83 +3836,83 @@ def ptx_cp_async_bulk_s2g( def ptx_fns_b32(mask, base, offset): - return call_intrin("uint32", "tirx.ptx_fns_b32", mask, base, offset) + return call_intrin("uint32", "tirx.ptx.fns_b32", mask, base, offset) def ptx_add_rn_f32_bf16(acc, x): - return call_intrin("float32", "tirx.ptx_add_rn_f32_bf16", acc, x) + return call_intrin("float32", "tirx.ptx.add_rn_f32_bf16", acc, x) def cuda_uint_as_float(bits): - return call_intrin("float32", "tirx.cuda_uint_as_float", bits) + return call_intrin("float32", "tirx.cuda.uint_as_float", bits) def cuda_float_as_uint(x): - return call_intrin("uint32", "tirx.cuda_float_as_uint", x) + return call_intrin("uint32", "tirx.cuda.float_as_uint", x) def cuda_ballot_sync(mask, pred): - return call_intrin("uint32", "tirx.cuda_ballot_sync", mask, pred) + return call_intrin("uint32", "tirx.cuda.ballot_sync", mask, pred) def cuda_ffs_u32(value): - return call_intrin("int32", "tirx.cuda_ffs_u32", value) + return call_intrin("int32", "tirx.cuda.ffs_u32", value) def cuda_reduce_add_sync_u32(mask, value): - return call_intrin("uint32", "tirx.cuda_reduce_add_sync_u32", mask, value) + return call_intrin("uint32", "tirx.cuda.reduce_add_sync_u32", mask, value) def cuda_reduce_min_sync_u32(mask, value): - return call_intrin("uint32", "tirx.cuda_reduce_min_sync_u32", mask, value) + return call_intrin("uint32", "tirx.cuda.reduce_min_sync_u32", mask, value) def cuda_clock64(): - return call_intrin("uint64", "tirx.cuda_clock64") + return call_intrin("uint64", "tirx.cuda.clock64") def cuda_make_float2(x, y): - return call_intrin("uint64", "tirx.cuda_make_float2", x, y) + return call_intrin("uint64", "tirx.cuda.make_float2", x, y) def cuda_float2_x(packed): - return call_intrin("float32", "tirx.cuda_float2_x", packed) + return call_intrin("float32", "tirx.cuda.float2_x", packed) def cuda_float2_y(packed): - return call_intrin("float32", "tirx.cuda_float2_y", packed) + return call_intrin("float32", "tirx.cuda.float2_y", packed) def cuda_fmul2_rn(a, b): - return call_intrin("uint64", "tirx.cuda_fmul2_rn", a, b) + return call_intrin("uint64", "tirx.cuda.fmul2_rn", a, b) def cuda_fadd2_rn(a, b): - return call_intrin("uint64", "tirx.cuda_fadd2_rn", a, b) + return call_intrin("uint64", "tirx.cuda.fadd2_rn", a, b) def cuda_float22bfloat162_rn(v0, v1): - return call_intrin("uint32", "tirx.cuda_float22bfloat162_rn", v0, v1) + return call_intrin("uint32", "tirx.cuda.float22bfloat162_rn", v0, v1) def cuda_float22bfloat162_rn_from_float2(packed): - return call_intrin("uint32", "tirx.cuda_float22bfloat162_rn_from_float2", packed) + return call_intrin("uint32", "tirx.cuda.float22bfloat162_rn_from_float2", packed) def cuda_bfloat1622float2(packed): - return call_intrin("uint64", "tirx.cuda_bfloat1622float2", packed) + return call_intrin("uint64", "tirx.cuda.bfloat1622float2", packed) def cuda_hmin2(a, b): - return call_intrin("uint32", "tirx.cuda_hmin2", a, b) + return call_intrin("uint32", "tirx.cuda.hmin2", a, b) def cuda_hmax2(a, b): - return call_intrin("uint32", "tirx.cuda_hmax2", a, b) + return call_intrin("uint32", "tirx.cuda.hmax2", a, b) def cuda_fp8x4_e4m3_from_float4(x, y, z, w): - return call_intrin("uint32", "tirx.cuda_fp8x4_e4m3_from_float4", x, y, z, w) + return call_intrin("uint32", "tirx.cuda.fp8x4_e4m3_from_float4", x, y, z, w) def ptx_map_shared_rank(ptr, rank): @@ -3941,7 +3941,7 @@ def ptx_mapa(ptr, rank, *, space="", ptx_type="u64", return_type="uint64"): raise ValueError(f"Unsupported mapa space {space!r}") if ptx_type not in ("u32", "u64"): raise ValueError(f"Unsupported mapa type {ptx_type!r}") - return call_intrin(return_type, "tirx.ptx_mapa", ptr, rank, space, ptx_type, return_type) + return call_intrin(return_type, "tirx.ptx.mapa", ptr, rank, space, ptx_type, return_type) def cuda_atomic_cas(ptr, old_val, new_val): @@ -3964,7 +3964,7 @@ def cuda_atomic_cas(ptr, old_val, new_val): The call expression. """ old_val = tir.convert(old_val) - return call_intrin(old_val.dtype, "tirx.cuda_atomic_cas", ptr, old_val, new_val) + return call_intrin(old_val.dtype, "tirx.cuda.atomic_cas", ptr, old_val, new_val) ######################################################## @@ -3981,7 +3981,7 @@ def nvshmem_my_pe(): The call expression. """ - return call_intrin("int32", "tirx.nvshmem_my_pe") + return call_intrin("int32", "tirx.nvshmem.my_pe") def nvshmem_n_pes(): @@ -3993,7 +3993,7 @@ def nvshmem_n_pes(): The call expression. """ - return call_intrin("int32", "tirx.nvshmem_n_pes") + return call_intrin("int32", "tirx.nvshmem.n_pes") def nvshmem_getmem_nbi(dst, src, nelems, pe): @@ -4019,7 +4019,7 @@ def nvshmem_getmem_nbi(dst, src, nelems, pe): The call expression. """ # noqa: E501 - return call_intrin("", "tirx.nvshmem_getmem_nbi", dst, src, nelems, pe) + return call_intrin("", "tirx.nvshmem.getmem_nbi", dst, src, nelems, pe) def nvshmem_putmem_nbi(dst, src, nelems, pe): @@ -4045,7 +4045,7 @@ def nvshmem_putmem_nbi(dst, src, nelems, pe): The call expression. """ - return call_intrin("", "tirx.nvshmem_putmem_nbi", dst, src, nelems, pe) + return call_intrin("", "tirx.nvshmem.putmem_nbi", dst, src, nelems, pe) def nvshmem_getmem_nbi_warp(dst, src, nelems, pe): @@ -4071,7 +4071,7 @@ def nvshmem_getmem_nbi_warp(dst, src, nelems, pe): The call expression. """ # noqa: E501 - return call_intrin("", "tirx.nvshmem_getmem_nbi_warp", dst, src, nelems, pe) + return call_intrin("", "tirx.nvshmem.getmem_nbi_warp", dst, src, nelems, pe) def nvshmem_putmem_nbi_warp(dst, src, nelems, pe): @@ -4097,7 +4097,7 @@ def nvshmem_putmem_nbi_warp(dst, src, nelems, pe): The call expression. """ - return call_intrin("", "tirx.nvshmem_putmem_nbi_warp", dst, src, nelems, pe) + return call_intrin("", "tirx.nvshmem.putmem_nbi_warp", dst, src, nelems, pe) def nvshmem_getmem_nbi_block(dst, src, nelems, pe): @@ -4123,7 +4123,7 @@ def nvshmem_getmem_nbi_block(dst, src, nelems, pe): The call expression. """ # noqa: E501 - return call_intrin("", "tirx.nvshmem_getmem_nbi_block", dst, src, nelems, pe) + return call_intrin("", "tirx.nvshmem.getmem_nbi_block", dst, src, nelems, pe) def nvshmem_putmem_nbi_block(dst, src, nelems, pe): @@ -4149,7 +4149,7 @@ def nvshmem_putmem_nbi_block(dst, src, nelems, pe): The call expression. """ - return call_intrin("", "tirx.nvshmem_putmem_nbi_block", dst, src, nelems, pe) + return call_intrin("", "tirx.nvshmem.putmem_nbi_block", dst, src, nelems, pe) def nvshmem_signal_op(sig_addr, signal, sig_op, pe): @@ -4176,7 +4176,7 @@ def nvshmem_signal_op(sig_addr, signal, sig_op, pe): """ _choice("sig_op", sig_op, _NVSHMEM_SIG_OP) - return call_intrin("", "tirx.nvshmem_signal_op", sig_addr, signal, sig_op, pe) + return call_intrin("", "tirx.nvshmem.signal_op", sig_addr, signal, sig_op, pe) def nvshmem_wait_until(ivar, cmp, cmp_value, type="uint64_t"): @@ -4203,7 +4203,7 @@ def nvshmem_wait_until(ivar, cmp, cmp_value, type="uint64_t"): """ _choice("cmp", cmp, _NVSHMEM_CMP) - return call_intrin("", "tirx.nvshmem_wait_until", ivar, cmp, cmp_value, type) + return call_intrin("", "tirx.nvshmem.wait_until", ivar, cmp, cmp_value, type) def nvshmem_quiet(): @@ -4215,7 +4215,7 @@ def nvshmem_quiet(): The call expression. """ - return call_intrin("", "tirx.nvshmem_quiet") + return call_intrin("", "tirx.nvshmem.quiet") def nvshmem_putmem_signal_nbi(dst, src, nelems, sig_addr, signal, sig_op, pe): @@ -4251,7 +4251,7 @@ def nvshmem_putmem_signal_nbi(dst, src, nelems, sig_addr, signal, sig_op, pe): """ # noqa: E501 return call_intrin( - "", "tirx.nvshmem_putmem_signal_nbi", dst, src, nelems, sig_addr, signal, sig_op, pe + "", "tirx.nvshmem.putmem_signal_nbi", dst, src, nelems, sig_addr, signal, sig_op, pe ) @@ -4288,7 +4288,7 @@ def nvshmem_putmem_signal_nbi_warp(dst, src, nelems, sig_addr, signal, sig_op, p """ # noqa: E501 return call_intrin( - "", "tirx.nvshmem_putmem_signal_nbi_warp", dst, src, nelems, sig_addr, signal, sig_op, pe + "", "tirx.nvshmem.putmem_signal_nbi_warp", dst, src, nelems, sig_addr, signal, sig_op, pe ) @@ -4325,7 +4325,7 @@ def nvshmem_putmem_signal_nbi_block(dst, src, nelems, sig_addr, signal, sig_op, """ # noqa: E501 return call_intrin( - "", "tirx.nvshmem_putmem_signal_nbi_block", dst, src, nelems, sig_addr, signal, sig_op, pe + "", "tirx.nvshmem.putmem_signal_nbi_block", dst, src, nelems, sig_addr, signal, sig_op, pe ) @@ -4338,7 +4338,7 @@ def nvshmem_fence(): The call expression. """ - return call_intrin("", "tirx.nvshmem_fence") + return call_intrin("", "tirx.nvshmem.fence") def nvshmem_barrier_all(): @@ -4350,4 +4350,4 @@ def nvshmem_barrier_all(): The call expression. """ - return call_intrin("", "tirx.nvshmem_barrier_all") + return call_intrin("", "tirx.nvshmem.barrier_all") diff --git a/python/tvm/backend/cuda/operator/intrinsics/cp_async.py b/python/tvm/backend/cuda/operator/intrinsics/cp_async.py index 3e6bc015e81f..a63c3e784fa1 100644 --- a/python/tvm/backend/cuda/operator/intrinsics/cp_async.py +++ b/python/tvm/backend/cuda/operator/intrinsics/cp_async.py @@ -383,6 +383,9 @@ def _scale(n): return result[0] if isinstance(result, tuple) else result +CODEGEN_REGISTRY["tirx.ptx.cp_async_raw"] = CODEGEN_REGISTRY["tirx.ptx.cp_async"] + + # ============================================================================= # cp.async.bulk.tensor (TMA) — one device_intrinsic per arity variant of each # PTX form. Per-dim coord operands materialise via the ``c_signature`` callable. diff --git a/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py b/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py index 0afa042b5e19..55beaaa69add 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py @@ -813,7 +813,7 @@ def _make_desc_wrap(desc_buf, smem_buf, base, ldo, sdo, swizzle_val): """Build: { AllocBuffer(desc); encode(desc, smem); krp }""" encode_call = tvm.tirx.call_intrin( "", - "tirx.ptx_tcgen05_encode_matrix_descriptor", + "tirx.ptx.tcgen05_encode_matrix_descriptor", tvm.tirx.address_of(desc_buf[0]), smem_buf.ptr_to(base), ldo, diff --git a/python/tvm/backend/cuda/script.py b/python/tvm/backend/cuda/script.py index 76ba87344bc3..effea0c88555 100644 --- a/python/tvm/backend/cuda/script.py +++ b/python/tvm/backend/cuda/script.py @@ -136,7 +136,7 @@ def __init__(self): def __call__(self, *args, **kwds): # Accept the legacy 6-arg form ``(elem_dtype, dst, dst_off, src, # src_off, cp_size)`` that the printer round-trips for the raw - # ``tirx.ptx_cp_async`` Call emitted by + # ``tirx.ptx.cp_async`` Call emitted by # ``tvm.backend.cuda.transform.InjectPTXAsyncCopy``. The pass-emitted # Call has 5 args (no ``tvm_access_ptr`` fold) and a # per-element-dtype Call.dtype, so build it directly. @@ -146,7 +146,7 @@ def __call__(self, *args, **kwds): elem_dtype, dst, dst_off, src, src_off, cp_size = args return tvm.tirx.Call( tvm.DataType(elem_dtype), - tvm.ir.Op.get("tirx.ptx_cp_async"), + tvm.ir.Op.get("tirx.ptx.cp_async_raw"), [dst, dst_off, src, src_off, cp_size], ) return _dtype_forward(_cuda_op.ptx_cp_async)(*args, **kwds) @@ -201,7 +201,7 @@ def g2c_bar_addr( cache_policy, has_cache_policy = _cuda_op._resolve_cache_policy(cache_hint, cache_policy) return _tir_op.call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_global_to_cluster", + "tirx.ptx.cp_async_bulk_tensor_global_to_cluster", dim, dst_ptr, bar_addr, @@ -230,7 +230,7 @@ def g2c_tile_gather4_bar_addr( cache_policy, has_cache_policy = _cuda_op._resolve_cache_policy(cache_hint, cache_policy) return _tir_op.call_intrin( "", - "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster", + "tirx.ptx.cp_async_bulk_tensor_tile_gather4_global_to_cluster", dim, dst_ptr, bar_addr, diff --git a/python/tvm/backend/trn/op.py b/python/tvm/backend/trn/op.py index d919e4fb8527..88cdd048fa98 100644 --- a/python/tvm/backend/trn/op.py +++ b/python/tvm/backend/trn/op.py @@ -22,49 +22,49 @@ def nki_load(res, data): - return call_intrin("", "tirx.nki_load", res, data) + return call_intrin("", "tirx.nki.load", res, data) def nki_store(res, data): - return call_intrin("", "tirx.nki_store", res, data) + return call_intrin("", "tirx.nki.store", res, data) def nki_tensor_copy(res, data): - return call_intrin("", "tirx.nki_tensor_copy", res, data) + return call_intrin("", "tirx.nki.tensor_copy", res, data) def nki_matmul(res, lhs, rhs, accum=True): - return call_intrin("", "tirx.nki_matmul", res, lhs, rhs, accum) + return call_intrin("", "tirx.nki.matmul", res, lhs, rhs, accum) def nki_activation(result, data, opcode, bias=0.0, scale=1.0): - return call_intrin("", "tirx.nki_activation", result, data, opcode, bias, scale) + return call_intrin("", "tirx.nki.activation", result, data, opcode, bias, scale) def nki_reciprocal(result, data): - return call_intrin("", "tirx.nki_reciprocal", result, data) + return call_intrin("", "tirx.nki.reciprocal", result, data) def nki_tensorreduce(result, data, opcode, negate, *axes): - return call_intrin("", "tirx.nki_tensorreduce", result, data, opcode, negate, *axes) + return call_intrin("", "tirx.nki.tensorreduce", result, data, opcode, negate, *axes) def nki_tensortensor(result, operand0, operand1, opcode): - return call_intrin("", "tirx.nki_tensortensor", result, operand0, operand1, opcode) + return call_intrin("", "tirx.nki.tensortensor", result, operand0, operand1, opcode) def nki_tensorscalar(result, operand0, operand1, opcode, reverse=False): - return call_intrin("", "tirx.nki_tensorscalar", result, operand0, operand1, opcode, reverse) + return call_intrin("", "tirx.nki.tensorscalar", result, operand0, operand1, opcode, reverse) def nki_memset(result, value): - return call_intrin("", "tirx.nki_memset", result, value) + return call_intrin("", "tirx.nki.memset", result, value) def nki_activation_reduce(reduce_res, act_res, data, opcode, reduce_opcode, bias=0.0, scale=1.0): return call_intrin( "", - "tirx.nki_activation_reduce", + "tirx.nki.activation_reduce", reduce_res, act_res, data, @@ -80,7 +80,7 @@ def nki_tensorscalar_reduce( ): return call_intrin( "", - "tirx.nki_tensorscalar_reduce", + "tirx.nki.tensorscalar_reduce", reduce_res, tensorscalar_res, operand0, @@ -92,7 +92,7 @@ def nki_tensorscalar_reduce( def nki_identity(result, size): - return call_intrin("", "tirx.nki_identity", result, size) + return call_intrin("", "tirx.nki.identity", result, size) def nki_scalar_tensor_tensor( @@ -100,7 +100,7 @@ def nki_scalar_tensor_tensor( ): return call_intrin( "", - "tirx.nki_scalar_tensor_tensor", + "tirx.nki.scalar_tensor_tensor", result, data, operand0, @@ -117,7 +117,7 @@ def nki_scalar_tensor_scalar( ): return call_intrin( "", - "tirx.nki_scalar_tensor_scalar", + "tirx.nki.scalar_tensor_scalar", result, data, operand0, @@ -130,7 +130,7 @@ def nki_scalar_tensor_scalar( def nki_affine_select(result, pred, true_value, false_value): - return call_intrin("", "tirx.nki_affine_select", result, pred, true_value, false_value) + return call_intrin("", "tirx.nki.affine_select", result, pred, true_value, false_value) __all__ = [ diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index 357f2c95857c..034620bf1c40 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -958,18 +958,18 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { static const Op& tvm_store_matrix_sync_op = Op::Get("tirx.tvm_store_matrix_sync"); static const Op& tvm_mma_sync_op = Op::Get("tirx.tvm_mma_sync"); static const Op& tvm_bmma_sync_op = Op::Get("tirx.tvm_bmma_sync"); - static const Op& ptx_mma_op = Op::Get("tirx.ptx_mma"); - static const Op& ptx_mma_sp_op = Op::Get("tirx.ptx_mma_sp"); + static const Op& ptx_mma_op = Op::Get("tirx.ptx.mma"); + static const Op& ptx_mma_sp_op = Op::Get("tirx.ptx.mma_sp"); static const Op& mma_store_op = Op::Get("tirx.mma_store"); static const Op& mma_fill_op = Op::Get("tirx.mma_fill"); - static const Op& ptx_mma_legacy_op = Op::Get("tirx.ptx_mma_legacy"); - static const Op& ptx_ldmatrix_legacy_op = Op::Get("tirx.ptx_ldmatrix_legacy"); + static const Op& ptx_mma_legacy_op = Op::Get("tirx.ptx.mma_legacy"); + static const Op& ptx_ldmatrix_legacy_op = Op::Get("tirx.ptx.ldmatrix_legacy"); static const Op& mma_store_legacy_op = Op::Get("tirx.mma_store_legacy"); static const Op& mma_fill_legacy_op = Op::Get("tirx.mma_fill_legacy"); - static const Op& ptx_cp_async_bulk_op = Op::Get("tirx.ptx_cp_async_bulk"); - static const Op& ptx_cp_async_mbarrier_arrive_op = Op::Get("tirx.ptx_cp_async_mbarrier_arrive"); + static const Op& ptx_cp_async_bulk_op = Op::Get("tirx.ptx.cp_async_bulk"); + static const Op& ptx_cp_async_mbarrier_arrive_op = Op::Get("tirx.ptx.cp_async_mbarrier_arrive"); static const Op& ptx_ldg32_op = Op::Get("tirx.ptx.ldg32"); - static const Op& cuda_func_call_op = Op::Get("tirx.cuda_func_call"); + static const Op& cuda_func_call_op = Op::Get("tirx.cuda.func_call"); if (op->op.same_as(tvm_fill_fragment_op)) { codegen_tags_.insert("mma"); @@ -1571,7 +1571,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { TVM_FFI_ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; this->VisitStmt(op->body); - static const Op& ptx_cp_async_commit_group_op = Op::Get("tirx.ptx_cp_async_commit_group"); + static const Op& ptx_cp_async_commit_group_op = Op::Get("tirx.ptx.cp_async_commit_group"); auto commit_group = Call(DataType::Void(), ptx_cp_async_commit_group_op, {}); this->PrintIndent(); this->VisitExpr(commit_group, this->stream); @@ -1583,7 +1583,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { TVM_FFI_ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; auto wait_cnt = wait_attrs.second; - static const Op& ptx_cp_async_wait_group_op = Op::Get("tirx.ptx_cp_async_wait_group"); + static const Op& ptx_cp_async_wait_group_op = Op::Get("tirx.ptx.cp_async_wait_group"); auto wait_group = Call(DataType::Void(), ptx_cp_async_wait_group_op, {wait_cnt}); this->PrintIndent(); this->VisitExpr(wait_group, this->stream); diff --git a/src/backend/cuda/op/target_builtin.cc b/src/backend/cuda/op/target_builtin.cc index 353c04b501ec..5c5ad0b12d9e 100644 --- a/src/backend/cuda/op/target_builtin.cc +++ b/src/backend/cuda/op/target_builtin.cc @@ -66,26 +66,11 @@ TIRX_DEFINE_BUILTIN_FUNC(tvm_fill_fragment) TIRX_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIRX_DEFINE_BUILTIN_FUNC(ptx_mma) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) - .set_attr("TScriptDtypePrintLocation", - static_cast(ScriptDtypePrintLocation::kFirst)); - // Siblings of ptx_mma / ptx_ldmatrix / mma_store / mma_fill that accept // (ptr_var, offset) pairs. Codegen emits `ptr + offset` C-pointer // arithmetic and lower_warp_memory rewrites the offset's group component // to its thread-local index. Used by the s_tir tensor_intrin tensorize // path so per-thread fragment offsets stay element-accurate. -TIRX_DEFINE_BUILTIN_FUNC(ptx_mma_legacy) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) - .set_attr("TScriptDtypePrintLocation", - static_cast(ScriptDtypePrintLocation::kFirst)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_ldmatrix_legacy) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) - .set_attr("TScriptDtypePrintLocation", - static_cast(ScriptDtypePrintLocation::kFirst)); - TIRX_DEFINE_BUILTIN_FUNC(mma_store_legacy) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -100,100 +85,6 @@ OpRegEntry::RegisterOrGet("tirx.ptx.ldg32") .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) .set_attr("TDeviceIntrinsicNamespace", ffi::String("ptx"), 10); -TIRX_DEFINE_BUILTIN_FUNC(ptx_mma_sp) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) - .set_attr("TScriptDtypePrintLocation", - static_cast(ScriptDtypePrintLocation::kFirst)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_ldmatrix) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) - .set_attr("TScriptDtypePrintLocation", - static_cast(ScriptDtypePrintLocation::kFirst)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) - .set_attr("TScriptDtypePrintLocation", - static_cast(ScriptDtypePrintLocation::kFirst)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) - .set_attr("TScriptDtypePrintLocation", - static_cast(ScriptDtypePrintLocation::kFirst)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_shared_to_cluster) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) - .set_attr("TScriptDtypePrintLocation", - static_cast(ScriptDtypePrintLocation::kFirst)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_commit_group) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_wait_group) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_mbarrier_arrive) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_fence).set_attr( - "TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_fence_proxy_async) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_init) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_arrive) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_arrive_expect_tx) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait_acquire_cluster) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_bar_arrive) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_bar_sync) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_tensor_global_to_cluster) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_tensor_shared_to_global) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_tensor_global_to_cluster_prefetch) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_tensor_shared_to_global_reduce) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_commit_group) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_wait_group) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_barrier_cluster_arrive) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_barrier_cluster_wait) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_elect_sync) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_fence_mbarrier_init) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - OpRegEntry::RegisterOrGet("tirx.ptx.fetch_register") .set_name() .set_num_inputs(-1) @@ -202,21 +93,17 @@ OpRegEntry::RegisterOrGet("tirx.ptx.fetch_register") .set_attr("TDeviceIntrinsicNamespace", ffi::String("ptx")) .set_attr("TScriptPrinterName", ffi::String("ptx.fetch_register")); -OpRegEntry::RegisterOrGet("tirx.ptx_fetch_register") +// Raw legacy cp.async form emitted by InjectPTXAsyncCopy (and round-tripped by +// the T.ptx.cp_async 6-arg surface). It carries the element dtype in Call.dtype +// and prints it dtype-first; the fork-native tirx.ptx.cp_async form does not. +OpRegEntry::RegisterOrGet("tirx.ptx.cp_async_raw") .set_name() - .set_num_inputs(-1) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("TIRxOpCategory", ffi::String("device_intrin")) .set_attr("TDeviceIntrinsicNamespace", ffi::String("ptx")) - .set_attr("TScriptPrinterName", ffi::String("ptx.fetch_register")); - -// griddepcontrol — programmatic dependent launch synchronization (sm_90+). -// Both are memory barriers; mark kOpaque to prevent CSE/reordering. -TIRX_DEFINE_BUILTIN_FUNC(ptx_griddepcontrol_wait) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_griddepcontrol_launch_dependents) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + .set_attr("TScriptPrinterName", ffi::String("ptx.cp_async")) + .set_attr("TScriptDtypePrintLocation", + static_cast(ScriptDtypePrintLocation::kFirst)); TIRX_DEFINE_BUILTIN_FUNC(mma_store) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) @@ -228,99 +115,6 @@ TIRX_DEFINE_BUILTIN_FUNC(mma_fill) .set_attr("TScriptDtypePrintLocation", static_cast(ScriptDtypePrintLocation::kFirst)); -TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_encode_matrix_descriptor) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_noop_barrier) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_mma_async_ss) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_mma_async_rs) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_fence) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_commit_group) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_wait_group) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_stmatrix) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_setmaxnreg) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_ld_global_acquire) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_alloc) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_dealloc) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_relinquish_alloc_permit) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_fence_before_thread_sync) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_fence_after_thread_sync) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_ld) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_st) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_wait_ld) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_wait_st) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_encode_matrix_descriptor) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_encode_instr_descriptor) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_encode_instr_descriptor_block_scaled) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_mma) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_mma_block_scale) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_mma_sp) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_mma_sp_block_scale) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_commit) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_cp) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_shift) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(ptx_map_shared_rank) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(cuda_func_call) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - TIRX_DEFINE_BUILTIN_FUNC(timer_init_cuda) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -333,54 +127,6 @@ TIRX_DEFINE_BUILTIN_FUNC(timer_end_cuda) TIRX_DEFINE_BUILTIN_FUNC(timer_finalize_cuda) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_my_pe) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_n_pes) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_getmem_nbi) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_nbi) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_getmem_nbi_warp) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_nbi_warp) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_getmem_nbi_block) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_nbi_block) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_signal_op) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_wait_until) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_quiet) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_signal_nbi) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_signal_nbi_warp) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_signal_nbi_block) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_fence) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nvshmem_barrier_all) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - RegisterDeviceIntrinsicAliases(); // clang-format on } @@ -388,21 +134,20 @@ RegisterDeviceIntrinsicAliases(); namespace { struct DeviceIntrinsicRegistration { - const char* flat_name; + const char* name; const char* namespace_name; CallEffectKind effect_kind; }; void RegisterDeviceIntrinsic(const DeviceIntrinsicRegistration& reg) { - std::string flat_name(reg.flat_name); + std::string name(reg.name); std::string namespace_name(reg.namespace_name); std::string prefix = namespace_name + "_"; - std::string suffix = flat_name; + std::string suffix = name; if (suffix.rfind(prefix, 0) == 0) { suffix = suffix.substr(prefix.size()); } - std::string flat_op_name = "tirx." + flat_name; std::string canonical_op_name = "tirx." + namespace_name + "." + suffix; ffi::String namespace_attr(namespace_name); ffi::String printer_name(namespace_name + "." + suffix); @@ -419,7 +164,6 @@ void RegisterDeviceIntrinsic(const DeviceIntrinsicRegistration& reg) { .set_attr("TScriptPrinterName", printer_name, /*plevel=*/15); }; - register_one(flat_op_name); register_one(canonical_op_name); } diff --git a/src/backend/trn/codegen/codegen_trn.cc b/src/backend/trn/codegen/codegen_trn.cc index 9b798c3dc8f3..b057cb1509d7 100644 --- a/src/backend/trn/codegen/codegen_trn.cc +++ b/src/backend/trn/codegen/codegen_trn.cc @@ -360,22 +360,22 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL auto is_op = [&](const Op& compat, const char* canonical_name) { return op->op.same_as(compat) || (op_node != nullptr && op_node->name == canonical_name); }; - static const Op& nki_matmul_op = Op::Get("tirx.nki_matmul"); - static const Op& nki_load_op = Op::Get("tirx.nki_load"); - static const Op& nki_store_op = Op::Get("tirx.nki_store"); - static const Op& nki_tensor_copy_op = Op::Get("tirx.nki_tensor_copy"); - static const Op& nki_activation_op = Op::Get("tirx.nki_activation"); - static const Op& nki_reciprocal_op = Op::Get("tirx.nki_reciprocal"); - static const Op& nki_tensortensor_op = Op::Get("tirx.nki_tensortensor"); - static const Op& nki_tensorscalar_op = Op::Get("tirx.nki_tensorscalar"); - static const Op& nki_memset_op = Op::Get("tirx.nki_memset"); - static const Op& nki_tensorreduce_op = Op::Get("tirx.nki_tensorreduce"); - static const Op& nki_activation_reduce_op = Op::Get("tirx.nki_activation_reduce"); - static const Op& nki_tensorscalar_reduce_op = Op::Get("tirx.nki_tensorscalar_reduce"); - static const Op& nki_identity_op = Op::Get("tirx.nki_identity"); - static const Op& nki_scalar_tensor_tensor_op = Op::Get("tirx.nki_scalar_tensor_tensor"); - static const Op& nki_scalar_tensor_scalar_op = Op::Get("tirx.nki_scalar_tensor_scalar"); - static const Op& nki_affine_select_op = Op::Get("tirx.nki_affine_select"); + static const Op& nki_matmul_op = Op::Get("tirx.nki.matmul"); + static const Op& nki_load_op = Op::Get("tirx.nki.load"); + static const Op& nki_store_op = Op::Get("tirx.nki.store"); + static const Op& nki_tensor_copy_op = Op::Get("tirx.nki.tensor_copy"); + static const Op& nki_activation_op = Op::Get("tirx.nki.activation"); + static const Op& nki_reciprocal_op = Op::Get("tirx.nki.reciprocal"); + static const Op& nki_tensortensor_op = Op::Get("tirx.nki.tensortensor"); + static const Op& nki_tensorscalar_op = Op::Get("tirx.nki.tensorscalar"); + static const Op& nki_memset_op = Op::Get("tirx.nki.memset"); + static const Op& nki_tensorreduce_op = Op::Get("tirx.nki.tensorreduce"); + static const Op& nki_activation_reduce_op = Op::Get("tirx.nki.activation_reduce"); + static const Op& nki_tensorscalar_reduce_op = Op::Get("tirx.nki.tensorscalar_reduce"); + static const Op& nki_identity_op = Op::Get("tirx.nki.identity"); + static const Op& nki_scalar_tensor_tensor_op = Op::Get("tirx.nki.scalar_tensor_tensor"); + static const Op& nki_scalar_tensor_scalar_op = Op::Get("tirx.nki.scalar_tensor_scalar"); + static const Op& nki_affine_select_op = Op::Get("tirx.nki.affine_select"); if (is_op(nki_matmul_op, "tirx.nki.matmul")) { TVM_FFI_ICHECK_EQ(op->args.size(), 4); diff --git a/src/backend/trn/op/target_builtin.cc b/src/backend/trn/op/target_builtin.cc index c0d915bb2a36..a73057e60976 100644 --- a/src/backend/trn/op/target_builtin.cc +++ b/src/backend/trn/op/target_builtin.cc @@ -35,12 +35,6 @@ namespace tvm { namespace tirx { namespace builtin { -#define TIRX_DEFINE_BUILTIN_FUNC(OpName) \ - OpRegEntry::RegisterOrGet("tirx." #OpName) \ - .set_name() \ - .set_attr("TScriptPrinterName", ffi::String(#OpName), 1) \ - .set_attr("TIRxOpCategory", ffi::String("builtin"), /*plevel=*/1) - namespace { void RegisterNKIIntrinsicAliases(); } @@ -51,69 +45,19 @@ static bool registered = false; if (registered) return; registered = true; -TIRX_DEFINE_BUILTIN_FUNC(nki_load).set_attr( - "TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_store).set_attr( - "TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_tensor_copy) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_matmul) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_activation) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_reciprocal) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_tensortensor) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_tensorscalar) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_memset) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_tensorreduce) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_activation_reduce) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_tensorscalar_reduce) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_identity) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_scalar_tensor_tensor) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_scalar_tensor_scalar) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIRX_DEFINE_BUILTIN_FUNC(nki_affine_select) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - RegisterNKIIntrinsicAliases(); // clang-format on } namespace { -void RegisterNKIIntrinsic(const char* flat_name) { - std::string flat(flat_name); +void RegisterNKIIntrinsic(const char* name) { std::string prefix = "nki_"; - std::string suffix = flat; + std::string suffix(name); if (suffix.rfind(prefix, 0) == 0) { suffix = suffix.substr(prefix.size()); } - std::string flat_op_name = "tirx." + flat; std::string canonical_op_name = "tirx.nki." + suffix; ffi::String namespace_attr("nki"); ffi::String printer_name("nki." + suffix); @@ -130,7 +74,6 @@ void RegisterNKIIntrinsic(const char* flat_name) { .set_attr("TScriptPrinterName", printer_name, /*plevel=*/15); }; - register_one(flat_op_name); register_one(canonical_op_name); } @@ -161,8 +104,6 @@ void RegisterNKIIntrinsicAliases() { } // namespace -#undef TIRX_DEFINE_BUILTIN_FUNC - TVM_FFI_STATIC_INIT_BLOCK() { RegisterTRNTargetBuiltins(); } } // namespace builtin diff --git a/src/s_tir/transform/inject_ptx_async_copy.cc b/src/s_tir/transform/inject_ptx_async_copy.cc index 3a0f113499f8..500c2623be41 100644 --- a/src/s_tir/transform/inject_ptx_async_copy.cc +++ b/src/s_tir/transform/inject_ptx_async_copy.cc @@ -90,7 +90,7 @@ class PTXAsyncCopyInjector : public StmtMutator { if (predicated) { args.push_back(predicate_value); } - static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async"); + static const Op& ptx_cp_async_op = Op::Get("tirx.ptx.cp_async_raw"); return Evaluate(Call(store->buffer->dtype, ptx_cp_async_op, args)); } @@ -119,7 +119,7 @@ class PTXAsyncCopyInjector : public StmtMutator { return PrimExpr(); }(); if (src_offset.defined() && dst_offset.defined()) { - static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async"); + static const Op& ptx_cp_async_op = Op::Get("tirx.ptx.cp_async_raw"); return Evaluate(Call(store->buffer->dtype, ptx_cp_async_op, {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)})); @@ -149,7 +149,7 @@ class PTXAsyncCopyInjector : public StmtMutator { }(); if (src_offset.defined() && dst_offset.defined()) { - static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async"); + static const Op& ptx_cp_async_op = Op::Get("tirx.ptx.cp_async_raw"); return Evaluate( Call(store->buffer->dtype, ptx_cp_async_op, {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), diff --git a/src/s_tir/transform/merge_shared_memory_allocations.cc b/src/s_tir/transform/merge_shared_memory_allocations.cc index 4b61c8994c97..1626d02e3f77 100644 --- a/src/s_tir/transform/merge_shared_memory_allocations.cc +++ b/src/s_tir/transform/merge_shared_memory_allocations.cc @@ -487,7 +487,7 @@ class SharedMemoryRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async"); + static const Op& ptx_cp_async_op = Op::Get("tirx.ptx.cp_async_raw"); if (op->op.same_as(builtin::tvm_access_ptr())) { TVM_FFI_ICHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); diff --git a/src/tirx/analysis/filter_canonical.cc b/src/tirx/analysis/filter_canonical.cc index dfefdd51c043..61af4812f710 100644 --- a/src/tirx/analysis/filter_canonical.cc +++ b/src/tirx/analysis/filter_canonical.cc @@ -45,12 +45,8 @@ bool IsBitwiseAndCall(const CallNode* call) { } bool IsPtxElectSyncCall(const CallNode* call) { - static const Op& ptx_elect_sync_op = Op::Get("tirx.ptx_elect_sync"); - if (call->op.same_as(ptx_elect_sync_op)) return true; - if (auto op = call->op.as()) { - return op.value()->name == "tirx.ptx.elect_sync"; - } - return false; + static const Op& ptx_elect_sync_op = Op::Get("tirx.ptx.elect_sync"); + return call->op.same_as(ptx_elect_sync_op); } // Strip implicit Cast wrappers from a predicate. Bool-vs-int mixing in the @@ -194,7 +190,7 @@ bool TryParseCompareAtom(const PrimExpr& expr, const ScopeIdPredicate& is_scope_ return true; } -// Try to read `expr` as a direct `Call("tirx.ptx_elect_sync")` atom. +// Try to read `expr` as a direct `Call("tirx.ptx.elect_sync")` atom. // Composed forms like `elect_sync() != 0` or `not elect_sync()` are NOT // accepted -- the canonical grammar requires a bare elect_sync call. bool TryParseElectSyncAtom(const PrimExpr& expr, FilterAtom* out) { diff --git a/src/tirx/analysis/filter_canonical.h b/src/tirx/analysis/filter_canonical.h index f3eb579214e0..6dfb6bff72d6 100644 --- a/src/tirx/analysis/filter_canonical.h +++ b/src/tirx/analysis/filter_canonical.h @@ -27,7 +27,7 @@ * * pred := atom (AND atom)* // pure n-ary conjunction (no OR/NOT) * atom := scopeid_var const // op in {==, <, <=, >, >=} - * | Call("tirx.ptx_elect_sync") + * | Call("tirx.ptx.elect_sync") * * Consumers: * 1. tile_primitive_dispatch routes a bare `if cond:` to atom-based @@ -62,7 +62,7 @@ namespace tirx { */ enum class FilterAtomKind { kRange, // scopeid_var in [lo, hi); covers ==, <, <=, >, >= - kElectSync, // Call("tirx.ptx_elect_sync") + kElectSync, // Call("tirx.ptx.elect_sync") }; /*! @@ -77,7 +77,7 @@ enum class FilterAtomKind { * - `elect_sync_call` is unset. * * For `kElectSync`: - * - `elect_sync_call`: the original `Call("tirx.ptx_elect_sync")` PrimExpr, + * - `elect_sync_call`: the original `Call("tirx.ptx.elect_sync")` PrimExpr, * preserved verbatim so downstream consumers (e.g. selector construction * in tile_primitive_dispatch) can reuse it without re-synthesizing. * - `scopeid_var`, `lo`, `hi` are unset. @@ -123,7 +123,7 @@ using ScopeIdPredicate = std::function; * Grammar (see file header): * pred := atom (AND atom)* * atom := scopeid_var const (op in {==, <, <=, >, >=}) - * | Call("tirx.ptx_elect_sync") + * | Call("tirx.ptx.elect_sync") * * Returns: * - `std::nullopt` if `cond` does not match the grammar. The caller should diff --git a/src/tirx/script/printer/expr.cc b/src/tirx/script/printer/expr.cc index f732aa619b68..d4212bbb105f 100644 --- a/src/tirx/script/printer/expr.cc +++ b/src/tirx/script/printer/expr.cc @@ -315,7 +315,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // cuda_func_call: last arg is source_code (keyword-only in the Python API). // Print it as source_code=... to enable TVMScript round-trip. - if (op->name == "tirx.cuda_func_call" || op->name == "tirx.cuda.func_call") { + if (op->name == "tirx.cuda.func_call") { int n_args = call->args.size(); ffi::Array args; // All args except the last (source_code) are positional. diff --git a/src/tirx/transform/tile_primitive_dispatch.cc b/src/tirx/transform/tile_primitive_dispatch.cc index 80cde75e7e28..eceb897893ff 100644 --- a/src/tirx/transform/tile_primitive_dispatch.cc +++ b/src/tirx/transform/tile_primitive_dispatch.cc @@ -125,12 +125,8 @@ class ElectSyncFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { auto is_canonical_elect_sync = [&]() { - static const Op& ptx_elect_sync_op = Op::Get("tirx.ptx_elect_sync"); - if (op->op.same_as(ptx_elect_sync_op)) return true; - if (auto call_op = op->op.as()) { - return call_op.value()->name == "tirx.ptx.elect_sync"; - } - return false; + static const Op& ptx_elect_sync_op = Op::Get("tirx.ptx.elect_sync"); + return op->op.same_as(ptx_elect_sync_op); }; if (is_canonical_elect_sync()) { found_ = true; diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py index 1ed69262a464..e71df73e7121 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py @@ -28,11 +28,26 @@ from tvm.testing import env +def test_cp_async_raw_dtype_round_trips(): + # The raw cp.async form emitted by InjectPTXAsyncCopy carries the element + # dtype in Call.dtype and must survive a TVMScript print -> parse round-trip + # (it prints dtype-first via tirx.ptx.cp_async_raw). Guards the regression + # where the element dtype was dropped after the flat op was phased out. + @T.prim_func + def f(A: T.Buffer((128,), "float16"), B: T.Buffer((128,), "float16")): + T.func_attr({"global_symbol": "f"}) + for i in T.serial(8): + T.ptx.cp_async("float16", B.data, i * 16, A.data, i * 16, 16) + + reparsed = tvm.script.from_source(f.script()) + tvm.ir.assert_structural_equal(f, reparsed) + + def count_cp_async(stmt): num_alloc = [0] def verify(n): - if isinstance(n, tvm.tirx.Call) and n.op.name == "tirx.ptx_cp_async": + if isinstance(n, tvm.tirx.Call) and n.op.name == "tirx.ptx.cp_async_raw": num_alloc[0] += 1 tvm.tirx.stmt_functor.post_order_visit(stmt, verify)