Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 161 additions & 161 deletions python/tvm/backend/cuda/op.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions python/tvm/backend/cuda/operator/intrinsics/cp_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,9 @@ def _scale(n):
return result[0] if isinstance(result, tuple) else result


CODEGEN_REGISTRY["tirx.ptx.cp_async_raw"] = CODEGEN_REGISTRY["tirx.ptx.cp_async"]


# =============================================================================
# cp.async.bulk.tensor (TMA) — one device_intrinsic per arity variant of each
# PTX form. Per-dim coord operands materialise via the ``c_signature`` callable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ def _make_desc_wrap(desc_buf, smem_buf, base, ldo, sdo, swizzle_val):
"""Build: { AllocBuffer(desc); encode(desc, smem); krp }"""
encode_call = tvm.tirx.call_intrin(
"",
"tirx.ptx_tcgen05_encode_matrix_descriptor",
"tirx.ptx.tcgen05_encode_matrix_descriptor",
tvm.tirx.address_of(desc_buf[0]),
smem_buf.ptr_to(base),
ldo,
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/backend/cuda/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def __init__(self):
def __call__(self, *args, **kwds):
# Accept the legacy 6-arg form ``(elem_dtype, dst, dst_off, src,
# src_off, cp_size)`` that the printer round-trips for the raw
# ``tirx.ptx_cp_async`` Call emitted by
# ``tirx.ptx.cp_async`` Call emitted by
# ``tvm.backend.cuda.transform.InjectPTXAsyncCopy``. The pass-emitted
# Call has 5 args (no ``tvm_access_ptr`` fold) and a
# per-element-dtype Call.dtype, so build it directly.
Expand All @@ -146,7 +146,7 @@ def __call__(self, *args, **kwds):
elem_dtype, dst, dst_off, src, src_off, cp_size = args
return tvm.tirx.Call(
tvm.DataType(elem_dtype),
tvm.ir.Op.get("tirx.ptx_cp_async"),
tvm.ir.Op.get("tirx.ptx.cp_async_raw"),
[dst, dst_off, src, src_off, cp_size],
)
return _dtype_forward(_cuda_op.ptx_cp_async)(*args, **kwds)
Expand Down Expand Up @@ -201,7 +201,7 @@ def g2c_bar_addr(
cache_policy, has_cache_policy = _cuda_op._resolve_cache_policy(cache_hint, cache_policy)
return _tir_op.call_intrin(
"",
"tirx.ptx_cp_async_bulk_tensor_global_to_cluster",
"tirx.ptx.cp_async_bulk_tensor_global_to_cluster",
dim,
dst_ptr,
bar_addr,
Expand Down Expand Up @@ -230,7 +230,7 @@ def g2c_tile_gather4_bar_addr(
cache_policy, has_cache_policy = _cuda_op._resolve_cache_policy(cache_hint, cache_policy)
return _tir_op.call_intrin(
"",
"tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster",
"tirx.ptx.cp_async_bulk_tensor_tile_gather4_global_to_cluster",
dim,
dst_ptr,
bar_addr,
Expand Down
32 changes: 16 additions & 16 deletions python/tvm/backend/trn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,49 +22,49 @@


def nki_load(res, data):
return call_intrin("", "tirx.nki_load", res, data)
return call_intrin("", "tirx.nki.load", res, data)


def nki_store(res, data):
return call_intrin("", "tirx.nki_store", res, data)
return call_intrin("", "tirx.nki.store", res, data)


def nki_tensor_copy(res, data):
return call_intrin("", "tirx.nki_tensor_copy", res, data)
return call_intrin("", "tirx.nki.tensor_copy", res, data)


def nki_matmul(res, lhs, rhs, accum=True):
return call_intrin("", "tirx.nki_matmul", res, lhs, rhs, accum)
return call_intrin("", "tirx.nki.matmul", res, lhs, rhs, accum)


def nki_activation(result, data, opcode, bias=0.0, scale=1.0):
return call_intrin("", "tirx.nki_activation", result, data, opcode, bias, scale)
return call_intrin("", "tirx.nki.activation", result, data, opcode, bias, scale)


def nki_reciprocal(result, data):
return call_intrin("", "tirx.nki_reciprocal", result, data)
return call_intrin("", "tirx.nki.reciprocal", result, data)


def nki_tensorreduce(result, data, opcode, negate, *axes):
return call_intrin("", "tirx.nki_tensorreduce", result, data, opcode, negate, *axes)
return call_intrin("", "tirx.nki.tensorreduce", result, data, opcode, negate, *axes)


def nki_tensortensor(result, operand0, operand1, opcode):
return call_intrin("", "tirx.nki_tensortensor", result, operand0, operand1, opcode)
return call_intrin("", "tirx.nki.tensortensor", result, operand0, operand1, opcode)


def nki_tensorscalar(result, operand0, operand1, opcode, reverse=False):
return call_intrin("", "tirx.nki_tensorscalar", result, operand0, operand1, opcode, reverse)
return call_intrin("", "tirx.nki.tensorscalar", result, operand0, operand1, opcode, reverse)


def nki_memset(result, value):
return call_intrin("", "tirx.nki_memset", result, value)
return call_intrin("", "tirx.nki.memset", result, value)


def nki_activation_reduce(reduce_res, act_res, data, opcode, reduce_opcode, bias=0.0, scale=1.0):
return call_intrin(
"",
"tirx.nki_activation_reduce",
"tirx.nki.activation_reduce",
reduce_res,
act_res,
data,
Expand All @@ -80,7 +80,7 @@ def nki_tensorscalar_reduce(
):
return call_intrin(
"",
"tirx.nki_tensorscalar_reduce",
"tirx.nki.tensorscalar_reduce",
reduce_res,
tensorscalar_res,
operand0,
Expand All @@ -92,15 +92,15 @@ def nki_tensorscalar_reduce(


def nki_identity(result, size):
return call_intrin("", "tirx.nki_identity", result, size)
return call_intrin("", "tirx.nki.identity", result, size)


def nki_scalar_tensor_tensor(
result, data, operand0, operand1, opcode0, opcode1, reverse0=False, reverse1=False
):
return call_intrin(
"",
"tirx.nki_scalar_tensor_tensor",
"tirx.nki.scalar_tensor_tensor",
result,
data,
operand0,
Expand All @@ -117,7 +117,7 @@ def nki_scalar_tensor_scalar(
):
return call_intrin(
"",
"tirx.nki_scalar_tensor_scalar",
"tirx.nki.scalar_tensor_scalar",
result,
data,
operand0,
Expand All @@ -130,7 +130,7 @@ def nki_scalar_tensor_scalar(


def nki_affine_select(result, pred, true_value, false_value):
return call_intrin("", "tirx.nki_affine_select", result, pred, true_value, false_value)
return call_intrin("", "tirx.nki.affine_select", result, pred, true_value, false_value)


__all__ = [
Expand Down
18 changes: 9 additions & 9 deletions src/backend/cuda/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -958,18 +958,18 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
static const Op& tvm_store_matrix_sync_op = Op::Get("tirx.tvm_store_matrix_sync");
static const Op& tvm_mma_sync_op = Op::Get("tirx.tvm_mma_sync");
static const Op& tvm_bmma_sync_op = Op::Get("tirx.tvm_bmma_sync");
static const Op& ptx_mma_op = Op::Get("tirx.ptx_mma");
static const Op& ptx_mma_sp_op = Op::Get("tirx.ptx_mma_sp");
static const Op& ptx_mma_op = Op::Get("tirx.ptx.mma");
static const Op& ptx_mma_sp_op = Op::Get("tirx.ptx.mma_sp");
static const Op& mma_store_op = Op::Get("tirx.mma_store");
static const Op& mma_fill_op = Op::Get("tirx.mma_fill");
static const Op& ptx_mma_legacy_op = Op::Get("tirx.ptx_mma_legacy");
static const Op& ptx_ldmatrix_legacy_op = Op::Get("tirx.ptx_ldmatrix_legacy");
static const Op& ptx_mma_legacy_op = Op::Get("tirx.ptx.mma_legacy");
static const Op& ptx_ldmatrix_legacy_op = Op::Get("tirx.ptx.ldmatrix_legacy");
static const Op& mma_store_legacy_op = Op::Get("tirx.mma_store_legacy");
static const Op& mma_fill_legacy_op = Op::Get("tirx.mma_fill_legacy");
static const Op& ptx_cp_async_bulk_op = Op::Get("tirx.ptx_cp_async_bulk");
static const Op& ptx_cp_async_mbarrier_arrive_op = Op::Get("tirx.ptx_cp_async_mbarrier_arrive");
static const Op& ptx_cp_async_bulk_op = Op::Get("tirx.ptx.cp_async_bulk");
static const Op& ptx_cp_async_mbarrier_arrive_op = Op::Get("tirx.ptx.cp_async_mbarrier_arrive");
static const Op& ptx_ldg32_op = Op::Get("tirx.ptx.ldg32");
static const Op& cuda_func_call_op = Op::Get("tirx.cuda_func_call");
static const Op& cuda_func_call_op = Op::Get("tirx.cuda.func_call");

if (op->op.same_as(tvm_fill_fragment_op)) {
codegen_tags_.insert("mma");
Expand Down Expand Up @@ -1571,7 +1571,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
TVM_FFI_ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
this->VisitStmt(op->body);
static const Op& ptx_cp_async_commit_group_op = Op::Get("tirx.ptx_cp_async_commit_group");
static const Op& ptx_cp_async_commit_group_op = Op::Get("tirx.ptx.cp_async_commit_group");
auto commit_group = Call(DataType::Void(), ptx_cp_async_commit_group_op, {});
this->PrintIndent();
this->VisitExpr(commit_group, this->stream);
Expand All @@ -1583,7 +1583,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
TVM_FFI_ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
auto wait_cnt = wait_attrs.second;
static const Op& ptx_cp_async_wait_group_op = Op::Get("tirx.ptx_cp_async_wait_group");
static const Op& ptx_cp_async_wait_group_op = Op::Get("tirx.ptx.cp_async_wait_group");
auto wait_group = Call(DataType::Void(), ptx_cp_async_wait_group_op, {wait_cnt});
this->PrintIndent();
this->VisitExpr(wait_group, this->stream);
Expand Down
Loading
Loading