From cdb1193c2c28694de57a5ecaa97bdcffa2736e9c Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 13 May 2026 13:32:40 +0000 Subject: [PATCH 01/31] quick install script on top of MLIR docker image --- quick_install.sh | 69 ++++++++++++++++++++++++++++++++++++++++++++++++ set_ptoas_env.sh | 11 ++++++++ 2 files changed, 80 insertions(+) create mode 100755 quick_install.sh create mode 100644 set_ptoas_env.sh diff --git a/quick_install.sh b/quick_install.sh new file mode 100755 index 000000000..e41233e1a --- /dev/null +++ b/quick_install.sh @@ -0,0 +1,69 @@ +#!/usr/bin/env bash +# For quick development, build and install ptoas and its python bindings +# on top of Docker image https://github.com/learning-chip/agent_docker_npu/pull/8 +# assume MLIR is already installed to save time, takes <3min to finish the build of pto extension +# +# Optional env: +# LLVM_BUILD_DIR - default: ${LLVM_SOURCE_DIR:-/llvm-workspace/llvm-project}/build-shared +# PTO_INSTALL_DIR - default: /install + +set -euo pipefail + +PTO_SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PTO_INSTALL_DIR="${PTO_INSTALL_DIR:-${PTO_SOURCE_DIR}/install}" + +LLVM_SOURCE_DIR="${LLVM_SOURCE_DIR:-/llvm-workspace/llvm-project}" +LLVM_BUILD_DIR="${LLVM_BUILD_DIR:-${LLVM_SOURCE_DIR}/build-shared}" + +PY_ROOT="$(python -c 'import sys; print(sys.prefix)')" + +for d in "$LLVM_BUILD_DIR/lib/cmake/llvm" "$LLVM_BUILD_DIR/lib/cmake/mlir"; do + test -d "$d" || { echo "error: missing $d (set LLVM_BUILD_DIR?)" >&2; exit 1; } +done + +PYBIND11_DIR="$(python -m pybind11 --cmakedir)" +MLIR_PY_PKG="${LLVM_BUILD_DIR}/tools/mlir/python_packages/mlir_core" +test -d "$MLIR_PY_PKG" || { echo "error: MLIR python package dir missing: $MLIR_PY_PKG" >&2; exit 1; } + +PTOAS_VERSION="${PTOAS_VERSION:-$(python "${PTO_SOURCE_DIR}/.github/scripts/compute_ptoas_version.py" --cmake-file "${PTO_SOURCE_DIR}/CMakeLists.txt" --mode dev)}" + +cd "$PTO_SOURCE_DIR" + +cmake -C "${PTO_SOURCE_DIR}/cmake/LinuxHardeningCache.cmake" -G Ninja \ + -S . \ + -B build \ + -DLLVM_DIR="${LLVM_BUILD_DIR}/lib/cmake/llvm" \ + -DMLIR_DIR="${LLVM_BUILD_DIR}/lib/cmake/mlir" \ + -DPython3_ROOT_DIR="${PY_ROOT}" \ + -DPython3_EXECUTABLE=python \ + -DPython3_FIND_STRATEGY=LOCATION \ + -Dpybind11_DIR="${PYBIND11_DIR}" \ + -DMLIR_PYTHON_PACKAGE_DIR="${MLIR_PY_PKG}" \ + -DPTOAS_RELEASE_VERSION_OVERRIDE="${PTOAS_VERSION}" \ + -DCMAKE_INSTALL_PREFIX="${PTO_INSTALL_DIR}" + +ninja -C build +ninja -C build install + +export PTO_SOURCE_DIR PTO_INSTALL_DIR LLVM_BUILD_DIR +export PTOAS_PYTHON_PACKAGE_VERSION="${PTOAS_PYTHON_PACKAGE_VERSION:-${PTOAS_VERSION}}" +bash "${PTO_SOURCE_DIR}/docker/create_wheel.sh" + +shopt -s nullglob +wheels=("${MLIR_PY_PKG}/dist/ptoas-"*.whl) +shopt -u nullglob +((${#wheels[@]} > 0)) || { echo "error: no ptoas-*.whl under ${MLIR_PY_PKG}/dist" >&2; exit 1; } +pip install --force-reinstall "${wheels[0]}" + +export PATH="${PTO_SOURCE_DIR}/build/tools/ptoas:${PATH}" +export LD_LIBRARY_PATH="${LLVM_BUILD_DIR}/lib:${PTO_INSTALL_DIR}/lib:${LD_LIBRARY_PATH:-}" + +python -c "import mlir.ir" +python -c "from mlir.dialects import pto" + +which ptoas + +(cd "${PTO_SOURCE_DIR}/test/samples/MatMul" && python ./tmatmulk.py > ./tmatmulk.pto && ptoas ./tmatmulk.pto -o ./tmatmulk.cpp) +(cd "${PTO_SOURCE_DIR}/test/samples/Abs" && python ./abs.py > ./abs.pto && ptoas --enable-insert-sync ./abs.pto -o ./abs.cpp) + +echo "quick_install.sh: OK" diff --git a/set_ptoas_env.sh b/set_ptoas_env.sh new file mode 100644 index 000000000..c8a94bfb1 --- /dev/null +++ b/set_ptoas_env.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +# after `quick_install.sh`, run `source set_ptoas_env.sh` in a new shell to find the lib +export PTO_SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PTO_INSTALL_DIR="${PTO_INSTALL_DIR:-${PTO_SOURCE_DIR}/install}" +export PATH="${PTO_SOURCE_DIR}/build/tools/ptoas:${PATH}" +export LD_LIBRARY_PATH="${LLVM_BUILD_DIR}/lib:${PTO_INSTALL_DIR}/lib:${LD_LIBRARY_PATH:-}" + +(cd "${PTO_SOURCE_DIR}/test/samples/MatMul" && python ./tmatmulk.py > ./tmatmulk.pto && ptoas ./tmatmulk.pto -o ./tmatmulk.cpp) +(cd "${PTO_SOURCE_DIR}/test/samples/Abs" && python ./abs.py > ./abs.pto && ptoas --enable-insert-sync ./abs.pto -o ./abs.cpp) + +echo "test set_env: OK" From 217dc2adcd82add26a6704eee19aacb1e1c465dc Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 13 May 2026 13:41:42 +0000 Subject: [PATCH 02/31] add reference result for top->vop expansion --- .../lit/vpto/expand_tileop_to_vpto_result.pto | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 test/lit/vpto/expand_tileop_to_vpto_result.pto diff --git a/test/lit/vpto/expand_tileop_to_vpto_result.pto b/test/lit/vpto/expand_tileop_to_vpto_result.pto new file mode 100644 index 000000000..9644fc204 --- /dev/null +++ b/test/lit/vpto/expand_tileop_to_vpto_result.pto @@ -0,0 +1,32 @@ +// Generated by command: +// ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand --mlir-print-ir-after-all ./expand_tile_op_tilelang.pto -o out.pto + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.kernel_kind = #pto.kernel_kind, pto.target_arch = "a5"} { + func.func @TADD() { + %c0_i64 = arith.constant 0 : i64 + %c16 = arith.constant 16 : index + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64_i32 = arith.constant 64 : i32 + %c64 = arith.constant 64 : index + pto.vecscope { + %0 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %1 = pto.castptr %c0_i64 : i64 -> !pto.ptr + scf.for %arg0 = %c0 to %c16 step %c1 { + %mask, %scalar_out = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %2 = arith.muli %arg0, %c64 : index + %3 = pto.addptr %0, %2 : -> + %4 = pto.vlds %3[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %5 = pto.addptr %1, %2 : -> + %6 = pto.vlds %5[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %7 = pto.vadd %4, %6, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %7, %5[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + return + } + } +} + From c5b540d52619a10677ac9406faf643297bf2a9c9 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 13 May 2026 14:11:30 +0000 Subject: [PATCH 03/31] low-level python binding example to generate vpto IR --- ptodsl/build_expand_tileop_to_vpto.py | 167 ++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 ptodsl/build_expand_tileop_to_vpto.py diff --git a/ptodsl/build_expand_tileop_to_vpto.py b/ptodsl/build_expand_tileop_to_vpto.py new file mode 100644 index 000000000..43bb6490a --- /dev/null +++ b/ptodsl/build_expand_tileop_to_vpto.py @@ -0,0 +1,167 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +Builds the MLIR IR module equivalent to expand_tileop_to_vpto_result.pto using +low-level MLIR Python bindings. + +Target IR (expand_tileop_to_vpto_result.pto): + module attributes {pto.target_arch = "a5"} { + module attributes {pto.kernel_kind = #pto.kernel_kind, pto.target_arch = "a5"} { + func.func @TADD() { + %c0_i64 = arith.constant 0 : i64 + %c16 = arith.constant 16 : index + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64_i32 = arith.constant 64 : i32 + %c64 = arith.constant 64 : index + pto.vecscope { + %0 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %1 = pto.castptr %c0_i64 : i64 -> !pto.ptr + scf.for %arg0 = %c0 to %c16 step %c1 { + %mask, %scalar_out = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %2 = arith.muli %arg0, %c64 : index + %3 = pto.addptr %0, %2 : -> + %4 = pto.vlds %3[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %5 = pto.addptr %1, %2 : -> + %6 = pto.vlds %5[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %7 = pto.vadd %4, %6, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %7, %5[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + return + } + } + } +""" + +from mlir.ir import ( + Attribute, + Context, + F32Type, + IntegerType, + IndexType, + InsertionPoint, + Location, + Module, + Operation, + StringAttr, + Type, +) +from mlir.dialects import arith, func, pto, scf + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(): + # ── Types ──────────────────────────────────────────────────────── + i32 = IntegerType.get_signless(32) + i64 = IntegerType.get_signless(64) + idx = IndexType.get() + + # !pto.ptr – pointer to f32 in the "ub" address space + ptr_f32_ub = Type.parse("!pto.ptr") + + # !pto.vreg<64xf32> – vector register holding 64 × f32 + vreg_64f32 = Type.parse("!pto.vreg<64xf32>") + + # !pto.mask – predicate register for 32-bit element ops + mask_b32 = Type.parse("!pto.mask") + + # ── Shared attributes ───────────────────────────────────────── + target_arch_attr = StringAttr.get("a5") + kernel_kind_attr = Attribute.parse("#pto.kernel_kind") + + # ── Outer module ───────────────────────────────────────────── + outer_mod = Module.create() + outer_mod.operation.attributes["pto.target_arch"] = target_arch_attr + + with InsertionPoint(outer_mod.body): + # ── Inner module ───────────────────────────────────────── + # Module.create() does not use the active InsertionPoint, so we + # use Operation.create("builtin.module") directly instead. + inner_op = Operation.create("builtin.module", regions=1) + inner_op.attributes["pto.target_arch"] = target_arch_attr + inner_op.attributes["pto.kernel_kind"] = kernel_kind_attr + + # builtin.module needs exactly one block in its body region. + inner_body = inner_op.regions[0].blocks.append() + + with InsertionPoint(inner_body): + # ── func @TADD() ────────────────────────────────────── + fn_ty = func.FunctionType.get([], []) + fn = func.FuncOp("TADD", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + # Constants live outside vecscope; they are visible + # inside because vecscope is not a new scope for SSA. + c0_i64 = arith.ConstantOp(i64, 0).result + c16 = arith.ConstantOp(idx, 16).result + c4096_i64 = arith.ConstantOp(i64, 4096).result + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c64_i32 = arith.ConstantOp(i32, 64).result + c64 = arith.ConstantOp(idx, 64).result + + # ── pto.vecscope { … } ──────────────────────────── + vecscope_op = pto.VecScopeOp() + # vecscope has one region; we must append its entry block. + vs_block = vecscope_op.body.blocks.append() + + with InsertionPoint(vs_block): + # %0 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + ptr0 = pto.CastPtrOp(ptr_f32_ub, c4096_i64).result + + # %1 = pto.castptr %c0_i64 : i64 -> !pto.ptr + ptr1 = pto.CastPtrOp(ptr_f32_ub, c0_i64).result + + # scf.for %arg0 = %c0 to %c16 step %c1 { … } + for_op = scf.ForOp(c0, c16, c1) + with InsertionPoint(for_op.body): + arg0 = for_op.induction_variable + + # %mask, %scalar_out = pto.plt_b32 %c64_i32 + plt = pto.PltB32Op(mask_b32, i32, c64_i32) + mask = plt.mask + # scalar_out is unused in this kernel + + # %2 = arith.muli %arg0, %c64 : index + off = arith.MulIOp(arg0, c64).result + + # %3 = pto.addptr %0, %2 + ptr3 = pto.AddPtrOp(ptr0, off).result + + # %4 = pto.vlds %3[%c0] : !pto.ptr -> !pto.vreg<64xf32> + vreg4 = pto.VldsOp(vreg_64f32, ptr3, c0).result + + # %5 = pto.addptr %1, %2 + ptr5 = pto.AddPtrOp(ptr1, off).result + + # %6 = pto.vlds %5[%c0] : !pto.ptr -> !pto.vreg<64xf32> + vreg6 = pto.VldsOp(vreg_64f32, ptr5, c0).result + + # %7 = pto.vadd %4, %6, %mask + vreg7 = pto.VaddOp(vreg_64f32, vreg4, vreg6, mask).result + + # pto.vsts %7, %5[%c0], %mask + pto.VstsOp(vreg7, ptr5, c0, mask) + + scf.YieldOp([]) + + func.ReturnOp([]) + + outer_mod.operation.verify() + return outer_mod + + +if __name__ == "__main__": + print(build()) From 8697f5ef62d4487262a16b8fce9d2fb2bb8a4aaa Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 13 May 2026 16:06:55 +0000 Subject: [PATCH 04/31] initial prototype of high-level dsl builder api --- ptodsl/ptodsl_utils.py | 251 ++++++++++++++++++ ptodsl/tile_and_vpto_builder_highlevel.py | 86 ++++++ ...o.py => tile_and_vpto_builder_lowlevel.py} | 0 3 files changed, 337 insertions(+) create mode 100644 ptodsl/ptodsl_utils.py create mode 100644 ptodsl/tile_and_vpto_builder_highlevel.py rename ptodsl/{build_expand_tileop_to_vpto.py => tile_and_vpto_builder_lowlevel.py} (100%) diff --git a/ptodsl/ptodsl_utils.py b/ptodsl/ptodsl_utils.py new file mode 100644 index 000000000..b7a4898df --- /dev/null +++ b/ptodsl/ptodsl_utils.py @@ -0,0 +1,251 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +Lightweight wrappers around the low-level MLIR Python bindings for the PTO +dialect. The goal is to eliminate boilerplate so that a vPTO kernel body can +be written in plain-looking Python without manual InsertionPoint management, +verbose type constructors, or raw Operation.create() calls. + +Design rules +──────────── +• Every helper is a plain function or a contextlib.contextmanager – no classes. +• All helpers work with the *current* MLIR context / location / insertion-point + (set by `pto_context` and `vpto_kernel`); no context parameter is threaded. +• The module is self-contained: only mlir.* imports are allowed. +""" + +from contextlib import contextmanager + +from mlir.ir import ( + Attribute, + Context, + IntegerType, + IndexType, + InsertionPoint, + Location, + Module, + Operation, + StringAttr, + Type, +) +from mlir.dialects import arith, func, pto, scf + + +# ─── Type constructors ──────────────────────────────────────────────────────── + +def i32_type(): + """Signless 32-bit integer type.""" + return IntegerType.get_signless(32) + + +def i64_type(): + """Signless 64-bit integer type.""" + return IntegerType.get_signless(64) + + +def idx_type(): + """MLIR index type.""" + return IndexType.get() + + +def ptr_type(elem_type, space="ub"): + """PTO pointer type: !pto.ptr<{elem_type}, {space}>.""" + return Type.parse(f"!pto.ptr<{elem_type}, {space}>") + + +def vreg_type(lanes, elem_type): + """PTO vector-register type: !pto.vreg<{lanes}x{elem_type}>.""" + return Type.parse(f"!pto.vreg<{lanes}x{elem_type}>") + + +def mask_type(bits="b32"): + """PTO mask/predicate type: !pto.mask<{bits}> (b8 | b16 | b32).""" + return Type.parse(f"!pto.mask<{bits}>") + + +# ─── Constant builders ─────────────────────────────────────────────────────── + +def c_idx(value): + """Emit an index constant.""" + return arith.ConstantOp(IndexType.get(), value).result + + +def c_i32(value): + """Emit a 32-bit integer constant.""" + return arith.ConstantOp(IntegerType.get_signless(32), value).result + + +def c_i64(value): + """Emit a 64-bit integer constant.""" + return arith.ConstantOp(IntegerType.get_signless(64), value).result + + +# ─── Arithmetic shorthands ─────────────────────────────────────────────────── + +def muli(lhs, rhs): + """arith.muli""" + return arith.MulIOp(lhs, rhs).result + + +def addi(lhs, rhs): + """arith.addi""" + return arith.AddIOp(lhs, rhs).result + + +def subi(lhs, rhs): + """arith.subi""" + return arith.SubIOp(lhs, rhs).result + + +# ─── PTO vector / pointer operations ──────────────────────────────────────── + +def castptr(int_addr, result_ptr_type): + """Cast an integer address to a typed PTO pointer (pto.castptr).""" + return pto.CastPtrOp(result_ptr_type, int_addr).result + + +def addptr(base_ptr, index_offset): + """Advance a PTO pointer by an index offset (pto.addptr).""" + return pto.AddPtrOp(base_ptr, index_offset).result + + +def vlds(src_ptr, offset, result_vreg_type): + """Vector load from a PTO pointer at *offset* (pto.vlds).""" + return pto.VldsOp(result_vreg_type, src_ptr, offset).result + + +def vadd(lhs, rhs, mask, result_vreg_type): + """Element-wise vector add under a predicate mask (pto.vadd).""" + return pto.VaddOp(result_vreg_type, lhs, rhs, mask).result + + +def vsts(val, dst_ptr, offset, mask): + """Vector store to a PTO pointer at *offset* under a mask (pto.vsts).""" + pto.VstsOp(val, dst_ptr, offset, mask) + + +def plt_b32(scalar): + """ + Predicate-load from a 32-bit scalar value (pto.plt_b32). + + Returns (mask_value, scalar_out) – the mask is typically the only value + used downstream; scalar_out can be discarded with ``_``. + """ + plt_op = pto.PltB32Op(mask_type("b32"), i32_type(), scalar) + return plt_op.mask, plt_op.scalar_out + + +# ─── Scope context managers ────────────────────────────────────────────────── + +@contextmanager +def vecscope(): + """ + Emit a ``pto.vecscope { ... }`` region. + + Usage:: + + with vecscope(): + ptr = castptr(addr, ptr_f32) + ... + """ + op = pto.VecScopeOp() + block = op.body.blocks.append() + with InsertionPoint(block): + yield + + +@contextmanager +def for_range(start, stop, step): + """ + Emit an ``scf.for`` loop; yield the induction variable. + The mandatory ``scf.yield`` terminator is inserted automatically on exit. + + Usage:: + + with for_range(c0, c16, c1) as i: + off = muli(i, c64) + ... + """ + for_op = scf.ForOp(start, stop, step) + with InsertionPoint(for_op.body): + yield for_op.induction_variable + scf.YieldOp([]) + + +# ─── Top-level module / kernel builder ─────────────────────────────────────── + +@contextmanager +def pto_context(): + """ + Activate an MLIR context with the PTO dialect registered. + Must wrap all other utility calls. + + Usage:: + + with pto_context(): + f32 = F32Type.get() + with vpto_kernel("MyKernel", arch="a5") as mod: + ... + """ + with Context() as ctx: + pto.register_dialect(ctx, load=True) + with Location.unknown(): + yield ctx + + +@contextmanager +def vpto_kernel(func_name, *, arch="a5"): + """ + Build the standard two-level nested-module + no-arg ``func.func`` shell + for a vPTO vector kernel, then yield the outer ``Module`` as the context + variable. ``func.ReturnOp`` and ``module.verify()`` are inserted/called + automatically on context exit. + + The emitted skeleton is:: + + module attributes {pto.target_arch = arch} { + module attributes {pto.kernel_kind = #pto.kernel_kind, + pto.target_arch = arch} { + func.func @func_name() { + + return + } + } + } + + Usage:: + + with vpto_kernel("TADD", arch="a5") as mod: + c0 = c_idx(0) + ... + return mod + """ + arch_attr = StringAttr.get(arch) + kind_attr = Attribute.parse("#pto.kernel_kind") + + outer_mod = Module.create() + outer_mod.operation.attributes["pto.target_arch"] = arch_attr + + with InsertionPoint(outer_mod.body): + # Module.create() ignores the active InsertionPoint, so use + # Operation.create("builtin.module") to insert the inner module. + inner_op = Operation.create("builtin.module", regions=1) + inner_op.attributes["pto.target_arch"] = arch_attr + inner_op.attributes["pto.kernel_kind"] = kind_attr + inner_body = inner_op.regions[0].blocks.append() + + with InsertionPoint(inner_body): + fn = func.FuncOp(func_name, func.FunctionType.get([], [])) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + yield outer_mod + func.ReturnOp([]) + + outer_mod.operation.verify() diff --git a/ptodsl/tile_and_vpto_builder_highlevel.py b/ptodsl/tile_and_vpto_builder_highlevel.py new file mode 100644 index 000000000..867657e4a --- /dev/null +++ b/ptodsl/tile_and_vpto_builder_highlevel.py @@ -0,0 +1,86 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +High-level builder for the TADD vPTO kernel. + +Reconstructs the same IR as expand_tileop_to_vpto_result.pto using the +thin wrappers in ptodsl_utils instead of raw MLIR Python binding calls. + +Compare with tile_and_vpto_builder_lowlevel.py to see what the utils hide: + • No manual InsertionPoint management + • No Operation.create("builtin.module", ...) boilerplate + • No Type.parse() / arith.ConstantOp(...).result calls in the kernel body + • vecscope and scf.for become ordinary Python context managers +""" + +from mlir.ir import F32Type + +from ptodsl_utils import ( + # types + ptr_type, vreg_type, + # constants + c_idx, c_i32, c_i64, + # arithmetic + muli, + # vector / pointer ops + castptr, addptr, vlds, vadd, vsts, plt_b32, + # scope helpers + vecscope, for_range, + # module builders + pto_context, vpto_kernel, +) + + +def build(): + with pto_context(): + # ── Types used in this kernel ───────────────────────────────────── + f32 = F32Type.get() + ptr_f32_ub = ptr_type(f32, "ub") # !pto.ptr + vreg_64f32 = vreg_type(64, f32) # !pto.vreg<64xf32> + + # ── Build the nested module shell and the @TADD function body ───── + with vpto_kernel("TADD", arch="a5") as mod: + + # Integer-address constants for the two input buffers + c0_i64 = c_i64(0) + c4096_i64 = c_i64(4096) + + # Loop-control constants + c0 = c_idx(0) + c1 = c_idx(1) + c16 = c_idx(16) # 1024-element array / 64-wide vreg = 16 tiles + + # Scalar used to generate the per-iteration mask + c64_i32 = c_i32(64) + c64 = c_idx(64) + + with vecscope(): + # Materialise typed pointers from the raw integer addresses + ptr_src = castptr(c4096_i64, ptr_f32_ub) # source buffer + ptr_dst = castptr(c0_i64, ptr_f32_ub) # destination buffer + + with for_range(c0, c16, c1) as tile_idx: + # Build a 64-lane all-true mask for this iteration + mask, _ = plt_b32(c64_i32) + + # Byte offset for the current 64-element tile + tile_off = muli(tile_idx, c64) + + # Load source tile, add to destination tile, store result + va = vlds(addptr(ptr_src, tile_off), c0, vreg_64f32) + ptr_dst_tile = addptr(ptr_dst, tile_off) + vb = vlds(ptr_dst_tile, c0, vreg_64f32) + vc = vadd(va, vb, mask, vreg_64f32) + vsts(vc, ptr_dst_tile, c0, mask) + + return mod + + +if __name__ == "__main__": + print(build()) diff --git a/ptodsl/build_expand_tileop_to_vpto.py b/ptodsl/tile_and_vpto_builder_lowlevel.py similarity index 100% rename from ptodsl/build_expand_tileop_to_vpto.py rename to ptodsl/tile_and_vpto_builder_lowlevel.py From cf4ece0e6c275c24c108d746cc5373f2a6b4c29f Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 13 May 2026 18:01:18 +0000 Subject: [PATCH 05/31] initial prototype of softmax IR builder --- ptodsl/ptodsl_utils.py | 249 +++++++++++++++++++++ ptodsl/softmax_builder_highlevel.py | 234 ++++++++++++++++++++ ptodsl/softmax_builder_lowlevel.py | 328 ++++++++++++++++++++++++++++ 3 files changed, 811 insertions(+) create mode 100644 ptodsl/softmax_builder_highlevel.py create mode 100644 ptodsl/softmax_builder_lowlevel.py diff --git a/ptodsl/ptodsl_utils.py b/ptodsl/ptodsl_utils.py index b7a4898df..ed702f861 100644 --- a/ptodsl/ptodsl_utils.py +++ b/ptodsl/ptodsl_utils.py @@ -33,6 +33,7 @@ Operation, StringAttr, Type, + UnitAttr, ) from mlir.dialects import arith, func, pto, scf @@ -249,3 +250,251 @@ def vpto_kernel(func_name, *, arch="a5"): func.ReturnOp([]) outer_mod.operation.verify() + + +# ─── Flat single-module builders (for direct func inside module) ───────────── + +@contextmanager +def flat_pto_module(arch="a5"): + """ + Flat single-level module with ``pto.target_arch`` and + ``pto.kernel_kind = #pto.kernel_kind``. + + Usage:: + + with flat_pto_module("a5") as mod: + with pto_aicore_func("MyKernel", [ptr_gm, i32]) as args: + ... + return mod + """ + m = Module.create() + m.operation.attributes["pto.target_arch"] = StringAttr.get(arch) + m.operation.attributes["pto.kernel_kind"] = Attribute.parse( + "#pto.kernel_kind" + ) + with InsertionPoint(m.body): + yield m + m.operation.verify() + + +@contextmanager +def pto_aicore_func(func_name, arg_types, *, ret_types=None): + """ + Create a ``func.func`` with the ``pto.aicore`` attribute. + Yields the function's block arguments tuple. + ``func.return`` is inserted automatically on exit. + + Usage:: + + with pto_aicore_func("f", [ptr_gm, ptr_gm, i32]) as (p0, p1, n): + ... + """ + fn_ty = func.FunctionType.get(arg_types, ret_types or []) + fn = func.FuncOp(func_name, fn_ty) + fn.attributes["pto.aicore"] = UnitAttr.get() + entry = fn.add_entry_block() + with InsertionPoint(entry): + yield tuple(entry.arguments) + func.ReturnOp([]) + + +# ─── Additional control-flow helpers ───────────────────────────────────────── + +@contextmanager +def if_ctx(cond): + """ + Emit ``scf.if cond { ... }`` with no results and no else branch. + The mandatory ``scf.yield`` terminator is inserted automatically. + + Usage:: + + with if_ctx(has_rows): + tload(part, tile) + ... + """ + op = scf.IfOp(cond) + with InsertionPoint(op.then_block): + yield + scf.YieldOp([]) + + +def if_op_returning(cond, result_types): + """ + Create a ``scf.if`` with results *and* an else branch. + Returns the raw ``IfOp`` so the caller can manage the two blocks + manually with ``InsertionPoint`` and close each with ``yield_vals()``. + + Usage:: + + br = if_op_returning(has_chunk, [vreg_f32, vreg_f32]) + with InsertionPoint(br.then_block): + ... + yield_vals(merged_max, merged_sum) + with InsertionPoint(br.else_block): + yield_vals(running_max, running_sum) + next_max, next_sum = br.results + """ + return scf.IfOp(cond, result_types, hasElse=True) + + +@contextmanager +def for_range_iter(start, stop, step, init_vals): + """ + Emit ``scf.for`` with iter_args. Yields the raw ``ForOp`` so the + caller can access ``induction_variable``, ``inner_iter_args``, and + ``results`` (after the ``with`` block). + + The caller **must** call ``yield_vals(...)`` at the end of the body. + + Usage:: + + with for_range_iter(c0, c128, c64, [a, b]) as cf: + i = cf.induction_variable + x, y = cf.inner_iter_args + ... + yield_vals(new_x, new_y) + final_x, final_y = cf.results + """ + for_op = scf.ForOp(start, stop, step, init_vals) + with InsertionPoint(for_op.body): + yield for_op + + +def yield_vals(*vals): + """Emit ``scf.yield`` with the given values (shorthand for scf.YieldOp).""" + scf.YieldOp(list(vals)) + + +# ─── Arithmetic helpers ─────────────────────────────────────────────────────── + +def index_cast(result_type, val): + """arith.index_cast from/to index.""" + return arith.IndexCastOp(result_type, val).result + + +def cmpi_sgt(lhs, rhs): + """arith.cmpi sgt (signed greater-than).""" + return arith.CmpIOp(arith.CmpIPredicate.sgt, lhs, rhs).result + + +def select_val(cond, true_val, false_val): + """arith.select.""" + return arith.SelectOp(cond, true_val, false_val).result + + +# ─── PTO hardware helpers ───────────────────────────────────────────────────── + +def get_block_idx(): + """pto.get_block_idx → i64 block index.""" + return pto.GetBlockIdxOp().result + + +def barrier_all(): + """pto.barrier #pto.pipe.""" + pto.BarrierOp(Attribute.parse("#pto.pipe")) + + +# ─── Tile-domain helpers ────────────────────────────────────────────────────── + +def tile_view(tv_type, ptr, shape, strides): + """pto.make_tensor_view → tensor_view SSA value.""" + return pto.MakeTensorViewOp(tv_type, ptr, shape, strides).result + + +def part_view(ptv_type, tv, offsets, sizes): + """pto.partition_view → partition_tensor_view SSA value.""" + return pto.PartitionViewOp(ptv_type, tv, offsets, sizes).result + + +def alloc_tile(tile_type, *, addr, valid_row, valid_col=None): + """pto.alloc_tile with optional valid_col.""" + return pto.AllocTileOp(tile_type, addr=addr, valid_row=valid_row, + valid_col=valid_col).result + + +def tload(part, tile): + """pto.tload ins(part) outs(tile).""" + pto.TLoadOp(None, part, tile) + + +def tstore(tile, part): + """pto.tstore ins(tile) outs(part).""" + pto.TStoreOp(None, tile, part) + + +def tile_ptr(tile, result_ptr_type): + """pto.tile_buf_addr – materialise a UB pointer from a tile handle.""" + return pto.TileBufAddrOp(result_ptr_type, tile).result + + +# ─── Mask helpers ───────────────────────────────────────────────────────────── + +def pset_b32(pattern): + """pto.pset_b32 "PATTERN" → !pto.mask (all-true when "PAT_ALL").""" + return pto.PsetB32Op(mask_type("b32"), pattern).result + + +# ─── Vector load / store with dist attribute ────────────────────────────────── + +def vbrc_load(src_ptr, offset, result_vreg_type): + """pto.vlds with dist="BRC_B32" – broadcast a scalar into all lanes.""" + return pto.VldsOp(result_vreg_type, src_ptr, offset, + dist="BRC_B32").result + + +def vsts_1pt(val, dst_ptr, offset, mask): + """pto.vsts with dist="1PT_B32" – store only the lowest lane.""" + pto.VstsOp(val, dst_ptr, offset, mask, dist="1PT_B32") + + +# ─── Vector math (result type inferred from first operand) ──────────────────── +# +# These wrappers follow the convention: if result_type is None the type is +# taken from the first operand (all PTO binary vector ops return the same +# type as their inputs). +# + +def vcmax(v, mask): + """pto.vcmax – cross-lane maximum reduction.""" + return pto.VcmaxOp(v.type, v, mask).result + + +def vdup_lowest(v, mask): + """pto.vdup {position="LOWEST"} – broadcast lane-0 to all lanes.""" + return pto.VdupOp(v.type, v, mask, position="LOWEST").result + + +def vmax(lhs, rhs, mask): + """pto.vmax – element-wise maximum.""" + return pto.VmaxOp(lhs.type, lhs, rhs, mask).result + + +def vexpdif(inp, ref, mask, part="ODD"): + """pto.vexpdif – exp(inp − ref), selecting ODD or EVEN lanes.""" + return pto.VexpdifOp(inp.type, inp, ref, mask, part).result + + +def vmul(lhs, rhs, mask): + """pto.vmul – element-wise multiply.""" + return pto.VmulOp(lhs.type, lhs, rhs, mask).result + + +def vcadd(v, mask): + """pto.vcadd – cross-lane add (sum reduction).""" + return pto.VcaddOp(v.type, v, mask).result + + +def vdiv(lhs, rhs, mask): + """pto.vdiv – element-wise divide.""" + return pto.VdivOp(lhs.type, lhs, rhs, mask).result + + +# Override vadd to make result_type optional (inferred from lhs when omitted) +_vadd_impl = vadd + + +def vadd(lhs, rhs, mask, result_type=None): # type: ignore[misc] + """pto.vadd – element-wise add (result_type inferred from lhs if None).""" + rt = result_type if result_type is not None else lhs.type + return pto.VaddOp(rt, lhs, rhs, mask).result + diff --git a/ptodsl/softmax_builder_highlevel.py b/ptodsl/softmax_builder_highlevel.py new file mode 100644 index 000000000..d0c80d2b5 --- /dev/null +++ b/ptodsl/softmax_builder_highlevel.py @@ -0,0 +1,234 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +High-level builder for the online softmax kernel. + +Reconstructs the same IR as softmax_builder_lowlevel.py using the +thin wrappers in ptodsl_utils. Compare the two files side by side to see +which boilerplate the utils eliminate. +""" + +from mlir.ir import F32Type, InsertionPoint, Type + +from ptodsl_utils import ( + # context / types + pto_context, flat_pto_module, pto_aicore_func, + i32_type, i64_type, idx_type, ptr_type, vreg_type, mask_type, + # constants + c_idx, c_i32, c_i64, + # arithmetic + muli, addi, subi, index_cast, cmpi_sgt, select_val, + # hardware + get_block_idx, barrier_all, + # tile domain + tile_view, part_view, alloc_tile, tload, tstore, tile_ptr, + # sync (pto.set_flag / pto.wait_flag come from pto module directly) + # vector / pointer + castptr, addptr, vlds, vsts, + plt_b32, pset_b32, vbrc_load, vsts_1pt, + # vector math + vcmax, vdup_lowest, vmax, vexpdif, vmul, vcadd, vadd, vdiv, + # control flow + vecscope, for_range, for_range_iter, yield_vals, + if_ctx, if_op_returning, +) +from mlir.dialects import pto + + +def build(): + with pto_context(): + # ── Types used throughout the kernel ────────────────────────────── + f32 = F32Type.get() + i32 = i32_type() + i64 = i64_type() + idx = idx_type() + ptr_gm = ptr_type(f32, "gm") # !pto.ptr + ptr_ub = ptr_type(f32, "ub") # !pto.ptr + tv5d = Type.parse("!pto.tensor_view") + ptv5d = Type.parse("!pto.partition_tensor_view") + tile_col = Type.parse("!pto.tile_buf") + tile_w = Type.parse("!pto.tile_buf") + vf32 = vreg_type(64, f32) # !pto.vreg<64xf32> + + with flat_pto_module("a5") as mod: + with pto_aicore_func( + "online_softmax_update_kernel_2d", + [ptr_gm] * 7 + [i32, i32], + ) as (a0, a1, a2, a3, a4, a5, a6, arg7, arg8): + + # ── Index constants ──────────────────────────────────── + c0, c1, c8, c64, c128 = (c_idx(v) for v in (0, 1, 8, 64, 128)) + + # ── i64 constants ───────────────────────────────────── + # Declared in the same order as the reference IR so that + # the round-tripped MLIR text compares equal. + c0_i64 = c_i64(0) + _c1_i64 = c_i64(1) # present in reference, unused here + _c8_i64 = c_i64(8) + _c16_i64 = c_i64(16) + _c32_i64 = c_i64(32) + _c64_i64 = c_i64(64) + c128_i64 = c_i64(128) + c256_i64 = c_i64(256) + _c512_i64 = c_i64(512) + c8448_i64 = c_i64(8448) + c16640_i64 = c_i64(16640) + c16768_i64 = c_i64(16768) + c16896_i64 = c_i64(16896) + + # ── i32 constants ────────────────────────────────────── + c1_i32 = c_i32(1); c8_i32 = c_i32(8) + c64_i32 = c_i32(64); c0_i32 = c_i32(0) + + # ── Block-level row assignment ───────────────────────── + block_i64 = get_block_idx() + block_idx = index_cast(idx, block_i64) + row_base = muli(block_idx, c8) + _ = index_cast(i32, c8) # block_rows_i32 + row_base_i32 = index_cast(i32, row_base) + remaining_rows= subi(arg8, row_base_i32) + has_rows = cmpi_sgt(remaining_rows, c0_i32) + too_many_rows = cmpi_sgt(remaining_rows, c8_i32) + row_count_i32 = select_val(too_many_rows, c8_i32, remaining_rows) + row_count = index_cast(idx, row_count_i32) + seq = index_cast(idx, arg7) + rows = index_cast(idx, arg8) + rows_x_128 = muli(rows, c128) + + with if_ctx(has_rows): + # ── Tensor views ─────────────────────────────────── + s1 = [rows, rows, rows, c1, rows] + s128 = [rows_x_128, rows_x_128, rows_x_128, c128, c1] + sh1 = [c1, c1, c1, rows, c1] + sh128= [c1, c1, c1, rows, c128] + + oldmax_view = tile_view(tv5d, a0, sh1, s1) + oldsum_view = tile_view(tv5d, a1, sh1, s1) + qk_view = tile_view(tv5d, a2, sh128, s128) + newmax_view = tile_view(tv5d, a3, sh1, s1) + newsum_view = tile_view(tv5d, a4, sh1, s1) + expmax_view = tile_view(tv5d, a5, sh1, s1) + out_view = tile_view(tv5d, a6, sh128, s128) + + # ── Partition views ──────────────────────────────── + off = [c0, c0, c0, row_base, c0] + z1 = [c1, c1, c1, row_count, c1] + zs = [c1, c1, c1, row_count, seq] + + oldmax_part = part_view(ptv5d, oldmax_view, off, z1) + oldsum_part = part_view(ptv5d, oldsum_view, off, z1) + qk_part = part_view(ptv5d, qk_view, off, zs) + newmax_part = part_view(ptv5d, newmax_view, off, z1) + newsum_part = part_view(ptv5d, newsum_view, off, z1) + expmax_part = part_view(ptv5d, expmax_view, off, z1) + out_part = part_view(ptv5d, out_view, off, zs) + + # ── UB tile allocation ───────────────────────────── + oldmax_tile = alloc_tile(tile_col, addr=c0_i64, valid_row=row_count) + oldsum_tile = alloc_tile(tile_col, addr=c128_i64, valid_row=row_count) + qk_tile = alloc_tile(tile_w, addr=c256_i64, valid_row=row_count, valid_col=seq) + out_tile = alloc_tile(tile_w, addr=c8448_i64, valid_row=row_count, valid_col=seq) + newmax_tile = alloc_tile(tile_col, addr=c16640_i64, valid_row=row_count) + newsum_tile = alloc_tile(tile_col, addr=c16768_i64, valid_row=row_count) + expmax_tile = alloc_tile(tile_col, addr=c16896_i64, valid_row=row_count) + + # ── Tile loads from GM ───────────────────────────── + tload(oldmax_part, oldmax_tile) + tload(oldsum_part, oldsum_tile) + tload(qk_part, qk_tile) + + pto.set_flag("PIPE_MTE2", "PIPE_V", pto.EVENT_ID0) + pto.wait_flag("PIPE_MTE2", "PIPE_V", pto.EVENT_ID0) + + with vecscope(): + # Materialise typed UB pointers from tile handles + ub_om = tile_ptr(oldmax_tile, ptr_ub) + ub_os = tile_ptr(oldsum_tile, ptr_ub) + ub_qk = tile_ptr(qk_tile, ptr_ub) + ub_out= tile_ptr(out_tile, ptr_ub) + ub_nm = tile_ptr(newmax_tile, ptr_ub) + ub_ns = tile_ptr(newsum_tile, ptr_ub) + ub_em = tile_ptr(expmax_tile, ptr_ub) + + active = pset_b32("PAT_ALL") + one_mask, _ = plt_b32(c1_i32) + + with for_range(c0, row_count, c1) as row: + row_qk = muli(row, c128) + oldmax_bc = vbrc_load(ub_om, row, vf32) + oldsum_bc = vbrc_load(ub_os, row, vf32) + + # ── Chunk loop: compute running max & sum ── + with for_range_iter(c0, c128, c64, + [oldmax_bc, oldsum_bc]) as cf: + chunk = cf.induction_variable + running_max, running_sum = cf.inner_iter_args + + rem_cols = subi(arg7, index_cast(i32, chunk)) + has_chunk = cmpi_sgt(rem_cols, c0_i32) + + br = if_op_returning(has_chunk, [vf32, vf32]) + with InsertionPoint(br.then_block): + cmask, _ = plt_b32(rem_cols) + cbase = addi(row_qk, chunk) + vec = vlds(ub_qk, cbase, vf32) + cmax = vcmax(vec, cmask) + cmax_bc = vdup_lowest(cmax, active) + mmax = vmax(running_max, cmax_bc, active) + sc_run = vexpdif(running_max, mmax, active) + rs_sc = vmul(sc_run, running_sum, active) + c_exp = vexpdif(vec, mmax, cmask) + c_sum = vcadd(c_exp, cmask) + c_sum_bc = vdup_lowest(c_sum, active) + m_sum = vadd(rs_sc, c_sum_bc, active) + yield_vals(mmax, m_sum) + with InsertionPoint(br.else_block): + yield_vals(running_max, running_sum) + + yield_vals(*br.results) + + final_max, final_sum = cf.results + + # ── Compute expmax scalar for this row ───── + raw_em = vexpdif(oldmax_bc, final_max, active) + sc_os = vmul(raw_em, oldsum_bc, active) + expmax = vdiv(sc_os, final_sum, active) + + vsts_1pt(final_max, ub_nm, row, one_mask) + vsts_1pt(final_sum, ub_ns, row, one_mask) + vsts_1pt(expmax, ub_em, row, one_mask) + + # ── Output normalisation loop ────────────── + with for_range(c0, c128, c64) as chunk2: + rem2 = subi(arg7, index_cast(i32, chunk2)) + has_c2 = cmpi_sgt(rem2, c0_i32) + with if_ctx(has_c2): + cmask2, _ = plt_b32(rem2) + cbase2 = addi(row_qk, chunk2) + vec2 = vlds(ub_qk, cbase2, vf32) + exp2 = vexpdif(vec2, final_max, cmask2) + out2 = vdiv(exp2, final_sum, cmask2) + vsts(out2, ub_out, cbase2, cmask2) + + pto.set_flag("PIPE_V", "PIPE_MTE3", pto.EVENT_ID0) + pto.wait_flag("PIPE_V", "PIPE_MTE3", pto.EVENT_ID0) + + # ── Tile stores to GM ────────────────────────────── + tstore(newmax_tile, newmax_part) + tstore(newsum_tile, newsum_part) + tstore(expmax_tile, expmax_part) + tstore(out_tile, out_part) + + barrier_all() + + return mod + + +if __name__ == "__main__": + print(build()) diff --git a/ptodsl/softmax_builder_lowlevel.py b/ptodsl/softmax_builder_lowlevel.py new file mode 100644 index 000000000..420eaf5b7 --- /dev/null +++ b/ptodsl/softmax_builder_lowlevel.py @@ -0,0 +1,328 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +Low-level builder for the online softmax kernel. + +Reconstructs the IR in + test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto +using raw MLIR Python binding calls, with no additional abstraction layer. +""" + +from mlir.ir import ( + Attribute, + Context, + F32Type, + InsertionPoint, + IntegerType, + IndexType, + Location, + Module, + StringAttr, + Type, + UnitAttr, +) +from mlir.dialects import arith, func, pto, scf + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(): + # ── Types ──────────────────────────────────────────────────────── + i1 = IntegerType.get_signless(1) + i32 = IntegerType.get_signless(32) + i64 = IntegerType.get_signless(64) + idx = IndexType.get() + + ptr_gm = Type.parse("!pto.ptr") + ptr_ub = Type.parse("!pto.ptr") + tv5d = Type.parse("!pto.tensor_view") + ptv5d = Type.parse("!pto.partition_tensor_view") + tile_col = Type.parse("!pto.tile_buf") + tile_wide = Type.parse("!pto.tile_buf") + vreg = Type.parse("!pto.vreg<64xf32>") + mask_b32 = Type.parse("!pto.mask") + + # ── Flat single module ──────────────────────────────────────── + m = Module.create() + m.operation.attributes["pto.target_arch"] = StringAttr.get("a5") + m.operation.attributes["pto.kernel_kind"] = Attribute.parse( + "#pto.kernel_kind" + ) + + fn_ty = func.FunctionType.get([ptr_gm] * 7 + [i32, i32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("online_softmax_update_kernel_2d", fn_ty) + fn.attributes["pto.aicore"] = UnitAttr.get() + entry = fn.add_entry_block() + + with InsertionPoint(entry): + a0, a1, a2, a3, a4, a5, a6, arg7, arg8 = entry.arguments + + # ── Index constants ─────────────────────────────────────── + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c8 = arith.ConstantOp(idx, 8).result + c64 = arith.ConstantOp(idx, 64).result + c128 = arith.ConstantOp(idx, 128).result + + # ── i64 constants ───────────────────────────────────────── + c0_i64 = arith.ConstantOp(i64, 0).result + c1_i64 = arith.ConstantOp(i64, 1).result + c8_i64 = arith.ConstantOp(i64, 8).result + c16_i64 = arith.ConstantOp(i64, 16).result + c32_i64 = arith.ConstantOp(i64, 32).result + c64_i64 = arith.ConstantOp(i64, 64).result + c128_i64 = arith.ConstantOp(i64, 128).result + c256_i64 = arith.ConstantOp(i64, 256).result + c512_i64 = arith.ConstantOp(i64, 512).result + c8448_i64 = arith.ConstantOp(i64, 8448).result + c16640_i64 = arith.ConstantOp(i64, 16640).result + c16768_i64 = arith.ConstantOp(i64, 16768).result + c16896_i64 = arith.ConstantOp(i64, 16896).result + + # ── i32 constants ───────────────────────────────────────── + c1_i32 = arith.ConstantOp(i32, 1).result + c8_i32 = arith.ConstantOp(i32, 8).result + c64_i32 = arith.ConstantOp(i32, 64).result + c0_i32 = arith.ConstantOp(i32, 0).result + + # ── Block / row computation ─────────────────────────────── + block = pto.GetBlockIdxOp().result # i64 + block_idx = arith.IndexCastOp(idx, block).result + row_base = arith.MulIOp(block_idx, c8).result + block_rows_i32= arith.IndexCastOp(i32, c8).result + row_base_i32 = arith.IndexCastOp(i32, row_base).result + remaining_rows= arith.SubIOp(arg8, row_base_i32).result + has_rows = arith.CmpIOp(arith.CmpIPredicate.sgt, + remaining_rows, c0_i32).result + too_many_rows = arith.CmpIOp(arith.CmpIPredicate.sgt, + remaining_rows, c8_i32).result + row_count_i32 = arith.SelectOp(too_many_rows, c8_i32, + remaining_rows).result + row_count = arith.IndexCastOp(idx, row_count_i32).result + seq = arith.IndexCastOp(idx, arg7).result + rows = arith.IndexCastOp(idx, arg8).result + rows_x_128 = arith.MulIOp(rows, c128).result + + # ── scf.if %has_rows ────────────────────────────────────── + if_rows = scf.IfOp(has_rows) + with InsertionPoint(if_rows.then_block): + + # ── Tensor views ────────────────────────────────────── + s1 = [rows, rows, rows, c1, rows] + s128 = [rows_x_128, rows_x_128, rows_x_128, c128, c1] + sh1 = [c1, c1, c1, rows, c1] + sh128 = [c1, c1, c1, rows, c128] + + oldmax_view = pto.MakeTensorViewOp(tv5d, a0, sh1, s1).result + oldsum_view = pto.MakeTensorViewOp(tv5d, a1, sh1, s1).result + qk_view = pto.MakeTensorViewOp(tv5d, a2, sh128, s128).result + newmax_view = pto.MakeTensorViewOp(tv5d, a3, sh1, s1).result + newsum_view = pto.MakeTensorViewOp(tv5d, a4, sh1, s1).result + expmax_view = pto.MakeTensorViewOp(tv5d, a5, sh1, s1).result + out_view = pto.MakeTensorViewOp(tv5d, a6, sh128, s128).result + + # ── Partition views ─────────────────────────────────── + off5 = [c0, c0, c0, row_base, c0] + sz1 = [c1, c1, c1, row_count, c1] + szs = [c1, c1, c1, row_count, seq] + + oldmax_part = pto.PartitionViewOp(ptv5d, oldmax_view, off5, sz1).result + oldsum_part = pto.PartitionViewOp(ptv5d, oldsum_view, off5, sz1).result + qk_part = pto.PartitionViewOp(ptv5d, qk_view, off5, szs).result + newmax_part = pto.PartitionViewOp(ptv5d, newmax_view, off5, sz1).result + newsum_part = pto.PartitionViewOp(ptv5d, newsum_view, off5, sz1).result + expmax_part = pto.PartitionViewOp(ptv5d, expmax_view, off5, sz1).result + out_part = pto.PartitionViewOp(ptv5d, out_view, off5, szs).result + + # ── Tile allocation ─────────────────────────────────── + oldmax_tile = pto.AllocTileOp(tile_col, addr=c0_i64, valid_row=row_count).result + oldsum_tile = pto.AllocTileOp(tile_col, addr=c128_i64, valid_row=row_count).result + qk_tile = pto.AllocTileOp(tile_wide, addr=c256_i64, valid_row=row_count, valid_col=seq).result + out_tile = pto.AllocTileOp(tile_wide, addr=c8448_i64, valid_row=row_count, valid_col=seq).result + newmax_tile = pto.AllocTileOp(tile_col, addr=c16640_i64, valid_row=row_count).result + newsum_tile = pto.AllocTileOp(tile_col, addr=c16768_i64, valid_row=row_count).result + expmax_tile = pto.AllocTileOp(tile_col, addr=c16896_i64, valid_row=row_count).result + + # ── Tile loads ──────────────────────────────────────── + pto.TLoadOp(None, oldmax_part, oldmax_tile) + pto.TLoadOp(None, oldsum_part, oldsum_tile) + pto.TLoadOp(None, qk_part, qk_tile) + + # ── Sync before vecscope ────────────────────────────── + pto.set_flag("PIPE_MTE2", "PIPE_V", pto.EVENT_ID0) + pto.wait_flag("PIPE_MTE2", "PIPE_V", pto.EVENT_ID0) + + # ── pto.vecscope ────────────────────────────────────── + vs_op = pto.VecScopeOp() + vs_block = vs_op.body.blocks.append() + with InsertionPoint(vs_block): + + # Materialise UB pointers from tile handles + ub_oldmax = pto.TileBufAddrOp(ptr_ub, oldmax_tile).result + ub_oldsum = pto.TileBufAddrOp(ptr_ub, oldsum_tile).result + ub_qk = pto.TileBufAddrOp(ptr_ub, qk_tile).result + ub_out = pto.TileBufAddrOp(ptr_ub, out_tile).result + ub_newmax = pto.TileBufAddrOp(ptr_ub, newmax_tile).result + ub_newsum = pto.TileBufAddrOp(ptr_ub, newsum_tile).result + ub_expmax = pto.TileBufAddrOp(ptr_ub, expmax_tile).result + + active = pto.PsetB32Op(mask_b32, "PAT_ALL").result + plt1 = pto.PltB32Op(mask_b32, i32, c1_i32) + one_mask = plt1.mask + + # ── for row in [0, row_count) ───────────────────── + row_for = scf.ForOp(c0, row_count, c1) + with InsertionPoint(row_for.body): + row = row_for.induction_variable + row_qk = arith.MulIOp(row, c128).result + + oldmax_bc = pto.VldsOp(vreg, ub_oldmax, row, + dist="BRC_B32").result + oldsum_bc = pto.VldsOp(vreg, ub_oldsum, row, + dist="BRC_B32").result + + # ── for chunk in [0,128,64) with iter_args ──── + chunk_for = scf.ForOp(c0, c128, c64, + [oldmax_bc, oldsum_bc]) + with InsertionPoint(chunk_for.body): + chunk = chunk_for.induction_variable + running_max = chunk_for.inner_iter_args[0] + running_sum = chunk_for.inner_iter_args[1] + + chunk_i32 = arith.IndexCastOp(i32, chunk).result + remaining_cols= arith.SubIOp(arg7, chunk_i32).result + has_chunk = arith.CmpIOp( + arith.CmpIPredicate.sgt, + remaining_cols, c0_i32).result + + # ── if has_chunk -> (vreg, vreg) ────────── + c_if = scf.IfOp(has_chunk, [vreg, vreg], + hasElse=True) + with InsertionPoint(c_if.then_block): + cplt = pto.PltB32Op(mask_b32, i32, + remaining_cols) + chunk_mask = cplt.mask + chunk_base = arith.AddIOp(row_qk, + chunk).result + vec = pto.VldsOp(vreg, ub_qk, + chunk_base).result + chunk_max = pto.VcmaxOp(vreg, vec, + chunk_mask).result + chunk_max_bc= pto.VdupOp(vreg, chunk_max, + active, + position="LOWEST").result + merged_max = pto.VmaxOp(vreg, running_max, + chunk_max_bc, + active).result + scaled_run = pto.VexpdifOp(vreg, + running_max, + merged_max, + active, + "ODD").result + run_sum_sc = pto.VmulOp(vreg, scaled_run, + running_sum, + active).result + chunk_exp = pto.VexpdifOp(vreg, vec, + merged_max, + chunk_mask, + "ODD").result + chunk_sum = pto.VcaddOp(vreg, chunk_exp, + chunk_mask).result + chunk_sum_bc= pto.VdupOp(vreg, chunk_sum, + active, + position="LOWEST").result + merged_sum = pto.VaddOp(vreg, run_sum_sc, + chunk_sum_bc, + active).result + scf.YieldOp([merged_max, merged_sum]) + with InsertionPoint(c_if.else_block): + scf.YieldOp([running_max, running_sum]) + + next_max, next_sum = c_if.results + scf.YieldOp([next_max, next_sum]) + + final_max, final_sum = chunk_for.results + + # ── Post-loop: compute expmax ───────────────── + raw_expmax = pto.VexpdifOp(vreg, oldmax_bc, + final_max, active, + "ODD").result + scaled_oldsum = pto.VmulOp(vreg, raw_expmax, + oldsum_bc, + active).result + expmax = pto.VdivOp(vreg, scaled_oldsum, + final_sum, + active).result + + pto.VstsOp(final_max, ub_newmax, row, one_mask, + dist="1PT_B32") + pto.VstsOp(final_sum, ub_newsum, row, one_mask, + dist="1PT_B32") + pto.VstsOp(expmax, ub_expmax, row, one_mask, + dist="1PT_B32") + + # ── Output normalisation loop ───────────────── + out_for = scf.ForOp(c0, c128, c64) + with InsertionPoint(out_for.body): + chunk2 = out_for.induction_variable + ci32_2 = arith.IndexCastOp(i32, + chunk2).result + rem2 = arith.SubIOp(arg7, ci32_2).result + has_chunk2 = arith.CmpIOp( + arith.CmpIPredicate.sgt, + rem2, c0_i32).result + + o_if = scf.IfOp(has_chunk2) + with InsertionPoint(o_if.then_block): + oplt = pto.PltB32Op(mask_b32, i32, + rem2) + cmask2 = oplt.mask + cbase2 = arith.AddIOp(row_qk, + chunk2).result + vec2 = pto.VldsOp(vreg, ub_qk, + cbase2).result + exp2 = pto.VexpdifOp(vreg, vec2, + final_max, + cmask2, + "ODD").result + out2 = pto.VdivOp(vreg, exp2, + final_sum, + cmask2).result + pto.VstsOp(out2, ub_out, cbase2, cmask2) + scf.YieldOp([]) + + scf.YieldOp([]) # out_for body + + scf.YieldOp([]) # row_for body + + # ── Sync after vecscope ─────────────────────────────── + pto.set_flag("PIPE_V", "PIPE_MTE3", pto.EVENT_ID0) + pto.wait_flag("PIPE_V", "PIPE_MTE3", pto.EVENT_ID0) + + # ── Tile stores ─────────────────────────────────────── + pto.TStoreOp(None, newmax_tile, newmax_part) + pto.TStoreOp(None, newsum_tile, newsum_part) + pto.TStoreOp(None, expmax_tile, expmax_part) + pto.TStoreOp(None, out_tile, out_part) + + scf.YieldOp([]) # if_rows then_block + + # ── Barrier and return ──────────────────────────────────── + pto.BarrierOp(Attribute.parse("#pto.pipe")) + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) From de75cadbb64cf3b62378fb331a8b7d45e070c6b3 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 13 May 2026 18:01:31 +0000 Subject: [PATCH 06/31] script to check IR equal --- ptodsl/README.md | 242 ++++++++++++++++++++++ ptodsl/check_ir.py | 160 ++++++++++++++ ptodsl/tile_and_vpto_builder_highlevel.py | 16 +- 3 files changed, 408 insertions(+), 10 deletions(-) create mode 100644 ptodsl/README.md create mode 100644 ptodsl/check_ir.py diff --git a/ptodsl/README.md b/ptodsl/README.md new file mode 100644 index 000000000..8d32b3280 --- /dev/null +++ b/ptodsl/README.md @@ -0,0 +1,242 @@ +# ptodsl — PTO Python IR Builders + +This directory contains Python scripts that construct PTO MLIR IR modules +programmatically using the MLIR Python bindings. Two complete kernel examples +are provided, each in a **low-level** (raw bindings) and a **high-level** +(utility-wrapped) variant. + +--- + +## Directory layout + +``` +ptodsl/ +├── ptodsl_utils.py # Reusable utility wrappers +│ +├── tile_and_vpto_builder_lowlevel.py # TADD kernel – raw bindings +├── tile_and_vpto_builder_highlevel.py # TADD kernel – ptodsl_utils +│ +├── softmax_builder_lowlevel.py # Softmax kernel – raw bindings +├── softmax_builder_highlevel.py # Softmax kernel – ptodsl_utils +│ +└── check_ir.py # IR correctness test for all builders +``` + +--- + +## Prerequisites + +The ptoas dialect must be installed and the environment set up before use: + +```bash +# Install (first time only) +cd /workdir/ptoas_a5 +bash quick_install.sh + +# Set up environment in every new shell +source set_ptoas_env.sh +``` + +--- + +## Running the IR check + +```bash +# From ptoas_a5/ptodsl/ +python3 check_ir.py + +# Or from the repository root (ptoas_a5/) +python3 ptodsl/check_ir.py +``` + +Expected output when everything is correct: + +``` +ptodsl IR check +================================================== + PASS TADD low-level + PASS TADD high-level + PASS softmax low-level + PASS softmax high-level +================================================== +Result: ALL PASS +``` + +Exit code is `0` on full pass, `1` if any check fails. +A unified diff of the first 60 diverging lines is printed for each failing case. + +--- + +## Kernel examples + +### TADD — simple vector add (vPTO) + +| File | Reference | +|---|---| +| `tile_and_vpto_builder_lowlevel.py` | `test/lit/vpto/expand_tileop_to_vpto_result.pto` | +| `tile_and_vpto_builder_highlevel.py` | same | + +The kernel performs an element-wise vector add over a 1024-element float32 +buffer using 16 iterations of 64-wide vector instructions inside a +`pto.vecscope`. It exercises: +`pto.castptr`, `pto.addptr`, `pto.plt_b32`, `pto.vlds`, `pto.vadd`, +`pto.vsts`, nested modules (`pto.target_arch` + `pto.kernel_kind`). + +### Online softmax update + +| File | Reference | +|---|---| +| `softmax_builder_lowlevel.py` | `test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto` | +| `softmax_builder_highlevel.py` | same | + +An online softmax update kernel that mixes tile-domain loads/stores with +raw vector compute inside a `pto.vecscope`. It exercises a significantly +larger set of ops including: +`pto.get_block_idx`, `pto.make_tensor_view`, `pto.partition_view`, +`pto.alloc_tile`, `pto.tload`/`pto.tstore`, `pto.set_flag`/`pto.wait_flag`, +`pto.tile_buf_addr`, `pto.pset_b32`, `pto.vcmax`, `pto.vdup`, `pto.vmax`, +`pto.vexpdif`, `pto.vmul`, `pto.vcadd`, `pto.vdiv`, `pto.barrier`, +`scf.for` with `iter_args`, and `scf.if` with result values. + +--- + +## How the IR check works + +`check_ir.py` calls `build()` in each builder, then compares the resulting +module against its reference `.pto` file using MLIR round-trip normalization: + +``` +generated IR ──┐ + ├── Module.parse() → canonical string ──── == ──── PASS/FAIL +reference .pto ──┘ (strips comments, normalises SSA names and attr order) +``` + +**Why round-trip normalization?** + +| Issue | Raw text comparison | Round-trip comparison | +|---|---|---| +| `// comment` lines in `.pto` files | breaks | ignored by MLIR parser | +| Named SSA values (`%block_idx`) vs anonymous (`%0`) | breaks | both become `%0`, `%1` … | +| Attribute dict ordering (`{a=1, b=2}` vs `{b=2, a=1}`) | breaks | normalized | +| Constant declaration order | breaks | **preserved** – must match | + +Because constant declaration order is preserved after round-trip, builders +must emit constants in the same order as the reference. The `check_ir.py` +diff output makes such mismatches easy to locate. + +--- + +## `ptodsl_utils.py` – utility reference + +The utility module eliminates boilerplate so kernel logic is immediately +readable. All helpers operate on the **current** MLIR context and insertion +point; no context argument is threaded. + +### Type constructors + +| Helper | MLIR type | +|---|---| +| `i32_type()` | `i32` | +| `i64_type()` | `i64` | +| `idx_type()` | `index` | +| `ptr_type(elem, space="ub")` | `!pto.ptr` | +| `vreg_type(lanes, elem)` | `!pto.vreg` | +| `mask_type(bits="b32")` | `!pto.mask` | + +### Constant builders + +| Helper | Op | +|---|---| +| `c_idx(v)` | `arith.constant v : index` | +| `c_i32(v)` | `arith.constant v : i32` | +| `c_i64(v)` | `arith.constant v : i64` | + +### Arithmetic + +`muli`, `addi`, `subi` — `arith.muli/addi/subi` +`index_cast(type, val)` — `arith.index_cast` +`cmpi_sgt(a, b)` — `arith.cmpi sgt` +`select_val(cond, t, f)` — `arith.select` + +### Module / function builders + +```python +with pto_context(): # MLIR Context + PTO dialect + with vpto_kernel("MyKernel", arch="a5") as mod: # nested module + func (no args) + ... + +with pto_context(): + with flat_pto_module("a5") as mod: # flat module + pto.kernel_kind + with pto_aicore_func("f", [ptr_gm, i32]) as (p, n): # func with args + ... +``` + +### Control-flow helpers + +```python +with vecscope(): # pto.vecscope { ... } + +with for_range(lo, hi, step) as i: # scf.for, auto-inserts scf.yield + ... + +with for_range_iter(lo, hi, step, [a, b]) as cf: # scf.for with iter_args + x, y = cf.inner_iter_args + yield_vals(new_x, new_y) # scf.yield at end of body +final_x, final_y = cf.results # results accessible after the block + +with if_ctx(cond): # scf.if, no results, auto-inserts scf.yield + ... + +br = if_op_returning(cond, [vreg, vreg]) # scf.if with results + else +with InsertionPoint(br.then_block): + yield_vals(a, b) +with InsertionPoint(br.else_block): + yield_vals(c, d) +x, y = br.results +``` + +### Tile-domain helpers + +```python +tv = tile_view(tv_type, ptr, shape, strides) # pto.make_tensor_view +ptv = part_view(ptv_type, tv, offsets, sizes) # pto.partition_view +t = alloc_tile(tile_type, addr=a, valid_row=r, valid_col=c) # pto.alloc_tile +tload(part, tile) # pto.tload +tstore(tile, part) # pto.tstore +ub = tile_ptr(tile, ptr_ub_type) # pto.tile_buf_addr +``` + +### Vector / pointer helpers + +```python +ptr = castptr(int_addr, ptr_type) # pto.castptr +ptr2 = addptr(ptr, offset) # pto.addptr +v = vlds(ptr, offset, vreg_type) # pto.vlds +v = vbrc_load(ptr, offset, vreg_type) # pto.vlds {dist="BRC_B32"} +vsts(v, ptr, offset, mask) # pto.vsts +vsts_1pt(v, ptr, offset, mask) # pto.vsts {dist="1PT_B32"} +mask, _ = plt_b32(scalar) # pto.plt_b32 +mask = pset_b32("PAT_ALL") # pto.pset_b32 +``` + +### Vector math (result type inferred from first operand) + +```python +vcmax(v, mask) # cross-lane max reduction +vdup_lowest(v, mask) # broadcast lane 0 to all lanes +vmax(a, b, mask) # element-wise max +vexpdif(x, ref, mask) # exp(x − ref), ODD lanes +vmul(a, b, mask) # element-wise multiply +vcadd(v, mask) # cross-lane add (sum reduction) +vadd(a, b, mask) # element-wise add (result_type optional) +vdiv(a, b, mask) # element-wise divide +``` + +### Hardware / sync + +```python +get_block_idx() # pto.get_block_idx → i64 +barrier_all() # pto.barrier #pto.pipe +# use pto.set_flag / pto.wait_flag directly (from mlir.dialects.pto) +# use yield_vals(*vals) as shorthand for scf.YieldOp(list(vals)) +``` diff --git a/ptodsl/check_ir.py b/ptodsl/check_ir.py new file mode 100644 index 000000000..355ca04a6 --- /dev/null +++ b/ptodsl/check_ir.py @@ -0,0 +1,160 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +IR correctness check for all ptodsl builder scripts. + +Run from the repository root or from this directory: + python3 ptodsl/check_ir.py # from ptoas_a5/ + python3 check_ir.py # from ptodsl/ + +Each builder's build() function is called; its output is compared against the +corresponding hand-written reference .pto file. + +Comparison methodology +────────────────────── +Both the generated module and the reference file are parsed by the MLIR Python +API (Module.parse), then printed back to a string. This round-trip: + + • Strips comments (// lines in .pto files are ignored by the MLIR parser) + • Normalises SSA value names (%block_idx → %0, %running_max → %arg11, …) + • Normalises attribute ordering (MLIR sorts dict-like attribute sets) + +The resulting canonical strings are compared with ==. A diff of the first 60 +differing lines is printed on failure to aid diagnosis. +""" + +import difflib +import os +import sys + +# Allow running from either ptoas_a5/ or ptoas_a5/ptodsl/ +_HERE = os.path.dirname(os.path.abspath(__file__)) +if _HERE not in sys.path: + sys.path.insert(0, _HERE) + +# ── MLIR bootstrap ─────────────────────────────────────────────────────────── +_MLIR_INSTALL = os.path.join(_HERE, "..", "install", "mlir") +if _MLIR_INSTALL not in sys.path: + sys.path.insert(0, _MLIR_INSTALL) + +from mlir.ir import Context, Module # noqa: E402 +from mlir.dialects import pto as _pto_mod # noqa: E402 + + +def _normalize(mlir_text: str) -> str: + """Parse *mlir_text* with MLIR and return the canonical printed form.""" + with Context() as ctx: + _pto_mod.register_dialect(ctx, load=True) + return str(Module.parse(mlir_text)) + + +def _strip_comments(text: str) -> str: + """Remove // comment lines that appear in hand-written .pto files.""" + return "\n".join( + line for line in text.splitlines() if not line.strip().startswith("//") + ) + + +# ── Test cases ──────────────────────────────────────────────────────────────── +# Each entry: (label, builder_module_path, reference_pto_path) +_REPO_ROOT = os.path.abspath(os.path.join(_HERE, "..")) + +CASES = [ + ( + "TADD low-level ", + "tile_and_vpto_builder_lowlevel", + os.path.join(_REPO_ROOT, + "test/lit/vpto/expand_tileop_to_vpto_result.pto"), + ), + ( + "TADD high-level", + "tile_and_vpto_builder_highlevel", + os.path.join(_REPO_ROOT, + "test/lit/vpto/expand_tileop_to_vpto_result.pto"), + ), + ( + "softmax low-level ", + "softmax_builder_lowlevel", + os.path.join(_REPO_ROOT, + "test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto"), + ), + ( + "softmax high-level", + "softmax_builder_highlevel", + os.path.join(_REPO_ROOT, + "test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto"), + ), +] + + +# ── Runner ──────────────────────────────────────────────────────────────────── + +def run_checks(cases=CASES) -> bool: + """Execute every check case; return True if all passed.""" + all_passed = True + + for label, module_name, ref_path in cases: + # -- import the builder and call build() -- + try: + builder = __import__(module_name) + generated_module = builder.build() + generated_text = str(generated_module) + except Exception as exc: + print(f" FAIL {label} [builder error: {exc}]") + all_passed = False + continue + + # -- load and prepare the reference -- + try: + ref_raw = open(ref_path).read() + except FileNotFoundError: + print(f" FAIL {label} [reference not found: {ref_path}]") + all_passed = False + continue + + ref_clean = _strip_comments(ref_raw) + + # -- normalise both through the MLIR parser -- + try: + gen_norm = _normalize(generated_text) + ref_norm = _normalize(ref_clean) + except Exception as exc: + print(f" FAIL {label} [MLIR parse error: {exc}]") + all_passed = False + continue + + # -- compare -- + if gen_norm == ref_norm: + print(f" PASS {label}") + else: + all_passed = False + diff_lines = list( + difflib.unified_diff( + ref_norm.splitlines(), + gen_norm.splitlines(), + fromfile="reference", + tofile="generated", + lineterm="", + ) + ) + snippet = "\n".join(diff_lines[:60]) + print(f" FAIL {label}\n{snippet}") + if len(diff_lines) > 60: + print(f" ... ({len(diff_lines) - 60} more diff lines)") + + return all_passed + + +if __name__ == "__main__": + print("ptodsl IR check") + print("=" * 50) + passed = run_checks() + print("=" * 50) + print("Result:", "ALL PASS" if passed else "SOME TESTS FAILED") + sys.exit(0 if passed else 1) diff --git a/ptodsl/tile_and_vpto_builder_highlevel.py b/ptodsl/tile_and_vpto_builder_highlevel.py index 867657e4a..7c16cadc7 100644 --- a/ptodsl/tile_and_vpto_builder_highlevel.py +++ b/ptodsl/tile_and_vpto_builder_highlevel.py @@ -47,18 +47,14 @@ def build(): # ── Build the nested module shell and the @TADD function body ───── with vpto_kernel("TADD", arch="a5") as mod: - # Integer-address constants for the two input buffers + # Constants – declared in the same order as the reference IR. c0_i64 = c_i64(0) + c16 = c_idx(16) # loop trip-count: 1024 elems / 64-wide vreg c4096_i64 = c_i64(4096) - - # Loop-control constants - c0 = c_idx(0) - c1 = c_idx(1) - c16 = c_idx(16) # 1024-element array / 64-wide vreg = 16 tiles - - # Scalar used to generate the per-iteration mask - c64_i32 = c_i32(64) - c64 = c_idx(64) + c0 = c_idx(0) + c1 = c_idx(1) + c64_i32 = c_i32(64) # scalar for mask generation + c64 = c_idx(64) with vecscope(): # Materialise typed pointers from the raw integer addresses From 8d1a8346339d5afa132ac88dcb77f36d07a04940 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 13 May 2026 18:14:55 +0000 Subject: [PATCH 07/31] avoid raw MLIR `Type.parse` --- ptodsl/ptodsl_utils.py | 82 ++++++++++++++++++++++-- ptodsl/softmax_builder_highlevel.py | 14 ++-- ptodsl/softmax_builder_lowlevel.py | 47 +++++++++++--- ptodsl/tile_and_vpto_builder_lowlevel.py | 14 ++-- 4 files changed, 132 insertions(+), 25 deletions(-) diff --git a/ptodsl/ptodsl_utils.py b/ptodsl/ptodsl_utils.py index ed702f861..b220239d4 100644 --- a/ptodsl/ptodsl_utils.py +++ b/ptodsl/ptodsl_utils.py @@ -31,12 +31,22 @@ Location, Module, Operation, + ShapedType, StringAttr, Type, UnitAttr, ) from mlir.dialects import arith, func, pto, scf +# Mapping from the textual address-space name used in !pto.ptr +# to the AddressSpace enum value exposed by the C extension. +_ADDR_SPACE = { + "ub": pto.AddressSpace.VEC, # "ub" (unified buffer) prints as VEC + "gm": pto.AddressSpace.GM, + "vec": pto.AddressSpace.VEC, + "l1": pto.AddressSpace.MAT, +} + # ─── Type constructors ──────────────────────────────────────────────────────── @@ -56,20 +66,84 @@ def idx_type(): def ptr_type(elem_type, space="ub"): - """PTO pointer type: !pto.ptr<{elem_type}, {space}>.""" + """PTO pointer type: !pto.ptr<{elem_type}, {space}>. + + Uses ``pto.PtrType.get`` with an ``AddressSpaceAttr`` when the address-space + name is known; falls back to ``Type.parse`` for unknown spaces. + """ + enum_val = _ADDR_SPACE.get(space) + if enum_val is not None: + space_attr = pto.AddressSpaceAttr.get(enum_val) + return pto.PtrType.get(elem_type, memory_space=space_attr) return Type.parse(f"!pto.ptr<{elem_type}, {space}>") def vreg_type(lanes, elem_type): - """PTO vector-register type: !pto.vreg<{lanes}x{elem_type}>.""" + """PTO vector-register type: !pto.vreg<{lanes}x{elem_type}>. + + VRegType has no Python-binding constructor; Type.parse is the only path. + """ return Type.parse(f"!pto.vreg<{lanes}x{elem_type}>") def mask_type(bits="b32"): - """PTO mask/predicate type: !pto.mask<{bits}> (b8 | b16 | b32).""" + """PTO mask/predicate type: !pto.mask<{bits}> (b8 | b16 | b32). + + MaskType has no Python-binding constructor; Type.parse is the only path. + """ return Type.parse(f"!pto.mask<{bits}>") +def tensor_view_type(rank, elem_type): + """PTO tensor-view type with all-dynamic dimensions: !pto.tensor_view. + + Uses ``pto.TensorViewType.get(rank, elem_type)``. + """ + return pto.TensorViewType.get(rank, elem_type) + + +def part_tensor_view_type(rank, elem_type): + """PTO partition-tensor-view type with all-dynamic dims: !pto.partition_tensor_view. + + Uses ``pto.PartitionTensorViewType.get([kDynamic]*rank, elem_type)``. + ``ShapedType.get_dynamic_size()`` (``INT64_MIN``) is the correct MLIR + sentinel; plain ``-1`` would produce a different printed form. + """ + kDynamic = ShapedType.get_dynamic_size() + return pto.PartitionTensorViewType.get([kDynamic] * rank, elem_type) + + +def tile_buf_type(shape, elem_type, valid_shape, *, + blayout="RowMajor", address_space="ub", + slayout="NoneBox", fractal_size=512, pad="Null"): + """PTO tile-buffer type via ``pto.TileBufType.get``. + + ``valid_shape`` entries may be ``-1`` for dynamic (``?``) dimensions. + ``blayout`` selects the block layout: ``"RowMajor"`` (default, omitted in + the printed form) or ``"ColMajor"`` (printed as ``blayout=col_major``). + + Common usage:: + + # !pto.tile_buf + tile_buf_type([8, 128], f32, [-1, -1]) + + # !pto.tile_buf + tile_buf_type([8, 1], f32, [-1, 1], blayout="ColMajor") + """ + space_enum = _ADDR_SPACE.get(address_space) + if space_enum is None: + raise ValueError(f"Unknown address_space '{address_space}'; " + f"known: {list(_ADDR_SPACE)}") + space_attr = pto.AddressSpaceAttr.get(space_enum) + cfg = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(getattr(pto.BLayout, blayout)), + pto.SLayoutAttr.get(getattr(pto.SLayout, slayout)), + fractal_size, + pto.PadValueAttr.get(getattr(pto.PadValue, pad)), + ) + return pto.TileBufType.get(shape, elem_type, space_attr, valid_shape, cfg) + + # ─── Constant builders ─────────────────────────────────────────────────────── def c_idx(value): @@ -391,7 +465,7 @@ def get_block_idx(): def barrier_all(): """pto.barrier #pto.pipe.""" - pto.BarrierOp(Attribute.parse("#pto.pipe")) + pto.BarrierOp(pto.PipeAttr.get(pto.PIPE.PIPE_ALL)) # ─── Tile-domain helpers ────────────────────────────────────────────────────── diff --git a/ptodsl/softmax_builder_highlevel.py b/ptodsl/softmax_builder_highlevel.py index d0c80d2b5..bc600484a 100644 --- a/ptodsl/softmax_builder_highlevel.py +++ b/ptodsl/softmax_builder_highlevel.py @@ -14,12 +14,14 @@ which boilerplate the utils eliminate. """ -from mlir.ir import F32Type, InsertionPoint, Type +from mlir.ir import F32Type, InsertionPoint from ptodsl_utils import ( # context / types pto_context, flat_pto_module, pto_aicore_func, - i32_type, i64_type, idx_type, ptr_type, vreg_type, mask_type, + i32_type, i64_type, idx_type, ptr_type, + tensor_view_type, part_tensor_view_type, tile_buf_type, + vreg_type, # constants c_idx, c_i32, c_i64, # arithmetic @@ -50,10 +52,10 @@ def build(): idx = idx_type() ptr_gm = ptr_type(f32, "gm") # !pto.ptr ptr_ub = ptr_type(f32, "ub") # !pto.ptr - tv5d = Type.parse("!pto.tensor_view") - ptv5d = Type.parse("!pto.partition_tensor_view") - tile_col = Type.parse("!pto.tile_buf") - tile_w = Type.parse("!pto.tile_buf") + tv5d = tensor_view_type(5, f32) # !pto.tensor_view + ptv5d = part_tensor_view_type(5, f32) # !pto.partition_tensor_view + tile_col = tile_buf_type([8, 1], f32, [-1, 1], blayout="ColMajor") # valid=?x1, col_major + tile_w = tile_buf_type([8, 128], f32, [-1, -1]) # valid=?x? vf32 = vreg_type(64, f32) # !pto.vreg<64xf32> with flat_pto_module("a5") as mod: diff --git a/ptodsl/softmax_builder_lowlevel.py b/ptodsl/softmax_builder_lowlevel.py index 420eaf5b7..d93e70ccf 100644 --- a/ptodsl/softmax_builder_lowlevel.py +++ b/ptodsl/softmax_builder_lowlevel.py @@ -23,6 +23,7 @@ IndexType, Location, Module, + ShapedType, StringAttr, Type, UnitAttr, @@ -40,19 +41,47 @@ def build(): i32 = IntegerType.get_signless(32) i64 = IntegerType.get_signless(64) idx = IndexType.get() + f32 = F32Type.get() + + # Address-space attributes used in pointer and tile types + _gm = pto.AddressSpaceAttr.get(pto.AddressSpace.GM) # gm = global memory + _ub = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC) # vec = UB (unified buffer) + # Sentinel value for a dynamic (unknown) dimension + _dyn = ShapedType.get_dynamic_size() + + # Pointer types built with PtrType.get + ptr_gm = pto.PtrType.get(f32, memory_space=_gm) # !pto.ptr + ptr_ub = pto.PtrType.get(f32, memory_space=_ub) # !pto.ptr + + # Tensor-view types built with TensorViewType / PartitionTensorViewType + tv5d = pto.TensorViewType.get(5, f32) # !pto.tensor_view + ptv5d = pto.PartitionTensorViewType.get([_dyn] * 5, f32) # !pto.partition_tensor_view + + # Tile-buffer config attributes + _col_cfg = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(pto.BLayout.ColMajor), + pto.SLayoutAttr.get(pto.SLayout.NoneBox), + 512, pto.PadValueAttr.get(pto.PadValue.Null), + ) + _row_cfg = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(pto.BLayout.RowMajor), + pto.SLayoutAttr.get(pto.SLayout.NoneBox), + 512, pto.PadValueAttr.get(pto.PadValue.Null), + ) + # !pto.tile_buf + tile_col = pto.TileBufType.get([8, 1], f32, _ub, [-1, 1], _col_cfg) + # !pto.tile_buf + tile_wide = pto.TileBufType.get([8, 128], f32, _ub, [-1, -1], _row_cfg) - ptr_gm = Type.parse("!pto.ptr") - ptr_ub = Type.parse("!pto.ptr") - tv5d = Type.parse("!pto.tensor_view") - ptv5d = Type.parse("!pto.partition_tensor_view") - tile_col = Type.parse("!pto.tile_buf") - tile_wide = Type.parse("!pto.tile_buf") - vreg = Type.parse("!pto.vreg<64xf32>") - mask_b32 = Type.parse("!pto.mask") + # VReg and Mask types have no Python-binding constructors yet; + # Type.parse is the only available path for these two. + vreg = Type.parse("!pto.vreg<64xf32>") + mask_b32 = Type.parse("!pto.mask") # ── Flat single module ──────────────────────────────────────── m = Module.create() m.operation.attributes["pto.target_arch"] = StringAttr.get("a5") + # FunctionKernelKindAttr has no binding; Attribute.parse is the only path. m.operation.attributes["pto.kernel_kind"] = Attribute.parse( "#pto.kernel_kind" ) @@ -317,7 +346,7 @@ def build(): scf.YieldOp([]) # if_rows then_block # ── Barrier and return ──────────────────────────────────── - pto.BarrierOp(Attribute.parse("#pto.pipe")) + pto.BarrierOp(pto.PipeAttr.get(pto.PIPE.PIPE_ALL)) func.ReturnOp([]) m.operation.verify() diff --git a/ptodsl/tile_and_vpto_builder_lowlevel.py b/ptodsl/tile_and_vpto_builder_lowlevel.py index 43bb6490a..8dd98c77d 100644 --- a/ptodsl/tile_and_vpto_builder_lowlevel.py +++ b/ptodsl/tile_and_vpto_builder_lowlevel.py @@ -66,15 +66,17 @@ def build(): i32 = IntegerType.get_signless(32) i64 = IntegerType.get_signless(64) idx = IndexType.get() + f32 = F32Type.get() - # !pto.ptr – pointer to f32 in the "ub" address space - ptr_f32_ub = Type.parse("!pto.ptr") + # !pto.ptr – pointer to f32 in the UB (VEC) address space + ptr_f32_ub = pto.PtrType.get( + f32, memory_space=pto.AddressSpaceAttr.get(pto.AddressSpace.VEC) + ) - # !pto.vreg<64xf32> – vector register holding 64 × f32 + # VReg and Mask types have no Python-binding constructors yet; + # Type.parse is the only available path for these two. vreg_64f32 = Type.parse("!pto.vreg<64xf32>") - - # !pto.mask – predicate register for 32-bit element ops - mask_b32 = Type.parse("!pto.mask") + mask_b32 = Type.parse("!pto.mask") # ── Shared attributes ───────────────────────────────────────── target_arch_attr = StringAttr.get("a5") From 8dc1a4aea800261a65e9f619c9eedda0d6e5c5b7 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 13 May 2026 19:36:43 +0000 Subject: [PATCH 08/31] more Pythonic builder style suggestions --- ptodsl/softmax_builder_suggested.py | 204 ++++++++++++++++++++++ ptodsl/tile_and_vpto_builder_suggested.py | 32 ++++ 2 files changed, 236 insertions(+) create mode 100644 ptodsl/softmax_builder_suggested.py create mode 100644 ptodsl/tile_and_vpto_builder_suggested.py diff --git a/ptodsl/softmax_builder_suggested.py b/ptodsl/softmax_builder_suggested.py new file mode 100644 index 000000000..3c2340af1 --- /dev/null +++ b/ptodsl/softmax_builder_suggested.py @@ -0,0 +1,204 @@ +# Minimum Pythonic mapping of test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto +import pto +s = pto.scalar + +@pto.to_ir( + name="online_softmax_update_kernel_2d", # default to function name if not given + kernel_kind="vector", + arch="a5", + func_attr="pto.aicore" +) +def softmax_demo( + arg0: pto.ptr(pto.float32, "GM"), + arg2: pto.ptr(pto.float32, "GM"), + arg3: pto.ptr(pto.float32, "GM"), + arg4: pto.ptr(pto.float32, "GM"), + arg5: pto.ptr(pto.float32, "GM"), + arg6: pto.ptr(pto.float32, "GM"), + arg7: pto.int32, + arg8: pto.int32 +) + c0 = pto.const(0) + c1 = pto.const(1) + c8 = pto.const(8) + c64 = pto.const(64) + c128 = pto.const(128) + + c0_i64 = pto.const(0, dtype=pto.int64) + c1_i64 = pto.const(1, dtype=pto.int64) + c8_i64 = pto.const(8, dtype=pto.int64) + c16_i64 = pto.const(16, dtype=pto.int64) + c32_i64 = pto.const(32, dtype=pto.int64) + c64_i64 = pto.const(64, dtype=pto.int64) + c128_i64 = pto.const(128, dtype=pto.int64) + c256_i64 = pto.const(256, dtype=pto.int64) + c512_i64 = pto.const(512, dtype=pto.int64) + c8448_i64 = pto.const(8448, dtype=pto.int64) + c16640_i64 = pto.const(16640, dtype=pto.int64) + c16768_i64 = pto.const(16768, dtype=pto.int64) + c16896_i64 = pto.const(16896, dtype=pto.int64) + + c1_i32 = pto.const(1, dtype=pto.int32) + c8_i32 = pto.const(8, dtype=pto.int32) + c64_i32 = pto.const(64, dtype=pto.int32) + c0_i32 = pto.const(0, dtype=pto.int32) + + block_i64 = pto.get_block_idx() + block_idx = s.index_cast(idx, block_i64) + row_base = s.muli(block_idx, c8) + _ = s.index_cast(i32, c8) # block_rows_i32 + row_base_i32 = s.index_cast(i32, row_base) + remaining_rows= s.subi(arg8, row_base_i32) + has_rows = s.cmpi_sgt(remaining_rows, c0_i32) # optionally overload __gt__ + too_many_rows = s.cmpi_sgt(remaining_rows, c8_i32) + row_count_i32 = s.select(too_many_rows, c8_i32, remaining_rows) + row_count = s.index_cast(row_count_i32) + seq = s.index_cast(arg7) + rows = s.index_cast(arg8) + rows_x_128 = s.muli(rows, c128) + + with pto.if_(has_rows): + # ── Tensor views ─────────────────────────────────── + s1 = [rows, rows, rows, c1, rows] + s128 = [rows_x_128, rows_x_128, rows_x_128, c128, c1] + sh1 = [c1, c1, c1, rows, c1] + sh128= [c1, c1, c1, rows, c128] + + # 5D type `!pto.tensor_view` can be inferred from shape rank + oldmax_view = pto.make_tensor_view(arg0, shape=sh1, strides=s1) + oldsum_view = pto.make_tensor_view(arg1, shape=sh1, strides=s1) + qk_view = pto.make_tensor_view(arg2, shape=h128, strides=s128) + newmax_view = pto.make_tensor_view(arg3, shape=sh1, strides=s1) + newsum_view = pto.make_tensor_view(arg4, shape=sh1, strides=s1) + expmax_view = pto.make_tensor_view(arg5, shape=sh1, strides=s1) + out_view = pto.make_tensor_view(arg6, shape=sh128, strides=s128) + + # ── Partition views ──────────────────────────────── + off = [c0, c0, c0, row_base, c0] + z1 = [c1, c1, c1, row_count, c1] + zs = [c1, c1, c1, row_count, seq] + + # 5D type `!pto.tensor_view -> !pto.partition_tensor_view` can be inferred from shape rank + oldmax_part = pto.partition_view(oldmax_view, offsets=off, sizes=z1) + oldsum_part = pto.partition_view(oldsum_view, offsets=off, sizes=z1) + qk_part = pto.partition_view(qk_view, offsets=off, sizes=zs) + newmax_part = pto.partition_view(newmax_view, offsets=off, sizes=z1) + newsum_part = pto.partition_view(newsum_view, offsets=off, sizes=z1) + expmax_part = pto.partition_view(expmax_view, offsets=off, sizes=z1) + out_part = pto.partition_view(out_view, offsets=off, sizes=zs) + + # ── UB tile allocation ───────────────────────────── + tile_col = pto.tile_buf_type( + shape=[8, 1], dtype=pto.float32, valid_shape=[-1, 1], blayout="ColMajor") # valid=?x1, col_major + tile_w = pto.tile_buf_type( + shape=[8, 128], dtype=pto.float32, valid_shape=[-1, -1]) # valid=?x? + + oldmax_tile = pto.alloc_tile(tile_col, addr=c0_i64, valid_row=row_count) + oldsum_tile = pto.alloc_tile(tile_col, addr=c128_i64, valid_row=row_count) + qk_tile = pto.alloc_tile(tile_w, addr=c256_i64, valid_row=row_count, valid_col=seq) + out_tile = pto.alloc_tile(tile_w, addr=c8448_i64, valid_row=row_count, valid_col=seq) + newmax_tile = pto.alloc_tile(tile_col, addr=c16640_i64, valid_row=row_count) + newsum_tile = pto.alloc_tile(tile_col, addr=c16768_i64, valid_row=row_count) + expmax_tile = pto.alloc_tile(tile_col, addr=c16896_i64, valid_row=row_count) + + # ── Tile loads from GM ───────────────────────────── + pto.tload(oldmax_part, oldmax_tile) + pto.tload(oldsum_part, oldsum_tile) + pto.tload(qk_part, qk_tile) + + pto.set_flag("MTE2", "V", event_id=0) + pto.wait_flag("MTE2", "V", event_id=0) + + with pto.vecscope(): + # Materialise typed UB pointers from tile handles + ptr_ub = pto.ptr(pto.float32, "UB") # !pto.ptr + vf32 = pto.vreg_type(64, pto.float32) + ub_om = pto.tile_ptr(oldmax_tile, ptr_ub) + ub_os = pto.tile_ptr(oldsum_tile, ptr_ub) + ub_qk = pto.tile_ptr(qk_tile, ptr_ub) + ub_out = pto.tile_ptr(out_tile, ptr_ub) + ub_nm = pto.tile_ptr(newmax_tile, ptr_ub) + ub_ns = pto.tile_ptr(newsum_tile, ptr_ub) + ub_em = pto.tile_ptr(expmax_tile, ptr_ub) + + active = pto.pset_b32("PAT_ALL") + one_mask, _ = pto.plt_b32(c1_i32) + + with pto.for_(c0, row_count, step=c1) as row: + row_qk = s.muli(row, c128) # can optionally overload __mul__ + oldmax_bc = pto.vbrc_load(ub_om, row, vf32) + oldsum_bc = pto.vbrc_load(ub_os, row, vf32) + + # ── Chunk loop: compute running max & sum ── + # %final_max, %final_sum = scf.for %chunk = %c0 to %c128 step %c64 + # iter_args(%running_max = %oldmax_bc, %running_sum = %oldsum_bc) + # -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + with pto.for_(c0, c128, step=c64, + iter_args=(oldmax_bc, oldsum_bc), + results=(vf32, vf32)) as loop: + chunk = loop.iv # induction variable %chunk (index) bound by `scf.for %chunk = ...` + running_max, running_sum = loop.iter_args + + chunk_i32 = s.index_cast(pto.int32, chunk) # arith.index_cast index to i32 + remaining_cols = s.subi(arg7, chunk_i32) # arith.subi + has_chunk = s.cmpi("sgt", remaining_cols, c0_i32) # arith.cmpi sgt + + # %next_max, %next_sum = scf.if %has_chunk -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + with pto.if_(has_chunk, results=(vf32, vf32)) as br: + with br.then_: + chunk_mask, chunk_rest = pto.plt_b32(remaining_cols) + chunk_base = s.addi(row_qk, chunk) + vec = pto.vlds(ub_qk, chunk_base, vf32) + chunk_max = pto.vcmax(vec, chunk_mask) + chunk_max_bc = pto.vdup(chunk_max, active, position="LOWEST") + merged_max = pto.vmax(running_max, chunk_max_bc, active) + scaled_running = pto.vexpdif(running_max, merged_max, active, "ODD") + running_sum_scaled = pto.vmul(scaled_running, running_sum, active) + chunk_exp = pto.vexpdif(vec, merged_max, chunk_mask, "ODD") + chunk_sum = pto.vcadd(chunk_exp, chunk_mask) + chunk_sum_bc = pto.vdup(chunk_sum, active, position="LOWEST") + merged_sum = pto.vadd(running_sum_scaled, chunk_sum_bc, active) + pto.yield_(merged_max, merged_sum) # scf.yield + with br.else_: + pto.yield_(running_max, running_sum) # scf.yield + next_max, next_sum = br.results + + pto.yield_(next_max, next_sum) # scf.yield + + final_max, final_sum = loop.results + + # ── Compute expmax scalar for this row ───── + raw_em = pto.vexpdif(oldmax_bc, final_max, active) + sc_os = pto.vmul(raw_em, oldsum_bc, active) + expmax = pto.vdiv(sc_os, final_sum, active) + + pto.vsts_1pt(final_max, ub_nm, row, one_mask) + pto.vsts_1pt(final_sum, ub_ns, row, one_mask) + pto.vsts_1pt(expmax, ub_em, row, one_mask) + + # ── Output normalisation loop ────────────── + with pto.for_(c0, c128, step=c64) as chunk2: + rem2 = s.subi(arg7, s.index_cast(pto.int32, chunk2)) + has_c2 = s.cmpi_sgt(rem2, c0_i32) + with pto.if_(has_c2): + cmask2, _ = pto.plt_b32(rem2) + cbase2 = s.addi(row_qk, chunk2) + vec2 = pto.vlds(ub_qk, cbase2, vf32) + exp2 = pto.vexpdif(vec2, final_max, cmask2) + out2 = pto.vdiv(exp2, final_sum, cmask2) + pto.vsts(out2, ub_out, cbase2, cmask2) + + pto.set_flag("V", "MTE3", event_id=0) + pto.wait_flag("V", "MTE3", event_id=1) + + # ── Tile stores to GM ────────────────────────────── + pto.tstore(newmax_tile, newmax_part) + pto.tstore(newsum_tile, newsum_part) + pto.tstore(expmax_tile, expmax_part) + pto.tstore(out_tile, out_part) + + pto.barrier_all() + + +if __name__ == "__main__": + print(softmax_demo) diff --git a/ptodsl/tile_and_vpto_builder_suggested.py b/ptodsl/tile_and_vpto_builder_suggested.py new file mode 100644 index 000000000..583d491ba --- /dev/null +++ b/ptodsl/tile_and_vpto_builder_suggested.py @@ -0,0 +1,32 @@ +# minimum Pythonic mapping of test/lit/vpto/expand_tileop_to_vpto_result.pto + +import pto +s = pto.scalar + +@pto.to_ir(name="TADD", kernel_kind="vector", arch="a5") +def vpto_demo(): + c0_i64 = pto.const(0, dtype=pto.int64) + c16 = pto.const(16, dtype=pto.index) # if no dtype passed, default to pto.index + c4096_i64 = pto.const(4096, dtype=pto.int64) + c0 = pto.const(0) + c1 = pto.const(1) + c64_i32 = pto.const(64, dtype=pto.int32) + c64 = pto.const(64) + with pto.vecscope(): + ptr_type = pto.ptr(pto.float32, "UB") + ptr_src = pto.castptr(c4096_i64, ptr_type) + ptr_dst = pto.castptr(c0_i64, ptr_type) + vreg_type = vreg_type(64, pto.float32) + with for_(c0, c16, step=c1) as tile_idx: + mask, _ = pto.plt_b32(c64_i32) + tile_off = s.muli(tile_idx, c64) # can optionally overload __mul__ + va = pto.vlds(pto.addptr(ptr_src, tile_off), c0, vreg_type) + ptr_dst_tile = pto.addptr(ptr_dst, tile_off) + vb = pto.vlds(ptr_dst_tile, c0, vreg_type) + vc = pto.vadd(va, vb, mask, vreg_type) + pto.vsts(vc, ptr_dst_tile, c0, mask) + # by default return None, matches IR `return` + + +if __name__ == "__main__": + print(vpto_demo) From 63e590a3adf872d5c18d329a387bb76246799499 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 13 May 2026 20:10:47 +0000 Subject: [PATCH 09/31] major refactor of dsl syntax and impl --- ptodsl/README.md | 291 ++++----- ptodsl/check_ir.py | 89 ++- ptodsl/examples/softmax_dsl.py | 245 ++++++++ .../softmax_lowlevel.py} | 0 ptodsl/examples/tadd_dsl.py | 67 ++ .../tadd_lowlevel.py} | 0 ptodsl/ptodsl/__init__.py | 12 + ptodsl/ptodsl/_bootstrap.py | 35 ++ ptodsl/ptodsl/_control_flow.py | 230 +++++++ ptodsl/ptodsl/_module.py | 159 +++++ ptodsl/ptodsl/_ops.py | 256 ++++++++ ptodsl/ptodsl/_types.py | 176 ++++++ ptodsl/ptodsl/pto.py | 58 ++ ptodsl/ptodsl/scalar.py | 90 +++ ptodsl/ptodsl_utils.py | 574 ------------------ ptodsl/pyproject.toml | 13 + ptodsl/softmax_builder_highlevel.py | 236 ------- ptodsl/softmax_builder_suggested.py | 204 ------- ptodsl/tile_and_vpto_builder_highlevel.py | 82 --- ptodsl/tile_and_vpto_builder_suggested.py | 32 - 20 files changed, 1513 insertions(+), 1336 deletions(-) create mode 100644 ptodsl/examples/softmax_dsl.py rename ptodsl/{softmax_builder_lowlevel.py => examples/softmax_lowlevel.py} (100%) create mode 100644 ptodsl/examples/tadd_dsl.py rename ptodsl/{tile_and_vpto_builder_lowlevel.py => examples/tadd_lowlevel.py} (100%) create mode 100644 ptodsl/ptodsl/__init__.py create mode 100644 ptodsl/ptodsl/_bootstrap.py create mode 100644 ptodsl/ptodsl/_control_flow.py create mode 100644 ptodsl/ptodsl/_module.py create mode 100644 ptodsl/ptodsl/_ops.py create mode 100644 ptodsl/ptodsl/_types.py create mode 100644 ptodsl/ptodsl/pto.py create mode 100644 ptodsl/ptodsl/scalar.py delete mode 100644 ptodsl/ptodsl_utils.py create mode 100644 ptodsl/pyproject.toml delete mode 100644 ptodsl/softmax_builder_highlevel.py delete mode 100644 ptodsl/softmax_builder_suggested.py delete mode 100644 ptodsl/tile_and_vpto_builder_highlevel.py delete mode 100644 ptodsl/tile_and_vpto_builder_suggested.py diff --git a/ptodsl/README.md b/ptodsl/README.md index 8d32b3280..d2c2dac81 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -1,9 +1,9 @@ # ptodsl — PTO Python IR Builders -This directory contains Python scripts that construct PTO MLIR IR modules -programmatically using the MLIR Python bindings. Two complete kernel examples -are provided, each in a **low-level** (raw bindings) and a **high-level** -(utility-wrapped) variant. +A lightweight, pip-installable DSL package for building PTO MLIR IR modules +in Python. The API is inspired by Triton / CuteDSL: kernels are ordinary +Python functions decorated with `@pto.to_ir`, type annotations carry PTO +types as lazy descriptors, and control-flow maps 1-to-1 to MLIR operations. --- @@ -11,26 +11,32 @@ are provided, each in a **low-level** (raw bindings) and a **high-level** ``` ptodsl/ -├── ptodsl_utils.py # Reusable utility wrappers -│ -├── tile_and_vpto_builder_lowlevel.py # TADD kernel – raw bindings -├── tile_and_vpto_builder_highlevel.py # TADD kernel – ptodsl_utils -│ -├── softmax_builder_lowlevel.py # Softmax kernel – raw bindings -├── softmax_builder_highlevel.py # Softmax kernel – ptodsl_utils -│ -└── check_ir.py # IR correctness test for all builders +├── ptodsl/ # pip-installable package +│ ├── __init__.py # exports: pto, scalar +│ ├── pto.py # main pto.* namespace +│ ├── scalar.py # pto.scalar.* arith helpers +│ ├── _bootstrap.py # MLIR path setup + context factory +│ ├── _types.py # lazy dtype descriptors and type constructors +│ ├── _ops.py # PTO operation wrappers +│ ├── _control_flow.py # vecscope, for_, if_, yield_ context managers +│ └── _module.py # @pto.to_ir decorator + module builders +├── examples/ +│ ├── tadd_lowlevel.py # TADD – raw MLIR Python binding calls +│ ├── tadd_dsl.py # TADD – @pto.to_ir DSL style +│ ├── softmax_lowlevel.py # Softmax – raw MLIR Python binding calls +│ └── softmax_dsl.py # Softmax – @pto.to_ir DSL style +├── pyproject.toml # pip install -e . +├── check_ir.py # IR correctness test runner +└── README.md ``` --- ## Prerequisites -The ptoas dialect must be installed and the environment set up before use: - ```bash -# Install (first time only) -cd /workdir/ptoas_a5 +# Install ptoas (first time only) +cd $PTOAS_REPO_ROOT # e.g. export PTOAS_REPO_ROOT=/workdir/ptoas_a5 bash quick_install.sh # Set up environment in every new shell @@ -39,204 +45,171 @@ source set_ptoas_env.sh --- +## Install the package + +```bash +cd $PTOAS_REPO_ROOT/ptodsl +pip install -e . +``` + +--- + ## Running the IR check ```bash -# From ptoas_a5/ptodsl/ +# From $PTOAS_REPO_ROOT/ptodsl/ python3 check_ir.py -# Or from the repository root (ptoas_a5/) +# From the repository root ($PTOAS_REPO_ROOT) python3 ptodsl/check_ir.py ``` -Expected output when everything is correct: +Expected output: ``` ptodsl IR check ================================================== PASS TADD low-level - PASS TADD high-level + PASS TADD dsl-style PASS softmax low-level - PASS softmax high-level + PASS softmax dsl-style ================================================== Result: ALL PASS ``` -Exit code is `0` on full pass, `1` if any check fails. -A unified diff of the first 60 diverging lines is printed for each failing case. - ---- - -## Kernel examples - -### TADD — simple vector add (vPTO) - -| File | Reference | -|---|---| -| `tile_and_vpto_builder_lowlevel.py` | `test/lit/vpto/expand_tileop_to_vpto_result.pto` | -| `tile_and_vpto_builder_highlevel.py` | same | - -The kernel performs an element-wise vector add over a 1024-element float32 -buffer using 16 iterations of 64-wide vector instructions inside a -`pto.vecscope`. It exercises: -`pto.castptr`, `pto.addptr`, `pto.plt_b32`, `pto.vlds`, `pto.vadd`, -`pto.vsts`, nested modules (`pto.target_arch` + `pto.kernel_kind`). - -### Online softmax update - -| File | Reference | -|---|---| -| `softmax_builder_lowlevel.py` | `test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto` | -| `softmax_builder_highlevel.py` | same | - -An online softmax update kernel that mixes tile-domain loads/stores with -raw vector compute inside a `pto.vecscope`. It exercises a significantly -larger set of ops including: -`pto.get_block_idx`, `pto.make_tensor_view`, `pto.partition_view`, -`pto.alloc_tile`, `pto.tload`/`pto.tstore`, `pto.set_flag`/`pto.wait_flag`, -`pto.tile_buf_addr`, `pto.pset_b32`, `pto.vcmax`, `pto.vdup`, `pto.vmax`, -`pto.vexpdif`, `pto.vmul`, `pto.vcadd`, `pto.vdiv`, `pto.barrier`, -`scf.for` with `iter_args`, and `scf.if` with result values. +Exit code is `0` on full pass, `1` on any failure. A unified diff of up to +60 diverging lines is printed for each failing case. --- -## How the IR check works - -`check_ir.py` calls `build()` in each builder, then compares the resulting -module against its reference `.pto` file using MLIR round-trip normalization: +## DSL-style API quick reference +```python +from ptodsl import pto +s = pto.scalar # arith shorthand alias ``` -generated IR ──┐ - ├── Module.parse() → canonical string ──── == ──── PASS/FAIL -reference .pto ──┘ (strips comments, normalises SSA names and attr order) -``` - -**Why round-trip normalization?** - -| Issue | Raw text comparison | Round-trip comparison | -|---|---|---| -| `// comment` lines in `.pto` files | breaks | ignored by MLIR parser | -| Named SSA values (`%block_idx`) vs anonymous (`%0`) | breaks | both become `%0`, `%1` … | -| Attribute dict ordering (`{a=1, b=2}` vs `{b=2, a=1}`) | breaks | normalized | -| Constant declaration order | breaks | **preserved** – must match | -Because constant declaration order is preserved after round-trip, builders -must emit constants in the same order as the reference. The `check_ir.py` -diff output makes such mismatches easy to locate. +### Kernel decorator ---- - -## `ptodsl_utils.py` – utility reference +```python +@pto.to_ir(name="MyKernel", kernel_kind="vector", arch="a5") +def MyKernel(): + ... -The utility module eliminates boilerplate so kernel logic is immediately -readable. All helpers operate on the **current** MLIR context and insertion -point; no context argument is threaded. +@pto.to_ir(name="Softmax", kernel_kind="vector", arch="a5", func_attr="pto.aicore") +def Softmax(arg0: pto.ptr(pto.float32, "gm"), n: pto.int32): + ... -### Type constructors +print(MyKernel) # prints MLIR text +mod = MyKernel.build() # returns mlir.ir.Module +``` -| Helper | MLIR type | -|---|---| -| `i32_type()` | `i32` | -| `i64_type()` | `i64` | -| `idx_type()` | `index` | -| `ptr_type(elem, space="ub")` | `!pto.ptr` | -| `vreg_type(lanes, elem)` | `!pto.vreg` | -| `mask_type(bits="b32")` | `!pto.mask` | +`func_attr="pto.aicore"` selects a flat single-module structure with the +`pto.aicore` function attribute (softmax style). Without it, a nested +double-module is emitted (TADD style). -### Constant builders +### Type descriptors (lazy – safe to use in annotations) -| Helper | Op | +| Expression | MLIR type | |---|---| -| `c_idx(v)` | `arith.constant v : index` | -| `c_i32(v)` | `arith.constant v : i32` | -| `c_i64(v)` | `arith.constant v : i64` | +| `pto.float32` | `f32` | +| `pto.int32` | `i32` | +| `pto.int64` | `i64` | +| `pto.index` | `index` | +| `pto.ptr(pto.float32, "gm")` | `!pto.ptr` | +| `pto.ptr(pto.float32, "ub")` | `!pto.ptr` | -### Arithmetic +### Type constructors (eager – require active context) -`muli`, `addi`, `subi` — `arith.muli/addi/subi` -`index_cast(type, val)` — `arith.index_cast` -`cmpi_sgt(a, b)` — `arith.cmpi sgt` -`select_val(cond, t, f)` — `arith.select` +```python +vf32 = pto.vreg_type(64, pto.float32) # !pto.vreg<64xf32> +tile_col = pto.tile_buf_type([8,1], pto.float32, [-1,1], blayout="ColMajor") +tile_w = pto.tile_buf_type([8,128], pto.float32, [-1,-1]) +``` -### Module / function builders +### Constants ```python -with pto_context(): # MLIR Context + PTO dialect - with vpto_kernel("MyKernel", arch="a5") as mod: # nested module + func (no args) - ... - -with pto_context(): - with flat_pto_module("a5") as mod: # flat module + pto.kernel_kind - with pto_aicore_func("f", [ptr_gm, i32]) as (p, n): # func with args - ... +c0 = pto.const(0) # index +c1_i32 = pto.const(1, dtype=pto.int32) +c64_i64= pto.const(64, dtype=pto.int64) ``` -### Control-flow helpers +### Control flow ```python -with vecscope(): # pto.vecscope { ... } - -with for_range(lo, hi, step) as i: # scf.for, auto-inserts scf.yield +with pto.vecscope(): # pto.vecscope { … } ... -with for_range_iter(lo, hi, step, [a, b]) as cf: # scf.for with iter_args - x, y = cf.inner_iter_args - yield_vals(new_x, new_y) # scf.yield at end of body -final_x, final_y = cf.results # results accessible after the block +with pto.for_(c0, c16, step=c1) as i: # simple scf.for + ... # scf.yield inserted automatically -with if_ctx(cond): # scf.if, no results, auto-inserts scf.yield +with pto.for_(c0, c128, step=c64, iter_args=(a, b)) as loop: + x, y = loop.iter_args ... + pto.yield_(nx, ny) # scf.yield with values +fx, fy = loop.results + +with pto.if_(has_rows): # simple scf.if + ... # scf.yield inserted automatically -br = if_op_returning(cond, [vreg, vreg]) # scf.if with results + else -with InsertionPoint(br.then_block): - yield_vals(a, b) -with InsertionPoint(br.else_block): - yield_vals(c, d) +with pto.if_(has_chunk, results=(vf32, vf32)) as br: + with br.then_: + ... + pto.yield_(merged_max, merged_sum) + with br.else_: + pto.yield_(running_max, running_sum) x, y = br.results ``` -### Tile-domain helpers +### Scalar arithmetic (`s = pto.scalar`) ```python -tv = tile_view(tv_type, ptr, shape, strides) # pto.make_tensor_view -ptv = part_view(ptv_type, tv, offsets, sizes) # pto.partition_view -t = alloc_tile(tile_type, addr=a, valid_row=r, valid_col=c) # pto.alloc_tile -tload(part, tile) # pto.tload -tstore(tile, part) # pto.tstore -ub = tile_ptr(tile, ptr_ub_type) # pto.tile_buf_addr +s.muli(a, b) # arith.muli +s.addi(a, b) # arith.addi +s.subi(a, b) # arith.subi +s.index_cast(val) # arith.index_cast → index +s.index_cast(pto.int32, val) # arith.index_cast → i32 +s.cmpi_sgt(a, b) # arith.cmpi sgt +s.cmpi("slt", a, b) # arith.cmpi with named predicate +s.select(cond, t, f) # arith.select ``` -### Vector / pointer helpers +### PTO operations ```python -ptr = castptr(int_addr, ptr_type) # pto.castptr -ptr2 = addptr(ptr, offset) # pto.addptr -v = vlds(ptr, offset, vreg_type) # pto.vlds -v = vbrc_load(ptr, offset, vreg_type) # pto.vlds {dist="BRC_B32"} -vsts(v, ptr, offset, mask) # pto.vsts -vsts_1pt(v, ptr, offset, mask) # pto.vsts {dist="1PT_B32"} -mask, _ = plt_b32(scalar) # pto.plt_b32 -mask = pset_b32("PAT_ALL") # pto.pset_b32 +pto.castptr(addr, ptr_type) # pto.castptr +pto.addptr(ptr, offset) # pto.addptr +pto.vlds(ptr, offset, vreg_type) # pto.vlds +pto.vbrc_load(ptr, offset, vreg_type) # pto.vlds {dist="BRC_B32"} +pto.vsts(v, ptr, offset, mask) # pto.vsts +pto.vsts_1pt(v, ptr, offset, mask) # pto.vsts {dist="1PT_B32"} +pto.plt_b32(scalar) # → (mask, scalar_out) +pto.pset_b32("PAT_ALL") # pto.pset_b32 → mask +pto.vadd(a, b, mask) # infers result type from a.type +pto.vmul / vmax / vdiv / vcmax / vcadd / vdup / vexpdif # similarly +pto.make_tensor_view(ptr, shape=…, strides=…) # type inferred +pto.partition_view(tv, offsets=…, sizes=…) # type inferred +pto.alloc_tile(tile_type, addr=…, valid_row=…, valid_col=…) +pto.tload(part, tile) +pto.tstore(tile, part) +pto.tile_ptr(tile, ptr_type) +pto.get_block_idx() # → i64 +pto.set_flag("MTE2", "V", event_id=0) +pto.wait_flag("MTE2", "V", event_id=0) +pto.barrier_all() ``` -### Vector math (result type inferred from first operand) - -```python -vcmax(v, mask) # cross-lane max reduction -vdup_lowest(v, mask) # broadcast lane 0 to all lanes -vmax(a, b, mask) # element-wise max -vexpdif(x, ref, mask) # exp(x − ref), ODD lanes -vmul(a, b, mask) # element-wise multiply -vcadd(v, mask) # cross-lane add (sum reduction) -vadd(a, b, mask) # element-wise add (result_type optional) -vdiv(a, b, mask) # element-wise divide -``` +--- -### Hardware / sync +## How the IR check works -```python -get_block_idx() # pto.get_block_idx → i64 -barrier_all() # pto.barrier #pto.pipe -# use pto.set_flag / pto.wait_flag directly (from mlir.dialects.pto) -# use yield_vals(*vals) as shorthand for scf.YieldOp(list(vals)) ``` +generated IR ──┐ + ├── Module.parse() → canonical string ──── == ──── PASS/FAIL +reference .pto ──┘ (strips comments, normalises SSA names and attr order) +``` + +Constant declaration order is preserved after the round-trip; builders must +emit constants in the same order as the reference. The diff output makes any +mismatch immediately visible. diff --git a/ptodsl/check_ir.py b/ptodsl/check_ir.py index 355ca04a6..6be6fdd71 100644 --- a/ptodsl/check_ir.py +++ b/ptodsl/check_ir.py @@ -7,46 +7,49 @@ # See LICENSE in the root of the software repository for the full text of the License. """ -IR correctness check for all ptodsl builder scripts. +IR correctness check for all ptodsl example scripts. Run from the repository root or from this directory: python3 ptodsl/check_ir.py # from ptoas_a5/ - python3 check_ir.py # from ptodsl/ + python3 check_ir.py # from ptoas_a5/ptodsl/ -Each builder's build() function is called; its output is compared against the -corresponding hand-written reference .pto file. +Each example's ``build()`` function is called; its output is compared against +the corresponding hand-written reference ``.pto`` file. Comparison methodology ────────────────────── Both the generated module and the reference file are parsed by the MLIR Python -API (Module.parse), then printed back to a string. This round-trip: +API (``Module.parse``), then printed back to a string. This round-trip: - • Strips comments (// lines in .pto files are ignored by the MLIR parser) - • Normalises SSA value names (%block_idx → %0, %running_max → %arg11, …) - • Normalises attribute ordering (MLIR sorts dict-like attribute sets) + • Strips ``//`` comments present in hand-written ``.pto`` files + • Normalises SSA value names (``%block_idx`` → ``%0``, …) + • Normalises attribute ordering -The resulting canonical strings are compared with ==. A diff of the first 60 -differing lines is printed on failure to aid diagnosis. +The resulting canonical strings are compared with ``==``. A unified diff of +the first 60 diverging lines is printed on failure. """ import difflib +import importlib import os import sys -# Allow running from either ptoas_a5/ or ptoas_a5/ptodsl/ -_HERE = os.path.dirname(os.path.abspath(__file__)) -if _HERE not in sys.path: - sys.path.insert(0, _HERE) +# ── Path setup ──────────────────────────────────────────────────────────────── -# ── MLIR bootstrap ─────────────────────────────────────────────────────────── +_HERE = os.path.dirname(os.path.abspath(__file__)) +_EXAMPLES = os.path.join(_HERE, "examples") _MLIR_INSTALL = os.path.join(_HERE, "..", "install", "mlir") -if _MLIR_INSTALL not in sys.path: - sys.path.insert(0, _MLIR_INSTALL) + +for _p in (_MLIR_INSTALL, _HERE, _EXAMPLES): + if _p not in sys.path: + sys.path.insert(0, _p) from mlir.ir import Context, Module # noqa: E402 from mlir.dialects import pto as _pto_mod # noqa: E402 +# ── Helpers ─────────────────────────────────────────────────────────────────── + def _normalize(mlir_text: str) -> str: """Parse *mlir_text* with MLIR and return the canonical printed form.""" with Context() as ctx: @@ -55,56 +58,44 @@ def _normalize(mlir_text: str) -> str: def _strip_comments(text: str) -> str: - """Remove // comment lines that appear in hand-written .pto files.""" + """Remove ``//`` comment lines found in hand-written ``.pto`` files.""" return "\n".join( line for line in text.splitlines() if not line.strip().startswith("//") ) # ── Test cases ──────────────────────────────────────────────────────────────── -# Each entry: (label, builder_module_path, reference_pto_path) +# Each entry: (label, module_name, reference_pto_path) + _REPO_ROOT = os.path.abspath(os.path.join(_HERE, "..")) +_TADD_REF = os.path.join(_REPO_ROOT, "test/lit/vpto/expand_tileop_to_vpto_result.pto") +_SOFTMAX_REF = os.path.join(_REPO_ROOT, + "test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto") CASES = [ - ( - "TADD low-level ", - "tile_and_vpto_builder_lowlevel", - os.path.join(_REPO_ROOT, - "test/lit/vpto/expand_tileop_to_vpto_result.pto"), - ), - ( - "TADD high-level", - "tile_and_vpto_builder_highlevel", - os.path.join(_REPO_ROOT, - "test/lit/vpto/expand_tileop_to_vpto_result.pto"), - ), - ( - "softmax low-level ", - "softmax_builder_lowlevel", - os.path.join(_REPO_ROOT, - "test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto"), - ), - ( - "softmax high-level", - "softmax_builder_highlevel", - os.path.join(_REPO_ROOT, - "test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto"), - ), + ("TADD low-level ", "tadd_lowlevel", _TADD_REF), + ("TADD dsl-style ", "tadd_dsl", _TADD_REF), + ("softmax low-level", "softmax_lowlevel", _SOFTMAX_REF), + ("softmax dsl-style", "softmax_dsl", _SOFTMAX_REF), ] # ── Runner ──────────────────────────────────────────────────────────────────── def run_checks(cases=CASES) -> bool: - """Execute every check case; return True if all passed.""" + """Execute every check case; return ``True`` if all passed.""" all_passed = True for label, module_name, ref_path in cases: - # -- import the builder and call build() -- + # -- import the example and call build() -- try: - builder = __import__(module_name) - generated_module = builder.build() - generated_text = str(generated_module) + # Re-import on every run so state doesn't leak between cases + spec = importlib.util.spec_from_file_location( + module_name, os.path.join(_EXAMPLES, f"{module_name}.py") + ) + builder = importlib.util.module_from_spec(spec) + spec.loader.exec_module(builder) + generated_text = str(builder.build()) except Exception as exc: print(f" FAIL {label} [builder error: {exc}]") all_passed = False @@ -146,7 +137,7 @@ def run_checks(cases=CASES) -> bool: snippet = "\n".join(diff_lines[:60]) print(f" FAIL {label}\n{snippet}") if len(diff_lines) > 60: - print(f" ... ({len(diff_lines) - 60} more diff lines)") + print(f" … ({len(diff_lines) - 60} more diff lines)") return all_passed diff --git a/ptodsl/examples/softmax_dsl.py b/ptodsl/examples/softmax_dsl.py new file mode 100644 index 000000000..ec8311c64 --- /dev/null +++ b/ptodsl/examples/softmax_dsl.py @@ -0,0 +1,245 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +Online softmax update kernel – DSL-style builder. + +Generates the same IR as + test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto +using the ``@pto.to_ir`` decorator and the ``pto.*`` namespace. + +The Python maps almost line-for-line to the target MLIR: + + func.func @online_softmax_update_kernel_2d( # function signature + %arg0: !pto.ptr, …, %arg7: i32, …) # arg0: pto.ptr(…), … + + scf.if %has_rows { # with pto.if_(has_rows): + pto.tload ins(…) outs(…) # pto.tload(part, tile) + pto.vecscope { # with pto.vecscope(): + scf.for %row = … { # with pto.for_(…) as row: + %final_max, %final_sum = # + scf.for %chunk = … iter_args(…) { # with pto.for_(…, iter_args=…) as loop: + scf.if %has_chunk → (vreg, vreg) { # with pto.if_(…, results=…) as br: + scf.yield %merged_max, %merged_sum # pto.yield_(…) + } else { # with br.else_: + scf.yield %running_max, %running_sum # pto.yield_(…) + } # + scf.yield %next_max, %next_sum # pto.yield_(…) + } # + } # + } # + } # + pto.barrier # pto.barrier_all() +""" + +from ptodsl import pto + +s = pto.scalar # arith shorthand alias + + +@pto.to_ir( + name="online_softmax_update_kernel_2d", + kernel_kind="vector", + arch="a5", + func_attr="pto.aicore", +) +def online_softmax_update_kernel_2d( + arg0: pto.ptr(pto.float32, "gm"), + arg1: pto.ptr(pto.float32, "gm"), + arg2: pto.ptr(pto.float32, "gm"), + arg3: pto.ptr(pto.float32, "gm"), + arg4: pto.ptr(pto.float32, "gm"), + arg5: pto.ptr(pto.float32, "gm"), + arg6: pto.ptr(pto.float32, "gm"), + arg7: pto.int32, + arg8: pto.int32, +): + # ── Index constants ────────────────────────────────────────────────────── + c0 = pto.const(0) + c1 = pto.const(1) + c8 = pto.const(8) + c64 = pto.const(64) + c128 = pto.const(128) + + # ── i64 address constants (UB tile base addresses) ─────────────────────── + c0_i64 = pto.const(0, dtype=pto.int64) + c1_i64 = pto.const(1, dtype=pto.int64) # noqa: F841 (present in IR) + c8_i64 = pto.const(8, dtype=pto.int64) # noqa: F841 + c16_i64 = pto.const(16, dtype=pto.int64) # noqa: F841 + c32_i64 = pto.const(32, dtype=pto.int64) # noqa: F841 + c64_i64 = pto.const(64, dtype=pto.int64) # noqa: F841 + c128_i64 = pto.const(128, dtype=pto.int64) + c256_i64 = pto.const(256, dtype=pto.int64) + c512_i64 = pto.const(512, dtype=pto.int64) # noqa: F841 + c8448_i64 = pto.const(8448, dtype=pto.int64) + c16640_i64 = pto.const(16640, dtype=pto.int64) + c16768_i64 = pto.const(16768, dtype=pto.int64) + c16896_i64 = pto.const(16896, dtype=pto.int64) + + # ── i32 constants ──────────────────────────────────────────────────────── + c1_i32 = pto.const(1, dtype=pto.int32) + c8_i32 = pto.const(8, dtype=pto.int32) + c64_i32 = pto.const(64, dtype=pto.int32) + c0_i32 = pto.const(0, dtype=pto.int32) + + # ── Block-level row assignment ──────────────────────────────────────────── + block_i64 = pto.get_block_idx() + block_idx = s.index_cast(block_i64) # → index + row_base = s.muli(block_idx, c8) + _ = s.index_cast(pto.int32, c8) # block_rows_i32 + row_base_i32 = s.index_cast(pto.int32, row_base) + remaining_rows= s.subi(arg8, row_base_i32) + has_rows = s.cmpi_sgt(remaining_rows, c0_i32) + too_many_rows = s.cmpi_sgt(remaining_rows, c8_i32) + row_count_i32 = s.select(too_many_rows, c8_i32, remaining_rows) + row_count = s.index_cast(row_count_i32) # → index + seq = s.index_cast(arg7) # → index + rows = s.index_cast(arg8) # → index + rows_x_128 = s.muli(rows, c128) + + with pto.if_(has_rows): + # ── Tensor views ───────────────────────────────────────────────────── + s1 = [rows, rows, rows, c1, rows] + s128 = [rows_x_128, rows_x_128, rows_x_128, c128, c1] + sh1 = [c1, c1, c1, rows, c1] + sh128= [c1, c1, c1, rows, c128] + + oldmax_view = pto.make_tensor_view(arg0, shape=sh1, strides=s1) + oldsum_view = pto.make_tensor_view(arg1, shape=sh1, strides=s1) + qk_view = pto.make_tensor_view(arg2, shape=sh128, strides=s128) + newmax_view = pto.make_tensor_view(arg3, shape=sh1, strides=s1) + newsum_view = pto.make_tensor_view(arg4, shape=sh1, strides=s1) + expmax_view = pto.make_tensor_view(arg5, shape=sh1, strides=s1) + out_view = pto.make_tensor_view(arg6, shape=sh128, strides=s128) + + # ── Partition views ─────────────────────────────────────────────────── + off = [c0, c0, c0, row_base, c0] + z1 = [c1, c1, c1, row_count, c1] + zs = [c1, c1, c1, row_count, seq] + + oldmax_part = pto.partition_view(oldmax_view, offsets=off, sizes=z1) + oldsum_part = pto.partition_view(oldsum_view, offsets=off, sizes=z1) + qk_part = pto.partition_view(qk_view, offsets=off, sizes=zs) + newmax_part = pto.partition_view(newmax_view, offsets=off, sizes=z1) + newsum_part = pto.partition_view(newsum_view, offsets=off, sizes=z1) + expmax_part = pto.partition_view(expmax_view, offsets=off, sizes=z1) + out_part = pto.partition_view(out_view, offsets=off, sizes=zs) + + # ── UB tile allocation ──────────────────────────────────────────────── + tile_col = pto.tile_buf_type([8, 1], pto.float32, [-1, 1], blayout="ColMajor") + tile_w = pto.tile_buf_type([8, 128], pto.float32, [-1, -1]) + + oldmax_tile = pto.alloc_tile(tile_col, addr=c0_i64, valid_row=row_count) + oldsum_tile = pto.alloc_tile(tile_col, addr=c128_i64, valid_row=row_count) + qk_tile = pto.alloc_tile(tile_w, addr=c256_i64, valid_row=row_count, valid_col=seq) + out_tile = pto.alloc_tile(tile_w, addr=c8448_i64, valid_row=row_count, valid_col=seq) + newmax_tile = pto.alloc_tile(tile_col, addr=c16640_i64, valid_row=row_count) + newsum_tile = pto.alloc_tile(tile_col, addr=c16768_i64, valid_row=row_count) + expmax_tile = pto.alloc_tile(tile_col, addr=c16896_i64, valid_row=row_count) + + # ── Tile loads from GM ──────────────────────────────────────────────── + pto.tload(oldmax_part, oldmax_tile) + pto.tload(oldsum_part, oldsum_tile) + pto.tload(qk_part, qk_tile) + + pto.set_flag("MTE2", "V", event_id=0) + pto.wait_flag("MTE2", "V", event_id=0) + + with pto.vecscope(): + # Materialise typed UB pointers from tile handles + ptr_ub = pto.ptr(pto.float32, "ub") + vf32 = pto.vreg_type(64, pto.float32) + + ub_om = pto.tile_ptr(oldmax_tile, ptr_ub) + ub_os = pto.tile_ptr(oldsum_tile, ptr_ub) + ub_qk = pto.tile_ptr(qk_tile, ptr_ub) + ub_out = pto.tile_ptr(out_tile, ptr_ub) + ub_nm = pto.tile_ptr(newmax_tile, ptr_ub) + ub_ns = pto.tile_ptr(newsum_tile, ptr_ub) + ub_em = pto.tile_ptr(expmax_tile, ptr_ub) + + active = pto.pset_b32("PAT_ALL") + one_mask, _ = pto.plt_b32(c1_i32) + + with pto.for_(c0, row_count, step=c1) as row: + row_qk = s.muli(row, c128) + oldmax_bc = pto.vbrc_load(ub_om, row, vf32) + oldsum_bc = pto.vbrc_load(ub_os, row, vf32) + + # scf.for with iter_args: accumulate (running_max, running_sum) + with pto.for_(c0, c128, step=c64, iter_args=(oldmax_bc, oldsum_bc)) as loop: + chunk = loop.iv + running_max, running_sum = loop.iter_args + + chunk_i32 = s.index_cast(pto.int32, chunk) + remaining_cols = s.subi(arg7, chunk_i32) + has_chunk = s.cmpi_sgt(remaining_cols, c0_i32) + + # scf.if with results – produce (next_max, next_sum) + with pto.if_(has_chunk, results=(vf32, vf32)) as br: + with br.then_: + chunk_mask, _ = pto.plt_b32(remaining_cols) + chunk_base = s.addi(row_qk, chunk) + vec = pto.vlds(ub_qk, chunk_base, vf32) + chunk_max = pto.vcmax(vec, chunk_mask) + chunk_max_bc = pto.vdup(chunk_max, active, position="LOWEST") + merged_max = pto.vmax(running_max, chunk_max_bc, active) + scaled_running = pto.vexpdif(running_max, merged_max, active) + running_sum_scaled = pto.vmul(scaled_running, running_sum, active) + chunk_exp = pto.vexpdif(vec, merged_max, chunk_mask) + chunk_sum = pto.vcadd(chunk_exp, chunk_mask) + chunk_sum_bc = pto.vdup(chunk_sum, active, position="LOWEST") + merged_sum = pto.vadd(running_sum_scaled, chunk_sum_bc, active) + pto.yield_(merged_max, merged_sum) + with br.else_: + pto.yield_(running_max, running_sum) + + next_max, next_sum = br.results + pto.yield_(next_max, next_sum) + + final_max, final_sum = loop.results + + # Compute per-row expmax scalar + raw_em = pto.vexpdif(oldmax_bc, final_max, active) + sc_os = pto.vmul(raw_em, oldsum_bc, active) + expmax = pto.vdiv(sc_os, final_sum, active) + + pto.vsts_1pt(final_max, ub_nm, row, one_mask) + pto.vsts_1pt(final_sum, ub_ns, row, one_mask) + pto.vsts_1pt(expmax, ub_em, row, one_mask) + + # Output normalisation loop + with pto.for_(c0, c128, step=c64) as chunk2: + rem2 = s.subi(arg7, s.index_cast(pto.int32, chunk2)) + has_chunk2= s.cmpi_sgt(rem2, c0_i32) + with pto.if_(has_chunk2): + cmask2, _ = pto.plt_b32(rem2) + cbase2 = s.addi(row_qk, chunk2) + vec2 = pto.vlds(ub_qk, cbase2, vf32) + exp2 = pto.vexpdif(vec2, final_max, cmask2) + out2 = pto.vdiv(exp2, final_sum, cmask2) + pto.vsts(out2, ub_out, cbase2, cmask2) + + pto.set_flag("V", "MTE3", event_id=0) + pto.wait_flag("V", "MTE3", event_id=0) + + # Tile stores to GM + pto.tstore(newmax_tile, newmax_part) + pto.tstore(newsum_tile, newsum_part) + pto.tstore(expmax_tile, expmax_part) + pto.tstore(out_tile, out_part) + + pto.barrier_all() + + +def build(): + return online_softmax_update_kernel_2d._ir_module + + +if __name__ == "__main__": + print(online_softmax_update_kernel_2d) diff --git a/ptodsl/softmax_builder_lowlevel.py b/ptodsl/examples/softmax_lowlevel.py similarity index 100% rename from ptodsl/softmax_builder_lowlevel.py rename to ptodsl/examples/softmax_lowlevel.py diff --git a/ptodsl/examples/tadd_dsl.py b/ptodsl/examples/tadd_dsl.py new file mode 100644 index 000000000..96983c2c0 --- /dev/null +++ b/ptodsl/examples/tadd_dsl.py @@ -0,0 +1,67 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +TADD vPTO kernel – DSL-style builder. + +Generates the same IR as expand_tileop_to_vpto_result.pto using the +``@pto.to_ir`` decorator and the ``pto.*`` namespace. + +The Python code maps 1-to-1 to the MLIR IR lines: + + func.func @TADD() { # @pto.to_ir(name="TADD", …) + %c0_i64 = arith.constant 0 : i64 # pto.const(0, dtype=pto.int64) + %c16 = arith.constant 16 : index # pto.const(16, dtype=pto.index) + … + pto.vecscope { # with pto.vecscope(): + %0 = pto.castptr %c4096_i64 … # pto.castptr(c4096_i64, …) + scf.for %arg0 = %c0 to %c16 … { # with pto.for_(c0, c16, step=c1) as i: + %mask, _ = pto.plt_b32 … # pto.plt_b32(c64_i32) + … + } + } + } +""" + +from ptodsl import pto + +s = pto.scalar # arith shorthand alias + + +@pto.to_ir(name="TADD", kernel_kind="vector", arch="a5") +def TADD(): + c0_i64 = pto.const(0, dtype=pto.int64) + c16 = pto.const(16, dtype=pto.index) + c4096_i64 = pto.const(4096, dtype=pto.int64) + c0 = pto.const(0) + c1 = pto.const(1) + c64_i32 = pto.const(64, dtype=pto.int32) + c64 = pto.const(64) + + with pto.vecscope(): + ptr_f32_ub = pto.ptr(pto.float32, "ub") + vf32 = pto.vreg_type(64, pto.float32) + ptr_src = pto.castptr(c4096_i64, ptr_f32_ub) + ptr_dst = pto.castptr(c0_i64, ptr_f32_ub) + + with pto.for_(c0, c16, step=c1) as tile_idx: + mask, _ = pto.plt_b32(c64_i32) + tile_off = s.muli(tile_idx, c64) + va = pto.vlds(pto.addptr(ptr_src, tile_off), c0, vf32) + ptr_dst_tile = pto.addptr(ptr_dst, tile_off) + vb = pto.vlds(ptr_dst_tile, c0, vf32) + vc = pto.vadd(va, vb, mask) + pto.vsts(vc, ptr_dst_tile, c0, mask) + + +def build(): + return TADD._ir_module + + +if __name__ == "__main__": + print(TADD) diff --git a/ptodsl/tile_and_vpto_builder_lowlevel.py b/ptodsl/examples/tadd_lowlevel.py similarity index 100% rename from ptodsl/tile_and_vpto_builder_lowlevel.py rename to ptodsl/examples/tadd_lowlevel.py diff --git a/ptodsl/ptodsl/__init__.py b/ptodsl/ptodsl/__init__.py new file mode 100644 index 000000000..f558e21e3 --- /dev/null +++ b/ptodsl/ptodsl/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""ptodsl – PTO MLIR DSL package.""" + +from . import pto, scalar # noqa: F401 + +__all__ = ["pto", "scalar"] diff --git a/ptodsl/ptodsl/_bootstrap.py b/ptodsl/ptodsl/_bootstrap.py new file mode 100644 index 000000000..894310ae3 --- /dev/null +++ b/ptodsl/ptodsl/_bootstrap.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +MLIR path bootstrap and context factory. + +Adds the ptoas install directory to sys.path so that the mlir package is +importable regardless of how the ptodsl package itself was installed. +""" + +import os +import sys + +_INSTALL = os.path.normpath( + os.path.join(os.path.dirname(__file__), "..", "..", "install", "mlir") +) +if os.path.isdir(_INSTALL) and _INSTALL not in sys.path: + sys.path.insert(0, _INSTALL) + +from mlir.dialects import pto as _pto_dialect # noqa: E402 +from mlir.ir import Context, Location # noqa: E402 + + +def make_context() -> Context: + """Create a fresh MLIR Context with the PTO dialect loaded.""" + ctx = Context() + _pto_dialect.register_dialect(ctx, load=True) + return ctx + + +__all__ = ["make_context"] diff --git a/ptodsl/ptodsl/_control_flow.py b/ptodsl/ptodsl/_control_flow.py new file mode 100644 index 000000000..5d20b91e6 --- /dev/null +++ b/ptodsl/ptodsl/_control_flow.py @@ -0,0 +1,230 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +Control-flow context managers for PTO kernels. + +All CMs work with the current MLIR insertion point; no context threading needed. + +Public API +────────── +``vecscope()`` – ``pto.vecscope { … }`` +``for_(lo, hi, step, *, iter_args)`` + – ``scf.for`` with optional iter_args +``if_(cond, *, results)`` – ``scf.if`` with optional results + else +``yield_(*vals)`` – ``scf.yield`` +""" + +from ._bootstrap import make_context # noqa: F401 +from ._types import _resolve + +from mlir.dialects import pto as _pto, scf +from mlir.ir import InsertionPoint + + +# ── vecscope ────────────────────────────────────────────────────────────────── + +class _VecScopeCM: + """Context manager for ``pto.vecscope { … }``.""" + + def __enter__(self): + self._op = _pto.VecScopeOp() + self._block = self._op.body.blocks.append() + self._ip = InsertionPoint(self._block) + self._ip.__enter__() + return None + + def __exit__(self, *exc): + self._ip.__exit__(*exc) + + +def vecscope() -> _VecScopeCM: + """Return a context manager that emits ``pto.vecscope { … }``.""" + return _VecScopeCM() + + +# ── for_ ────────────────────────────────────────────────────────────────────── + +class LoopHandle: + """ + Handle for a ``scf.for`` loop with iter_args. + + Attributes available *after* the ``with pto.for_(…) as loop:`` block:: + + loop.iv – induction variable + loop.iter_args – tuple of inner (mutable) SSA values + loop.results – tuple of ForOp results (after loop exit) + """ + + def __init__(self, for_op): + self._op = for_op + + @property + def iv(self): + return self._op.induction_variable + + @property + def iter_args(self): + return tuple(self._op.inner_iter_args) + + @property + def results(self): + return tuple(self._op.results) + + +class _ForCM: + def __init__(self, start, stop, step, iter_args): + self._start = start + self._stop = stop + self._step = step + self._iter_args = list(iter_args) if iter_args is not None else [] + self._for_op = None + self._ip = None + + def __enter__(self): + self._for_op = scf.ForOp( + self._start, self._stop, self._step, + self._iter_args if self._iter_args else None, + ) + self._ip = InsertionPoint(self._for_op.body) + self._ip.__enter__() + if not self._iter_args: + return self._for_op.induction_variable + return LoopHandle(self._for_op) + + def __exit__(self, *exc): + if not self._iter_args: + scf.YieldOp([]) + self._ip.__exit__(*exc) + + +def for_(start, stop, *, step, iter_args=None) -> _ForCM: + """ + ``scf.for`` context manager. + + Without ``iter_args`` – yields the induction variable; ``scf.yield`` is + inserted automatically:: + + with pto.for_(c0, c16, step=c1) as i: + ... + + With ``iter_args`` – yields a :class:`LoopHandle`; the caller must emit + ``pto.yield_(…)`` before the block closes:: + + with pto.for_(c0, c128, step=c64, iter_args=(a, b)) as loop: + x, y = loop.iter_args + ... + pto.yield_(nx, ny) + fa, fb = loop.results + """ + return _ForCM(start, stop, step, iter_args) + + +# ── if_ ─────────────────────────────────────────────────────────────────────── + +class _BlockCM: + """Enters the InsertionPoint of a single block for ``with br.then_:`` style.""" + + def __init__(self, block): + self._block = block + self._ip = None + + def __enter__(self): + self._ip = InsertionPoint(self._block) + self._ip.__enter__() + + def __exit__(self, *exc): + self._ip.__exit__(*exc) + + +class BranchHandle: + """ + Handle for ``scf.if`` with results and an else branch. + + Usage:: + + with pto.if_(cond, results=(vf32, vf32)) as br: + with br.then_: + ... + pto.yield_(a, b) + with br.else_: + pto.yield_(c, d) + x, y = br.results + """ + + def __init__(self, if_op): + self._op = if_op + self.then_ = _BlockCM(if_op.then_block) + self.else_ = _BlockCM(if_op.else_block) + + @property + def results(self): + return tuple(self._op.results) + + +class _IfCM: + def __init__(self, cond, result_types): + self._cond = cond + self._result_types = [_resolve(t) for t in result_types] if result_types else [] + self._if_op = None + self._ip = None + + def __enter__(self): + if self._result_types: + # if/else with results: create IfOp but don't enter any block; + # the caller manages blocks via br.then_ / br.else_ + self._if_op = scf.IfOp(self._cond, self._result_types, hasElse=True) + return BranchHandle(self._if_op) + else: + # simple if without results: enter then_block automatically + self._if_op = scf.IfOp(self._cond) + self._ip = InsertionPoint(self._if_op.then_block) + self._ip.__enter__() + return None + + def __exit__(self, *exc): + if not self._result_types: + scf.YieldOp([]) + self._ip.__exit__(*exc) + # for if/else with results: blocks are managed by BranchHandle; nothing to do + + +def if_(cond, *, results=None) -> _IfCM: + """ + ``scf.if`` context manager. + + Without ``results`` – simple if with no else; ``scf.yield`` is inserted + automatically:: + + with pto.if_(has_rows): + ... + + With ``results`` – if/else pair that produces SSA values; the caller must + manage ``br.then_`` and ``br.else_`` and emit ``pto.yield_(…)`` in each:: + + with pto.if_(has_chunk, results=(vf32, vf32)) as br: + with br.then_: + ... + pto.yield_(merged_max, merged_sum) + with br.else_: + pto.yield_(running_max, running_sum) + x, y = br.results + """ + return _IfCM(cond, results) + + +# ── yield_ ──────────────────────────────────────────────────────────────────── + +def yield_(*vals): + """Emit ``scf.yield`` with the given values.""" + scf.YieldOp(list(vals)) + + +__all__ = [ + "vecscope", "LoopHandle", "BranchHandle", + "for_", "if_", "yield_", +] diff --git a/ptodsl/ptodsl/_module.py b/ptodsl/ptodsl/_module.py new file mode 100644 index 000000000..2745504f8 --- /dev/null +++ b/ptodsl/ptodsl/_module.py @@ -0,0 +1,159 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +``@pto.to_ir`` decorator and module-level IR builders. + +The decorator: +1. Inspects the function signature – annotations are ``_DType`` lazy + descriptors or concrete ``mlir.ir.Type`` objects. +2. Creates the MLIR context and module. +3. Calls the Python function body with actual MLIR SSA values. +4. Verifies the module and caches it as ``fn._ir_module``. +5. Adds ``__str__`` so ``print(my_kernel)`` prints the MLIR text. + +Module structure is selected by ``func_attr``: +- ``func_attr="pto.aicore"`` → flat module + ``pto.aicore`` function attribute + (used by softmax-style kernels) +- otherwise → nested double-module (used by vPTO TADD-style) +""" + +import inspect + +from ._bootstrap import make_context +from ._types import _resolve + +from mlir.dialects import func, pto as _pto +from mlir.ir import ( + Attribute, + InsertionPoint, + Location, + Module, + Operation, + StringAttr, + UnitAttr, +) + + +def _call_body(ir_fn, fn, arg_types): + """Add entry block to *ir_fn* and call *fn* with the SSA arguments.""" + entry = ir_fn.add_entry_block() + with InsertionPoint(entry): + fn(*entry.arguments) + func.ReturnOp([]) + + +def _build_flat_module(fn_name, arg_types, fn, arch, kernel_kind): + """ + Flat ``module attributes {pto.target_arch, pto.kernel_kind}`` with a + single function that carries ``pto.aicore``. + """ + m = Module.create() + m.operation.attributes["pto.target_arch"] = StringAttr.get(arch) + m.operation.attributes["pto.kernel_kind"] = Attribute.parse( + f"#pto.kernel_kind<{kernel_kind}>" + ) + fn_ty = func.FunctionType.get(arg_types, []) + with InsertionPoint(m.body): + ir_fn = func.FuncOp(fn_name, fn_ty) + ir_fn.attributes["pto.aicore"] = UnitAttr.get() + _call_body(ir_fn, fn, arg_types) + return m + + +def _build_nested_module(fn_name, arg_types, fn, arch, kernel_kind): + """ + Nested ``module { module { func … } }`` structure used by vPTO kernels + without function arguments (e.g. TADD). + """ + outer = Module.create() + outer.operation.attributes["pto.target_arch"] = StringAttr.get(arch) + + with InsertionPoint(outer.body): + # Module.create() ignores the active InsertionPoint, so create + # the inner module via Operation.create("builtin.module") instead. + inner_op = Operation.create("builtin.module", regions=1) + inner_op.attributes["pto.target_arch"] = StringAttr.get(arch) + inner_op.attributes["pto.kernel_kind"] = Attribute.parse( + f"#pto.kernel_kind<{kernel_kind}>" + ) + inner_body = inner_op.regions[0].blocks.append() + + with InsertionPoint(inner_body): + fn_ty = func.FunctionType.get(arg_types, []) + ir_fn = func.FuncOp(fn_name, fn_ty) + + _call_body(ir_fn, fn, arg_types) + return outer + + +def to_ir(name=None, *, kernel_kind: str = "vector", arch: str = "a5", + func_attr: str = None): + """ + Decorator that eagerly lowers a Python function to an MLIR module. + + Parameters + ---------- + name: IR function name (defaults to the Python function name). + kernel_kind: ``"vector"`` or ``"cube"`` – sets ``pto.kernel_kind``. + arch: Target architecture string, e.g. ``"a5"``. + func_attr: Optional function attribute. Pass ``"pto.aicore"`` to + select the flat-module structure with the aicore attribute. + + The decorated function is replaced by a :class:`KernelHandle` that: + + - prints as the MLIR module text (``print(my_kernel)``), + - exposes ``my_kernel.build()`` returning the ``mlir.ir.Module``, + - exposes ``my_kernel._ir_module`` for direct access. + """ + + def decorator(fn): + fn_name = name or fn.__name__ + sig = inspect.signature(fn) + ctx = make_context() + with ctx, Location.unknown(): + arg_types = [ + _resolve(p.annotation) + for p in sig.parameters.values() + if p.annotation is not inspect.Parameter.empty + ] + if func_attr == "pto.aicore": + mod = _build_flat_module(fn_name, arg_types, fn, arch, kernel_kind) + else: + mod = _build_nested_module(fn_name, arg_types, fn, arch, kernel_kind) + mod.operation.verify() + + return KernelHandle(fn.__name__, mod) + + return decorator + + +class KernelHandle: + """ + Represents a compiled PTO kernel. + + ``print(handle)`` emits the MLIR module text. + ``handle.build()`` returns the ``mlir.ir.Module`` (for ``check_ir.py``). + ``handle._ir_module`` is the raw module for direct access. + """ + + def __init__(self, py_name: str, module): + self._py_name = py_name + self._ir_module = module + + def build(self): + """Return the compiled ``mlir.ir.Module``.""" + return self._ir_module + + def __str__(self): + return str(self._ir_module) + + def __repr__(self): + return str(self._ir_module) + + +__all__ = ["to_ir", "KernelHandle"] diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py new file mode 100644 index 000000000..0db485888 --- /dev/null +++ b/ptodsl/ptodsl/_ops.py @@ -0,0 +1,256 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +PTO operation wrappers. + +Every function in this module emits one or more MLIR operations at the +active insertion point and returns the primary SSA result(s). + +Design rules: +- Vector math ops infer the result type from the first operand's type. +- ``vlds`` / ``vbrc_load`` still require an explicit ``vreg_type`` argument + because the result type cannot be inferred from the pointer alone. +- ``make_tensor_view`` infers the TensorViewType from ``len(shape)`` and the + pointer's element type. +- ``partition_view`` infers the PartitionTensorViewType from the source type. +""" + +from ._bootstrap import make_context # noqa: F401 – ensure MLIR on sys.path +from ._types import _resolve, mask_type, part_tensor_view_type, tensor_view_type + +from mlir.dialects import arith, pto as _pto +from mlir.ir import ( + Attribute, + IndexType, + IntegerType, + ShapedType, + StringAttr, +) + +# Pipe name shorthands → canonical PIPE_* names +_PIPE_ALIASES = { + "MTE1": "PIPE_MTE1", + "MTE2": "PIPE_MTE2", + "MTE3": "PIPE_MTE3", + "MTE4": "PIPE_MTE4", + "V": "PIPE_V", + "M": "PIPE_M", + "S": "PIPE_S", + "ALL": "PIPE_ALL", +} + + +def _pipe_attr(name: str): + canonical = _PIPE_ALIASES.get(name, name) + if not canonical.startswith("PIPE_"): + canonical = "PIPE_" + canonical + return _pto.PipeAttr.get(getattr(_pto.PIPE, canonical)) + + +def _event_attr(event_id: int): + return getattr(_pto, f"EVENT_ID{event_id}") + + +# ── Constants ──────────────────────────────────────────────────────────────── + +def const(value: int, *, dtype=None): + """ + Emit an ``arith.constant``. + + ``dtype`` is a ``_DType`` descriptor or a concrete ``mlir.ir.Type``. + Defaults to ``index`` when omitted. + """ + from ._types import index as _idx_dtype + mlir_type = _resolve(dtype) if dtype is not None else _resolve(_idx_dtype) + return arith.ConstantOp(mlir_type, value).result + + +# ── Pointer ops ─────────────────────────────────────────────────────────────── + +def castptr(int_addr, result_ptr_type): + """``pto.castptr`` – cast an integer address to a typed PTO pointer.""" + return _pto.CastPtrOp(_resolve(result_ptr_type), int_addr).result + + +def addptr(base_ptr, index_offset): + """``pto.addptr`` – advance a pointer by an index offset.""" + return _pto.AddPtrOp(base_ptr, index_offset).result + + +# ── Vector load / store ─────────────────────────────────────────────────────── + +def vlds(src_ptr, offset, result_vreg_type): + """``pto.vlds`` – vector load from *src_ptr* at *offset*.""" + return _pto.VldsOp(_resolve(result_vreg_type), src_ptr, offset).result + + +def vbrc_load(src_ptr, offset, result_vreg_type): + """``pto.vlds {dist="BRC_B32"}`` – broadcast a scalar into all lanes.""" + return _pto.VldsOp(_resolve(result_vreg_type), src_ptr, offset, + dist="BRC_B32").result + + +def vsts(val, dst_ptr, offset, mask): + """``pto.vsts`` – vector store.""" + _pto.VstsOp(val, dst_ptr, offset, mask) + + +def vsts_1pt(val, dst_ptr, offset, mask): + """``pto.vsts {dist="1PT_B32"}`` – store only the lowest lane.""" + _pto.VstsOp(val, dst_ptr, offset, mask, dist="1PT_B32") + + +# ── Mask / predicate ops ────────────────────────────────────────────────────── + +def plt_b32(scalar): + """ + ``pto.plt_b32`` – predicate-load from a 32-bit scalar. + + Returns ``(mask_value, scalar_out)``. ``scalar_out`` is often unused + and can be discarded with ``_``. + """ + plt_op = _pto.PltB32Op(mask_type("b32"), IntegerType.get_signless(32), scalar) + return plt_op.mask, plt_op.scalar_out + + +def pset_b32(pattern: str): + """``pto.pset_b32 "PATTERN"`` → ``!pto.mask``.""" + return _pto.PsetB32Op(mask_type("b32"), pattern).result + + +# ── Vector math (result type inferred from first operand) ───────────────────── + +def vadd(lhs, rhs, mask, result_type=None): + """``pto.vadd`` – element-wise add.""" + rt = result_type if result_type is not None else lhs.type + return _pto.VaddOp(_resolve(rt), lhs, rhs, mask).result + + +def vmul(lhs, rhs, mask): + """``pto.vmul`` – element-wise multiply.""" + return _pto.VmulOp(lhs.type, lhs, rhs, mask).result + + +def vmax(lhs, rhs, mask): + """``pto.vmax`` – element-wise maximum.""" + return _pto.VmaxOp(lhs.type, lhs, rhs, mask).result + + +def vdiv(lhs, rhs, mask): + """``pto.vdiv`` – element-wise divide.""" + return _pto.VdivOp(lhs.type, lhs, rhs, mask).result + + +def vcmax(v, mask): + """``pto.vcmax`` – cross-lane maximum reduction.""" + return _pto.VcmaxOp(v.type, v, mask).result + + +def vcadd(v, mask): + """``pto.vcadd`` – cross-lane add (sum reduction).""" + return _pto.VcaddOp(v.type, v, mask).result + + +def vdup(v, mask, *, position=None): + """``pto.vdup`` – duplicate a lane value into all lanes. + + Pass ``position="LOWEST"`` to broadcast the lowest (lane-0) element. + """ + return _pto.VdupOp(v.type, v, mask, position=position).result + + +def vexpdif(inp, ref, mask, part: str = "ODD"): + """``pto.vexpdif`` – ``exp(inp - ref)`` selecting ODD or EVEN lanes.""" + return _pto.VexpdifOp(inp.type, inp, ref, mask, part).result + + +# ── Tile-domain operations ──────────────────────────────────────────────────── + +def make_tensor_view(ptr, *, shape, strides): + """ + ``pto.make_tensor_view`` – wrap a pointer as a tensor view. + + Type is inferred: rank from ``len(shape)``, element type from ``ptr``. + """ + rank = len(shape) + elem = _pto.PtrType(ptr.type).element_type + tv_type = tensor_view_type(rank, elem) + return _pto.MakeTensorViewOp(tv_type, ptr, list(shape), list(strides)).result + + +def partition_view(tv, *, offsets, sizes): + """ + ``pto.partition_view`` – slice a tensor view. + + Type is inferred from the source tensor-view type. + """ + src_type = _pto.TensorViewType(tv.type) + rank = src_type.rank + elem = src_type.element_type + ptv_type = part_tensor_view_type(rank, elem) + return _pto.PartitionViewOp(ptv_type, tv, list(offsets), list(sizes)).result + + +def alloc_tile(tile_type, *, addr, valid_row, valid_col=None): + """``pto.alloc_tile``.""" + return _pto.AllocTileOp(_resolve(tile_type), addr=addr, valid_row=valid_row, + valid_col=valid_col).result + + +def tload(part, tile): + """``pto.tload ins(part) outs(tile)``.""" + _pto.TLoadOp(None, part, tile) + + +def tstore(tile, part): + """``pto.tstore ins(tile) outs(part)``.""" + _pto.TStoreOp(None, tile, part) + + +def tile_ptr(tile, result_ptr_type): + """``pto.tile_buf_addr`` – materialise a UB pointer from a tile handle.""" + return _pto.TileBufAddrOp(_resolve(result_ptr_type), tile).result + + +# ── Hardware / sync ─────────────────────────────────────────────────────────── + +def get_block_idx(): + """``pto.get_block_idx`` → i64 block index.""" + return _pto.GetBlockIdxOp().result + + +def barrier_all(): + """``pto.barrier #pto.pipe``.""" + _pto.BarrierOp(_pipe_attr("ALL")) + + +def set_flag(src: str, dst: str, *, event_id: int = 0): + """``pto.set_flag[src, dst, event_id]``. + + Accepts short pipe names (``"MTE2"``, ``"V"``, …) or full ``"PIPE_MTE2"`` + names. ``event_id`` is an integer in ``[0, 7]``. + """ + _pto.set_flag(_pipe_attr(src), _pipe_attr(dst), _event_attr(event_id)) + + +def wait_flag(src: str, dst: str, *, event_id: int = 0): + """``pto.wait_flag[src, dst, event_id]``.""" + _pto.wait_flag(_pipe_attr(src), _pipe_attr(dst), _event_attr(event_id)) + + +__all__ = [ + "const", + "castptr", "addptr", + "vlds", "vbrc_load", "vsts", "vsts_1pt", + "plt_b32", "pset_b32", + "vadd", "vmul", "vmax", "vdiv", + "vcmax", "vcadd", "vdup", "vexpdif", + "make_tensor_view", "partition_view", + "alloc_tile", "tload", "tstore", "tile_ptr", + "get_block_idx", "barrier_all", "set_flag", "wait_flag", +] diff --git a/ptodsl/ptodsl/_types.py b/ptodsl/ptodsl/_types.py new file mode 100644 index 000000000..693b69a35 --- /dev/null +++ b/ptodsl/ptodsl/_types.py @@ -0,0 +1,176 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +Lazy MLIR type descriptors and eager type constructors. + +Type descriptors (``_DType`` subclasses) can be created *before* any MLIR +Context exists – they only resolve to concrete ``mlir.ir.Type`` objects when +``_resolve()`` is called inside an active context. This lets users write:: + + def softmax(arg0: pto.ptr(pto.float32, "GM"), ...): + ... + +where the annotation is evaluated at *import* time (no active context), and +the actual type is materialised later by the ``@pto.to_ir`` decorator. +""" + +from ._bootstrap import make_context # ensure MLIR is on sys.path + +from mlir.dialects import pto as _pto +from mlir.ir import ( + F16Type, + F32Type, + IndexType, + IntegerType, + ShapedType, + Type, +) + +# ── Address-space name → AddressSpace enum ─────────────────────────────────── +_ADDR_SPACE = { + "ub": _pto.AddressSpace.VEC, # UB == unified buffer == VEC in PTO + "gm": _pto.AddressSpace.GM, + "vec": _pto.AddressSpace.VEC, + "mat": _pto.AddressSpace.MAT, + "GM": _pto.AddressSpace.GM, + "UB": _pto.AddressSpace.VEC, + "VEC": _pto.AddressSpace.VEC, + "MAT": _pto.AddressSpace.MAT, +} + + +# ── Lazy type descriptor base ───────────────────────────────────────────────── + +class _DType: + """Deferred MLIR type: only resolves inside an active MLIR context.""" + + def __init__(self, factory): + self._factory = factory + + def resolve(self) -> Type: + return self._factory() + + def __repr__(self): + return f"" + + +class _PtrDescriptor(_DType): + def __init__(self, elem, space: str): + self._elem = elem + self._space = space + + def resolve(self) -> Type: + elem = _resolve(self._elem) + space_enum = _ADDR_SPACE.get(self._space) + if space_enum is None: + raise ValueError( + f"Unknown address space '{self._space}'; " + f"known: {list(_ADDR_SPACE)}" + ) + space_attr = _pto.AddressSpaceAttr.get(space_enum) + return _pto.PtrType.get(elem, memory_space=space_attr) + + def __repr__(self): + return f"" + + +class _VRegDescriptor(_DType): + def __init__(self, lanes: int, elem): + self._lanes = lanes + self._elem = elem + + def resolve(self) -> Type: + elem = _resolve(self._elem) + return Type.parse(f"!pto.vreg<{self._lanes}x{elem}>") + + def __repr__(self): + return f"" + + +def _resolve(dtype) -> Type: + """Coerce a ``_DType`` descriptor or a concrete ``mlir.ir.Type`` to a Type.""" + if isinstance(dtype, _DType): + return dtype.resolve() + return dtype # already an mlir.ir.Type + + +# ── Scalar dtype singletons ─────────────────────────────────────────────────── + +float32 = _DType(F32Type.get) +float16 = _DType(F16Type.get) +int8 = _DType(lambda: IntegerType.get_signless(8)) +int16 = _DType(lambda: IntegerType.get_signless(16)) +int32 = _DType(lambda: IntegerType.get_signless(32)) +int64 = _DType(lambda: IntegerType.get_signless(64)) +index = _DType(IndexType.get) + + +# ── Type constructor functions ──────────────────────────────────────────────── + +def ptr(elem, space: str = "ub") -> _PtrDescriptor: + """Return a lazy descriptor for ``!pto.ptr``.""" + return _PtrDescriptor(elem, space) + + +def vreg_type(lanes: int, elem) -> _VRegDescriptor: + """Return a lazy descriptor for ``!pto.vreg``.""" + return _VRegDescriptor(lanes, elem) + + +def mask_type(bits: str = "b32") -> Type: + """Return ``!pto.mask`` (b8 | b16 | b32). Requires active context.""" + return Type.parse(f"!pto.mask<{bits}>") + + +def tile_buf_type(shape, dtype, valid_shape, *, + blayout: str = "RowMajor", + address_space: str = "ub", + slayout: str = "NoneBox", + fractal_size: int = 512, + pad: str = "Null") -> Type: + """ + Construct a ``!pto.tile_buf<…>`` type via the Python bindings. + + ``valid_shape`` entries may be ``-1`` for dynamic (``?``) dimensions. + ``blayout="ColMajor"`` prints as ``blayout=col_major``. + + Requires an active MLIR context. + """ + elem = _resolve(dtype) + space_enum = _ADDR_SPACE.get(address_space) + if space_enum is None: + raise ValueError( + f"Unknown address_space '{address_space}'; known: {list(_ADDR_SPACE)}" + ) + space_attr = _pto.AddressSpaceAttr.get(space_enum) + cfg = _pto.TileBufConfigAttr.get( + _pto.BLayoutAttr.get(getattr(_pto.BLayout, blayout)), + _pto.SLayoutAttr.get(getattr(_pto.SLayout, slayout)), + fractal_size, + _pto.PadValueAttr.get(getattr(_pto.PadValue, pad)), + ) + return _pto.TileBufType.get(shape, elem, space_attr, valid_shape, cfg) + + +def tensor_view_type(rank: int, elem) -> Type: + """``!pto.tensor_view`` with *rank* all-dynamic dims.""" + return _pto.TensorViewType.get(rank, _resolve(elem)) + + +def part_tensor_view_type(rank: int, elem) -> Type: + """``!pto.partition_tensor_view`` with *rank* all-dynamic dims.""" + kDynamic = ShapedType.get_dynamic_size() + return _pto.PartitionTensorViewType.get([kDynamic] * rank, _resolve(elem)) + + +__all__ = [ + "_DType", "_resolve", + "float32", "float16", "int8", "int16", "int32", "int64", "index", + "ptr", "vreg_type", "mask_type", + "tile_buf_type", "tensor_view_type", "part_tensor_view_type", +] diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py new file mode 100644 index 000000000..e28c34fe0 --- /dev/null +++ b/ptodsl/ptodsl/pto.py @@ -0,0 +1,58 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +``pto`` – the public DSL namespace. + +Import as:: + + import pto + +or as the sub-namespace ``pto`` from the ptodsl package:: + + from ptodsl import pto + +All user-facing symbols live here. Low-level MLIR bindings are accessed +internally as ``_pto`` (``from mlir.dialects import pto as _pto``). +""" + +# ── Types ───────────────────────────────────────────────────────────────────── +from ._types import ( # noqa: F401 + float32, float16, + int8, int16, int32, int64, + index, + ptr, vreg_type, mask_type, + tile_buf_type, tensor_view_type, part_tensor_view_type, + _resolve, +) + +# ── Operations ──────────────────────────────────────────────────────────────── +from ._ops import ( # noqa: F401 + const, + castptr, addptr, + vlds, vbrc_load, vsts, vsts_1pt, + plt_b32, pset_b32, + vadd, vmul, vmax, vdiv, + vcmax, vcadd, vdup, vexpdif, + make_tensor_view, partition_view, + alloc_tile, tload, tstore, tile_ptr, + get_block_idx, barrier_all, + set_flag, wait_flag, +) + +# ── Control flow ────────────────────────────────────────────────────────────── +from ._control_flow import ( # noqa: F401 + vecscope, + for_, if_, yield_, + LoopHandle, BranchHandle, +) + +# ── Decorator ───────────────────────────────────────────────────────────────── +from ._module import to_ir, KernelHandle # noqa: F401 + +# ── Scalar sub-namespace ────────────────────────────────────────────────────── +from . import scalar # noqa: F401 diff --git a/ptodsl/ptodsl/scalar.py b/ptodsl/ptodsl/scalar.py new file mode 100644 index 000000000..4902112c0 --- /dev/null +++ b/ptodsl/ptodsl/scalar.py @@ -0,0 +1,90 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +Scalar arithmetic helpers – exposed as ``pto.scalar.*`` (or ``s = pto.scalar``). + +All functions operate on raw ``mlir.ir.Value`` objects and emit the +corresponding arith dialect operations at the active insertion point. +""" + +from ._bootstrap import make_context # ensure MLIR is on sys.path # noqa: F401 +from ._types import _resolve + +from mlir.dialects import arith +from mlir.ir import IndexType, IntegerType + +_CMPI_PREDICATES = { + "eq": arith.CmpIPredicate.eq, + "ne": arith.CmpIPredicate.ne, + "slt": arith.CmpIPredicate.slt, + "sle": arith.CmpIPredicate.sle, + "sgt": arith.CmpIPredicate.sgt, + "sge": arith.CmpIPredicate.sge, + "ult": arith.CmpIPredicate.ult, + "ule": arith.CmpIPredicate.ule, + "ugt": arith.CmpIPredicate.ugt, + "uge": arith.CmpIPredicate.uge, +} + + +def muli(lhs, rhs): + """arith.muli""" + return arith.MulIOp(lhs, rhs).result + + +def addi(lhs, rhs): + """arith.addi""" + return arith.AddIOp(lhs, rhs).result + + +def subi(lhs, rhs): + """arith.subi""" + return arith.SubIOp(lhs, rhs).result + + +def index_cast(type_or_val, val=None): + """ + arith.index_cast. + + Two calling conventions:: + + index_cast(result_type, value) # explicit result type + index_cast(value) # result type = index (1-arg shorthand) + """ + if val is None: + # 1-arg form: cast to index + return arith.IndexCastOp(IndexType.get(), type_or_val).result + return arith.IndexCastOp(_resolve(type_or_val), val).result + + +def cmpi(pred: str, lhs, rhs): + """ + arith.cmpi with a named predicate string. + + ``pred`` is one of: ``"eq"``, ``"ne"``, ``"slt"``, ``"sle"``, + ``"sgt"``, ``"sge"``, ``"ult"``, ``"ule"``, ``"ugt"``, ``"uge"``. + """ + predicate = _CMPI_PREDICATES.get(pred) + if predicate is None: + raise ValueError( + f"Unknown cmpi predicate '{pred}'; known: {list(_CMPI_PREDICATES)}" + ) + return arith.CmpIOp(predicate, lhs, rhs).result + + +def cmpi_sgt(lhs, rhs): + """arith.cmpi sgt (signed greater-than).""" + return arith.CmpIOp(arith.CmpIPredicate.sgt, lhs, rhs).result + + +def select(cond, true_val, false_val): + """arith.select""" + return arith.SelectOp(cond, true_val, false_val).result + + +__all__ = ["muli", "addi", "subi", "index_cast", "cmpi", "cmpi_sgt", "select"] diff --git a/ptodsl/ptodsl_utils.py b/ptodsl/ptodsl_utils.py deleted file mode 100644 index b220239d4..000000000 --- a/ptodsl/ptodsl_utils.py +++ /dev/null @@ -1,574 +0,0 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -""" -Lightweight wrappers around the low-level MLIR Python bindings for the PTO -dialect. The goal is to eliminate boilerplate so that a vPTO kernel body can -be written in plain-looking Python without manual InsertionPoint management, -verbose type constructors, or raw Operation.create() calls. - -Design rules -──────────── -• Every helper is a plain function or a contextlib.contextmanager – no classes. -• All helpers work with the *current* MLIR context / location / insertion-point - (set by `pto_context` and `vpto_kernel`); no context parameter is threaded. -• The module is self-contained: only mlir.* imports are allowed. -""" - -from contextlib import contextmanager - -from mlir.ir import ( - Attribute, - Context, - IntegerType, - IndexType, - InsertionPoint, - Location, - Module, - Operation, - ShapedType, - StringAttr, - Type, - UnitAttr, -) -from mlir.dialects import arith, func, pto, scf - -# Mapping from the textual address-space name used in !pto.ptr -# to the AddressSpace enum value exposed by the C extension. -_ADDR_SPACE = { - "ub": pto.AddressSpace.VEC, # "ub" (unified buffer) prints as VEC - "gm": pto.AddressSpace.GM, - "vec": pto.AddressSpace.VEC, - "l1": pto.AddressSpace.MAT, -} - - -# ─── Type constructors ──────────────────────────────────────────────────────── - -def i32_type(): - """Signless 32-bit integer type.""" - return IntegerType.get_signless(32) - - -def i64_type(): - """Signless 64-bit integer type.""" - return IntegerType.get_signless(64) - - -def idx_type(): - """MLIR index type.""" - return IndexType.get() - - -def ptr_type(elem_type, space="ub"): - """PTO pointer type: !pto.ptr<{elem_type}, {space}>. - - Uses ``pto.PtrType.get`` with an ``AddressSpaceAttr`` when the address-space - name is known; falls back to ``Type.parse`` for unknown spaces. - """ - enum_val = _ADDR_SPACE.get(space) - if enum_val is not None: - space_attr = pto.AddressSpaceAttr.get(enum_val) - return pto.PtrType.get(elem_type, memory_space=space_attr) - return Type.parse(f"!pto.ptr<{elem_type}, {space}>") - - -def vreg_type(lanes, elem_type): - """PTO vector-register type: !pto.vreg<{lanes}x{elem_type}>. - - VRegType has no Python-binding constructor; Type.parse is the only path. - """ - return Type.parse(f"!pto.vreg<{lanes}x{elem_type}>") - - -def mask_type(bits="b32"): - """PTO mask/predicate type: !pto.mask<{bits}> (b8 | b16 | b32). - - MaskType has no Python-binding constructor; Type.parse is the only path. - """ - return Type.parse(f"!pto.mask<{bits}>") - - -def tensor_view_type(rank, elem_type): - """PTO tensor-view type with all-dynamic dimensions: !pto.tensor_view. - - Uses ``pto.TensorViewType.get(rank, elem_type)``. - """ - return pto.TensorViewType.get(rank, elem_type) - - -def part_tensor_view_type(rank, elem_type): - """PTO partition-tensor-view type with all-dynamic dims: !pto.partition_tensor_view. - - Uses ``pto.PartitionTensorViewType.get([kDynamic]*rank, elem_type)``. - ``ShapedType.get_dynamic_size()`` (``INT64_MIN``) is the correct MLIR - sentinel; plain ``-1`` would produce a different printed form. - """ - kDynamic = ShapedType.get_dynamic_size() - return pto.PartitionTensorViewType.get([kDynamic] * rank, elem_type) - - -def tile_buf_type(shape, elem_type, valid_shape, *, - blayout="RowMajor", address_space="ub", - slayout="NoneBox", fractal_size=512, pad="Null"): - """PTO tile-buffer type via ``pto.TileBufType.get``. - - ``valid_shape`` entries may be ``-1`` for dynamic (``?``) dimensions. - ``blayout`` selects the block layout: ``"RowMajor"`` (default, omitted in - the printed form) or ``"ColMajor"`` (printed as ``blayout=col_major``). - - Common usage:: - - # !pto.tile_buf - tile_buf_type([8, 128], f32, [-1, -1]) - - # !pto.tile_buf - tile_buf_type([8, 1], f32, [-1, 1], blayout="ColMajor") - """ - space_enum = _ADDR_SPACE.get(address_space) - if space_enum is None: - raise ValueError(f"Unknown address_space '{address_space}'; " - f"known: {list(_ADDR_SPACE)}") - space_attr = pto.AddressSpaceAttr.get(space_enum) - cfg = pto.TileBufConfigAttr.get( - pto.BLayoutAttr.get(getattr(pto.BLayout, blayout)), - pto.SLayoutAttr.get(getattr(pto.SLayout, slayout)), - fractal_size, - pto.PadValueAttr.get(getattr(pto.PadValue, pad)), - ) - return pto.TileBufType.get(shape, elem_type, space_attr, valid_shape, cfg) - - -# ─── Constant builders ─────────────────────────────────────────────────────── - -def c_idx(value): - """Emit an index constant.""" - return arith.ConstantOp(IndexType.get(), value).result - - -def c_i32(value): - """Emit a 32-bit integer constant.""" - return arith.ConstantOp(IntegerType.get_signless(32), value).result - - -def c_i64(value): - """Emit a 64-bit integer constant.""" - return arith.ConstantOp(IntegerType.get_signless(64), value).result - - -# ─── Arithmetic shorthands ─────────────────────────────────────────────────── - -def muli(lhs, rhs): - """arith.muli""" - return arith.MulIOp(lhs, rhs).result - - -def addi(lhs, rhs): - """arith.addi""" - return arith.AddIOp(lhs, rhs).result - - -def subi(lhs, rhs): - """arith.subi""" - return arith.SubIOp(lhs, rhs).result - - -# ─── PTO vector / pointer operations ──────────────────────────────────────── - -def castptr(int_addr, result_ptr_type): - """Cast an integer address to a typed PTO pointer (pto.castptr).""" - return pto.CastPtrOp(result_ptr_type, int_addr).result - - -def addptr(base_ptr, index_offset): - """Advance a PTO pointer by an index offset (pto.addptr).""" - return pto.AddPtrOp(base_ptr, index_offset).result - - -def vlds(src_ptr, offset, result_vreg_type): - """Vector load from a PTO pointer at *offset* (pto.vlds).""" - return pto.VldsOp(result_vreg_type, src_ptr, offset).result - - -def vadd(lhs, rhs, mask, result_vreg_type): - """Element-wise vector add under a predicate mask (pto.vadd).""" - return pto.VaddOp(result_vreg_type, lhs, rhs, mask).result - - -def vsts(val, dst_ptr, offset, mask): - """Vector store to a PTO pointer at *offset* under a mask (pto.vsts).""" - pto.VstsOp(val, dst_ptr, offset, mask) - - -def plt_b32(scalar): - """ - Predicate-load from a 32-bit scalar value (pto.plt_b32). - - Returns (mask_value, scalar_out) – the mask is typically the only value - used downstream; scalar_out can be discarded with ``_``. - """ - plt_op = pto.PltB32Op(mask_type("b32"), i32_type(), scalar) - return plt_op.mask, plt_op.scalar_out - - -# ─── Scope context managers ────────────────────────────────────────────────── - -@contextmanager -def vecscope(): - """ - Emit a ``pto.vecscope { ... }`` region. - - Usage:: - - with vecscope(): - ptr = castptr(addr, ptr_f32) - ... - """ - op = pto.VecScopeOp() - block = op.body.blocks.append() - with InsertionPoint(block): - yield - - -@contextmanager -def for_range(start, stop, step): - """ - Emit an ``scf.for`` loop; yield the induction variable. - The mandatory ``scf.yield`` terminator is inserted automatically on exit. - - Usage:: - - with for_range(c0, c16, c1) as i: - off = muli(i, c64) - ... - """ - for_op = scf.ForOp(start, stop, step) - with InsertionPoint(for_op.body): - yield for_op.induction_variable - scf.YieldOp([]) - - -# ─── Top-level module / kernel builder ─────────────────────────────────────── - -@contextmanager -def pto_context(): - """ - Activate an MLIR context with the PTO dialect registered. - Must wrap all other utility calls. - - Usage:: - - with pto_context(): - f32 = F32Type.get() - with vpto_kernel("MyKernel", arch="a5") as mod: - ... - """ - with Context() as ctx: - pto.register_dialect(ctx, load=True) - with Location.unknown(): - yield ctx - - -@contextmanager -def vpto_kernel(func_name, *, arch="a5"): - """ - Build the standard two-level nested-module + no-arg ``func.func`` shell - for a vPTO vector kernel, then yield the outer ``Module`` as the context - variable. ``func.ReturnOp`` and ``module.verify()`` are inserted/called - automatically on context exit. - - The emitted skeleton is:: - - module attributes {pto.target_arch = arch} { - module attributes {pto.kernel_kind = #pto.kernel_kind, - pto.target_arch = arch} { - func.func @func_name() { - - return - } - } - } - - Usage:: - - with vpto_kernel("TADD", arch="a5") as mod: - c0 = c_idx(0) - ... - return mod - """ - arch_attr = StringAttr.get(arch) - kind_attr = Attribute.parse("#pto.kernel_kind") - - outer_mod = Module.create() - outer_mod.operation.attributes["pto.target_arch"] = arch_attr - - with InsertionPoint(outer_mod.body): - # Module.create() ignores the active InsertionPoint, so use - # Operation.create("builtin.module") to insert the inner module. - inner_op = Operation.create("builtin.module", regions=1) - inner_op.attributes["pto.target_arch"] = arch_attr - inner_op.attributes["pto.kernel_kind"] = kind_attr - inner_body = inner_op.regions[0].blocks.append() - - with InsertionPoint(inner_body): - fn = func.FuncOp(func_name, func.FunctionType.get([], [])) - entry = fn.add_entry_block() - - with InsertionPoint(entry): - yield outer_mod - func.ReturnOp([]) - - outer_mod.operation.verify() - - -# ─── Flat single-module builders (for direct func inside module) ───────────── - -@contextmanager -def flat_pto_module(arch="a5"): - """ - Flat single-level module with ``pto.target_arch`` and - ``pto.kernel_kind = #pto.kernel_kind``. - - Usage:: - - with flat_pto_module("a5") as mod: - with pto_aicore_func("MyKernel", [ptr_gm, i32]) as args: - ... - return mod - """ - m = Module.create() - m.operation.attributes["pto.target_arch"] = StringAttr.get(arch) - m.operation.attributes["pto.kernel_kind"] = Attribute.parse( - "#pto.kernel_kind" - ) - with InsertionPoint(m.body): - yield m - m.operation.verify() - - -@contextmanager -def pto_aicore_func(func_name, arg_types, *, ret_types=None): - """ - Create a ``func.func`` with the ``pto.aicore`` attribute. - Yields the function's block arguments tuple. - ``func.return`` is inserted automatically on exit. - - Usage:: - - with pto_aicore_func("f", [ptr_gm, ptr_gm, i32]) as (p0, p1, n): - ... - """ - fn_ty = func.FunctionType.get(arg_types, ret_types or []) - fn = func.FuncOp(func_name, fn_ty) - fn.attributes["pto.aicore"] = UnitAttr.get() - entry = fn.add_entry_block() - with InsertionPoint(entry): - yield tuple(entry.arguments) - func.ReturnOp([]) - - -# ─── Additional control-flow helpers ───────────────────────────────────────── - -@contextmanager -def if_ctx(cond): - """ - Emit ``scf.if cond { ... }`` with no results and no else branch. - The mandatory ``scf.yield`` terminator is inserted automatically. - - Usage:: - - with if_ctx(has_rows): - tload(part, tile) - ... - """ - op = scf.IfOp(cond) - with InsertionPoint(op.then_block): - yield - scf.YieldOp([]) - - -def if_op_returning(cond, result_types): - """ - Create a ``scf.if`` with results *and* an else branch. - Returns the raw ``IfOp`` so the caller can manage the two blocks - manually with ``InsertionPoint`` and close each with ``yield_vals()``. - - Usage:: - - br = if_op_returning(has_chunk, [vreg_f32, vreg_f32]) - with InsertionPoint(br.then_block): - ... - yield_vals(merged_max, merged_sum) - with InsertionPoint(br.else_block): - yield_vals(running_max, running_sum) - next_max, next_sum = br.results - """ - return scf.IfOp(cond, result_types, hasElse=True) - - -@contextmanager -def for_range_iter(start, stop, step, init_vals): - """ - Emit ``scf.for`` with iter_args. Yields the raw ``ForOp`` so the - caller can access ``induction_variable``, ``inner_iter_args``, and - ``results`` (after the ``with`` block). - - The caller **must** call ``yield_vals(...)`` at the end of the body. - - Usage:: - - with for_range_iter(c0, c128, c64, [a, b]) as cf: - i = cf.induction_variable - x, y = cf.inner_iter_args - ... - yield_vals(new_x, new_y) - final_x, final_y = cf.results - """ - for_op = scf.ForOp(start, stop, step, init_vals) - with InsertionPoint(for_op.body): - yield for_op - - -def yield_vals(*vals): - """Emit ``scf.yield`` with the given values (shorthand for scf.YieldOp).""" - scf.YieldOp(list(vals)) - - -# ─── Arithmetic helpers ─────────────────────────────────────────────────────── - -def index_cast(result_type, val): - """arith.index_cast from/to index.""" - return arith.IndexCastOp(result_type, val).result - - -def cmpi_sgt(lhs, rhs): - """arith.cmpi sgt (signed greater-than).""" - return arith.CmpIOp(arith.CmpIPredicate.sgt, lhs, rhs).result - - -def select_val(cond, true_val, false_val): - """arith.select.""" - return arith.SelectOp(cond, true_val, false_val).result - - -# ─── PTO hardware helpers ───────────────────────────────────────────────────── - -def get_block_idx(): - """pto.get_block_idx → i64 block index.""" - return pto.GetBlockIdxOp().result - - -def barrier_all(): - """pto.barrier #pto.pipe.""" - pto.BarrierOp(pto.PipeAttr.get(pto.PIPE.PIPE_ALL)) - - -# ─── Tile-domain helpers ────────────────────────────────────────────────────── - -def tile_view(tv_type, ptr, shape, strides): - """pto.make_tensor_view → tensor_view SSA value.""" - return pto.MakeTensorViewOp(tv_type, ptr, shape, strides).result - - -def part_view(ptv_type, tv, offsets, sizes): - """pto.partition_view → partition_tensor_view SSA value.""" - return pto.PartitionViewOp(ptv_type, tv, offsets, sizes).result - - -def alloc_tile(tile_type, *, addr, valid_row, valid_col=None): - """pto.alloc_tile with optional valid_col.""" - return pto.AllocTileOp(tile_type, addr=addr, valid_row=valid_row, - valid_col=valid_col).result - - -def tload(part, tile): - """pto.tload ins(part) outs(tile).""" - pto.TLoadOp(None, part, tile) - - -def tstore(tile, part): - """pto.tstore ins(tile) outs(part).""" - pto.TStoreOp(None, tile, part) - - -def tile_ptr(tile, result_ptr_type): - """pto.tile_buf_addr – materialise a UB pointer from a tile handle.""" - return pto.TileBufAddrOp(result_ptr_type, tile).result - - -# ─── Mask helpers ───────────────────────────────────────────────────────────── - -def pset_b32(pattern): - """pto.pset_b32 "PATTERN" → !pto.mask (all-true when "PAT_ALL").""" - return pto.PsetB32Op(mask_type("b32"), pattern).result - - -# ─── Vector load / store with dist attribute ────────────────────────────────── - -def vbrc_load(src_ptr, offset, result_vreg_type): - """pto.vlds with dist="BRC_B32" – broadcast a scalar into all lanes.""" - return pto.VldsOp(result_vreg_type, src_ptr, offset, - dist="BRC_B32").result - - -def vsts_1pt(val, dst_ptr, offset, mask): - """pto.vsts with dist="1PT_B32" – store only the lowest lane.""" - pto.VstsOp(val, dst_ptr, offset, mask, dist="1PT_B32") - - -# ─── Vector math (result type inferred from first operand) ──────────────────── -# -# These wrappers follow the convention: if result_type is None the type is -# taken from the first operand (all PTO binary vector ops return the same -# type as their inputs). -# - -def vcmax(v, mask): - """pto.vcmax – cross-lane maximum reduction.""" - return pto.VcmaxOp(v.type, v, mask).result - - -def vdup_lowest(v, mask): - """pto.vdup {position="LOWEST"} – broadcast lane-0 to all lanes.""" - return pto.VdupOp(v.type, v, mask, position="LOWEST").result - - -def vmax(lhs, rhs, mask): - """pto.vmax – element-wise maximum.""" - return pto.VmaxOp(lhs.type, lhs, rhs, mask).result - - -def vexpdif(inp, ref, mask, part="ODD"): - """pto.vexpdif – exp(inp − ref), selecting ODD or EVEN lanes.""" - return pto.VexpdifOp(inp.type, inp, ref, mask, part).result - - -def vmul(lhs, rhs, mask): - """pto.vmul – element-wise multiply.""" - return pto.VmulOp(lhs.type, lhs, rhs, mask).result - - -def vcadd(v, mask): - """pto.vcadd – cross-lane add (sum reduction).""" - return pto.VcaddOp(v.type, v, mask).result - - -def vdiv(lhs, rhs, mask): - """pto.vdiv – element-wise divide.""" - return pto.VdivOp(lhs.type, lhs, rhs, mask).result - - -# Override vadd to make result_type optional (inferred from lhs when omitted) -_vadd_impl = vadd - - -def vadd(lhs, rhs, mask, result_type=None): # type: ignore[misc] - """pto.vadd – element-wise add (result_type inferred from lhs if None).""" - rt = result_type if result_type is not None else lhs.type - return pto.VaddOp(rt, lhs, rhs, mask).result - diff --git a/ptodsl/pyproject.toml b/ptodsl/pyproject.toml new file mode 100644 index 000000000..07762191c --- /dev/null +++ b/ptodsl/pyproject.toml @@ -0,0 +1,13 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "ptodsl" +version = "0.1.0" +description = "PTO MLIR DSL – Pythonic JIT-compiler-style IR builder for the PTO dialect" +requires-python = ">=3.9" + +[tool.setuptools.packages.find] +where = ["."] +include = ["ptodsl*"] diff --git a/ptodsl/softmax_builder_highlevel.py b/ptodsl/softmax_builder_highlevel.py deleted file mode 100644 index bc600484a..000000000 --- a/ptodsl/softmax_builder_highlevel.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -""" -High-level builder for the online softmax kernel. - -Reconstructs the same IR as softmax_builder_lowlevel.py using the -thin wrappers in ptodsl_utils. Compare the two files side by side to see -which boilerplate the utils eliminate. -""" - -from mlir.ir import F32Type, InsertionPoint - -from ptodsl_utils import ( - # context / types - pto_context, flat_pto_module, pto_aicore_func, - i32_type, i64_type, idx_type, ptr_type, - tensor_view_type, part_tensor_view_type, tile_buf_type, - vreg_type, - # constants - c_idx, c_i32, c_i64, - # arithmetic - muli, addi, subi, index_cast, cmpi_sgt, select_val, - # hardware - get_block_idx, barrier_all, - # tile domain - tile_view, part_view, alloc_tile, tload, tstore, tile_ptr, - # sync (pto.set_flag / pto.wait_flag come from pto module directly) - # vector / pointer - castptr, addptr, vlds, vsts, - plt_b32, pset_b32, vbrc_load, vsts_1pt, - # vector math - vcmax, vdup_lowest, vmax, vexpdif, vmul, vcadd, vadd, vdiv, - # control flow - vecscope, for_range, for_range_iter, yield_vals, - if_ctx, if_op_returning, -) -from mlir.dialects import pto - - -def build(): - with pto_context(): - # ── Types used throughout the kernel ────────────────────────────── - f32 = F32Type.get() - i32 = i32_type() - i64 = i64_type() - idx = idx_type() - ptr_gm = ptr_type(f32, "gm") # !pto.ptr - ptr_ub = ptr_type(f32, "ub") # !pto.ptr - tv5d = tensor_view_type(5, f32) # !pto.tensor_view - ptv5d = part_tensor_view_type(5, f32) # !pto.partition_tensor_view - tile_col = tile_buf_type([8, 1], f32, [-1, 1], blayout="ColMajor") # valid=?x1, col_major - tile_w = tile_buf_type([8, 128], f32, [-1, -1]) # valid=?x? - vf32 = vreg_type(64, f32) # !pto.vreg<64xf32> - - with flat_pto_module("a5") as mod: - with pto_aicore_func( - "online_softmax_update_kernel_2d", - [ptr_gm] * 7 + [i32, i32], - ) as (a0, a1, a2, a3, a4, a5, a6, arg7, arg8): - - # ── Index constants ──────────────────────────────────── - c0, c1, c8, c64, c128 = (c_idx(v) for v in (0, 1, 8, 64, 128)) - - # ── i64 constants ───────────────────────────────────── - # Declared in the same order as the reference IR so that - # the round-tripped MLIR text compares equal. - c0_i64 = c_i64(0) - _c1_i64 = c_i64(1) # present in reference, unused here - _c8_i64 = c_i64(8) - _c16_i64 = c_i64(16) - _c32_i64 = c_i64(32) - _c64_i64 = c_i64(64) - c128_i64 = c_i64(128) - c256_i64 = c_i64(256) - _c512_i64 = c_i64(512) - c8448_i64 = c_i64(8448) - c16640_i64 = c_i64(16640) - c16768_i64 = c_i64(16768) - c16896_i64 = c_i64(16896) - - # ── i32 constants ────────────────────────────────────── - c1_i32 = c_i32(1); c8_i32 = c_i32(8) - c64_i32 = c_i32(64); c0_i32 = c_i32(0) - - # ── Block-level row assignment ───────────────────────── - block_i64 = get_block_idx() - block_idx = index_cast(idx, block_i64) - row_base = muli(block_idx, c8) - _ = index_cast(i32, c8) # block_rows_i32 - row_base_i32 = index_cast(i32, row_base) - remaining_rows= subi(arg8, row_base_i32) - has_rows = cmpi_sgt(remaining_rows, c0_i32) - too_many_rows = cmpi_sgt(remaining_rows, c8_i32) - row_count_i32 = select_val(too_many_rows, c8_i32, remaining_rows) - row_count = index_cast(idx, row_count_i32) - seq = index_cast(idx, arg7) - rows = index_cast(idx, arg8) - rows_x_128 = muli(rows, c128) - - with if_ctx(has_rows): - # ── Tensor views ─────────────────────────────────── - s1 = [rows, rows, rows, c1, rows] - s128 = [rows_x_128, rows_x_128, rows_x_128, c128, c1] - sh1 = [c1, c1, c1, rows, c1] - sh128= [c1, c1, c1, rows, c128] - - oldmax_view = tile_view(tv5d, a0, sh1, s1) - oldsum_view = tile_view(tv5d, a1, sh1, s1) - qk_view = tile_view(tv5d, a2, sh128, s128) - newmax_view = tile_view(tv5d, a3, sh1, s1) - newsum_view = tile_view(tv5d, a4, sh1, s1) - expmax_view = tile_view(tv5d, a5, sh1, s1) - out_view = tile_view(tv5d, a6, sh128, s128) - - # ── Partition views ──────────────────────────────── - off = [c0, c0, c0, row_base, c0] - z1 = [c1, c1, c1, row_count, c1] - zs = [c1, c1, c1, row_count, seq] - - oldmax_part = part_view(ptv5d, oldmax_view, off, z1) - oldsum_part = part_view(ptv5d, oldsum_view, off, z1) - qk_part = part_view(ptv5d, qk_view, off, zs) - newmax_part = part_view(ptv5d, newmax_view, off, z1) - newsum_part = part_view(ptv5d, newsum_view, off, z1) - expmax_part = part_view(ptv5d, expmax_view, off, z1) - out_part = part_view(ptv5d, out_view, off, zs) - - # ── UB tile allocation ───────────────────────────── - oldmax_tile = alloc_tile(tile_col, addr=c0_i64, valid_row=row_count) - oldsum_tile = alloc_tile(tile_col, addr=c128_i64, valid_row=row_count) - qk_tile = alloc_tile(tile_w, addr=c256_i64, valid_row=row_count, valid_col=seq) - out_tile = alloc_tile(tile_w, addr=c8448_i64, valid_row=row_count, valid_col=seq) - newmax_tile = alloc_tile(tile_col, addr=c16640_i64, valid_row=row_count) - newsum_tile = alloc_tile(tile_col, addr=c16768_i64, valid_row=row_count) - expmax_tile = alloc_tile(tile_col, addr=c16896_i64, valid_row=row_count) - - # ── Tile loads from GM ───────────────────────────── - tload(oldmax_part, oldmax_tile) - tload(oldsum_part, oldsum_tile) - tload(qk_part, qk_tile) - - pto.set_flag("PIPE_MTE2", "PIPE_V", pto.EVENT_ID0) - pto.wait_flag("PIPE_MTE2", "PIPE_V", pto.EVENT_ID0) - - with vecscope(): - # Materialise typed UB pointers from tile handles - ub_om = tile_ptr(oldmax_tile, ptr_ub) - ub_os = tile_ptr(oldsum_tile, ptr_ub) - ub_qk = tile_ptr(qk_tile, ptr_ub) - ub_out= tile_ptr(out_tile, ptr_ub) - ub_nm = tile_ptr(newmax_tile, ptr_ub) - ub_ns = tile_ptr(newsum_tile, ptr_ub) - ub_em = tile_ptr(expmax_tile, ptr_ub) - - active = pset_b32("PAT_ALL") - one_mask, _ = plt_b32(c1_i32) - - with for_range(c0, row_count, c1) as row: - row_qk = muli(row, c128) - oldmax_bc = vbrc_load(ub_om, row, vf32) - oldsum_bc = vbrc_load(ub_os, row, vf32) - - # ── Chunk loop: compute running max & sum ── - with for_range_iter(c0, c128, c64, - [oldmax_bc, oldsum_bc]) as cf: - chunk = cf.induction_variable - running_max, running_sum = cf.inner_iter_args - - rem_cols = subi(arg7, index_cast(i32, chunk)) - has_chunk = cmpi_sgt(rem_cols, c0_i32) - - br = if_op_returning(has_chunk, [vf32, vf32]) - with InsertionPoint(br.then_block): - cmask, _ = plt_b32(rem_cols) - cbase = addi(row_qk, chunk) - vec = vlds(ub_qk, cbase, vf32) - cmax = vcmax(vec, cmask) - cmax_bc = vdup_lowest(cmax, active) - mmax = vmax(running_max, cmax_bc, active) - sc_run = vexpdif(running_max, mmax, active) - rs_sc = vmul(sc_run, running_sum, active) - c_exp = vexpdif(vec, mmax, cmask) - c_sum = vcadd(c_exp, cmask) - c_sum_bc = vdup_lowest(c_sum, active) - m_sum = vadd(rs_sc, c_sum_bc, active) - yield_vals(mmax, m_sum) - with InsertionPoint(br.else_block): - yield_vals(running_max, running_sum) - - yield_vals(*br.results) - - final_max, final_sum = cf.results - - # ── Compute expmax scalar for this row ───── - raw_em = vexpdif(oldmax_bc, final_max, active) - sc_os = vmul(raw_em, oldsum_bc, active) - expmax = vdiv(sc_os, final_sum, active) - - vsts_1pt(final_max, ub_nm, row, one_mask) - vsts_1pt(final_sum, ub_ns, row, one_mask) - vsts_1pt(expmax, ub_em, row, one_mask) - - # ── Output normalisation loop ────────────── - with for_range(c0, c128, c64) as chunk2: - rem2 = subi(arg7, index_cast(i32, chunk2)) - has_c2 = cmpi_sgt(rem2, c0_i32) - with if_ctx(has_c2): - cmask2, _ = plt_b32(rem2) - cbase2 = addi(row_qk, chunk2) - vec2 = vlds(ub_qk, cbase2, vf32) - exp2 = vexpdif(vec2, final_max, cmask2) - out2 = vdiv(exp2, final_sum, cmask2) - vsts(out2, ub_out, cbase2, cmask2) - - pto.set_flag("PIPE_V", "PIPE_MTE3", pto.EVENT_ID0) - pto.wait_flag("PIPE_V", "PIPE_MTE3", pto.EVENT_ID0) - - # ── Tile stores to GM ────────────────────────────── - tstore(newmax_tile, newmax_part) - tstore(newsum_tile, newsum_part) - tstore(expmax_tile, expmax_part) - tstore(out_tile, out_part) - - barrier_all() - - return mod - - -if __name__ == "__main__": - print(build()) diff --git a/ptodsl/softmax_builder_suggested.py b/ptodsl/softmax_builder_suggested.py deleted file mode 100644 index 3c2340af1..000000000 --- a/ptodsl/softmax_builder_suggested.py +++ /dev/null @@ -1,204 +0,0 @@ -# Minimum Pythonic mapping of test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto -import pto -s = pto.scalar - -@pto.to_ir( - name="online_softmax_update_kernel_2d", # default to function name if not given - kernel_kind="vector", - arch="a5", - func_attr="pto.aicore" -) -def softmax_demo( - arg0: pto.ptr(pto.float32, "GM"), - arg2: pto.ptr(pto.float32, "GM"), - arg3: pto.ptr(pto.float32, "GM"), - arg4: pto.ptr(pto.float32, "GM"), - arg5: pto.ptr(pto.float32, "GM"), - arg6: pto.ptr(pto.float32, "GM"), - arg7: pto.int32, - arg8: pto.int32 -) - c0 = pto.const(0) - c1 = pto.const(1) - c8 = pto.const(8) - c64 = pto.const(64) - c128 = pto.const(128) - - c0_i64 = pto.const(0, dtype=pto.int64) - c1_i64 = pto.const(1, dtype=pto.int64) - c8_i64 = pto.const(8, dtype=pto.int64) - c16_i64 = pto.const(16, dtype=pto.int64) - c32_i64 = pto.const(32, dtype=pto.int64) - c64_i64 = pto.const(64, dtype=pto.int64) - c128_i64 = pto.const(128, dtype=pto.int64) - c256_i64 = pto.const(256, dtype=pto.int64) - c512_i64 = pto.const(512, dtype=pto.int64) - c8448_i64 = pto.const(8448, dtype=pto.int64) - c16640_i64 = pto.const(16640, dtype=pto.int64) - c16768_i64 = pto.const(16768, dtype=pto.int64) - c16896_i64 = pto.const(16896, dtype=pto.int64) - - c1_i32 = pto.const(1, dtype=pto.int32) - c8_i32 = pto.const(8, dtype=pto.int32) - c64_i32 = pto.const(64, dtype=pto.int32) - c0_i32 = pto.const(0, dtype=pto.int32) - - block_i64 = pto.get_block_idx() - block_idx = s.index_cast(idx, block_i64) - row_base = s.muli(block_idx, c8) - _ = s.index_cast(i32, c8) # block_rows_i32 - row_base_i32 = s.index_cast(i32, row_base) - remaining_rows= s.subi(arg8, row_base_i32) - has_rows = s.cmpi_sgt(remaining_rows, c0_i32) # optionally overload __gt__ - too_many_rows = s.cmpi_sgt(remaining_rows, c8_i32) - row_count_i32 = s.select(too_many_rows, c8_i32, remaining_rows) - row_count = s.index_cast(row_count_i32) - seq = s.index_cast(arg7) - rows = s.index_cast(arg8) - rows_x_128 = s.muli(rows, c128) - - with pto.if_(has_rows): - # ── Tensor views ─────────────────────────────────── - s1 = [rows, rows, rows, c1, rows] - s128 = [rows_x_128, rows_x_128, rows_x_128, c128, c1] - sh1 = [c1, c1, c1, rows, c1] - sh128= [c1, c1, c1, rows, c128] - - # 5D type `!pto.tensor_view` can be inferred from shape rank - oldmax_view = pto.make_tensor_view(arg0, shape=sh1, strides=s1) - oldsum_view = pto.make_tensor_view(arg1, shape=sh1, strides=s1) - qk_view = pto.make_tensor_view(arg2, shape=h128, strides=s128) - newmax_view = pto.make_tensor_view(arg3, shape=sh1, strides=s1) - newsum_view = pto.make_tensor_view(arg4, shape=sh1, strides=s1) - expmax_view = pto.make_tensor_view(arg5, shape=sh1, strides=s1) - out_view = pto.make_tensor_view(arg6, shape=sh128, strides=s128) - - # ── Partition views ──────────────────────────────── - off = [c0, c0, c0, row_base, c0] - z1 = [c1, c1, c1, row_count, c1] - zs = [c1, c1, c1, row_count, seq] - - # 5D type `!pto.tensor_view -> !pto.partition_tensor_view` can be inferred from shape rank - oldmax_part = pto.partition_view(oldmax_view, offsets=off, sizes=z1) - oldsum_part = pto.partition_view(oldsum_view, offsets=off, sizes=z1) - qk_part = pto.partition_view(qk_view, offsets=off, sizes=zs) - newmax_part = pto.partition_view(newmax_view, offsets=off, sizes=z1) - newsum_part = pto.partition_view(newsum_view, offsets=off, sizes=z1) - expmax_part = pto.partition_view(expmax_view, offsets=off, sizes=z1) - out_part = pto.partition_view(out_view, offsets=off, sizes=zs) - - # ── UB tile allocation ───────────────────────────── - tile_col = pto.tile_buf_type( - shape=[8, 1], dtype=pto.float32, valid_shape=[-1, 1], blayout="ColMajor") # valid=?x1, col_major - tile_w = pto.tile_buf_type( - shape=[8, 128], dtype=pto.float32, valid_shape=[-1, -1]) # valid=?x? - - oldmax_tile = pto.alloc_tile(tile_col, addr=c0_i64, valid_row=row_count) - oldsum_tile = pto.alloc_tile(tile_col, addr=c128_i64, valid_row=row_count) - qk_tile = pto.alloc_tile(tile_w, addr=c256_i64, valid_row=row_count, valid_col=seq) - out_tile = pto.alloc_tile(tile_w, addr=c8448_i64, valid_row=row_count, valid_col=seq) - newmax_tile = pto.alloc_tile(tile_col, addr=c16640_i64, valid_row=row_count) - newsum_tile = pto.alloc_tile(tile_col, addr=c16768_i64, valid_row=row_count) - expmax_tile = pto.alloc_tile(tile_col, addr=c16896_i64, valid_row=row_count) - - # ── Tile loads from GM ───────────────────────────── - pto.tload(oldmax_part, oldmax_tile) - pto.tload(oldsum_part, oldsum_tile) - pto.tload(qk_part, qk_tile) - - pto.set_flag("MTE2", "V", event_id=0) - pto.wait_flag("MTE2", "V", event_id=0) - - with pto.vecscope(): - # Materialise typed UB pointers from tile handles - ptr_ub = pto.ptr(pto.float32, "UB") # !pto.ptr - vf32 = pto.vreg_type(64, pto.float32) - ub_om = pto.tile_ptr(oldmax_tile, ptr_ub) - ub_os = pto.tile_ptr(oldsum_tile, ptr_ub) - ub_qk = pto.tile_ptr(qk_tile, ptr_ub) - ub_out = pto.tile_ptr(out_tile, ptr_ub) - ub_nm = pto.tile_ptr(newmax_tile, ptr_ub) - ub_ns = pto.tile_ptr(newsum_tile, ptr_ub) - ub_em = pto.tile_ptr(expmax_tile, ptr_ub) - - active = pto.pset_b32("PAT_ALL") - one_mask, _ = pto.plt_b32(c1_i32) - - with pto.for_(c0, row_count, step=c1) as row: - row_qk = s.muli(row, c128) # can optionally overload __mul__ - oldmax_bc = pto.vbrc_load(ub_om, row, vf32) - oldsum_bc = pto.vbrc_load(ub_os, row, vf32) - - # ── Chunk loop: compute running max & sum ── - # %final_max, %final_sum = scf.for %chunk = %c0 to %c128 step %c64 - # iter_args(%running_max = %oldmax_bc, %running_sum = %oldsum_bc) - # -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) - with pto.for_(c0, c128, step=c64, - iter_args=(oldmax_bc, oldsum_bc), - results=(vf32, vf32)) as loop: - chunk = loop.iv # induction variable %chunk (index) bound by `scf.for %chunk = ...` - running_max, running_sum = loop.iter_args - - chunk_i32 = s.index_cast(pto.int32, chunk) # arith.index_cast index to i32 - remaining_cols = s.subi(arg7, chunk_i32) # arith.subi - has_chunk = s.cmpi("sgt", remaining_cols, c0_i32) # arith.cmpi sgt - - # %next_max, %next_sum = scf.if %has_chunk -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) - with pto.if_(has_chunk, results=(vf32, vf32)) as br: - with br.then_: - chunk_mask, chunk_rest = pto.plt_b32(remaining_cols) - chunk_base = s.addi(row_qk, chunk) - vec = pto.vlds(ub_qk, chunk_base, vf32) - chunk_max = pto.vcmax(vec, chunk_mask) - chunk_max_bc = pto.vdup(chunk_max, active, position="LOWEST") - merged_max = pto.vmax(running_max, chunk_max_bc, active) - scaled_running = pto.vexpdif(running_max, merged_max, active, "ODD") - running_sum_scaled = pto.vmul(scaled_running, running_sum, active) - chunk_exp = pto.vexpdif(vec, merged_max, chunk_mask, "ODD") - chunk_sum = pto.vcadd(chunk_exp, chunk_mask) - chunk_sum_bc = pto.vdup(chunk_sum, active, position="LOWEST") - merged_sum = pto.vadd(running_sum_scaled, chunk_sum_bc, active) - pto.yield_(merged_max, merged_sum) # scf.yield - with br.else_: - pto.yield_(running_max, running_sum) # scf.yield - next_max, next_sum = br.results - - pto.yield_(next_max, next_sum) # scf.yield - - final_max, final_sum = loop.results - - # ── Compute expmax scalar for this row ───── - raw_em = pto.vexpdif(oldmax_bc, final_max, active) - sc_os = pto.vmul(raw_em, oldsum_bc, active) - expmax = pto.vdiv(sc_os, final_sum, active) - - pto.vsts_1pt(final_max, ub_nm, row, one_mask) - pto.vsts_1pt(final_sum, ub_ns, row, one_mask) - pto.vsts_1pt(expmax, ub_em, row, one_mask) - - # ── Output normalisation loop ────────────── - with pto.for_(c0, c128, step=c64) as chunk2: - rem2 = s.subi(arg7, s.index_cast(pto.int32, chunk2)) - has_c2 = s.cmpi_sgt(rem2, c0_i32) - with pto.if_(has_c2): - cmask2, _ = pto.plt_b32(rem2) - cbase2 = s.addi(row_qk, chunk2) - vec2 = pto.vlds(ub_qk, cbase2, vf32) - exp2 = pto.vexpdif(vec2, final_max, cmask2) - out2 = pto.vdiv(exp2, final_sum, cmask2) - pto.vsts(out2, ub_out, cbase2, cmask2) - - pto.set_flag("V", "MTE3", event_id=0) - pto.wait_flag("V", "MTE3", event_id=1) - - # ── Tile stores to GM ────────────────────────────── - pto.tstore(newmax_tile, newmax_part) - pto.tstore(newsum_tile, newsum_part) - pto.tstore(expmax_tile, expmax_part) - pto.tstore(out_tile, out_part) - - pto.barrier_all() - - -if __name__ == "__main__": - print(softmax_demo) diff --git a/ptodsl/tile_and_vpto_builder_highlevel.py b/ptodsl/tile_and_vpto_builder_highlevel.py deleted file mode 100644 index 7c16cadc7..000000000 --- a/ptodsl/tile_and_vpto_builder_highlevel.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -""" -High-level builder for the TADD vPTO kernel. - -Reconstructs the same IR as expand_tileop_to_vpto_result.pto using the -thin wrappers in ptodsl_utils instead of raw MLIR Python binding calls. - -Compare with tile_and_vpto_builder_lowlevel.py to see what the utils hide: - • No manual InsertionPoint management - • No Operation.create("builtin.module", ...) boilerplate - • No Type.parse() / arith.ConstantOp(...).result calls in the kernel body - • vecscope and scf.for become ordinary Python context managers -""" - -from mlir.ir import F32Type - -from ptodsl_utils import ( - # types - ptr_type, vreg_type, - # constants - c_idx, c_i32, c_i64, - # arithmetic - muli, - # vector / pointer ops - castptr, addptr, vlds, vadd, vsts, plt_b32, - # scope helpers - vecscope, for_range, - # module builders - pto_context, vpto_kernel, -) - - -def build(): - with pto_context(): - # ── Types used in this kernel ───────────────────────────────────── - f32 = F32Type.get() - ptr_f32_ub = ptr_type(f32, "ub") # !pto.ptr - vreg_64f32 = vreg_type(64, f32) # !pto.vreg<64xf32> - - # ── Build the nested module shell and the @TADD function body ───── - with vpto_kernel("TADD", arch="a5") as mod: - - # Constants – declared in the same order as the reference IR. - c0_i64 = c_i64(0) - c16 = c_idx(16) # loop trip-count: 1024 elems / 64-wide vreg - c4096_i64 = c_i64(4096) - c0 = c_idx(0) - c1 = c_idx(1) - c64_i32 = c_i32(64) # scalar for mask generation - c64 = c_idx(64) - - with vecscope(): - # Materialise typed pointers from the raw integer addresses - ptr_src = castptr(c4096_i64, ptr_f32_ub) # source buffer - ptr_dst = castptr(c0_i64, ptr_f32_ub) # destination buffer - - with for_range(c0, c16, c1) as tile_idx: - # Build a 64-lane all-true mask for this iteration - mask, _ = plt_b32(c64_i32) - - # Byte offset for the current 64-element tile - tile_off = muli(tile_idx, c64) - - # Load source tile, add to destination tile, store result - va = vlds(addptr(ptr_src, tile_off), c0, vreg_64f32) - ptr_dst_tile = addptr(ptr_dst, tile_off) - vb = vlds(ptr_dst_tile, c0, vreg_64f32) - vc = vadd(va, vb, mask, vreg_64f32) - vsts(vc, ptr_dst_tile, c0, mask) - - return mod - - -if __name__ == "__main__": - print(build()) diff --git a/ptodsl/tile_and_vpto_builder_suggested.py b/ptodsl/tile_and_vpto_builder_suggested.py deleted file mode 100644 index 583d491ba..000000000 --- a/ptodsl/tile_and_vpto_builder_suggested.py +++ /dev/null @@ -1,32 +0,0 @@ -# minimum Pythonic mapping of test/lit/vpto/expand_tileop_to_vpto_result.pto - -import pto -s = pto.scalar - -@pto.to_ir(name="TADD", kernel_kind="vector", arch="a5") -def vpto_demo(): - c0_i64 = pto.const(0, dtype=pto.int64) - c16 = pto.const(16, dtype=pto.index) # if no dtype passed, default to pto.index - c4096_i64 = pto.const(4096, dtype=pto.int64) - c0 = pto.const(0) - c1 = pto.const(1) - c64_i32 = pto.const(64, dtype=pto.int32) - c64 = pto.const(64) - with pto.vecscope(): - ptr_type = pto.ptr(pto.float32, "UB") - ptr_src = pto.castptr(c4096_i64, ptr_type) - ptr_dst = pto.castptr(c0_i64, ptr_type) - vreg_type = vreg_type(64, pto.float32) - with for_(c0, c16, step=c1) as tile_idx: - mask, _ = pto.plt_b32(c64_i32) - tile_off = s.muli(tile_idx, c64) # can optionally overload __mul__ - va = pto.vlds(pto.addptr(ptr_src, tile_off), c0, vreg_type) - ptr_dst_tile = pto.addptr(ptr_dst, tile_off) - vb = pto.vlds(ptr_dst_tile, c0, vreg_type) - vc = pto.vadd(va, vb, mask, vreg_type) - pto.vsts(vc, ptr_dst_tile, c0, mask) - # by default return None, matches IR `return` - - -if __name__ == "__main__": - print(vpto_demo) From 1ac8d0d6048441e7987968b5abd9f69cc91140aa Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 14 May 2026 07:15:01 +0000 Subject: [PATCH 10/31] [vpto] Add ptodsl tracing POC --- docs/designs/ptodsl-tiletrace-poc-proposal.md | 178 +++++ lib/TileOps/tadd_template_tracing_poc.py | 79 ++ ptodsl/README.md | 39 + ptodsl/ptodsl/__init__.py | 12 +- ptodsl/ptodsl/_bootstrap.py | 45 +- ptodsl/ptodsl/vpto.py | 743 ++++++++++++++++++ 6 files changed, 1087 insertions(+), 9 deletions(-) create mode 100644 docs/designs/ptodsl-tiletrace-poc-proposal.md create mode 100644 lib/TileOps/tadd_template_tracing_poc.py create mode 100644 ptodsl/ptodsl/vpto.py diff --git a/docs/designs/ptodsl-tiletrace-poc-proposal.md b/docs/designs/ptodsl-tiletrace-poc-proposal.md new file mode 100644 index 000000000..d44f1c974 --- /dev/null +++ b/docs/designs/ptodsl-tiletrace-poc-proposal.md @@ -0,0 +1,178 @@ +# ptodsl `vpto` POC Proposal + +## Background + +Today we have two very different authoring paths for VPTO-related Python DSLs: + +- `ptodsl` executes Python directly and builds IR through tracing-style wrappers. +- `tilelang-dsl` captures Python source as AST, then runs `frontend_ast -> semantic -> lowering`. + +This split is especially visible for tile templates such as +[`lib/TileOps/tadd_template.py`](/home/zhangzhendong/ptoas-workspace/PTOAS/lib/TileOps/tadd_template.py), +whose body is conceptually simple but currently depends on the full AST frontend. + +For the longer-term direction, we want VPTO-level authoring to converge on the +same tracing-style route as `ptodsl`, while preserving as much of the +TileLang-style surface as practical. + +## Problem Statement + +The current AST route gives us good source diagnostics and broad surface +coverage, but it also has clear costs: + +- Every new surface feature needs to be added in three layers: + frontend node building, semantic typing, and text lowering. +- Reusing mature `ptodsl` builder idioms is difficult because authored Python is + no longer the execution model. +- Simple tile templates still pay the cost of a compiler frontend even when the + kernel body is static, structured, and already close to the desired VPTO form. + +For team discussion, the concrete question is: + +Can we execute a TileLang-style tile template directly and emit useful VPTO IR +without going through AST capture? + +## Proposal + +Introduce an experimental `ptodsl.vpto` namespace as a tracing-oriented POC for +TileLang-style tile templates. + +### Design Goals + +- Reuse the authored Python function body directly. +- Keep the POC independent from `tilelang-dsl` internals. +- Preserve the most recognizable TileLang surface where it is cheap: + `@pto.vkernel`, `Tile`, `dst.element_type`, `dst.valid_shape`, + `tile[row, col:]`, `get_lanes`, `make_mask`, `vlds`, `vadd`, `vsts`. +- Keep the implementation minimal and explicit enough that the team can judge + whether the tracing route is viable before we invest in broader migration. + +### Non-Goals for This POC + +- No attempt to replace `tilelang-dsl` in-place. +- No matcher, multi-dtype registry, template slots, inline-proc, or cube + surface. +- No source-diagnostic parity with the AST frontend. +- No requirement to generalize beyond the minimal pybinding-backed subset + needed for `tadd_template.py`. + +## POC Scope + +The POC is intentionally limited to a single template shape: + +- Target template: `tadd_template.py` +- Supported parameter kind: bare static 2D `Tile` +- Supported control flow: explicit builder-style `vecscope()` and `for_()` +- Supported ops: `make_mask`, `vlds`, `vadd`, `vsts` +- Supported lowering shape: one `pto.vecscope` containing nested `scf.for`, + `pto.tile_buf_addr`, and vector micro-ops + +This means the first implementation validates the core idea: + +1. specialize bare `Tile` parameters with static shape + dtype +2. execute the authored Python body directly +3. trace tile slice accesses such as `src0[row, col:]` +4. emit structured VPTO IR with `scf.for` and no AST capture + +## Why This Cut Is Useful + +This is not yet the final architecture, but it answers the most important +migration question with low implementation risk: + +- If the POC is too awkward even for `tadd_template.py`, we should not try to + move the main TileLang route onto tracing. +- If the POC stays small and readable, then we have evidence that a tracing + backend can carry at least a meaningful subset of tile templates. + +This cut also forces one important architectural decision early: + +- The tracing route should standardize on explicit builder-style control flow. + Reconstructing `scf.for` from raw Python `for range(...)` would pull us back + toward AST capture or source transformation, which defeats the purpose of the + experiment. + +## Proposed Architecture + +Add a new lightweight module: + +- [`ptodsl/ptodsl/vpto.py`](/home/zhangzhendong/ptoas-workspace/PTOAS/ptodsl/ptodsl/vpto.py) + +Core pieces: + +- `Tile` annotation marker +- `TileSpec(shape, dtype, memory_space="ub")` +- `@vkernel(target="a5", op="pto.tadd")` +- `TracingKernelDescriptor.specialize(...)` +- proxy `Tile` arguments that expose: + - `.element_type` + - `.valid_shape` + - `tile[row, col:]` +- a trace builder that emits structured MLIR objects through Python bindings + +The key idea is that `tile[row, col:]` is not lowered from AST. Instead, it is +captured at runtime through a proxy object and immediately converted into a +traced tile-slice value. + +## Expected Output Shape + +For the `tadd_template.py`-style kernel body, the POC emits: + +- tile-buffer arguments +- one `pto.vecscope` +- nested `scf.for` for rows and columns +- `pto.tile_buf_addr` for each referenced tile +- `pto.plt_b32` +- `pto.vlds` +- `pto.vadd` +- `pto.vsts` +- `scf.yield` for loop-carried `remained` + +This is intentionally close to the already documented tile-op expand form, but +keeps structured control flow instead of concretely unrolling the loops. + +## Tradeoffs + +### Advantages + +- Very small implementation surface for the first proof point. +- No dependency on AST parsing or source capture. +- Easy to compare source body and emitted IR side by side. +- Makes it clear which parts of TileLang syntax are “real execution” versus + “frontend-only sugar”. +- Produces IR that is much closer to a future scalable frontend than the + original fully unrolled POC. + +### Limitations + +- No rich diagnostics or semantic model yet. +- No integration with the existing `tilelang-dsl` package entrypoint. +- Current output is deliberately narrow and only covers the pybinding-backed + operations needed by the first POC template. +- Control flow currently needs explicit `vecscope()` / `for_()` builders instead + of raw Python `for range(...)`. + +These are acceptable for the first experiment because the goal is not feature +completeness; it is to validate the tracing execution model on a real tile +template. + +## Rollout Path If The POC Works + +If the POC proves maintainable, the next steps should be: + +1. Add a slightly broader vector subset beyond `tadd`. +2. Replace any remaining POC-specific glue with shared `ptodsl`/MLIR builders. +3. Introduce a reusable runtime contract layer for dtype, mask, and slice + checks. +4. Decide whether `tilelang-dsl` should: + - keep AST frontend and optionally target the tracing backend, or + - expose a parallel tracing-first authoring mode for static templates. + +## Deliverables In This Change + +- This proposal document +- an experimental `ptodsl.vpto` namespace +- a minimal `tadd_template.py`-oriented POC example + +The change is intentionally framed as a team discussion artifact plus a narrow +executable proof-of-concept, not as a replacement plan for the current +TileLang frontend. diff --git a/lib/TileOps/tadd_template_tracing_poc.py b/lib/TileOps/tadd_template_tracing_poc.py new file mode 100644 index 000000000..418783104 --- /dev/null +++ b/lib/TileOps/tadd_template_tracing_poc.py @@ -0,0 +1,79 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Experimental `ptodsl.vpto` POC version of the TileLang tadd template. + +This keeps the authored kernel body intentionally close to +`lib/TileOps/tadd_template.py`, but routes it through the experimental +`ptodsl.vpto` path instead of the TileLang AST frontend. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +PTODSL_DIR = REPO_ROOT / "ptodsl" +if str(PTODSL_DIR) not in sys.path: + sys.path.insert(0, str(PTODSL_DIR)) + +from ptodsl import vpto as pto + + +@pto.vkernel( + target="a5", + op="pto.tadd", +) +def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + mask_scalar_ty = pto.i32 + + with pto.vecscope(): + with pto.for_(0, valid_rows, step=1) as row: + remained0 = pto.scalar_const(64, mask_scalar_ty) + with pto.for_(0, valid_cols, step=pto.get_lanes(dtype), state={"remained": remained0}) as loop: + col = loop.iv + remained = loop.state.remained + mask, next_remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + loop.yield_state(remained=next_remained) + + +def build_specialized_kernel(): + return template_tadd.specialize( + src0=pto.TileSpec(shape=(16, 64), dtype=pto.f32), + src1=pto.TileSpec(shape=(16, 64), dtype=pto.f32), + dst=pto.TileSpec(shape=(16, 64), dtype=pto.f32), + ) + + +def main(argv: list[str]) -> int: + materialized = build_specialized_kernel() + + if len(argv) > 2: + print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) + return 2 + + if len(argv) == 2: + output_path = Path(argv[1]) + materialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(materialized.mlir_text(), end="") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/ptodsl/README.md b/ptodsl/README.md index d2c2dac81..3f26a9eb5 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -200,6 +200,45 @@ pto.wait_flag("MTE2", "V", event_id=0) pto.barrier_all() ``` +### Experimental `vpto` POC + +For early experiments around AST-free tracing of TileLang-style tile templates, +`ptodsl` also exposes an experimental namespace: + +```python +from ptodsl import vpto as pto + +@pto.vkernel(target="a5", op="pto.tadd") +def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + with pto.vecscope(): + with pto.for_(0, valid_rows, step=1) as row: + remained0 = pto.scalar_const(64, pto.i32) + with pto.for_(0, valid_cols, step=pto.get_lanes(dtype), state={"remained": remained0}) as loop: + col = loop.iv + remained = loop.state.remained + mask, next_remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.vadd(lhs, rhs, mask) + pto.vsts(out, dst[row, col:], mask) + loop.yield_state(remained=next_remained) +``` + +Current limitations: + +- pybinding-backed POC only; it still covers a narrow TileLang-shaped subset +- supports only static 2D `Tile` parameters +- supports only a narrow vector subset needed by `tadd_template.py` +- requires explicit builder-style `vecscope()` / `for_()` rather than Python `for range(...)` + +Reference script: + +```bash +python3 lib/TileOps/tadd_template_tracing_poc.py +``` + --- ## How the IR check works diff --git a/ptodsl/ptodsl/__init__.py b/ptodsl/ptodsl/__init__.py index f558e21e3..cfd6e6537 100644 --- a/ptodsl/ptodsl/__init__.py +++ b/ptodsl/ptodsl/__init__.py @@ -7,6 +7,14 @@ # See LICENSE in the root of the software repository for the full text of the License. """ptodsl – PTO MLIR DSL package.""" -from . import pto, scalar # noqa: F401 +from importlib import import_module -__all__ = ["pto", "scalar"] +__all__ = ["pto", "scalar", "vpto"] + + +def __getattr__(name): + if name in {"pto", "scalar", "vpto"}: + module = import_module(f".{name}", __name__) + globals()[name] = module + return module + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/ptodsl/ptodsl/_bootstrap.py b/ptodsl/ptodsl/_bootstrap.py index 894310ae3..50449e312 100644 --- a/ptodsl/ptodsl/_bootstrap.py +++ b/ptodsl/ptodsl/_bootstrap.py @@ -8,18 +8,49 @@ """ MLIR path bootstrap and context factory. -Adds the ptoas install directory to sys.path so that the mlir package is -importable regardless of how the ptodsl package itself was installed. +Discovers local LLVM MLIR Python bindings plus PTO Python dialect artifacts so +that ``ptodsl`` can import ``mlir`` / ``mlir.dialects.pto`` directly from a +developer workspace without requiring the caller to pre-seed ``PYTHONPATH``. """ import os import sys +from pathlib import Path -_INSTALL = os.path.normpath( - os.path.join(os.path.dirname(__file__), "..", "..", "install", "mlir") -) -if os.path.isdir(_INSTALL) and _INSTALL not in sys.path: - sys.path.insert(0, _INSTALL) + +def _candidate_python_roots() -> list[Path]: + here = Path(__file__).resolve() + repo_root = here.parents[2] + workspace_root = repo_root.parent + env_roots = [] + for env_name in ("MLIR_PYTHON_ROOT", "PTO_PYTHON_ROOT"): + raw = os.environ.get(env_name) + if raw: + env_roots.append(Path(raw)) + + return [ + *env_roots, + repo_root / "build" / "python", + repo_root / "install", + workspace_root / "llvm-project" / "build-shared" / "tools" / "mlir" / "python_packages" / "mlir_core", + ] + + +def _bootstrap_python_paths() -> None: + added = set() + for root in _candidate_python_roots(): + if not root or not root.is_dir(): + continue + if not (root / "mlir").exists(): + continue + root_text = str(root) + if root_text in added or root_text in sys.path: + continue + sys.path.insert(0, root_text) + added.add(root_text) + + +_bootstrap_python_paths() from mlir.dialects import pto as _pto_dialect # noqa: E402 from mlir.ir import Context, Location # noqa: E402 diff --git a/ptodsl/ptodsl/vpto.py b/ptodsl/ptodsl/vpto.py new file mode 100644 index 000000000..5ad7e7bca --- /dev/null +++ b/ptodsl/ptodsl/vpto.py @@ -0,0 +1,743 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +Experimental `ptodsl.vpto` POC for TileLang-style tile templates. + +This module keeps the authored Python body close to TileLang-style templates, +but traces execution directly into MLIR Python bindings instead of going through +an AST-capture frontend. + +Current scope: +- bare ``Tile`` parameters with static 2D specializations +- ``dst.element_type`` / ``dst.valid_shape`` +- explicit `with pto.vecscope():` +- explicit structured `with pto.for_(...) as ...:` +- optional named loop-carried state via ``state={...}`` +- ``get_lanes(dtype)`` +- ``make_mask(dtype, remained)`` +- ``vlds(tile[row, col:])`` +- ``vadd(lhs, rhs, mask)`` +- ``vsts(vec, tile[row, col:], mask)`` + +The goal of this POC is to validate a tracing-oriented VPTO frontend shape that +already builds real MLIR Python objects, while staying intentionally narrow and +readable for `tadd_template.py`. +""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass +from pathlib import Path + +from . import scalar as _scalar +from ._bootstrap import make_context +from ._types import ( + _resolve, + float16 as _float16, + float32 as _float32, + index as _index, + int8 as _int8, + int16 as _int16, + int32 as _int32, + int64 as _int64, + mask_type as _mask_type, + ptr as _ptr, + tile_buf_type as _tile_buf_type, + vreg_type as _vreg_type, +) + +from mlir.dialects import arith, func, pto as _pto, scf +from mlir.ir import Attribute, InsertionPoint, IntegerType, Location, Module, Operation, StringAttr, Type + + +_ACTIVE_TRACE = None + + +@dataclass(frozen=True) +class ScalarType: + name: str + lanes: int + mask_bits: int + bytewidth: int + + def __repr__(self) -> str: + return self.name + + +f32 = ScalarType("f32", lanes=64, mask_bits=32, bytewidth=4) +f16 = ScalarType("f16", lanes=128, mask_bits=16, bytewidth=2) +bf16 = ScalarType("bf16", lanes=128, mask_bits=16, bytewidth=2) +i32 = ScalarType("i32", lanes=64, mask_bits=32, bytewidth=4) +i16 = ScalarType("i16", lanes=128, mask_bits=16, bytewidth=2) +i8 = ScalarType("i8", lanes=256, mask_bits=8, bytewidth=1) + + +class Tile: + """Tile annotation marker for the tracing POC.""" + + +@dataclass(frozen=True) +class TileSpec: + shape: tuple[int, int] + dtype: ScalarType + memory_space: str = "ub" + + def __post_init__(self): + if len(self.shape) != 2: + raise ValueError("TileSpec currently only supports rank-2 tile shapes") + if any(not isinstance(dim, int) or dim <= 0 for dim in self.shape): + raise ValueError("TileSpec.shape must contain positive integers") + if self.memory_space != "ub": + raise ValueError("TileSpec currently only supports ub tiles") + + def mlir_type(self): + rows, cols = self.shape + return _tile_buf_type( + [rows, cols], + _scalar_descriptor(self.dtype), + [rows, cols], + blayout="RowMajor", + address_space=self.memory_space, + slayout="NoneBox", + fractal_size=512, + pad="Null", + ) + + +@dataclass(frozen=True) +class _Value: + value: object + const_value: int | None = None + + def __repr__(self) -> str: + return str(self.value) + + @property + def type_text(self) -> str: + return str(self.value.type) + + @property + def is_const(self) -> bool: + return self.const_value is not None + + +@dataclass(frozen=True) +class _MaskValue: + value: object + dtype: ScalarType + + @property + def type_text(self) -> str: + return str(self.value.type) + + +@dataclass(frozen=True) +class _VectorValue: + value: object + dtype: ScalarType + + @property + def type_text(self) -> str: + return str(self.value.type) + + +@dataclass(frozen=True) +class _TileSlice: + tile: "_TileProxy" + row: int | _Value + col: int | _Value + + +class _TileProxy: + def __init__(self, trace: "_TraceBuilder", arg_value, spec: TileSpec): + self._trace = trace + self._arg_value = arg_value + self._spec = spec + + @property + def element_type(self) -> ScalarType: + return self._spec.dtype + + @property + def valid_shape(self) -> tuple[_Value, _Value]: + return ( + self._trace.index_const(self._spec.shape[0]), + self._trace.index_const(self._spec.shape[1]), + ) + + @property + def type_text(self) -> str: + return str(self._arg_value.type) + + def __getitem__(self, key): + if ( + not isinstance(key, tuple) + or len(key) != 2 + or not _is_index_like(key[0]) + or not isinstance(key[1], slice) + ): + raise TypeError("vpto POC only supports tile[row, col:] indexing") + row, col_slice = key + if col_slice.stop is not None or col_slice.step is not None: + raise TypeError("vpto POC only supports tile[row, col:] slices") + col = 0 if col_slice.start is None else col_slice.start + if not _is_index_like(col): + raise TypeError("vpto POC only supports integer/index column offsets") + _validate_static_bound(row, self._spec.shape[0], "row") + _validate_static_bound(col, self._spec.shape[1], "column") + return _TileSlice(self, row=row, col=col) + + +class _LoopStateView: + def __init__(self, names: tuple[str, ...], values: tuple[_Value, ...]): + self._values = dict(zip(names, values)) + + def __getattr__(self, name: str) -> _Value: + try: + return self._values[name] + except KeyError as exc: + raise AttributeError(name) from exc + + +class _LoopHandle: + def __init__( + self, + trace: "_TraceBuilder", + for_op, + iv: _Value, + iter_args: tuple[_Value, ...], + state_names: tuple[str, ...] = (), + ): + self._trace = trace + self._for_op = for_op + self.iv = iv + self.iter_args = iter_args + self._state_names = state_names + self.state = _LoopStateView(state_names, iter_args) if state_names else None + self.results: tuple[_Value, ...] = () + + def _finalize(self) -> None: + self.results = tuple(_Value(result) for result in self._for_op.results) + + def yield_state(self, **kwargs) -> None: + if not self._state_names: + raise RuntimeError("loop.yield_state(...) requires for_(..., state={...})") + missing = [name for name in self._state_names if name not in kwargs] + extra = [name for name in kwargs if name not in self._state_names] + if missing or extra: + pieces = [] + if missing: + pieces.append(f"missing: {', '.join(missing)}") + if extra: + pieces.append(f"unexpected: {', '.join(extra)}") + raise RuntimeError( + "loop.yield_state(...) must match loop state names exactly; " + + "; ".join(pieces) + ) + ordered = tuple(kwargs[name] for name in self._state_names) + self._trace._yield_loop_values(ordered, surface="loop.yield_state", from_named_state=True) + + +class _VecScopeCM: + def __init__(self, trace: "_TraceBuilder"): + self._trace = trace + + def __enter__(self): + self._trace._enter_vecscope() + return None + + def __exit__(self, exc_type, exc, tb): + self._trace._exit_vecscope(exc_type, exc, tb) + + +class _ForCM: + def __init__(self, trace: "_TraceBuilder", start, stop, step, iter_args, state): + self._trace = trace + self._start = start + self._stop = stop + self._step = step + self._iter_args = list(iter_args) if iter_args is not None else [] + self._state = tuple(state.items()) if state is not None else () + self._handle: _LoopHandle | None = None + + def __enter__(self): + self._handle = self._trace._enter_for( + self._start, + self._stop, + self._step, + self._iter_args, + self._state, + ) + if self._iter_args or self._state: + return self._handle + return self._handle.iv + + def __exit__(self, exc_type, exc, tb): + self._trace._exit_for(self._handle, exc_type, exc, tb) + + +class _TraceBuilder: + def __init__(self, descriptor: "TracingKernelDescriptor", tile_specs: dict[str, TileSpec]): + self.descriptor = descriptor + self.tile_specs = tile_specs + self._const_cache: dict[tuple[int, str], _Value] = {} + self._tile_ptr_cache: dict[int, _Value] = {} + self._row_offset_cache: dict[tuple[str, str], _Value] = {} + self._loop_stack: list[dict] = [] + self._inside_vecscope = False + + def build_module(self): + global _ACTIVE_TRACE + if _ACTIVE_TRACE is not None: + raise RuntimeError("nested vpto builds are not supported") + + signature = inspect.signature(self.descriptor.py_fn) + ctx = make_context() + with ctx, Location.unknown(): + arg_types = [] + ordered_specs = [] + for param_name, param in signature.parameters.items(): + if not _is_tile_annotation(param.annotation): + raise TypeError( + "vpto POC currently only supports Tile parameters; " + f"parameter {param_name!r} uses {param.annotation!r}" + ) + spec = self.tile_specs.get(param_name) + if spec is None: + raise ValueError(f"missing specialization for Tile parameter {param_name!r}") + ordered_specs.append((param_name, spec)) + arg_types.append(spec.mlir_type()) + + module = Module.create() + module.operation.attributes["pto.target_arch"] = StringAttr.get(self.descriptor.target) + + with InsertionPoint(module.body): + inner_op = Operation.create("builtin.module", regions=1) + inner_op.attributes["pto.target_arch"] = StringAttr.get(self.descriptor.target) + inner_op.attributes["pto.kernel_kind"] = Attribute.parse("#pto.kernel_kind") + inner_body = inner_op.regions[0].blocks.append() + + with InsertionPoint(inner_body): + fn_ty = func.FunctionType.get(arg_types, []) + ir_fn = func.FuncOp(self.descriptor.name, fn_ty) + + entry = ir_fn.add_entry_block() + with InsertionPoint(entry): + args = [] + for arg_value, (_, spec) in zip(entry.arguments, ordered_specs): + args.append(_TileProxy(self, arg_value, spec)) + + _ACTIVE_TRACE = self + try: + self.descriptor.py_fn(*args) + finally: + _ACTIVE_TRACE = None + + if self._inside_vecscope: + raise RuntimeError("vpto kernel exited with an open vecscope block") + if self._loop_stack: + raise RuntimeError("vpto kernel exited with an open scf.for block") + + func.ReturnOp([]) + + module.operation.verify() + return module + + def vecscope(self) -> _VecScopeCM: + return _VecScopeCM(self) + + def for_(self, start, stop, *, step, iter_args=None, state=None) -> _ForCM: + if iter_args is not None and state is not None: + raise ValueError("for_() accepts either iter_args= or state=, not both") + if state is not None: + if not hasattr(state, "items"): + raise TypeError("for_(..., state=...) expects a mapping of name -> initial value") + for name in state: + if not isinstance(name, str) or not name: + raise TypeError("for_ state names must be non-empty strings") + return _ForCM(self, start, stop, step, iter_args, state) + + def yield_(self, *vals): + self._yield_loop_values(vals, surface="yield_", from_named_state=False) + + def _yield_loop_values(self, vals, *, surface: str, from_named_state: bool): + if not self._loop_stack: + raise RuntimeError(f"{surface}(...) may only be used inside a vpto for_ block") + frame = self._loop_stack[-1] + if frame["kind"] != "for": + raise RuntimeError(f"{surface}(...) may only be used inside a vpto for_ block") + if frame["state_names"] and not from_named_state: + raise RuntimeError( + f"{surface}(...) is ambiguous for vpto for_ with named state; " + "use loop.yield_state(...) instead" + ) + if frame["yielded"]: + raise RuntimeError(f"{surface}(...) may only be emitted once per vpto for_ block") + if len(vals) != len(frame["iter_args"]): + raise RuntimeError( + f"{surface}(...) expected {len(frame['iter_args'])} value(s), got {len(vals)}" + ) + coerced = tuple( + self._coerce_like(arg, expected.type_text) + for arg, expected in zip(vals, frame["iter_args"]) + ) + scf.YieldOp([val.value for val in coerced]) + frame["yielded"] = True + frame["yield_vals"] = coerced + + def index_const(self, value: int) -> _Value: + return self._const(value, _resolve(_index)) + + def scalar_const(self, value: int, dtype: ScalarType) -> _Value: + return self._const(value, _resolve(_scalar_descriptor(dtype))) + + def _const(self, value: int, mlir_type) -> _Value: + cache_key = (value, str(mlir_type)) + cached = self._const_cache.get(cache_key) + if cached is not None: + return cached + const = _Value(arith.ConstantOp(mlir_type, value).result, const_value=value) + self._const_cache[cache_key] = const + return const + + def ensure_tile_ptr(self, tile: _TileProxy) -> _Value: + cache_key = id(tile._arg_value) + cached = self._tile_ptr_cache.get(cache_key) + if cached is not None: + return cached + ptr_type = _resolve(_ptr(_scalar_descriptor(tile.element_type), tile._spec.memory_space)) + ptr_value = _Value(_pto.TileBufAddrOp(ptr_type, tile._arg_value).result) + self._tile_ptr_cache[cache_key] = ptr_value + return ptr_value + + def materialize_linear_offset(self, tile_slice: _TileSlice) -> _Value: + cols = tile_slice.tile._spec.shape[1] + row = self._coerce_index(tile_slice.row) + col = self._coerce_index(tile_slice.col) + if row.is_const and col.is_const: + return self.index_const(row.const_value * cols + col.const_value) + row_stride = self.index_const(cols) + row_off = self._materialize_row_offset(row, row_stride) + return _Value(_scalar.addi(row_off.value, col.value)) + + def _enter_vecscope(self): + if self._inside_vecscope: + raise RuntimeError("nested vpto vecscope blocks are not supported in this POC") + vecscope_op = _pto.VecScopeOp() + vecscope_block = vecscope_op.body.blocks.append() + vecscope_ip = InsertionPoint(vecscope_block) + vecscope_ip.__enter__() + self._loop_stack.append( + { + "kind": "vecscope", + "ip": vecscope_ip, + } + ) + self._inside_vecscope = True + + def _exit_vecscope(self, exc_type, exc, tb): + if not self._inside_vecscope: + raise RuntimeError("vecscope exit without matching enter") + frame = self._loop_stack.pop() + if frame["kind"] != "vecscope": + raise RuntimeError("vpto vecscope stack corruption detected") + frame["ip"].__exit__(exc_type, exc, tb) + self._inside_vecscope = False + + def _enter_for(self, start, stop, step, iter_args, state_items) -> _LoopHandle: + if not self._inside_vecscope: + raise RuntimeError("vpto POC currently only supports for_ inside vecscope") + start_val = self._coerce_index(start) + stop_val = self._coerce_index(stop) + step_val = self._coerce_index(step) + state_names = tuple(name for name, _ in state_items) + if state_names: + iter_arg_vals = tuple(self._coerce_value(arg) for _, arg in state_items) + else: + iter_arg_vals = tuple(self._coerce_value(arg) for arg in iter_args) + for_op = scf.ForOp( + start_val.value, + stop_val.value, + step_val.value, + [arg.value for arg in iter_arg_vals] if iter_arg_vals else None, + ) + loop_ip = InsertionPoint(for_op.body) + loop_ip.__enter__() + iv = _Value(for_op.induction_variable) + inner_iter_args = tuple(_Value(arg) for arg in for_op.inner_iter_args) + handle = _LoopHandle(self, for_op, iv, inner_iter_args, state_names=state_names) + self._loop_stack.append( + { + "kind": "for", + "handle": handle, + "ip": loop_ip, + "iter_args": inner_iter_args, + "state_names": state_names, + "yielded": False, + "yield_vals": (), + } + ) + return handle + + def _exit_for(self, handle: _LoopHandle | None, exc_type, exc, tb): + if handle is None: + raise RuntimeError("for_ exit without a loop handle") + frame = self._loop_stack.pop() + if frame["kind"] != "for" or frame["handle"] is not handle: + raise RuntimeError("vpto for_ stack corruption detected") + if exc_type is None: + if frame["iter_args"] and not frame["yielded"]: + if frame["state_names"]: + raise RuntimeError( + "vpto for_ with named state requires explicit loop.yield_state(...)" + ) + raise RuntimeError("vpto for_ with iter_args requires explicit yield_(...)") + if not frame["iter_args"]: + scf.YieldOp([]) + frame["ip"].__exit__(exc_type, exc, tb) + if exc_type is not None: + return + handle._finalize() + + def _materialize_row_offset(self, row: _Value, row_stride: _Value) -> _Value: + if row.is_const and row_stride.is_const: + return self.index_const(row.const_value * row_stride.const_value) + cache_key = (str(row.value), str(row_stride.value)) + cached = self._row_offset_cache.get(cache_key) + if cached is not None: + return cached + result = _Value(_scalar.muli(row.value, row_stride.value)) + self._row_offset_cache[cache_key] = result + return result + + def _coerce_index(self, value) -> _Value: + coerced = self._coerce_value(value) + if coerced.type_text != str(_resolve(_index)): + raise TypeError(f"expected index value, got {coerced.type_text}") + return coerced + + def _coerce_value(self, value) -> _Value: + if isinstance(value, _Value): + return value + if isinstance(value, int): + return self.index_const(value) + if hasattr(value, "type"): + return _Value(value) + raise TypeError(f"unsupported vpto scalar value {value!r}") + + def _coerce_like(self, value, ty: str) -> _Value: + coerced = self._coerce_value(value) + if coerced.type_text != ty: + raise TypeError(f"expected value of type {ty}, got {coerced.type_text}") + return coerced + + +@dataclass(frozen=True) +class TracingKernelDescriptor: + py_fn: object + target: str + op: str + name: str + source_label: str + + def specialize(self, **tile_specs: TileSpec) -> "MaterializedTracingKernel": + return MaterializedTracingKernel(self, tile_specs) + + +class MaterializedTracingKernel: + def __init__(self, descriptor: TracingKernelDescriptor, tile_specs: dict[str, TileSpec]): + self.descriptor = descriptor + self.tile_specs = tile_specs + self._cached_module = None + + def build(self): + if self._cached_module is None: + self._cached_module = _TraceBuilder(self.descriptor, self.tile_specs).build_module() + return self._cached_module + + def mlir_text(self) -> str: + return str(self.build()) + + def emit(self, path: str | Path) -> None: + Path(path).write_text(self.mlir_text(), encoding="utf-8") + + def __str__(self) -> str: + return self.mlir_text() + + +def vkernel(*, target: str = "a5", op: str, name: str | None = None): + if target != "a5": + raise ValueError("vpto POC currently only supports target='a5'") + + def decorator(fn): + source_path = Path(inspect.getsourcefile(fn) or "") + descriptor_name = name or fn.__name__ + return TracingKernelDescriptor( + py_fn=fn, + target=target, + op=op, + name=descriptor_name, + source_label=f"{source_path}:{fn.__name__}", + ) + + return decorator + + +def vecscope() -> _VecScopeCM: + return _require_active_trace("vecscope").vecscope() + + +def for_(start, stop, *, step, iter_args=None, state=None) -> _ForCM: + return _require_active_trace("for_").for_(start, stop, step=step, iter_args=iter_args, state=state) + + +def yield_(*vals): + _require_active_trace("yield_").yield_(*vals) + + +def get_lanes(dtype: ScalarType) -> _Value: + return _require_active_trace("get_lanes").index_const(dtype.lanes) + + +def scalar_const(value: int, dtype: ScalarType) -> _Value: + return _require_active_trace("scalar_const").scalar_const(value, dtype) + + +def make_mask(dtype: ScalarType, remained) -> tuple[_MaskValue, _Value]: + trace = _require_active_trace("make_mask") + remained_val = trace._coerce_value(remained) + expected_scalar_ty = str(_resolve(_scalar_descriptor(_scalar_type_for_mask(dtype)))) + if remained_val.type_text != expected_scalar_ty: + raise TypeError( + f"vpto POC expects make_mask remained to use {expected_scalar_ty}, got {remained_val.type_text}" + ) + if dtype.mask_bits not in {8, 16, 32}: + raise ValueError(f"unsupported mask bit-width {dtype.mask_bits}") + mask_ty = _mask_type(f"b{dtype.mask_bits}") + scalar_ty = IntegerType.get_signless(dtype.mask_bits) + op_cls = getattr(_pto, f"PltB{dtype.mask_bits}Op", None) + if op_cls is None: + raise NotImplementedError( + f"pto.PltB{dtype.mask_bits}Op is not available in the current Python bindings" + ) + plt_op = op_cls(mask_ty, scalar_ty, remained_val.value) + lanes = trace.scalar_const(dtype.lanes, _scalar_type_for_mask(dtype)) + next_value = _Value(_scalar.subi(remained_val.value, lanes.value)) + return _MaskValue(plt_op.mask, dtype), next_value + + +def vlds(tile_slice: _TileSlice) -> _VectorValue: + trace = _require_active_trace("vlds") + if not isinstance(tile_slice, _TileSlice): + raise TypeError("vpto POC only supports vlds(tile[row, col:])") + ptr_value = trace.ensure_tile_ptr(tile_slice.tile) + offset = trace.materialize_linear_offset(tile_slice) + vector_ty = _resolve(_vreg_type(tile_slice.tile.element_type.lanes, _scalar_descriptor(tile_slice.tile.element_type))) + result = _pto.VldsOp(vector_ty, ptr_value.value, offset.value).result + return _VectorValue(result, tile_slice.tile.element_type) + + +def vadd(lhs: _VectorValue, rhs: _VectorValue, mask: _MaskValue) -> _VectorValue: + if lhs.dtype != rhs.dtype: + raise TypeError("vpto POC expects vadd operands to use the same dtype") + if lhs.dtype != mask.dtype: + raise TypeError("vpto POC expects vadd mask dtype to match vector dtype") + result = _pto.VaddOp(lhs.value.type, lhs.value, rhs.value, mask.value).result + return _VectorValue(result, lhs.dtype) + + +def vsts(vec: _VectorValue, tile_slice: _TileSlice, mask: _MaskValue) -> None: + trace = _require_active_trace("vsts") + if vec.dtype != mask.dtype: + raise TypeError("vpto POC expects vsts mask dtype to match vector dtype") + if vec.dtype != tile_slice.tile.element_type: + raise TypeError("vpto POC expects vsts destination dtype to match vector dtype") + ptr_value = trace.ensure_tile_ptr(tile_slice.tile) + offset = trace.materialize_linear_offset(tile_slice) + _pto.VstsOp(vec.value, ptr_value.value, offset.value, mask.value) + + +def _require_active_trace(surface: str) -> _TraceBuilder: + if _ACTIVE_TRACE is None: + raise RuntimeError(f"{surface}() may only be used while tracing a vpto kernel") + return _ACTIVE_TRACE + + +def _is_tile_annotation(annotation) -> bool: + if annotation is Tile: + return True + if isinstance(annotation, str): + return annotation == "Tile" or annotation.endswith(".Tile") + return getattr(annotation, "__name__", None) == "Tile" + + +def _is_index_like(value) -> bool: + return isinstance(value, int) or (isinstance(value, _Value) and value.type_text == str(_resolve(_index))) + + +def _validate_static_bound(value, upper_bound: int, label: str): + if isinstance(value, int): + if value < 0 or value >= upper_bound: + raise IndexError(f"{label} {value} is outside tile bound {upper_bound}") + return + if isinstance(value, _Value) and value.is_const: + concrete = value.const_value + if concrete < 0 or concrete >= upper_bound: + raise IndexError(f"{label} {concrete} is outside tile bound {upper_bound}") + + +def _scalar_descriptor(dtype: ScalarType): + descriptors = { + "f32": _float32, + "f16": _float16, + "bf16": Type.parse("bf16"), + "i8": _int8, + "i16": _int16, + "i32": _int32, + "i64": _int64, + } + descriptor = descriptors.get(dtype.name) + if descriptor is None: + raise ValueError(f"unsupported scalar dtype {dtype.name}") + return descriptor + + +def _scalar_type_for_mask(dtype: ScalarType) -> ScalarType: + if dtype.mask_bits == 8: + return i8 + if dtype.mask_bits == 16: + return i16 + if dtype.mask_bits == 32: + return i32 + raise ValueError(f"unsupported mask bit-width {dtype.mask_bits}") + + +__all__ = [ + "Tile", + "TileSpec", + "TracingKernelDescriptor", + "MaterializedTracingKernel", + "ScalarType", + "f32", + "f16", + "bf16", + "i32", + "i16", + "i8", + "vkernel", + "vecscope", + "for_", + "yield_", + "get_lanes", + "scalar_const", + "make_mask", + "vlds", + "vadd", + "vsts", +] From 2c2cf6de235ccb9dbdf8784d16722736a6a73779 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 14 May 2026 07:36:16 +0000 Subject: [PATCH 11/31] [vpto] Allow structured loops without vecscope --- docs/designs/ptodsl-tiletrace-poc-proposal.md | 13 +++++----- lib/TileOps/tadd_template_tracing_poc.py | 23 ++++++++-------- ptodsl/README.md | 26 +++++++++---------- ptodsl/ptodsl/vpto.py | 4 +-- 4 files changed, 32 insertions(+), 34 deletions(-) diff --git a/docs/designs/ptodsl-tiletrace-poc-proposal.md b/docs/designs/ptodsl-tiletrace-poc-proposal.md index d44f1c974..6b66b23ff 100644 --- a/docs/designs/ptodsl-tiletrace-poc-proposal.md +++ b/docs/designs/ptodsl-tiletrace-poc-proposal.md @@ -62,10 +62,11 @@ The POC is intentionally limited to a single template shape: - Target template: `tadd_template.py` - Supported parameter kind: bare static 2D `Tile` -- Supported control flow: explicit builder-style `vecscope()` and `for_()` +- Supported control flow: explicit structured `for_()` builders, with optional + `vecscope()` when the author wants to spell it directly - Supported ops: `make_mask`, `vlds`, `vadd`, `vsts` -- Supported lowering shape: one `pto.vecscope` containing nested `scf.for`, - `pto.tile_buf_addr`, and vector micro-ops +- Supported lowering shape: nested `scf.for`, `pto.tile_buf_addr`, and vector + micro-ops, with optional `pto.vecscope` This means the first implementation validates the core idea: @@ -118,7 +119,6 @@ traced tile-slice value. For the `tadd_template.py`-style kernel body, the POC emits: - tile-buffer arguments -- one `pto.vecscope` - nested `scf.for` for rows and columns - `pto.tile_buf_addr` for each referenced tile - `pto.plt_b32` @@ -148,8 +148,9 @@ keeps structured control flow instead of concretely unrolling the loops. - No integration with the existing `tilelang-dsl` package entrypoint. - Current output is deliberately narrow and only covers the pybinding-backed operations needed by the first POC template. -- Control flow currently needs explicit `vecscope()` / `for_()` builders instead - of raw Python `for range(...)`. +- Control flow currently needs explicit structured `for_()` builders instead of + raw Python `for range(...)`. `vecscope()` can still be used, but is not a + hard requirement in the POC. These are acceptable for the first experiment because the goal is not feature completeness; it is to validate the tracing execution model on a real tile diff --git a/lib/TileOps/tadd_template_tracing_poc.py b/lib/TileOps/tadd_template_tracing_poc.py index 418783104..0b711f6da 100644 --- a/lib/TileOps/tadd_template_tracing_poc.py +++ b/lib/TileOps/tadd_template_tracing_poc.py @@ -36,18 +36,17 @@ def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): valid_rows, valid_cols = dst.valid_shape mask_scalar_ty = pto.i32 - with pto.vecscope(): - with pto.for_(0, valid_rows, step=1) as row: - remained0 = pto.scalar_const(64, mask_scalar_ty) - with pto.for_(0, valid_cols, step=pto.get_lanes(dtype), state={"remained": remained0}) as loop: - col = loop.iv - remained = loop.state.remained - mask, next_remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - summed = pto.vadd(lhs, rhs, mask) - pto.vsts(summed, dst[row, col:], mask) - loop.yield_state(remained=next_remained) + with pto.for_(0, valid_rows, step=1) as row: + remained0 = pto.scalar_const(64, mask_scalar_ty) + with pto.for_(0, valid_cols, step=pto.get_lanes(dtype), state={"remained": remained0}) as loop: + col = loop.iv + remained = loop.state.remained + mask, next_remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + loop.yield_state(remained=next_remained) def build_specialized_kernel(): diff --git a/ptodsl/README.md b/ptodsl/README.md index 3f26a9eb5..f1510bf1e 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -212,18 +212,17 @@ from ptodsl import vpto as pto def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): dtype = dst.element_type valid_rows, valid_cols = dst.valid_shape - with pto.vecscope(): - with pto.for_(0, valid_rows, step=1) as row: - remained0 = pto.scalar_const(64, pto.i32) - with pto.for_(0, valid_cols, step=pto.get_lanes(dtype), state={"remained": remained0}) as loop: - col = loop.iv - remained = loop.state.remained - mask, next_remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - out = pto.vadd(lhs, rhs, mask) - pto.vsts(out, dst[row, col:], mask) - loop.yield_state(remained=next_remained) + with pto.for_(0, valid_rows, step=1) as row: + remained0 = pto.scalar_const(64, pto.i32) + with pto.for_(0, valid_cols, step=pto.get_lanes(dtype), state={"remained": remained0}) as loop: + col = loop.iv + remained = loop.state.remained + mask, next_remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.vadd(lhs, rhs, mask) + pto.vsts(out, dst[row, col:], mask) + loop.yield_state(remained=next_remained) ``` Current limitations: @@ -231,7 +230,8 @@ Current limitations: - pybinding-backed POC only; it still covers a narrow TileLang-shaped subset - supports only static 2D `Tile` parameters - supports only a narrow vector subset needed by `tadd_template.py` -- requires explicit builder-style `vecscope()` / `for_()` rather than Python `for range(...)` +- currently uses explicit structured `for_()` builders rather than Python `for range(...)` +- `vecscope()` remains available, but it is no longer required by the POC Reference script: diff --git a/ptodsl/ptodsl/vpto.py b/ptodsl/ptodsl/vpto.py index 5ad7e7bca..cd8c3db23 100644 --- a/ptodsl/ptodsl/vpto.py +++ b/ptodsl/ptodsl/vpto.py @@ -15,7 +15,7 @@ Current scope: - bare ``Tile`` parameters with static 2D specializations - ``dst.element_type`` / ``dst.valid_shape`` -- explicit `with pto.vecscope():` +- optional `with pto.vecscope():` - explicit structured `with pto.for_(...) as ...:` - optional named loop-carried state via ``state={...}`` - ``get_lanes(dtype)`` @@ -451,8 +451,6 @@ def _exit_vecscope(self, exc_type, exc, tb): self._inside_vecscope = False def _enter_for(self, start, stop, step, iter_args, state_items) -> _LoopHandle: - if not self._inside_vecscope: - raise RuntimeError("vpto POC currently only supports for_ inside vecscope") start_val = self._coerce_index(start) stop_val = self._coerce_index(stop) step_val = self._coerce_index(step) From 16303bdef6c988b0b678d2c0d997976d0bfc1599 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 14 May 2026 11:03:03 +0000 Subject: [PATCH 12/31] Add user guides --- ptodsl/docs/user_guide/01-introduction.md | 47 + ptodsl/docs/user_guide/02-quick-start.md | 78 + .../docs/user_guide/03-kernel-declaration.md | 528 ++++++ ptodsl/docs/user_guide/04-template-kernels.md | 333 ++++ ptodsl/docs/user_guide/05-type-system.md | 686 +++++++ ptodsl/docs/user_guide/06-control-flow.md | 181 ++ .../docs/user_guide/07-frontend-operations.md | 352 ++++ .../docs/user_guide/08-sync-dma-operations.md | 622 +++++++ .../user_guide/09-vector-memory-operations.md | 1058 +++++++++++ .../user_guide/10-predicate-operations.md | 637 +++++++ .../11-vector-arithmetic-operations.md | 1611 +++++++++++++++++ ptodsl/docs/user_guide/12-cube-operations.md | 454 +++++ ptodsl/docs/user_guide/13-examples.md | 417 +++++ ptodsl/docs/user_guide/14-common-errors.md | 51 + .../docs/user_guide/15-compatibility-notes.md | 9 + ptodsl/docs/user_guide/16-next-steps.md | 7 + 16 files changed, 7071 insertions(+) create mode 100644 ptodsl/docs/user_guide/01-introduction.md create mode 100644 ptodsl/docs/user_guide/02-quick-start.md create mode 100644 ptodsl/docs/user_guide/03-kernel-declaration.md create mode 100644 ptodsl/docs/user_guide/04-template-kernels.md create mode 100644 ptodsl/docs/user_guide/05-type-system.md create mode 100644 ptodsl/docs/user_guide/06-control-flow.md create mode 100644 ptodsl/docs/user_guide/07-frontend-operations.md create mode 100644 ptodsl/docs/user_guide/08-sync-dma-operations.md create mode 100644 ptodsl/docs/user_guide/09-vector-memory-operations.md create mode 100644 ptodsl/docs/user_guide/10-predicate-operations.md create mode 100644 ptodsl/docs/user_guide/11-vector-arithmetic-operations.md create mode 100644 ptodsl/docs/user_guide/12-cube-operations.md create mode 100644 ptodsl/docs/user_guide/13-examples.md create mode 100644 ptodsl/docs/user_guide/14-common-errors.md create mode 100644 ptodsl/docs/user_guide/15-compatibility-notes.md create mode 100644 ptodsl/docs/user_guide/16-next-steps.md diff --git a/ptodsl/docs/user_guide/01-introduction.md b/ptodsl/docs/user_guide/01-introduction.md new file mode 100644 index 000000000..26012f781 --- /dev/null +++ b/ptodsl/docs/user_guide/01-introduction.md @@ -0,0 +1,47 @@ +# TileLang Python DSL Guide + +The TileLang Python DSL provides a high-level, Pythonic interface for authoring vector compute kernels targeting the Ascend NPU hardware. This guide is intended for library developers and performance engineers who need to write efficient, hardware-aware kernels using the PTO micro instruction set. + +The DSL is designed to generate MLIR function libraries rather than direct binary executables. These MLIR libraries are intended to be consumed by other compilation frameworks that transform high-level tile semantics into low-level vector operations. This enables library developers to focus on hardware-aware kernel authoring while relying on upstream compilers for tile-level optimizations and code generation. + +## Language Tier + +The DSL surface is organized into multiple maturity tiers, reflecting the stability and intended use of different language features. As the design evolves, the basic authoring path is being explicitly separated from more advanced surfaces. Refer to the following table when reading this guide: + +| Surface Family | Tier | Usage Guidance | +|----------------|------|----------------| +| `TensorView` | `basic` | Default GM-facing data model for starter kernels. | +| `Tile` | `basic` | Default UB-facing compute tile for starter kernels. | +| Base vector ops (`make_mask`, `vlds`, `vsts`, `vadd`, `vmuls`, etc.) | `basic` | Default compute skeleton for starter kernels. | +| `strict_vecscope` | `advanced` | Explicit vector-scope management for expert authoring. | +| Raw pointer family (`ptr(...)`, `castptr`, `addptr`) | `advanced` | For expert authoring and migration; not required for Quick Start. | +| DMA family (`copy_*`, `set_loop*_stride_*`, `set_loop_size_*`, pad-fill control) | `advanced` | Direct DMA engine control for expert authoring, including GM→UB padding behavior. | +| Tile pointer helper (`tile.as_ptr()`) | `advanced` | Expert-only helper when advanced authoring needs explicit typed pointers. | + +For the authoritative tier classification, consult `tilelang-dsl/python/tilelang_dsl/support_matrix.py`. For known implementation gaps, refer to `tilelang-dsl/docs/unsupported-features.md`. + +### Basic vs Advanced Authoring Modes + +The TileLang DSL provides two distinct authoring modes: + +**Basic Mode (default)** +- Uses **Tile element/slice semantics** for buffer access +- Direct tile indexing syntax: `tile[start:]`, `tile[row, col:]`, `tile[row:, col]` (Tile indexing sugar only supports open-ended vector slices; explicit `stop` and `step` forms are not accepted for `Tile` indexing) +- Vector operations use element-indexing syntax: `pto.vlds(tile[row, col:])`, `pto.vsts(vec, tile[start:], mask)` +- No pointer arithmetic or explicit offset calculations +- Suitable for most kernel authoring with high-level abstractions + +**Advanced Mode (`advanced=True` in `@pto.vkernel`)** +- Uses **raw pointer semantics** for explicit memory management +- Direct pointer operations correspond to `pto.ptr` types in MLIR +- Explicit pointer arithmetic: `ptr(...)`, `castptr`, `addptr` +- Manual DMA engine control with low-level copy operations and explicit GM→UB padding behavior +- Requires explicit buffer management and pointer arithmetic +- Intended for expert users and performance-critical optimizations + +**Key Differences** +- **Basic mode**: Uses tile element-indexing syntax (`tile[row, col:]`, `tile[start:]`) for vector operations +- **Advanced mode**: Uses pointer byte-offset syntax (`pto.vlds(buf: ptr, offset)`) for vector operations +- Tile slices in basic mode correspond to MLIR `memref` types +- Raw pointers in advanced mode correspond to MLIR `pto.ptr` types +- No automatic conversion between tile and pointer semantics - choose the appropriate syntax for your authoring mode diff --git a/ptodsl/docs/user_guide/02-quick-start.md b/ptodsl/docs/user_guide/02-quick-start.md new file mode 100644 index 000000000..26b0ba58b --- /dev/null +++ b/ptodsl/docs/user_guide/02-quick-start.md @@ -0,0 +1,78 @@ +## Quick Start + +**Note on mask pattern enums**: For brevity, examples in this guide use `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). You can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. + +TileLang DSL provides the following core constructs for kernel authoring: + +- `TensorView` – Access global memory (GM) tensors +- `Tile` – Local computation buffers in unified buffer (UB) +- Base vector operations (`make_mask`, `vlds`, `vmuls`, `vadd`, `vsts`) – Perform vector computations + +A typical kernel follows the GM → UB → vector compute → GM pattern: + +```python +import tilelang_dsl as pto + +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32)]) +def tile_scale( + input_tensor: pto.TensorView, + output_tensor: pto.TensorView, + work_tile: pto.Tile, + scale_factor: pto.f32, +): + dim0 = 4 + dim1 = 16 + + # Stage one GM tile into UB. + # GM -> UB data movement (implementation detail) + + # Run vector compute over the UB tile using tile indexing sugar. + for i in range(0, dim0): + mask = pto.make_mask(pto.f32, PAT.ALL) + vec = pto.vlds(work_tile[i, 0:]) + scaled = pto.vmuls(vec, scale_factor, mask) + pto.vsts(scaled, work_tile[i, 0:], mask) + + # Write the UB result back to GM. + # UB -> GM data movement (implementation detail) +``` + +The example illustrates the key components of a TileLang kernel: + +1. **`TensorView` parameters** – Access global memory tensors +2. **`Tile` parameters** – Local computation buffers in unified buffer (UB) +3. **Base vector operations** (`make_mask`, `vlds`, `vmuls`, `vadd`, `vsts`) – Perform vector computations + +Here is a second example with two inputs and one output: + +```python +@pto.vkernel( + target="a5", + op="elementwise_add", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.f32, pto.f32)], +) +def elementwise_add( + lhs_gm: pto.TensorView, + rhs_gm: pto.TensorView, + out_gm: pto.TensorView, + lhs_tile: pto.Tile, + rhs_tile: pto.Tile, + dst_tile: pto.Tile, +): + dim0 = 4 + dim1 = 16 + + # GM -> UB data movement (implementation detail) + + for lane in range(0, 256, 64): + mask = pto.make_mask(pto.f32, PAT.ALL) + lhs_vec = pto.vlds(lhs_tile, lane) + rhs_vec = pto.vlds(rhs_tile, lane) + summed = pto.vadd(lhs_vec, rhs_vec, mask) + pto.vsts(summed, dst_tile, lane, mask) + + # UB -> GM data movement (implementation detail) +``` + +Both examples follow the same fundamental pattern: load data from global memory into local tiles, perform vector operations, and store results back. The compiler automatically infers vector-scope boundaries for the base vector operations. The `Tile` parameters are specialized to concrete shapes during compilation. Later sections cover advanced features such as matchers, template slots, raw pointer operations, and explicit scope management with `strict_vecscope`. + diff --git a/ptodsl/docs/user_guide/03-kernel-declaration.md b/ptodsl/docs/user_guide/03-kernel-declaration.md new file mode 100644 index 000000000..73c9e1800 --- /dev/null +++ b/ptodsl/docs/user_guide/03-kernel-declaration.md @@ -0,0 +1,528 @@ +## Core Concepts + +### Kernel Declaration + +TileLang DSL exposes two kernel decorators: + +- `@pto.vkernel` for the Vector (AIV) execution model +- `@pto.ckernel` for the Cube (AIC) execution model + +#### Basic Syntax + +```python +@pto.vkernel( + target="a5", # Target architecture + op="pto.matmul ins(a, b) -> outs(c)", # PTO op + operand schema + dtypes=[(pto.f16, pto.f16, pto.f32)], # Type signatures + constraints=[ # Additional constraints + lambda a, b: a.shape[1] == b.shape[0], + lambda batch=1: batch >= 1, + ], + priority=100 # Priority for selection +) +def matmul_fallback(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # kernel implementation +``` + +#### Decorator Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `target` | `str` | Yes | Target hardware architecture (e.g., `"a5"` for Ascend 950). | +| `op` | `str` | No* | PTO operation matcher. Preferred form is schema mode: `"pto.op_name ins(in0, in1, ...) -> outs(out0, out1, ...)"`. Legacy bare-op form (`"pto.op_name"`) is still accepted for compatibility. **Mutually exclusive with `ops`**. | +| `ops` | `List[str]` | No* | List of PTO operation names to match. **Mutually exclusive with `op`**. Use this when one descriptor should match multiple concrete ops (schema mode is currently only supported in `op`). | +| `dtypes` | `List[Tuple[Type, ...]]` | Yes | List of type signatures. Each tuple specifies the expected data types for the operation's operands (inputs and outputs) in order. | +| `templates` | `Dict[str, Dict[str, str]]` | No | Static template-slot mappings. Each slot maps concrete matcher ops to real `pto.*` op names. Required when the kernel body uses `pto.tpl(...)`. | +| `constraints` | `List[Callable[..., bool]]` | No | Additional selection-time predicates. Constraint arguments bind by name to kernel parameter proxy objects or `context_attrs` keys. Default: empty list. | +| `priority` | `int` | No | Selection priority when multiple kernels match. Higher values have higher priority. Default: `0`. | +| `name` | `str` | No | Kernel name (used for debugging and profiling). Defaults to the decorated function's name. | +| `advanced` | `bool` | No | Enable advanced-tier DSL surfaces (for example `strict_vecscope`, raw pointer family, and low-level DMA family). Implicit vecscope inference is available in both modes and runs only when no explicit `with pto.vecscope():` is present. Default: `False`. | + +#### Operation Schema in `op` (ins/outs) + +`op` supports a schema string that declares how kernel parameter names map to PTO op operands: + +```python +op="pto.tadds ins(src, scalar) -> outs(dst)" +``` + +Schema form: + +```text + ins(, , ...) -> outs(, , ...) +``` + +Rules: + +1. `ins(...)` and `outs(...)` are both required in schema mode. +2. Names in `ins` and `outs` must be valid, unique Python identifiers. +3. The decorated function parameter list must exactly match `ins + outs` by both count and name. +4. MLIR function argument ordering is defined by schema order (`ins` first, then `outs`). +5. Constraint binding keeps using parameter names; schema mode makes these names explicit and stable. +6. Schema mode applies to `op=...` (single matcher op). `ops=[...]` remains bare-op matching. + +Example: + +```python +@pto.vkernel( + target="a5", + op="pto.tadds ins(src, scalar) -> outs(dst)", + dtypes=[(pto.f32, pto.f32, pto.f32)], +) +def template_tadds(src: pto.Tile, scalar: pto.f32, dst: pto.Tile): + return None +``` + +If names or order do not match, descriptor construction fails early with a schema mismatch error. + + +#### Type Matching Rules + +The `dtypes` parameter supports flexible type matching: + +1. **Concrete Types**: Exact type matches using DSL scalar types: + - `pto.f16`, `pto.f32`, `pto.bf16` + - `pto.i8`, `pto.si8`, `pto.ui8` + - `pto.i16`, `pto.si16`, `pto.ui16` + - `pto.i32`, `pto.si32`, `pto.ui32` + - `pto.i64`, `pto.si64`, `pto.ui64` + - `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32` + + Builtin vector operands still use their element dtype in `dtypes=[...]`. + For example, a parameter annotated as `ex_vec: pto.vector(pto.i16, (4,))` + contributes `pto.i16` to the signature tuple, while the vector shape + contract stays in the parameter annotation. + +2. **Type Wildcards**: Generic type patterns: + - `pto.AnyFloat`: Matches any floating-point type (`f16`, `bf16`, `f32`) + - `pto.AnyInt`: Matches any integer type (`i*`, `si*`, `ui*`) + - `pto.AnyType`: Matches any scalar type + - `pto.AnyMask`: Matches any mask type (`mask_b8`, `mask_b16`, `mask_b32`) + +3. **Type Variables**: Named type variables that enforce consistency within a signature: + ```python + T = pto.TypeVar('T') # Define a type variable + + @pto.vkernel( + target="a5", + op="elementwise", + dtypes=[(T, T, T)], # All three operands must have the same type + constraints=[] + ) + def elementwise_same_type(x: pto.Tile, y: pto.Tile, out: pto.Tile) -> None: + # x, y, and out must have identical element types + pass + ``` + +4. **Mixed Signatures**: Multiple type signatures for the same operation: + ```python + @pto.vkernel( + target="a5", + op="add", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), # Float addition + (pto.AnyInt, pto.AnyInt, pto.AnyInt) # Integer addition + ] + ) + def generic_add(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Supports both float and integer types + pass + ``` + +#### Constraint System + +Constraints are compile-time predicates that refine kernel selection. In the current implementation, each entry in `constraints=[...]` is a Python callable returning `True` or `False`. + +##### Predefined Constraints + +| Constraint | Description | +|------------|-------------| +| `k_dim_aligned_64` | K dimension is aligned to 64 elements (for matmul kernels). | +| `continuous_memory` | Operands reside in contiguous memory regions. | +| `requires_ub_memory` | Operation requires Unified Buffer memory (vs. Global Memory). | +| `tensor_rank(rank)` | Operand tensor has specified rank (e.g., `tensor_rank(2)` for 2D tensors). | +| `broadcastable` | Operands are broadcastable according to NumPy-style broadcasting rules. | +| `static_shape` | All tensor dimensions are known at compile time (no dynamic shapes). | + +##### Logical Constraint Combinators + +| Combinator | Description | Example | +|------------|-------------|---------| +| `AnyOf(c1, c2, ...)` | At least one of the constraints must be satisfied. | `AnyOf(k_dim_aligned_64, continuous_memory)` | +| `AllOf(c1, c2, ...)` | All constraints must be satisfied. | `AllOf(tensor_rank(2), static_shape)` | +| `Not(c)` | The constraint must not be satisfied. | `Not(requires_ub_memory)` | + +##### Custom Constraints + +Users can define custom constraints using predicate functions: + +```python +# Define a custom constraint that consumes one context attr by name. +def large_batch(min_batch: int): + return lambda batch=0: batch >= min_batch + +@pto.vkernel( + target="a5", + op="pto.matmul ins(a, b) -> outs(c)", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[large_batch(1024)] +) +def large_batch_matmul(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Optimized for large batch sizes + pass +``` + +Constraint callables bind by parameter name. + +- Kernel parameter names such as `src`, `dst`, `a`, `b` receive lightweight proxy objects, so constraints can use direct expressions like `src.shape[0] <= dst.shape[0]`. +- Extra `context_attrs` passed to `pto.select_kernel(...)` bind by key name, for example `batch`, `enabled`, or `expected_rows`. + +##### Parameter Proxy Objects + +When a constraint argument name matches a kernel parameter name, the callable receives a lightweight proxy object rather than raw Python data. + +- For `TensorView` parameters, the proxy exposes `rank`, `shape`, `strides`, `dtype`, and `memory_space`. +- For `Tile` parameters, the proxy exposes `rank`, `shape`, `valid_shape`, `dtype`, `memory_space`, and `config`. +- `shape`, `strides`, and `valid_shape` support index access such as `src.shape[0]` or `dst.valid_shape[1]`. +- Missing or not-yet-known metadata evaluates as "unknown", so comparisons conservatively pass rather than failing early. + +Example: + +```python +def tload_preconditions(src, dst): + logical_rows = src.shape[0] * src.shape[1] * src.shape[2] * src.shape[3] + logical_cols = src.shape[4] + return ( + src.rank == 5 + and src.strides[4] == 1 + and dst.valid_shape[0] <= logical_rows + and dst.valid_shape[1] <= logical_cols + and logical_rows <= dst.shape[0] + and logical_cols <= dst.shape[1] + ) + +@pto.vkernel( + target="a5", + op="pto.tload", + dtypes=[(pto.f32, pto.f32)], + constraints=[tload_preconditions], +) +def template_tload(src: pto.TensorView, dst: pto.Tile): + return None +``` + +This is the recommended constraint style for current TileLang DSL head. + +##### Builtin Vector Parameters + +When a kernel needs to match a builtin MLIR vector operand, annotate that +parameter with `pto.vector(element_dtype, shape)`. + +```python +@pto.vkernel( + target="a5", + op="pto.tmrgsort ins(src0, src1, tmp) -> outs(dst, ex_vec)", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.i16)], +) +def template( + src0: pto.Tile, + src1: pto.Tile, + tmp: pto.Tile, + dst: pto.Tile, + ex_vec: pto.vector(pto.i16, (4,)), +): + return None +``` + +Rules: + +- Use `pto.vector(...)` for builtin vector operands, not Python `list`. +- `shape` is a Python tuple. A 1-D vector of length 4 is written `(4,)`. +- `dtypes=[...]` still records only the element dtype for that operand (`pto.i16` + in the example above). +- `pto.vector(...)` is distinct from `pto.vreg(...)`: the former models builtin + `vector<...>`, the latter models fixed-width VPTO vector registers. + +#### Kernel Selection Mechanism + +When a PTO operation needs implementation, the system performs the following matching process: + +1. **Target Filtering**: Select kernels with matching `target` architecture. +2. **Operation Filtering**: Select kernels whose matcher metadata covers the concrete query op: + - `op="foo"` requires exact match + - `op="foo ins(...) -> outs(...)"` still matches by op name `foo`; `ins/outs` additionally defines parameter naming/order contract for descriptor validation and materialization + - `ops=[...]` requires the concrete query op to appear in that list +3. **Type Matching**: For each kernel's `dtypes` list, check if any signature matches the operation's operand types: + - Concrete types must match exactly. + - Wildcard types match according to their category. + - Type variables must be consistent within the signature. +4. **Constraint Validation**: For each matching kernel, evaluate all `constraints`. If any constraint fails, the kernel is rejected. +5. **Priority Selection**: From the remaining kernels, select the one with the highest `priority` value. +6. **Fallback**: If no kernel matches, compilation fails with an error. + +For multi-op descriptors selected through `ops=[...]`, `pto.select_kernel(...)` +also binds the concrete query op before materialization. This bound +`selected_op` is what template-slot expansion uses later. + +The package also exposes explicit selection utilities: + +```python +registry = pto.KernelRegistry() +registry.register(my_kernel) + +selected = pto.select_kernel( + "a5", + "matmul", + (pto.f16, pto.f16, pto.f32), + context_attrs={"k_aligned": True}, + registry=registry, +) +``` + +`pto.select_kernel(...)` also supports an opt-in diagnostics path for matcher debugging: + +```python +report = pto.select_kernel( + "a5", + "matmul", + (pto.f16, pto.f16, pto.f32), + context_attrs={"k_aligned": False}, + return_metadata=True, + include_mlir=False, +) +``` + +When `return_metadata=True`, the result is a `KernelSelectionReport` instead of one +selected descriptor. + +- `report.selected` carries the winner when one candidate is selected. +- `report.final_status` is one of `selected`, `no_candidate`, or `priority_tie`. +- `report.final_error` summarizes the final selection outcome. +- `report.candidates` contains one `KernelSelectionCandidateMetadata` per + `target/op`-matched descriptor, including `dtype_mismatch`, + `constraint_failed`, `constraint_error`, `priority_shadowed`, `selected`, and + `priority_tie` states. + +Constraint diagnostics in report mode include: + +- `failed_constraint_index` +- `failed_constraint_name` +- `failed_constraint_location` as `file:line` + +For best diagnostics, prefer splitting compound predicates into multiple +constraint entries instead of writing one large `cond0 and cond1 and cond2` +callable. Report mode can precisely identify which constraint entry failed, but +it does not introspect which sub-expression inside one Python boolean +expression returned `False`. + +When `include_mlir=True`, report mode also attempts `mlir_text()` for candidates +that pass constraint evaluation. + +- On success, the candidate carries `mlir_text`. +- On materialization failure such as missing `specialize()` bindings, the + candidate carries `mlir_error`. +- Use `include_mlir=False` to skip this extra materialization attempt. + +#### Examples + +##### Matmul with Multiple Implementations + +```python +# High-performance kernel for aligned K dimension +def k_aligned_64(k=0): + return k % 64 == 0 + +@pto.vkernel( + target="a5", + op="pto.matmul ins(a, b) -> outs(c)", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[k_aligned_64], + priority=200 +) +def matmul_aligned_k(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Optimized implementation for aligned K + pass + +# General-purpose fallback +@pto.vkernel( + target="a5", + op="pto.matmul ins(a, b) -> outs(c)", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], + constraints=[], + priority=100 +) +def matmul_general(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Generic implementation + pass +``` + +##### Elementwise Operation with Type Polymorphism + +```python +def same_shape(a, b, out): + return a.shape[0] == out.shape[0] and b.shape[0] == out.shape[0] + +@pto.vkernel( + target="a5", + op="pto.add ins(a, b) -> outs(out)", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), + (pto.AnyInt, pto.AnyInt, pto.AnyInt) + ], + constraints=[same_shape] +) +def polymorphic_add(a: pto.Tile, b: pto.Tile, out: pto.Tile) -> None: + # Single implementation handles both float and integer types + dtype = a.element_type + all_mask = pto.make_mask(dtype, PAT.ALL) + # ... implementation using generic vector operations + pass +``` + +##### Constrained Convolution Kernel + +```python +def prefer_static_nhwc(src, weight): + return src.rank == 4 and weight.rank == 4 + +@pto.vkernel( + target="a5", + op="pto.conv2d ins(input, filter) -> outs(output)", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[prefer_static_nhwc], + priority=150 +) +def conv2d_nhwc_f16_f32(input: pto.Tile, filter: pto.Tile, output: pto.Tile) -> None: + # Optimized for NHWC layout with static shapes + pass +``` + +--- + +### Cube Kernel Declaration + +Cube kernels target the AIC (Cube) hardware unit for matrix multiplication operations. Unlike Vector kernels, Cube kernels operate on raw `pto.ptr` pointers and do not use `vecscope` execution scopes. + +#### Basic Syntax + +```python +@pto.ckernel( + target="a5", + op="pto.mad", # concrete matcher op + dtypes=[(pto.f16, pto.f16, pto.f32)], # selection dtype signature + name="my_gemm", # optional registry/debug name +) +def gemm(inp: pto.TensorView): + # Cube kernel body — linear cube authoring IR + ... +``` + +#### Parameter Type Conventions + +Cube kernel parameters represent different roles in the data flow: + +| Parameter Type | Role | Description | +|---------------|------|-------------| +| `PartitionTensorView` | GM input/output | Tiled view of a logical tensor in GM, partitioned by the caller | +| `TensorView` | GM input/output | Full logical tensor view in GM (for non-partitioned use) | +| `Tile` (specific addr space) | Pre-allocated hardware buffer | Tile already allocated in LEFT/RIGHT/ACC/MAT/BIAS address space | +| `int` | Dimension | Scalar dimension parameter (M, K, N, etc.) | +| `pto.f16` / `pto.f32` etc. | Scalar | Scalar parameters (threshold, alpha, etc.) | + +GM payload is modeled through `TensorView` and `PartitionTensorView`. `Tile` +values represent staged hardware buffers allocated in concrete hardware address +spaces such as `MAT`, `LEFT`, `RIGHT`, `ACC`, and `BIAS` via `pto.Tile`. + +#### Decorator Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `target` | `str` | No | Target hardware architecture. Cube DSL v1 supports `"a5"`. Default: `"a5"`. | +| `op` | `str` | 与 `ops` 二选一 | Single concrete matcher op. Bare-op strings such as `"pto.mad"` are supported. **Mutually exclusive with `ops`**. | +| `ops` | `List[str]` | 与 `op` 二选一 | List of concrete matcher ops for shared-body selection and template-slot dispatch. **Mutually exclusive with `op`**. | +| `dtypes` | `List[Tuple[Type, ...]]` | Recommended | List of selection dtype signatures. For cube kernels, these signatures describe the concrete query op rather than necessarily mirroring the Python parameter list. | +| `templates` | `Dict[str, Dict[str, str]]` | No | Static template-slot mappings. Each slot maps concrete op names to real `pto.*` calls. Required when the kernel body uses `pto.tpl(...)`. | +| `name` | `str` | No | Descriptor name used for registration, debugging, and emitted symbol naming. Defaults to the decorated function name. | +| `priority` | `int` | No | Selection priority when multiple kernels match. Default: `0`. | + +#### Key Differences from `@pto.vkernel` + +| Feature | `@pto.vkernel` (Vector) | `@pto.ckernel` (Cube) | +|---------|--------------------------|------------------------| +| Hardware unit | AIV (Vector) | AIC (Cube) | +| Execution scope | `pto.vecscope` / `pto.strict_vecscope` | **No scope** — function body is linear IR | +| GM data input | `TensorView` / `Tile` | `TensorView` / `PartitionTensorView` | +| Operand abstraction | Tile + vector registers + masks | `pto.ptr` raw pointers | +| Core operations | Vector ALU, load/store | Data movement (cube_load/store) + matmul (mad) | +| Address spaces | GM, UB (VEC) | GM, MAT, LEFT, RIGHT, ACC, BIAS, UB | +| Generated IR attr | `#pto.kernel_kind` | `#pto.kernel_kind` | + +#### Programming Model + +Cube kernels follow a GM → L1 → L0 → compute → L0 → GM data flow: + +```python +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm", +) +def gemm(a_tv: pto.PartitionTensorView, # [M, K] in GM + b_tv: pto.PartitionTensorView, # [K, N] in GM + c_tv: pto.PartitionTensorView): # [M, N] in GM, output + # 1. Get GM pointers from PartitionTensorViews + a_ptr = a_tv.as_ptr() # -> pto.ptr + b_ptr = b_tv.as_ptr() # -> pto.ptr + c_ptr = c_tv.as_ptr() # -> pto.ptr + + # 2. Allocate L1 (MAT) tile buffers (returns Tile, then get ptr) + l1_a = pto.Tile([16, 32], pto.f16, pto.MemorySpace.MAT) + l1_b = pto.Tile([32, 16], pto.f16, pto.MemorySpace.MAT) + + # 3. Allocate L0 tile buffers + l0a = pto.Tile([16, 32], pto.f16, pto.MemorySpace.LEFT) + l0b = pto.Tile([32, 16], pto.f16, pto.MemorySpace.RIGHT) + l0c = pto.Tile([16, 16], pto.f32, pto.MemorySpace.ACC) + + # 4. GM → L1 data movement + pto.cube_load(a_ptr, l1_a.as_ptr(), 16, nburst=(1, 0, 0)) + pto.cube_load(b_ptr, l1_b.as_ptr(), 16, nburst=(1, 0, 0)) + + # 5. L1 → L0 data movement + pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), 16, 32) + pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), 32, 16) + + # 6. Matrix multiplication + pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), 16, 16, 32) + + # 7. L0C → GM writeback + pto.acc_store_gm( + l0c.as_ptr(), c_ptr, 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND + ) +``` + +This example shows a **full-pipeline** kernel that handles data movement and compute. Alternatively, a **pure-compute** kernel can take pre-allocated tiles directly: + +```python +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="matmul_compute", +) +def matmul_compute(a_left: pto.Tile, # Pre-allocated LEFT tile (L0A) + b_right: pto.Tile, # Pre-allocated RIGHT tile (L0B) + c_acc: pto.Tile): # Pre-allocated ACC tile (L0C) + pto.mad_acc(a_left.as_ptr(), b_right.as_ptr(), c_acc.as_ptr(), 16, 16, 32) +``` + +#### Hardware Isolation + +- `@pto.ckernel` functions generate `#pto.kernel_kind` IR attribute. +- `@pto.vkernel` functions generate `#pto.kernel_kind` IR attribute. +- The IR verifier prevents Cube and Vector operations from appearing in the same function. +- The DSL semantic analyzer additionally checks that Cube kernel bodies do not contain Vector-specific operations (`vlds`, `vadd`, etc.) or `vecscope` scopes. +- Both kernel types can coexist in the same `.py` file; each compiles independently with conditional compilation macros (`__DAV_CUBE__` / `__DAV_VEC__`). + +For the complete Cube operation reference and `pto.Tile` constructor details, see [Cube Matrix Multiply Operations](12-cube-operations.md). diff --git a/ptodsl/docs/user_guide/04-template-kernels.md b/ptodsl/docs/user_guide/04-template-kernels.md new file mode 100644 index 000000000..9fcda0fd0 --- /dev/null +++ b/ptodsl/docs/user_guide/04-template-kernels.md @@ -0,0 +1,333 @@ +### Template-based Kernel Authoring + +For operations that share similar computation patterns but differ in their core vector operations, the DSL supports template-based kernel authoring. This allows a single kernel implementation to serve multiple related operations through parameterized templates. + +#### Multi-operation Kernels with `ops` Parameter + +Instead of specifying a single `op` parameter, you can provide an `ops` list to match multiple operations: + +```python +@pto.vkernel( + target="a5", + ops=["tadd", "tsub", "tmul", "tdiv"], # List of operations + dtypes=[(T, T, T)], # Type signature using type variable + advanced=True, + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + } + } +) +def elementwise_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + rows, cols = dst.valid_shape + elems_per_vreg = pto.elements_per_vreg(dtype) # Number of elements per vector register + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, elems_per_vreg): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("core", lhs, rhs, mask) # Template dispatch + pto.vsts(out, dst[row, col:], mask) +``` + +`op` and `ops` are mutually exclusive, and exactly one of them must be +provided. `ops=[...]` only widens the matcher set; callers still use +`pto.select_kernel(target, concrete_op, operand_types, ...)` with a concrete +PTO op such as `"tadd"` or `"tmul"`. + +#### Template System + +The template system consists of three components: + +1. **`templates` parameter**: A dictionary mapping template names to operation-specific implementations +2. **`pto.tpl()` function**: A compile-time placeholder that resolves to the appropriate implementation for the currently selected concrete op +3. **`ops` parameter**: Replaces the singular `op` parameter for multi-operation kernels + +##### Template Definition + +Templates are defined in the `templates` parameter of `@pto.vkernel`. Each template is a dictionary mapping operation names to implementation strings: + +```python +templates={ + "template_name": { + "op1": "implementation_for_op1", + "op2": "implementation_for_op2", + # ... + }, + "another_template": { + "op1": "different_implementation_for_op1", + # ... + } +} +``` + +Template-slot metadata is static and validated when the descriptor is +registered: + +- slot names must be non-empty strings +- mapping keys must be concrete ops covered by the descriptor matcher set +- mapping values must be supported real `pto.*` op names + +The implementation strings are typically vector operation names such as +`"vadd"`, `"vsub"`, `"vmul"`, and `"vdiv"`, which are resolved during kernel +expansion. + +##### Template Usage with `pto.tpl()` + +The `pto.tpl()` operation enables template dispatch for multi-operation kernels, allowing code reuse across related operations through compile-time substitution. + +#### `pto.tpl(template_name: str, *args) -> Any` + +**Description**: Template dispatch operation for multi-operation kernels. Resolves to different implementations based on the current operation being expanded. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `template_name` | `str` | Name of the template to dispatch | +| `*args` | `Any` | Positional arguments passed unchanged to the resolved real implementation | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `Any` | Result of the template implementation | + +**Behavior**: +- Only valid inside kernels decorated with `@pto.vkernel` that have a `templates` parameter +- The first argument must be a string literal template-slot name +- During kernel expansion for a specific operation `op_name`, `pto.tpl("template_name", ...)` is replaced with the implementation specified in `templates["template_name"]["op_name"]` +- The replacement is a direct compile-time substitution; positional arguments are passed unchanged +- Template implementations are typically string names of vector operations (e.g., `"vadd"`, `"vsub"`) +- `pto.select_kernel(...)` must bind a concrete op before template expansion can happen +- Python dict lookup, callable values, lambdas, and other runtime dispatch patterns are not part of the supported kernel-body surface + +**Example**: +```python +@pto.vkernel( + ops=["tadd", "tsub"], + dtypes=[(T, T, T)], + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + } + } +) +def elementwise_kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + # ... load vectors + result = pto.tpl("core", lhs, rhs, mask) # Expands to vadd for tadd, vsub for tsub + # ... store result +``` + +**Constraints**: +- Template names must be defined in the `templates` parameter of the `@pto.vkernel` decorator +- When a kernel body uses `pto.tpl("slot", ...)`, that slot must define an implementation for the currently selected concrete op +- Template implementations must be valid operation names in the DSL + +#### Decorator Parameters Update + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `target` | `str` | Yes | Target hardware architecture (e.g., `"a5"` for Ascend 950). | +| `op` | `str` | No* | Name of the PTO operation to match. **Mutually exclusive with `ops`**. | +| `ops` | `List[str]` | No* | List of PTO operation names to match. **Mutually exclusive with `op`**. | +| `dtypes` | `List[Tuple[Type, ...]]` | Yes | List of type signatures. Each tuple specifies the expected data types for the operation's operands. | +| `templates` | `Dict[str, Dict[str, str]]` | No | Static slot mappings from concrete matcher ops to real `pto.*` op names. Required when the kernel body uses `pto.tpl(...)`. | +| `constraints` | `List[Constraint]` | No | Additional constraints that must be satisfied for kernel selection. | +| `priority` | `int` | No | Selection priority when multiple kernels match. Default: `0`. | +| `name` | `str` | No | Kernel name (used for debugging and profiling). Defaults to the decorated function's name. | +| `advanced` | `bool` | No | Enable advanced-tier DSL surfaces (for example `strict_vecscope`, raw pointer family, and low-level DMA family). Implicit vecscope inference is mode-independent and runs only when no explicit `with pto.vecscope():` is present. Default: `False`. | + +**Note**: +- Either `op` or `ops` must be provided, but not both. +- `templates` is only needed when the kernel body uses `pto.tpl(...)`. +- `pto.select_kernel(...)` still queries with a concrete op even for `ops=[...]` descriptors. + +#### Advanced Template Patterns + +##### Multiple Templates per Kernel + +A kernel can define multiple templates for different aspects of the computation: + +```python +@pto.vkernel( + target="a5", + ops=["tadd_relu", "tsub_relu", "tadd_abs", "tsub_abs"], + dtypes=[(T, T, T)], + templates={ + "arithmetic": { + "tadd_relu": "vadd", + "tsub_relu": "vsub", + "tadd_abs": "vadd", + "tsub_abs": "vsub", + }, + "postprocess": { + "tadd_relu": "vrelu", + "tsub_relu": "vrelu", # Same activation for both + "tadd_abs": "vabs", + "tsub_abs": "vabs", + } + } +) +def elementwise_with_postprocess(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + # ... load vectors + arith_result = pto.tpl("arithmetic", lhs, rhs, mask) + postprocessed = pto.tpl("postprocess", arith_result, mask) + # ... store result +``` + +##### Compile-time Substitution Model + +Template-slot expansion happens before semantic checking and lowering: + +- `pto.select_kernel(...)` first binds a concrete op such as `"tadd"` +- the frontend then resolves `pto.tpl("core", ...)` using `templates["core"]["tadd"]` +- the placeholder is rewritten to a real `pto.*` call before semantic analysis +- diagnostics for unknown slots, missing mappings, or unsupported resolved surfaces are raised before any VPTO IR is generated + +#### Type Variables in Template Kernels + +Template kernels often use type variables to enforce type consistency: + +```python +T = pto.TypeVar('T') + +@pto.vkernel( + target="a5", + ops=["tadd", "tsub"], + dtypes=[(T, T, T)], # All three operands share type T + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + } + } +) +def typed_elementwise(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + # Type variable T ensures all tiles have same element type + dtype = dst.element_type # This is type T + # ... implementation +``` + +#### Selection Mechanism for Template Kernels + +When a PTO operation matches a template kernel: +1. The system selects the descriptor based on `op` exact match or `ops` list inclusion. +2. `pto.select_kernel(...)` binds the concrete query op as the descriptor's `selected_op`. +3. During frontend expansion, `pto.tpl()` calls are resolved using that bound concrete op. +4. For operation `"op_name"`, template `"template_name"` resolves to `templates["template_name"]["op_name"]`. +5. The resolved string (e.g., `"vadd"`) is replaced with the corresponding real DSL operation before semantic analysis and lowering. + +#### Example: Unified Arithmetic Kernel + +```python +T = pto.TypeVar('T') + +@pto.vkernel( + ops=["tadd", "tsub", "tmul", "tdiv", "tmax", "tmin"], + dtypes=[(T, T, T)], + advanced=True, + templates={ + "arithmetic": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + "tmax": "vmax", + "tmin": "vmin", + } + } +) +def unified_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + """Single implementation for six arithmetic operations.""" + dtype = dst.element_type + rows, cols = dst.valid_shape + elems_per_vreg = pto.elements_per_vreg(dtype) # Number of elements per vector register + + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, elems_per_vreg): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("arithmetic", lhs, rhs, mask) + pto.vsts(out, dst[row, col:], mask) +``` + +#### Compile-time Specialization with `pto.constexpr` + +The `pto.constexpr` construct enables compile-time branching for kernel specialization, allowing different code paths to be selected based on static compile-time information. Unlike runtime conditionals that generate control flow, `pto.constexpr` branches are resolved during kernel descriptor materialization, with only the selected branch retained for lowering. + +**Syntax and Usage**: +```python +if pto.constexpr(condition): + # Branch taken if condition evaluates to True at compile time + ... +else: + # Branch taken if condition evaluates to False at compile time + ... +``` + +**Semantics**: +- The `condition` must be evaluable at compile time during kernel descriptor materialization. +- Only the selected branch is analyzed, semantically checked, and lowered to VPTO IR. +- The non-selected branch is discarded entirely and does not contribute to runtime control flow or value merging. +- If the condition cannot be proven static, descriptor materialization fails with a frontend diagnostic. + +**Comparison with Runtime Conditionals**: + +| Aspect | Runtime `if` | `pto.constexpr` | +|--------|--------------|-----------------| +| **Evaluation time** | Runtime | Compile-time (descriptor materialization) | +| **Control flow** | Generates `scf.if` with merge logic | No runtime control flow; branch eliminated | +| **Value merging** | Both branches must produce compatible values for merge | No value merging; only one branch exists after elimination | +| **Use case** | Dynamic decision making based on runtime values | Code generation specialization based on static parameters | + +**Typical Static Inputs**: +- Literal integers, booleans, and strings +- Data type symbols (`src.element_type`, `dst.element_type`) and comparisons derived from them +- Statically specialized `Tile.shape` and `Tile.valid_shape` values +- Frontend query helpers such as `pto.bytewidth(dtype)` and `pto.elements_per_vreg(dtype)` (which computes elements per vector register) + +**Constraints and Notes**: +- `TensorView.shape` and `TensorView.strides` may be represented by hidden kernel parameters rather than descriptor-time constants. They should not be assumed constexpr unless separately bound through specialization or other compile-time context. +- `pto.constexpr` is a frontend-only authoring construct; it does not correspond to any runtime VPTO instruction. + +**Guidelines**: +- Use `constraints=[...]` and `pto.select_kernel(...)` when specialization requires selecting an entirely different kernel descriptor. +- Use `pto.constexpr` when the kernel remains the same but internal regions require specialization based on compile-time parameters. + +**Example**: +```python +@pto.vkernel(target="a5", op="pto.trowsum") +def template_trowsum(dst: pto.Tile, src: pto.Tile, tmp: pto.Tile): + acc_dtype = tmp.element_type + dst_dtype = dst.element_type + acc_mask_1, _ = pto.make_mask(acc_dtype, 1) + dst_mask_1, _ = pto.make_mask(dst_dtype, 1) + + if pto.constexpr(acc_dtype != dst_dtype): + # Type conversion required + v_acc_casted = pto.vcvt(v_acc, dst_dtype, acc_mask_1) + pto.vsts(v_acc_casted, dst[row, 0:], dst_mask_1) + else: + # No conversion needed + pto.vsts(v_acc, dst[row, 0:], dst_mask_1) +``` + +### Value Model + +The DSL operates on symbolic values, not Python runtime values: +- **Constants**: Python literals that are typed to machine types +- **Operation results**: Values produced by DSL operations +- **Block arguments**: Values introduced by control flow structures + +### Memory Spaces + +The DSL supports different memory spaces: +- `MemorySpace.GM`: Global Memory +- `MemorySpace.UB`: Unified Buffer (local storage for vector computation) diff --git a/ptodsl/docs/user_guide/05-type-system.md b/ptodsl/docs/user_guide/05-type-system.md new file mode 100644 index 000000000..c40f12475 --- /dev/null +++ b/ptodsl/docs/user_guide/05-type-system.md @@ -0,0 +1,686 @@ + + +## Type System + +### Scalar Types + +| DSL Type | Description | Bit Width | +|----------|-------------|-----------| +| `pto.i1` | Boolean | 1 | +| `pto.i8` | 8-bit signless integer | 8 | +| `pto.si8` | 8-bit signed integer | 8 | +| `pto.ui8` | 8-bit unsigned integer | 8 | +| `pto.i16` | 16-bit signless integer | 16 | +| `pto.si16` | 16-bit signed integer | 16 | +| `pto.ui16` | 16-bit unsigned integer | 16 | +| `pto.i32` | 32-bit signless integer | 32 | +| `pto.si32` | 32-bit signed integer | 32 | +| `pto.ui32` | 32-bit unsigned integer | 32 | +| `pto.i64` | 64-bit signless integer | 64 | +| `pto.si64` | 64-bit signed integer | 64 | +| `pto.ui64` | 64-bit unsigned integer | 64 | +| `pto.f16` | Half precision float | 16 | +| `pto.bf16` | Brain float 16 | 16 | +| `pto.f32` | Single precision float | 32 | + +Python literals are automatically typed: +- `bool` → `pto.i1` +- `int` → Context-dependent (typically `pto.i32` or `pto.i64`) +- `float` → `pto.f32` + +For explicit typing, use type constructors: +```python +x = pto.i32(1024) # Explicit i32 constant +y: pto.i32 = 1024 # Type annotation +z = pto.ui16(7) # Explicit unsigned 16-bit constant +``` + +Static dtype bindings can also be called like constructors. This is useful when +the dtype comes from compile-time metadata such as `element_type`: + +```python +idx_dtype = tile.element_type +zero_idx = idx_dtype(0) +v_col = idx_dtype(col) +``` + +Integer sign semantics are part of the DSL type surface. `pto.si16`, +`pto.ui16`, and `pto.i16` are distinct scalar dtypes and lower to `si16`, +`ui16`, and `i16` respectively in VPTO IR. + +### Integer Literal Guidance + +For ordinary integer constants, prefer plain integer literals instead of +string forms. + +```python +count = pto.i32(1024) +delta = pto.i16(-12) +min_i32 = pto.i32(-2147483648) +unsigned_hi = pto.ui16(32768) +``` + +Integer string literals are reserved for explicit bit-pattern authoring. They +must use hex form. + +```python +# Use hex strings only when you intentionally want fixed-width bit-pattern +# interpretation at the target dtype width. +hi_bit = pto.i32("0x80000000") # -2147483648 +all_ones = pto.i16("0xFFFF") # -1 +unsigned_hi = pto.ui16("0x8000") # 32768 +``` + +Rules: +- Prefer plain integer literals such as `pto.i32(1024)` or `pto.i16(-12)` for normal integer authoring. +- Integer string literals must use hex bit-pattern form such as `"0xFFFF"`. +- Ordinary integer strings such as `"1024"` or `"-12"` are rejected; write them as integer literals instead. +- For signed and signless integer dtypes (`pto.i*`, `pto.si*`), hex strings use two's-complement interpretation at the target dtype width. +- For unsigned integer dtypes (`pto.ui*`), hex strings keep their unsigned value. +- Hex strings must fit within the target bit width. For example, `pto.i16("0x10000")` is rejected because the literal exceeds 16 bits. + +### Floating-Point Literal Forms + +`pto.f16(...)`, `pto.bf16(...)`, and `pto.f32(...)` accept multiple literal forms. + +```python +# Signed numeric literals +a = pto.f16(-1.5) +b = pto.bf16(+2.5) +c = pto.f32(-3.5) + +# Special floating-point values +pos_inf = pto.f32("inf") +neg_inf = pto.f32("-inf") +qnan = pto.f32("nan") + +# Bit-pattern form (hex string, interpreted by target dtype) +f16_neg_inf = pto.f16("0xFC00") +bf16_neg_inf = pto.bf16("0xFF80") +f32_neg_inf = pto.f32("0xFF800000") +``` + +Notes: +- Prefer dtype constructors for reduction seeds and boundary values (for example rowmax initialization). +- For float bit-pattern constants, pass a **string** hex literal to the matching dtype constructor. +- Avoid passing raw integer bit-patterns directly into vector broadcast/dup APIs when a floating vector is expected. +- `float(...)` function calls are not part of the TileLang DSL public call surface; use constructor forms above. + +### Vector Register Type + +Vector registers have fixed 256-byte width: + +```python +v_f32 = pto.vreg(pto.f32) # !pto.vreg<64xf32> +v_f16 = pto.vreg(pto.f16) # !pto.vreg<128xf16> +v_i8 = pto.vreg(pto.i8) # !pto.vreg<256xi8> +``` + +`pto.vreg(dtype)` only takes the element type. The frontend infers the element count automatically from the fixed 256-byte register width: + +- `pto.f32` → `!pto.vreg<64xf32>` +- `pto.f16` → `!pto.vreg<128xf16>` +- `pto.bf16` → `!pto.vreg<128xbf16>` +- `pto.i32` → `!pto.vreg<64xi32>` +- `pto.si32` → `!pto.vreg<64xsi32>` +- `pto.ui32` → `!pto.vreg<64xui32>` +- `pto.i16` → `!pto.vreg<128xi16>` +- `pto.si16` → `!pto.vreg<128xsi16>` +- `pto.ui16` → `!pto.vreg<128xui16>` +- `pto.i8` → `!pto.vreg<256xi8>` +- `pto.si8` → `!pto.vreg<256xsi8>` +- `pto.ui8` → `!pto.vreg<256xui8>` + +Constraint: `element_count × bitwidth(element_type) = 2048` + +Use `pto.elements_per_vreg(dtype)` when you need the inferred element count explicitly: + +```python +v_dtype = pto.vreg(pto.f32) +lanes0 = v_dtype.elements_per_vreg # 64 +lanes1 = pto.elements_per_vreg(pto.f32) # 64 +``` + +Current TileLang DSL v1 vector lowering supports the 8/16/32-bit integer +families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32` element types. + +### Builtin Vector Type + +TileLang DSL v1 also exposes builtin MLIR vector types through +`pto.vector(element_dtype, shape)`. + +```python +executed_ty = pto.vector(pto.i16, (4,)) # vector<4xi16> +``` + +This type is different from `pto.vreg(...)`: + +- `pto.vreg(dtype)` models a VPTO vector register with fixed 256-byte width. +- `pto.vector(dtype, shape)` models a builtin MLIR `vector<...>` type with an + explicit static shape. + +Use `pto.vector(...)` when a kernel parameter or intermediate value must match +an existing builtin vector operand in PTO IR, for example an auxiliary +`vector<4xi16>` operand carried by a tile op template. + +```python +@pto.vkernel( + target="a5", + op="pto.tmrgsort ins(src0, src1, tmp) -> outs(dst, ex_vec)", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.i16)], +) +def template( + src0: pto.Tile, + src1: pto.Tile, + tmp: pto.Tile, + dst: pto.Tile, + ex_vec: pto.vector(pto.i16, (4,)), +): + return None +``` + +Notes: + +- `shape` must be a Python tuple of integers. For a 1-D vector, write `(4,)`, + not `(4)`. The trailing comma is Python's single-element tuple syntax. +- The current public surface is intended for static builtin vector types. +- In descriptor `dtypes=[...]`, builtin vector operands are matched by their + element dtype (`pto.i16` in the example above). The vector shape contract is + carried by the parameter annotation `pto.vector(...)`. + +### Vector Type Reinterpretation (vbitcast) + +Vector registers support bitwise type reinterpretation via `pto.vbitcast`: + +```python +result = pto.vbitcast(vector, to_type) +``` + +Interface summary: +- `vector`: a vector register value of type `!pto.vreg` +- `to_type`: target element dtype such as `pto.i32`, `pto.ui32`, `pto.f16`, `pto.bf16`, `pto.f32` +- return: a new vector register `!pto.vreg` whose element count is inferred from the fixed 256-byte vreg width + +Constraints: +- `vector` must be a vreg value; scalar values, pointers, `Tile`, and `TensorView` are rejected +- `to_type` must be a DSL-supported vreg element dtype +- `vbitcast` preserves the total register storage size, so only reinterpretations with the same total bit count are allowed +- the operation has no mask, rounding, saturation, or lane-placement parameters + +Lane count is recomputed from `to_type`: +- `!pto.vreg<64xf32> + pto.i32 -> !pto.vreg<64xi32>` +- `!pto.vreg<64xf32> + pto.f16 -> !pto.vreg<128xf16>` +- `!pto.vreg<128xbf16> + pto.ui16 -> !pto.vreg<128xui16>` + +```python +# Float to integer bitwise reinterpretation +fvec = pto.vlds(ub_ptr, lane) # !pto.vreg<64xf32> +ivec = pto.vbitcast(fvec, pto.i32) # !pto.vreg<64xi32> + +# Signed to unsigned integer reinterpretation +signed_vec = pto.vlds(ptr, lane) # !pto.vreg<64xsi32> +unsigned_vec = pto.vbitcast(signed_vec, pto.ui32) # !pto.vreg<64xui32> + +# Element size change (32-bit to 16-bit) +f32_vec = pto.vlds(ptr, lane) # !pto.vreg<64xf32> +f16_vec = pto.vbitcast(f32_vec, pto.f16) # !pto.vreg<128xf16> +``` + +Pythonic syntax sugar via `astype()` method: + +```python +ivec = fvec.astype(pto.i32) # Float to integer +unsigned_vec = signed_vec.astype(pto.ui32) # Signed to unsigned +f16_vec = f32_vec.astype(pto.f16) # 32-bit to 16-bit +``` + +`astype()` on a vector register is syntax sugar for `pto.vbitcast(...)`. In other words, it is a bit reinterpretation API, not a numeric conversion API. + +**Note**: `vbitcast` preserves the exact bit pattern (type punning), unlike `vcvt` which performs value conversion with rounding/saturation. Use `vcvt` when you want numeric conversion semantics; use `vbitcast` when you want the bits to stay unchanged. + +### Typed Masks + +Masks are typed by their bit granularity: + +| DSL Type | VPTO Type | Description | +|----------|-----------|-------------| +| `pto.mask_b8` | `!pto.mask` | 8-bit granularity mask | +| `pto.mask_b16` | `!pto.mask` | 16-bit granularity mask | +| `pto.mask_b32` | `!pto.mask` | 32-bit granularity mask | + +```python +mask_ty = pto.mask_b32 +mask: pto.mask_b32 = pto.make_mask(pto.f32, PAT.ALL) +``` + +Typed masks also support explicit type reinterpretation via `pto.pbitcast`: + +```python +mask_b8 = pto.plds(mask_ptr, offset, pto.PredicateDist.US) +mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) +mask_b32 = pto.pbitcast(mask_b16, pto.mask_b32) +``` + +`pto.pbitcast(...)` is the predicate analogue of `pto.vbitcast(...)`: +- it changes the static mask granularity seen by later DSL/VPTO consumers +- it preserves the underlying predicate bit image +- it does not perform pack/unpack or interleave/deinterleave by itself + +Mask operations must match the vector element family: +- `f32`, `i32`, `si32`, and `ui32` vectors use `mask_b32` +- `f16`, `bf16`, `i16`, `si16`, and `ui16` vectors use `mask_b16` +- `i8`, `si8`, and `ui8` vectors use `mask_b8` + +```python +# Correct: f32 vector with b32 mask +mask32 = pto.make_mask(pto.f32, PAT.ALL) +vec_f32 = pto.vlds(ptr, offset) +out = pto.vabs(vec_f32, mask32) + +# Error: mismatched mask granularity +mask16 = pto.make_mask(pto.f16, PAT.ALL) +out = pto.vabs(vec_f32, mask16) # Type error! +``` + +### Pointer Types [Advanced Tier] + +Pointers combine element type and memory space: + +```python +from pto import MemorySpace + +ptr_gm = pto.ptr(pto.f32, MemorySpace.GM) # GM pointer to f32 +ptr_ub = pto.ptr(pto.f16, MemorySpace.UB) # UB pointer to f16 +``` + +The `MemorySpace` enum provides type-safe memory space specification: + +| Enum Value | Description | +|------------|-------------| +| `MemorySpace.GM` | Global Memory (off-chip HBM/DDR) | +| `MemorySpace.MAT` | Cube L1 / cbuf staging buffer | +| `MemorySpace.LEFT` | Cube L0A left-operand buffer | +| `MemorySpace.RIGHT` | Cube L0B right-operand buffer | +| `MemorySpace.ACC` | Cube L0C accumulator buffer | +| `MemorySpace.BIAS` | Cube bias table buffer | +| `MemorySpace.UB` | Unified Buffer (on-chip SRAM, 256KB) | + +This replaces ad-hoc string literals with compile-time checked enums and is +shared by both the Vector and Cube DSL surfaces. + +### Public Buffer Types + +TileLang uses three public buffer-facing type names in kernel signatures: + +| Public Type | Description | +|-------------|-------------| +| `pto.TensorView` | GM-facing tensor view descriptor used for DMA-oriented data access | +| `pto.PartitionTensorView` | Logical GM partition (slice) descriptor, corresponding to `!pto.partition_tensor_view<...>` | +| `pto.Tile` | Tile buffer value for hardware-resident staged compute/storage buffers | + +### TensorView Types + +TensorView types represent multi-dimensional (up to 5D) views into tensors residing in Global Memory (GM). They are used as kernel parameters for describing GM data and support slicing operations to create logical partitions for DMA load/store operations. + +#### TensorView Type Definition + +TensorView types are parameterized by shape (a tuple of up to 5 dimensions) and element type: + +```python +# Kernel parameter using TensorView +@pto.vkernel(target="a5", op="custom", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tensor: pto.TensorView, # GM tensor view + output_tensor: pto.TensorView, # GM tensor view + tile_buf: pto.Tile # UB tile +): + # Access tensor view properties + shape = input_tensor.shape # tuple of dimensions (dynamic or static, up to 5D) + dtype = input_tensor.element_type # e.g., pto.f32 + strides = input_tensor.strides # stride in elements +``` + +Important notes: +- TensorView is a read-only descriptor for GM data, though DMA store operations can write through it. +- Shape can be static (compile-time constants) or dynamic (determined at runtime). +- Strides are expressed in elements, not bytes. +- Memory space is always GM (Global Memory). +- Maximum rank is 5. PTO ISA right-aligns lower-rank shapes to 5D. +- When higher dimensions are 1, a 5D TensorView can be abbreviated to lower-rank forms. For example, shape `(1, 1, 64, 32, 16)` can be written as `(64, 32, 16)`, and shape `(1, 1, 1, 32, 16)` can be written as `(32, 16)`. + +#### TensorView Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Tensor dimensions (supports up to 5 dimensions, right-aligned to 5D in PTO ISA) | +| `element_type` | `Type` | Element data type (for example `pto.f32`, `pto.f16`) | +| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | +| `offset` | `pto.i64` | Byte offset from base pointer (internal) | + +#### Padding Mode Enum + +Padding mode controls how out-of-bounds accesses are handled during DMA load/store operations: + +| Enum Value | Description | +|------------|-------------| +| `PadMode.PadNull` | No padding. Out-of-bounds access is invalid | +| `PadMode.PadFirstElem` | Pad using the first element of the source | +| `PadMode.PadValue` | Pad using a specified value and requires `pad_value` | + +#### Slicing Syntax + +TensorView supports Python slicing syntax to create logical partitions: + +```python +# Create a partition from a tensor view +partition = tensor_view[dim0_start:dim0_end, dim1_start:dim1_end] + +# Example: extract a 16x16 tile from a larger tensor +tile_view = large_tensor[0:16, 0:16] + +# Dynamic offsets and sizes +dim0_start = tensor_view.shape[0] // 2 +dynamic_partition = tensor_view[dim0_start:tensor_view.shape[0], 4:20] + +# Static positive step on dimension 0 +stepped_partition = tensor_view[0:32:2, 0:16] + +# Right-aligned shorthand on a 5D descriptor +partition_3d = tensor_view[d2_start:d2_end, d3_start:d3_end, d4_start:d4_end] + +# Full 5D spelling remains available when needed +partition_5d = tensor_view[ + d0_start:d0_end, + d1_start:d1_end, + d2_start:d2_end, + d3_start:d3_end, + d4_start:d4_end, +] +``` + +Constraints: +- Slicing returns a new `pto.PartitionTensorView` representing the logical partition. +- The partition must be within the original tensor bounds. +- When fewer than 5 slice axes are written, they are right-aligned to the trailing physical axes of the 5D descriptor. +- `stop` must be explicit on all dimensions. +- `start` may be static or dynamic. +- `step` must be a static positive integer. +- Dimension 0 may use `step > 1`. +- Dimension 1 must keep `step == 1` in the current DMA-oriented implementation. + +### PartitionTensorView Types + +`pto.PartitionTensorView` models a logical partition of GM tensor data and maps to +`!pto.partition_tensor_view` in PTO IR. +Like `TensorView`, it is a descriptor type and does not own storage. + +#### PartitionTensorView Type Definition + +```python +@pto.vkernel(target="a5", op="custom_partition", dtypes=[(pto.f32, pto.f32)]) +def kernel(inp: pto.TensorView, out: pto.TensorView): + part: pto.PartitionTensorView = inp[0:16, 0:16] + p_rows, p_cols = part.shape + s_row, s_col = part.strides + return None +``` + +Important notes: +- A `PartitionTensorView` carries partition `shape` and `strides` metadata in element units. +- Element dtype is inherited from the source tensor view. +- Memory space remains GM. +- Rank handling follows the same right-aligned 5D contract as `TensorView`. +- `PartitionTensorView` can be used where DMA-oriented TensorView-like descriptors are accepted. +- Prefer direct indexing or tuple unpacking for `shape`/`strides` metadata values in current DSL v1 lowering. + +#### PartitionTensorView Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Partition dimensions | +| `element_type` | `Type` | Element data type inherited from source tensor view | +| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | +| `offset` | `pto.i64` | Byte offset from the base tensor pointer (internal) | + +### Tile Types + +Tile types represent data blocks in memory with layout and configuration information, corresponding to `!pto.tile_buf` in the VPTO IR. Tiles are commonly used as kernel parameters for tiled computations. + +#### Tile Type Definition + +`pto.Tile` is the public tile type used for hardware buffer allocation in specific +address spaces. Tiles are constructed directly via the `pto.Tile` constructor: + +```python +pto.Tile( + shape: tuple[int, ...], # Buffer shape (required) + dtype: Type, # Element type (required) + memory_space: MemorySpace, # Address space (required) + valid_shape: tuple[int, ...] | None = None, # Valid region, defaults to shape + blayout: BLayout | None = None, # B layout, auto-detected from address space + slayout: SLayout | None = None, # S layout, auto-detected from address space + fractal_size: int | None = None, # Fractal size, auto-detected from address space + pad_value: PadValue = PadValue.Null, # Pad policy + compact_mode: CompactMode = CompactMode.Null, # Compact mode + addr: int | None = None, # Pre-assigned address (level3 only) +) -> Tile +``` + +Layout defaults are selected automatically based on the address space: + +| Address Space | blayout default | slayout default | fractal_size default | +|--------------|----------------|----------------|---------------------| +| `MAT` | `ColMajor` | `RowMajor` | `TileConfig.fractalABSize` (512) | +| `LEFT` | `ColMajor` | `RowMajor` | `TileConfig.fractalABSize` (512) | +| `RIGHT` | `RowMajor` | `ColMajor` | `TileConfig.fractalABSize` (512) | +| `ACC` | `ColMajor` | `RowMajor` | `TileConfig.fractalCSize` (1024) | +| `BIAS` | `RowMajor` | `NoneBox` | `TileConfig.fractalABSize` (512) | +| `UB` / `VEC` | `RowMajor` | `NoneBox` | `TileConfig.fractalABSize` (512) | + +Related enum types: + +| Enum | Values | +|------|--------| +| `BLayout` | `ColMajor` (0), `RowMajor` (1) | +| `SLayout` | `NoneBox` (0), `RowMajor` (1), `ColMajor` (2) | +| `PadValue` | `Null` (0), `Zero` (1), `Max` (2), `Min` (3) | +| `CompactMode` | `Null` (0), `Normal` (1), `RowPlusOne` (2) | + +Usage: + +```python +# Allocate tiles in @vkernel or @ckernel +tile_ub = pto.Tile([256, 128], pto.f32, MemorySpace.UB) +tile_left = pto.Tile([16, 64], pto.f16, MemorySpace.LEFT) +tile_acc = pto.Tile([16, 16], pto.f32, MemorySpace.ACC, valid_shape=(12, 12)) +``` + +Important notes on shape and valid shape: +- `shape` must be a compile-time constant. Tile dimensions are fixed at compilation time and cannot change at runtime. +- `valid_shape` can be either static or dynamic and must be less than or equal to `shape` in each dimension. +- When `valid_shape` is not specified, it defaults to the full `shape`. + +#### Tile Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Full tile dimensions. These are compile-time constants | +| `element_type` | `Type` | Element data type (for example `pto.f32`) | +| `memory_space` | `MemorySpace` | Memory space such as UB, MAT, LEFT, RIGHT, ACC, or BIAS | +| `valid_shape` | `tuple[int, ...]` | Actual data dimensions within the tile. Must be less than or equal to `shape` in each dimension | +| `config` | `TileConfig` | Layout and padding configuration | + +#### Tile Pad Values + +`TileConfig.pad_value` is modeled after the C++ `PadValue : uint64_t` design. + +Standard pad values use small integer encodings: + +| DSL Value | Encoded Value | Meaning | +|-----------|---------------|---------| +| `pto.PadValue.NULL` | `0` | No concrete fill value | +| `pto.PadValue.ZERO` | `1` | Zero fill | +| `pto.PadValue.MAX` | `2` | Maximum finite / integer max for the tile element dtype | +| `pto.PadValue.MIN` | `3` | Minimum finite / integer min for the tile element dtype | + +Custom pad values use the `CustomBase = 0x100000000` convention and are authored with `pto.PadValue.custom_f32(...)`: + +```python +pad0 = pto.PadValue.ZERO +pad1 = pto.PadValue.custom_f32(-1.0) +pad2 = pto.PadValue.custom_f32("0xBF800000") # float32 bit pattern for -1.0f +``` + +Notes: +- `PadValue.encoded` exposes the host-side uint64 payload. `PadValue.value` is intentionally unavailable to avoid confusion with `.eval(...)` scalar materialization. +- `PadValue.text` exposes the standard textual spelling for built-ins such as `null` and `zero`. +- Custom pad values currently model an `f32` payload. In DSL v1, materializing a custom pad into a scalar is only supported for floating tile element dtypes. +- `PadValue.NULL` does not denote a usable scalar fill constant. Calling `tile.pad_value.eval()` or `tile.config.pad_value.eval()` when the enum is `NULL` is a frontend error. +- **DMA padding**: When performing GM→UB DMA transfers with padding enabled (via `enable_ub_pad=True` in `pto.copy_gm_to_ubuf`), the pad value must be configured explicitly using `pto.set_mov_pad_val`. Tile `PadValue` descriptors are not automatically translated to hardware register configurations in TileLang DSL v1. See [Pad Fill Semantics](08-sync-dma-operations.md#pad-fill-semantics) for usage details. + +Host-side code can materialize a scalar with an explicit dtype: + +```python +pad_max_f32 = pto.PadValue.MAX.eval(pto.f32) +pad_min_i16 = pto.PadValue.MIN.eval(pto.i16) +``` + +#### Tile Shape Concepts + +- `shape` is the static physical allocation size of the tile buffer. +- `valid_shape` is the logical data region and may be static or dynamic. +- `valid_shape[i] <= shape[i]` must hold for each dimension. +- Fixed-size tiles with smaller valid regions are useful for padding and partial-tile cases. + +#### Basic Access Operations + +```python +# Get tile properties +shape = tile.shape # (256, 128) +elem_type = tile.element_type # pto.f32 +mem_space = tile.memory_space # MemorySpace.UB +valid_shape = tile.valid_shape # (240, 120) or same as shape + +# Get configuration properties +config = tile.config +b_layout = config.b_layout # pto.BLayout.ROW_MAJOR +s_layout = config.s_layout # pto.SLayout.NONE_BOX +s_fractal = config.s_fractal_size # pto.i32(512) +pad_desc = tile.config.pad_value # PadValue enum bound to the tile element dtype +pad_desc2 = tile.pad_value # direct sugar for the same PadValue enum + +# Dynamic properties +rank = tile.rank # 2 +``` + +`tile.config.pad_value` and `tile.pad_value` are enum-typed inside kernel code. Use `.eval()` to materialize the configured pad descriptor against the tile element dtype: + +- `tile.pad_value.eval()` with `PadValue.ZERO` becomes `0` / `0.0` +- `tile.pad_value.eval()` with `PadValue.MAX` becomes dtype-aware max +- `tile.pad_value.eval()` with `PadValue.MIN` becomes dtype-aware min +- `tile.pad_value.eval()` with `PadValue.custom_f32(...)` becomes the authored floating scalar +- `tile.pad_value.eval()` with `PadValue.NULL` raises a frontend error + +For dtype-dependent fill seeds, prefer `tile.pad_value.eval()` over handwritten +`if dtype == ...` ladders. + +For standalone `PadValue` symbols that are not bound to a tile, pass the target dtype explicitly: + +```python +pad_scalar = pto.PadValue.MAX.eval(pto.f32) +``` + +```python +@pto.vkernel(op="fill_pad_value", dtypes=[(pto.AnyType,)]) +def fill_pad_value(dst: pto.Tile): + pad_scalar = dst.pad_value.eval() + pad_vec = pto.vbr(pad_scalar) + # ... +``` + +Typical materialized values: + +- `PadValue.ZERO` -> `0` / `0.0` +- `PadValue.MAX` -> dtype-aware max, for example `4294967295` for `pto.ui32` +- `PadValue.MIN` -> dtype-aware min, for example `-2147483648` for `pto.i32` and `0` for `pto.ui32` + +This is usually simpler than spelling every dtype case manually with +`pto.constexpr(dst.element_type == ...)`. + +Example: reading pad value from a `Tile` + +```python +@pto.vkernel(op="fill_pad_demo", dtypes=[(pto.f16,)]) +def kernel(dst: pto.Tile): + mask, _ = pto.make_mask(pto.f16, 8) + + # Read the Tile-bound PadValue enum. + pad0 = dst.pad_value + + # Equivalent form through TileConfig metadata. + pad1 = dst.config.pad_value + + if pto.constexpr(pad0 != pto.PadValue.NULL): + scalar0 = pad0.eval() + scalar1 = pad1.eval() + vec0 = pto.vdup(scalar0, mask) + vec1 = pto.vdup(scalar1, mask) + pto.vsts(vec0, dst[0, 0:], mask) + pto.vsts(vec1, dst[1, 0:], mask) +``` + +If `dst` is specialized with `config=pto.TileConfig.from_mapping({"pad_value": pto.PadValue.ZERO})`, +both `pad0` and `pad1` are `PadValue.ZERO`, and `pad0.eval()` / `pad1.eval()` materialize to the scalar `0.0` for an `f16` tile. + +#### Conversion Operations + +Basic mode syntax uses tile element-indexing directly in vector operations: + +```python +# 2D tile indexing +vec = pto.vlds(tile[row, col:]) +pto.vsts(vec, tile[row, col:], mask) + +# 1D tile indexing +vec = pto.vlds(tile[start:]) +pto.vsts(vec, tile[start:], mask) +``` + +Advanced mode syntax converts tiles to typed pointers for byte-offset operations: + +```python +# Convert tile to pointer +ptr = tile.as_ptr() # Returns pto.ptr(pto.f32, MemorySpace.UB) + +# Use pointer with byte offset +vec = pto.vlds(ptr, offset) +pto.vsts(vec, ptr, offset, mask) +``` + +#### Kernel Parameter Usage + +```python +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tile: pto.Tile, + output_tile: pto.Tile, + scale: pto.f32 +): + all_mask = pto.make_mask(pto.f32, PAT.ALL) + for i in range(0, 256, 64): + vec = pto.vlds(input_tile[i, 0:]) + scaled = pto.vmuls(vec, scale, all_mask) + pto.vsts(scaled, output_tile[i, 0:], all_mask) +``` + +### Alignment Type + +The `pto.align` type is used for alignment carrier operations and maps to `!pto.align`. diff --git a/ptodsl/docs/user_guide/06-control-flow.md b/ptodsl/docs/user_guide/06-control-flow.md new file mode 100644 index 000000000..41b623d1d --- /dev/null +++ b/ptodsl/docs/user_guide/06-control-flow.md @@ -0,0 +1,181 @@ +## Control Flow + +### Vector Scopes + +The TileLang DSL supports implicit vector scope inference, allowing developers to write vector operations directly without explicit `pto.vecscope()` blocks. The compiler automatically groups consecutive, data-dependent vector operations into implicit vector scopes during lowering. + +#### Implicit Scope Inference + +**Note:** `pto.vecscope()` is supported. Automatic scope inference runs only when the kernel does **not** contain explicit `with pto.vecscope():` blocks. + +When you write vector operations like `pto.vlds`, `pto.vadd`, `pto.vsts` directly in your code, the compiler's **Scope Inference Pass** analyzes the control flow graph and automatically creates vector scopes: + +```python +# No explicit vecscope needed - compiler infers scope boundaries +vec = pto.vlds(outer_ptr, offset) +result = pto.vadd(vec, vec, all_mask) +pto.vsts(result, dst_ptr, offset, all_mask) +``` + +The compiler automatically groups these three operations into a single implicit vector scope because they form a data-dependent chain (when no explicit `pto.vecscope()` appears in the kernel). + +**Scope boundary rules:** +1. **Control flow boundaries**: Branches (`if`/`else`), loops (`for`/`while`), and function calls create implicit scope boundaries +2. **Scalar operations**: Non-vector operations (e.g., scalar arithmetic, pointer arithmetic) create boundaries +3. **Explicit scope blocks**: User-defined `vecscope` and `strict_vecscope` blocks create hard boundaries + +#### Explicit Scope Boundaries with `strict_vecscope` [Advanced Tier] + +##### `pto.strict_vecscope(*captures: AnyType) -> ContextManager[Tuple[AnyType, ...]]` + +**Description**: Creates an explicit vector scope boundary with explicit value captures. Values used inside the scope must be passed as arguments; implicit capture from outer scope is rejected. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `*captures` | `AnyType` | Variable number of values to be captured and passed into the scope | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `context_manager` | `ContextManager[Tuple[AnyType, ...]]` | Context manager that yields a tuple of captured values when entered | + +**Constraints**: +- The scope body cannot implicitly capture values from the surrounding scope; all used values must be passed as `captures`. +- Creates a hard boundary that prevents the compiler from merging vector operations across the scope boundary. +- Useful for performance optimization, debugging, resource management, and hardware compatibility. + +For precise control over scope boundaries, use explicit `strict_vecscope` blocks. These create hard boundaries that prevent the compiler from merging operations across the block boundary: + +```python +with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): + # Operations inside this block are isolated from outside + # Compiler will not merge operations across this boundary + for i in range(lb, ub, 64): + vec = pto.vlds(s, i) + pto.vsts(vec, d, i, all_mask) +``` + +**Use cases for strict_vecscope:** +- Performance optimization: Isolate critical vector computation regions +- Debugging: Create explicit boundaries to isolate vector operations +- Resource management: Control vector register allocation boundaries +- Compatibility: Ensure deterministic scope placement for hardware constraints + +#### Explicit Scope Blocks with `vecscope` + +`pto.vecscope` provides an explicit vector-scope boundary without strict capture ABI constraints: + +```python +with pto.vecscope(): + vec = pto.vlds(src, 0) + vec = pto.vadd(vec, vec, mask) + pto.vsts(vec, dst, 0, mask) +``` + +**Rules**: +- `pto.vecscope()` takes no positional/keyword arguments. +- `pto.vecscope()` does not support `as (...)` bindings. +- When any explicit `pto.vecscope()` is present in a kernel body, automatic vecscope inference is disabled for that kernel. + +### Inline Procedures (`@pto.inline_proc`) + +TileLang DSL supports reusable top-level procedures decorated with `@pto.inline_proc`. +`inline_proc` follows function-call semantics in frontend IR and is force-inlined +later by the VPTO backend mainline in `ptoas`. + +```python +@pto.inline_proc +def store_row(dst: pto.Tile, src: pto.Tile, row: pto.i32): + vec = pto.vlds(src[row, 0:]) + mask = pto.make_mask(dst.element_type, pto.PAT.ALL) + pto.vsts(vec, dst[row, 0:], mask) + return None + +@pto.vkernel(op="pto.row_copy", dtypes=[(pto.f32, pto.f32, pto.i32)]) +def row_copy(dst: pto.Tile, src: pto.Tile, row: pto.i32): + store_row(dst, src, row) + return None +``` + +Important semantics: + +- `pto.(...)` and bare helper calls are different mechanisms. +- Calls written as `pto.vadd(...)`, `pto.vdiv(...)`, `pto.vlds(...)`, etc. target + built-in TileLang/VPTO surfaces directly. +- Calls written as bare Python names such as `store_row(...)` target a + user-defined `@pto.inline_proc` helper when the callee name resolves to a + registered top-level inline procedure in the current module. +- `inline_proc` helpers do not live in the `pto` namespace; using the same + basename as a `pto.` op is allowed because the frontend distinguishes + `pto.xxx(...)` from bare `xxx(...)` calls. +- Frontend preserves helper `func.func` and `func.call` in `mlir_text()` output. +- VPTO backend mainline force-inlines helper calls before downstream lowering. +- Helper definitions support default parameter values. +- Helper calls support positional arguments and keyword arguments. +- Helper calls can appear in statement and expression positions. +- Helper definitions can use trailing `return ` to return values. +- Implicit capture is rejected except module-level globals whose current bound value is `bool`/`int`/`float`/`str`; pass other required values as explicit arguments. +- Recursive/mutually-recursive helper call graphs are rejected. +- `*args`, `**kwargs`, and keyword-only parameters are unsupported in current version. + +Shared helpers can live in a separate Python file in the template directory and +be imported directly by templates: + +```python +# shared_rows.py +import tilelang_dsl as pto + +@pto.inline_proc +def touch_row(dst: pto.Tile, row: pto.i32): + mask = pto.make_mask(dst.element_type, pto.PAT.ALL) + vec = pto.vlds(dst[row, 0:]) + pto.vsts(vec, dst[row, 0:], mask) + return None + +# trow_template.py +import tilelang_dsl as pto +from shared_rows import touch_row + +@pto.vkernel(op="pto.row_touch", dtypes=[(pto.f32, pto.i32)]) +def row_touch(dst: pto.Tile, row: pto.i32): + touch_row(dst, row) + return None +``` + +Only directly imported `@pto.inline_proc` helpers are part of this shared-helper +surface. Ordinary Python functions remain unsupported in DSL bodies, and +qualified calls such as `shared_rows.touch_row(...)` are not part of this +version. If multiple imported helpers expose the same bare name, the frontend +rejects the template instead of choosing one by import order. + +### Loops + +Counted loops use Python's `range` syntax: + +```python +for i in range(lb, ub, step): + # Loop body + mask, rem = pto.make_mask(pto.f32, remaining) + # ... +``` + +Loop-carried state is automatically handled through variable updates within the loop. + +### Conditionals + +`if` statements support value merging: + +```python +flag: pto.i1 = some_condition +step: pto.i32 = 0 + +if flag: + step = pto.i32(64) +else: + step = pto.i32(128) + +# 'step' here is the merged result from both branches +``` + +Variables defined in only one branch are local to that branch. diff --git a/ptodsl/docs/user_guide/07-frontend-operations.md b/ptodsl/docs/user_guide/07-frontend-operations.md new file mode 100644 index 000000000..621a8c78f --- /dev/null +++ b/ptodsl/docs/user_guide/07-frontend-operations.md @@ -0,0 +1,352 @@ + +### Frontend-only Authoring Operations + +Operations in this family affect descriptor construction and code generation +shape. They are consumed by the frontend and do not correspond to runtime VPTO +instructions by themselves. + +#### `pto.constexpr(value: bool) -> bool` + +**Description**: Compile-time conditional construct for kernel specialization. Marks a boolean expression for evaluation during descriptor materialization, enabling branch elimination based on static compile-time information. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `bool` | Boolean expression that must be evaluable at compile time. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `bool` | A frontend-only compile-time boolean used to guard `if` statements. | + +**Behavior**: +- Evaluated during kernel descriptor materialization before semantic analysis and lowering. +- When used in `if pto.constexpr(...):` statements, only the selected branch is retained; the other branch is discarded entirely. +- If the condition cannot be proven static, descriptor materialization fails with a frontend diagnostic. +- Does not generate runtime control flow or value merging logic. + +**Examples**: +```python +# Specialize based on element size +dtype = dst.element_type +elem_bytes = pto.bytewidth(dtype) + +if pto.constexpr(elem_bytes == 2): + # Specialized path for 16-bit types (f16/bf16) + ... +else: + # Fallback path for other types + ... +``` + +```python +# Specialize based on tile shape +rows, cols = dst.shape + +if pto.constexpr(rows == 1 and cols == 16): + # Fast path for specific tile configuration + ... +``` + +**Constraints**: +- `pto.constexpr` is a frontend-only authoring construct with no runtime representation. +- The condition must be statically evaluable from descriptor-time information (data types, tile shapes, literals, etc.). +- For kernel-level specialization, prefer `constraints=[...]` and `pto.select_kernel(...)`. +- See [Compile-time Specialization with `pto.constexpr`](04-template-kernels.md#compile-time-specialization-with-ptoconstexpr) for detailed usage guidelines. + +### Type Query Operations + +Operations for querying type properties. + +#### `pto.bytewidth(dtype: Type) -> pto.i32` + +**Description**: Returns the size in bytes of a single element of the given data type. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`, `pto.si16`, `pto.ui32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `size` | `pto.i32` | Element size in bytes | + +**Example**: +```python +f32_size = pto.bytewidth(pto.f32) # Returns 4 +f16_size = pto.bytewidth(pto.f16) # Returns 2 +i8_size = pto.bytewidth(pto.i8) # Returns 1 +ui64_size = pto.bytewidth(pto.ui64) # Returns 8 +``` + +**Common Use Case**: Calculate byte offsets for memory access: +```python +element_type = pto.f32 +byte_offset = index * pto.bytewidth(element_type) +``` + +#### `pto.elements_per_vreg(dtype: Type) -> pto.i32` + +**Description**: Returns the number of elements per vector register for a given element type, based on the hardware vector register size (256 bytes). This function computes `256 // bytewidth(dtype)`, which represents the maximum number of elements of the given type that can fit in a single vector register. Useful for determining vector width and loop stride calculations. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`, `pto.si16`, `pto.ui32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `elems` | `pto.i32` | Number of elements per vector register for the given element type | + +**Example**: +```python +f32_elems_per_vreg = pto.elements_per_vreg(pto.f32) # Returns 64 (256 / 4) +f16_elems_per_vreg = pto.elements_per_vreg(pto.f16) # Returns 128 (256 / 2) +i8_elems_per_vreg = pto.elements_per_vreg(pto.i8) # Returns 256 (256 / 1) +si16_elems_per_vreg = pto.elements_per_vreg(pto.si16) # Returns 128 (256 / 2) +``` + +**Common Use Case**: Loop stride calculation for vector operations: +```python +dtype = pto.f32 +elems_per_vreg = pto.elements_per_vreg(dtype) # Returns 64 for f32 +for col in range(0, cols, elems_per_vreg): + # Load/store vectors of 'elems_per_vreg' elements + pass +``` + +**Relationship with `pto.bytewidth`**: +```python +# The relationship between bytewidth and elements per vector register: +elems = 256 // pto.bytewidth(dtype) +# This is equivalent to: +elems = pto.elements_per_vreg(dtype) +``` + +### Runtime Block Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar +code. They are pure scalar producers: + +- they do not move data +- they do not allocate buffers +- they do not by themselves create `vecscope` boundaries + +Their main purpose is workload partitioning. A common pattern is: + +1. query the current block or subblock id +2. compute a per-instance starting offset +3. use that offset to derive GM/UB pointers or TensorView slices +4. run the local tile or vector loop for that partition + +#### `pto.get_block_idx() -> pto.i64` + +**Description**: Returns the current block ID for the running kernel instance. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `block` | `pto.i64` | Current block index in the range `[0, pto.get_block_num())` | + +**Behavior**: +- The returned value is launch-instance-local and may differ across concurrently running blocks. +- The value is stable for the lifetime of one kernel instance. +- The op is scalar-only and can be used before pointer arithmetic, TensorView partitioning, DMA setup, or loop construction. + +#### `pto.get_subblock_idx() -> pto.i64` + +**Description**: Returns the current subblock ID visible to the running kernel instance. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `subblock` | `pto.i64` | Current subblock index in the range `[0, pto.get_subblock_num())` | + +**Behavior**: +- Used when one block is further subdivided by the launch/runtime model. +- Like `pto.get_block_idx()`, this is a pure scalar query with no side effects. + +#### `pto.get_block_num() -> pto.i64` + +**Description**: Returns the total number of blocks visible to the current kernel launch. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `block_num` | `pto.i64` | Total block count for the current launch domain | + +**Behavior**: +- Typically paired with `pto.get_block_idx()` to compute per-block ranges. +- The result is a runtime value and should not be assumed to be a compile-time constant. + +#### `pto.get_subblock_num() -> pto.i64` + +**Description**: Returns the total number of subblocks visible to the current execution instance. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `subblock_num` | `pto.i64` | Total subblock count in the current runtime execution domain | + +**Behavior**: +- Typically paired with `pto.get_subblock_idx()` for finer-grained partitioning inside one block. + +**Example**: +```python +block = pto.get_block_idx() +block_num = pto.get_block_num() +subblock = pto.get_subblock_idx() +subblock_num = pto.get_subblock_num() +``` + +**Typical Use Case**: Compute a per-block base pointer. +```python +block = pto.get_block_idx() +block_len = 2048 +base_elem = block * block_len +block_src = pto.addptr(src_gm, base_elem) +block_dst = pto.addptr(dst_gm, base_elem) +``` + +**Constraints**: +- These ops return runtime scalar values, not compile-time specialization constants. +- They are intended for scalar address/control computation, not as vector operands. +- When mixing them with pointer arithmetic, remember that `pto.addptr(...)` uses element offsets, not byte offsets. + +### Scalar Pointer Helpers [Advanced Tier] + +These ops perform scalar element access on typed PTO pointers. Unlike +`pto.vlds(...)` / `pto.vsts(...)`, they operate on exactly one element and do +not create or consume vector registers or masks. + +They are useful when a kernel needs a small amount of scalar state next to +vector code, for example: + +- reading a scalar coefficient or loop-carried value from UB +- writing a scalar flag or reduction result +- patching a small header/metadata area without vector load-store semantics + +#### `pto.load_scalar(ptr: PtrType, offset: Index) -> ScalarType` +#### `pto.load_scalar(dtype: Type, ptr: PtrType, offset: Index) -> ScalarType` + +**Description**: Loads one scalar element from a typed PTO pointer at the given element offset. + +**Parameters (`load_scalar(ptr, offset)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Typed pointer created by `pto.ptr(...)`, `pto.castptr(...)`, `Tile.as_ptr()`, or `TensorView.as_ptr()` | +| `offset` | `Index` | Element displacement from `ptr` | + +**Parameters (`load_scalar(dtype, ptr, offset)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Optional explicit result dtype; must match the pointer element type | +| `ptr` | `PtrType` | Typed pointer source | +| `offset` | `Index` | Element displacement from `ptr` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `value` | `ScalarType` | One scalar element loaded from `ptr[offset]` | + +**Behavior**: +- Access is element-based, not byte-based. +- The loaded value has the same scalar dtype as the pointer element type. +- This is a scalar memory helper; it does not participate in vector distribution families such as `dist`. +- It may target any memory space represented by the pointer type; the memory-space legality follows the pointer producer. + +#### `pto.store_scalar(ptr: PtrType, offset: Index, value: ScalarType) -> None` +#### `pto.store_scalar(value: ScalarType, ptr: PtrType, offset: Index) -> None` + +**Description**: Stores one scalar element to a typed PTO pointer at the given element offset. + +**Parameters (`store_scalar(ptr, offset, value)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Typed destination pointer | +| `offset` | `Index` | Element displacement from `ptr` | +| `value` | `ScalarType` | Scalar value to write | + +**Parameters (`store_scalar(value, ptr, offset)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `ScalarType` | Scalar value to write | +| `ptr` | `PtrType` | Typed destination pointer | +| `offset` | `Index` | Element displacement from `ptr` | + +**Returns**: None (side-effect operation) + +**Behavior**: +- Stores exactly one scalar element to `ptr[offset]`. +- Does not consume a predicate mask. +- Does not imply vector-store ordering semantics such as `dist` or unaligned store state. + +**Example**: +```python +value = pto.load_scalar(src_ptr, 0) +pto.store_scalar(dst_ptr, 0, value) +``` + +**Typical Use Case**: Read-modify-write scalar metadata next to vector code. +```python +flag = pto.load_scalar(status_ptr, 0) +# scalar compute on `flag` +pto.store_scalar(status_ptr, 0, flag) +``` + +**Constraints**: +- `ptr` must be a typed `pto.ptr(...)` value. +- `offset` is element-based and must be index-typed after frontend normalization. + Plain integer literals such as `0` are accepted and lowered as index constants. +- The scalar dtype must match the pointer element dtype. +- These ops are advanced pointer-surface operations; prefer Tile/TensorView authoring surfaces when scalar pointer manipulation is not required. + +### Pointer Construction [Advanced Tier] + +Operations for creating and manipulating typed pointers. + +#### `pto.castptr(offset: pto.i64, ptr_type: Type) -> PtrType` + +**Description**: Creates a typed pointer from an integer address, a memref-backed address value, or another typed pointer in the same memory space. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `offset` | `pto.i64` / address-like value | Integer address, memref-backed address value, or existing pointer | +| `ptr_type` | `Type` | Target pointer type (e.g., `pto.ptr(pto.f32, MemorySpace.GM)`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `ptr` | `PtrType` | Typed pointer value | + +**Example**: +```python +ub_ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +``` + +`TensorView.as_ptr()` and `Tile.as_ptr()` remain the preferred high-level APIs. They lower directly to address-extraction intrinsics (`pto.tensor_view_addr` / `pto.tile_buf_addr`) with pointer result types, while tile slice / buffer-view authoring paths continue to materialize memref results from the same intrinsics. + +#### `pto.addptr(ptr: PtrType, offset: pto.i64) -> PtrType` + +**Description**: Adds an element offset to an existing pointer. The offset is counted in elements, not bytes. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Source pointer | +| `offset` | `pto.i64` | Element offset to add (counted in elements, not bytes) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `new_ptr` | `PtrType` | Pointer with element offset applied | + +**Example**: +```python +# Advance pointer by 1024 f32 elements (not bytes) +next_ptr = pto.addptr(ub_ptr, 1024) +``` + diff --git a/ptodsl/docs/user_guide/08-sync-dma-operations.md b/ptodsl/docs/user_guide/08-sync-dma-operations.md new file mode 100644 index 000000000..883e5104a --- /dev/null +++ b/ptodsl/docs/user_guide/08-sync-dma-operations.md @@ -0,0 +1,622 @@ +### Synchronization & Buffer Control + +Operations for pipeline synchronization and buffer management. + +#### Enum Types for Synchronization + +The following enum types provide type-safe parameter specification for synchronization operations: + +- **`BarrierType`**: Memory barrier types for `pto.mem_bar` + - `VV_ALL`, `VST_VLD`, `VLD_VST`, `VST_VST`: vector→vector barriers + - `VS_ALL`, `VST_LD`, `VLD_ST`, `VST_ST`: vector→scalar barriers + - `SV_ALL`, `ST_VLD`, `LD_VST`, `ST_VST`: scalar→vector barriers + +- **`Pipe`**: Hardware pipeline identifiers + - `MTE2`: Memory Transfer Engine 2 pipeline + - `V`: Vector pipeline + - `MTE3`: Memory Transfer Engine 3 pipeline + - `ALL`: All pipelines (for barrier operations) + +- **`Event`**: Event identifiers for synchronization + - `ID0`, `ID1`, `ID2`, `ID3`, ..., `ID31`: Event IDs 0-31 (A5 supports 32 event IDs, 0-15 for subblock 0, 16-31 for subblock 1) + +#### `pto.set_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` + +**Description**: Sets a synchronization flag between hardware pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | +| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | +| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE, EVENT + +pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### `pto.wait_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` + +**Description**: Waits for a synchronization flag between hardware pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | +| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | +| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE, EVENT + +pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### `pto.pipe_barrier(pipes: PIPE) -> None` + +**Description**: Executes a barrier across specified pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipes` | `PIPE` | Pipeline specification (e.g., `PIPE.ALL`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE + +pto.pipe_barrier(PIPE.ALL) +``` + +#### `pto.get_buf(pipe: Pipe, buf_id: pto.i64, mode: pto.i64) -> None` + +**Description**: Acquire buffer slot for inter-pipeline double-buffering coordination. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Pipeline identifier (e.g., `Pipe.MTE2`, `Pipe.V`, `Pipe.MTE3`) | +| `buf_id` | `pto.i64` | Buffer identifier | +| `mode` | `pto.i64` | Acquisition mode | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Pipe + +# Acquire buffer for MTE2 pipeline +pto.get_buf(Pipe.MTE2, 0, 0) +``` + +#### `pto.rls_buf(pipe: Pipe, buf_id: pto.i64, mode: pto.i64) -> None` + +**Description**: Release buffer slot to allow other pipeline to proceed. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Pipeline identifier (e.g., `Pipe.MTE2`, `Pipe.V`, `Pipe.MTE3`) | +| `buf_id` | `pto.i64` | Buffer identifier | +| `mode` | `pto.i64` | Release mode | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Pipe + +# Release buffer for MTE2 pipeline +pto.rls_buf(Pipe.MTE2, 0, 0) +``` + +#### `pto.mem_bar(barrier_type: BarrierType) -> None` + +**Description**: Memory barrier for pipeline synchronization within vector scope. Required when UB addresses alias between vector load/store operations. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `barrier_type` | `BarrierType` | Barrier type controlling prior/subsequent instruction ordering. Supported values are `BarrierType.VV_ALL`, `BarrierType.VST_VLD`, `BarrierType.VLD_VST`, `BarrierType.VST_VST`, `BarrierType.VS_ALL`, `BarrierType.VST_LD`, `BarrierType.VLD_ST`, `BarrierType.VST_ST`, `BarrierType.SV_ALL`, `BarrierType.ST_VLD`, `BarrierType.LD_VST`, and `BarrierType.ST_VST`. | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import BarrierType + +# Ensure stores are visible before loads to same UB region +pto.mem_bar(BarrierType.VST_VLD) +``` + +#### `pto.set_cross_core(core_id: pto.i64, event_id: Event) -> None` + +**Description**: Signal event to another core (cross-core synchronization). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `core_id` | `pto.i64` | Target/source core identifier (platform-specific mapping) | +| `event_id` | `Event` | Cross-core event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Signal event ID0 to core 0 +pto.set_cross_core(0, Event.ID0) +``` + +#### `pto.set_intra_block(block_id: pto.i64, event_id: Event) -> None` + +**Description**: Signal event within a block (A5). Specifies trigger pipe. 1:1 per subblock. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `block_id` | `pto.i64` | Block/pipeline identifier specifying trigger pipe | +| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Signal event ID0 on block/pipeline 0 +pto.set_intra_block(0, Event.ID0) +``` + +#### `pto.set_intra_core(config: pto.i32) -> None` + +**Description**: Configures intra-core synchronization settings. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `config` | `pto.i32` | Configuration value for intra-core synchronization | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_intra_core(3) +``` + +#### `pto.wait_flag_dev(core_id: pto.i64, event_id: Event) -> None` + +**Description**: Wait for event from another core. SU-level blocking — entire core stalls. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `core_id` | `pto.i64` | Core identifier | +| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Wait for event ID0 from core 0 +pto.wait_flag_dev(0, Event.ID0) +``` + +#### `pto.wait_intra_core(block_id: pto.i64, event_id: Event) -> None` + +**Description**: Wait for event within block (A5). Specifies which pipeline should wait — only that pipe stalls, SU and other pipes continue. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `block_id` | `pto.i64` | Block/pipeline identifier specifying which pipeline should wait | +| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Wait for event ID0 on block/pipeline 0 +pto.wait_intra_core(0, Event.ID0) +``` + +### DMA Programming [Advanced Tier] + +This section covers Direct Memory Access (DMA) operations for transferring data between Global Memory (GM) and Unified Buffer (UB). DMA operations are performance-critical and require careful configuration of stride parameters and transfer sizes. + +**Key Concepts:** +- **DMA Configuration**: Set stride parameters and loop sizes using `set_loop*_stride_*` and `set_loop_size_*` operations. +- **DMA Execution**: Perform transfers using `copy_gm_to_ubuf`, `copy_ubuf_to_gm`, and `copy_ubuf_to_ubuf` operations. +- **GM→UB Padding**: Optionally fill out-of-bounds regions with a specified value when copying from GM to UB. See [Pad Fill Semantics](#pad-fill-semantics) for details. + +**Usage Flow:** +1. Configure DMA parameters (strides, loop sizes) +2. Execute the DMA transfer operation +3. Optionally enable padding for GM→UB transfers + +**Note**: All DMA operations in this section are part of the **Advanced Tier** and require explicit buffer management and pointer arithmetic. For basic tile-based authoring, refer to the [Basic Authoring Mode](01-introduction.md#basic-vs-advanced-authoring-modes) documentation. + +#### Manual Configuration Example + +```python +# DMA configuration example (requires careful parameter tuning) +pto.set_loop2_stride_outtoub(src_stride=32, dst_stride=128) # Outer loop strides +pto.set_loop1_stride_outtoub(src_stride=1, dst_stride=32) # Inner loop strides +pto.set_loop_size_outtoub(loop1=16, loop2=16) # Transfer size +pto.copy_gm_to_ubuf(src=gm_ptr, dst=ub_ptr, n_burst=16, len_burst=128, gm_stride=128, ub_stride=128) + +``` + +#### Pad Fill Semantics + +When copying data from Global Memory (GM) to Unified Buffer (UB), you can enable padding to fill out-of-bounds regions with a specified value. This is useful when the source data dimensions don't perfectly match the destination tile allocation, or when you need to handle boundary conditions in tiled computations. + +##### How Padding Works + +1. **Configure the hardware pad register**: Call `pto.set_mov_pad_val` to set the pad value in the hardware register. This must be done before any `pto.copy_gm_to_ubuf` operation with padding enabled. + +2. **Enable padding in the DMA operation**: Set `enable_ub_pad=True` in the `pto.copy_gm_to_ubuf` call to activate the padded transfer path. The pad value from the hardware register will be used for filling out-of-bounds regions. + +3. **Hardware mapping**: The `pto.set_mov_pad_val` operation corresponds directly to the low-level VPTO instruction that configures the hardware pad register. There is no automatic translation from tile `PadValue` descriptors—you must explicitly set the pad register before padded DMA transfers. + +##### Example Workflow + +Configure the hardware pad register using `pto.set_mov_pad_val`, then perform the DMA transfer with padding enabled: + +```python +# First, configure the hardware pad register with a scalar value +# For zero fill, use an appropriate scalar type based on your data +pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float32 data + +# Then perform the DMA transfer with padding enabled +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, # Enable padded transfer +) +``` + +##### Accessing Pad Values in Kernel Code + +Tile `PadValue` descriptors can be used within kernel code for computation purposes (e.g., initializing vectors with a specific fill value). However, note that **these descriptors are not automatically used for DMA padding**—you must still call `pto.set_mov_pad_val` explicitly to configure the hardware pad register for GM→UB transfers. + +To access a pad value from a tile descriptor in kernel code: + +```python +# Get the pad descriptor from the destination tile +pad_desc = dst.pad_value + +# Check if a valid pad value is configured +if pto.constexpr(pad_desc != pto.PadValue.NULL): + # Materialize the scalar value + pad_scalar = pad_desc.eval() + + # Use the scalar value (e.g., for vector duplication) + mask = pto.make_mask(pto.f32, PAT.ALL) + pad_vector = pto.vdup(pad_scalar, mask) +``` + +##### Important Notes + +- The `PadValue.NULL` descriptor indicates no pad value is configured. Attempting to call `.eval()` on `PadValue.NULL` will raise a frontend error. +- Custom pad values currently support only 32-bit float payloads (`PadValue.custom_f32(...)`). +- Padding only affects GM→UB transfers (`pto.copy_gm_to_ubuf`). UB→GM and UB→UB transfers do not support padding. +- The padded region is determined by the difference between the tile's `valid_shape` and its full `shape`. Ensure your tile is configured with appropriate dimensions. +- Tile `PadValue` descriptors are not automatically used for DMA padding. You must call `pto.set_mov_pad_val` explicitly to configure the hardware pad register for padded GM→UB transfers. + +##### `pto.set_mov_pad_val` Operation [Advanced Tier] + +The `pto.set_mov_pad_val` operation configures the hardware pad register used for GM→UB transfers when padding is enabled. This operation must be called explicitly before any `pto.copy_gm_to_ubuf` operation with `enable_ub_pad=True`, as the TileLang DSL v1 does not automatically translate tile `PadValue` descriptors to hardware register configurations. + +**Operation Signature**: +```python +pto.set_mov_pad_val(pad_value: ScalarType) -> None +``` + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pad_value` | `ScalarType` | Scalar value used for padding. Supported types: any 8/16/32-bit integer scalar (`pto.i8`, `pto.si8`, `pto.ui8`, `pto.i16`, `pto.si16`, `pto.ui16`, `pto.i32`, `pto.si32`, `pto.ui32`) plus `pto.f16`, `pto.bf16`, and `pto.f32`. The value's bit pattern is encoded into the hardware pad register. Integer inputs are automatically normalized to the corresponding signless hardware operand width during lowering, so no manual cast is required before calling `pto.set_mov_pad_val`. For standard pad values, use `PadValue.eval(...)` to obtain the appropriate scalar: `0` or `0.0` for `PadValue.ZERO`, dtype-aware maximum for `PadValue.MAX`, dtype-aware minimum for `PadValue.MIN`. | + +**Returns**: None (side-effect operation) + +**Example**: + +Using a scalar value directly: +```python +# Configure the hardware pad register for zero fill using an integer scalar +pto.set_mov_pad_val(pto.i32(0)) # Zero fill for integer types + +# Or using a float scalar for floating-point padding +pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float types + +# Perform DMA transfer with padding enabled +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, +) +``` + +Using a tile's pad value descriptor: +```python +# Get the pad value from a tile configuration +pad_desc = tile.pad_value # PadValue enum +if pto.constexpr(pad_desc != pto.PadValue.NULL): + pad_scalar = pad_desc.eval() # Materializes to a scalar value + pto.set_mov_pad_val(pad_scalar) + + # Perform padded DMA transfer + pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, + ) +``` + +Using a standalone `PadValue` with an explicit dtype: +```python +pad_scalar = pto.PadValue.MAX.eval(pto.f32) +pto.set_mov_pad_val(pad_scalar) +``` + +For integer tile dtypes such as `pto.ui16` or `pto.si32`, `pad_desc.eval()` can be passed directly to `pto.set_mov_pad_val`. TileLang DSL v1 will automatically insert the required same-width bitcast to the signless hardware operand type during lowering. + +**Important**: You are responsible for ensuring the pad register is properly configured before any `pto.copy_gm_to_ubuf` operation with `enable_ub_pad=True`. The pad register configuration persists until changed by another `pto.set_mov_pad_val` call. + +**Future Improvement**: Future versions of TileLang DSL may provide an implicit approach that automatically translates `PadValue` descriptors from tile configurations to hardware register configurations, similar to DMA syntax sugar features. + +#### `pto.set_loop2_stride_outtoub(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for GM → UB transfers (loop2). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop1_stride_outtoub(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for GM → UB transfers (loop1). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop_size_outtoub(loop1: pto.i64, loop2: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA transfer size for GM → UB transfers. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop1` | `pto.i64` | Inner loop trip count | +| `loop2` | `pto.i64` | Outer loop trip count | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_loop_size_outtoub(loop1=1, loop2=1) +``` + +#### `pto.set_loop2_stride_ubtoout(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for UB → GM transfers (loop2). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop1_stride_ubtoout(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for UB → GM transfers (loop1). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop_size_ubtoout(loop1: pto.i64, loop2: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA transfer size for UB → GM transfers. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop1` | `pto.i64` | Inner loop trip count | +| `loop2` | `pto.i64` | Outer loop trip count | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop(loop_id: pto.i32, src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for a generic loop. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop_id` | `pto.i32` | Loop identifier (e.g., 1 for inner loop, 2 for outer loop) | +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_loop(1, src_stride=32, dst_stride=64) +``` + +#### `pto.set_loop_size(loop_id: pto.i32, size: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA transfer size for a generic loop. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop_id` | `pto.i32` | Loop identifier (e.g., 1 for inner loop, 2 for outer loop) | +| `size` | `pto.i64` | Loop trip count | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_loop_size(1, 16) +``` + +#### DMA Execution Operations + +**Note**: These operations execute DMA transfers but require manual configuration of DMA parameters (loop strides, loop sizes) using the `set_loop*_stride_*` and `set_loop_size_*` operations described above. + +The following operations provide direct control over DMA transfers but require manual stride and size configuration. + +#### `pto.copy_gm_to_ubuf(src: GMPtr, dst: UBPtr, sid: pto.i64 = 0, n_burst: pto.i64, len_burst: pto.i64, left_padding_count: pto.i64 = 0, right_padding_count: pto.i64 = 0, enable_ub_pad: pto.i1 = False, l2_cache_ctl: pto.i64 = 0, gm_stride: pto.i64, ub_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Copies data from Global Memory (GM) to Unified Buffer (UB). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `GMPtr` | Source GM pointer | +| `dst` | `UBPtr` | Destination UB pointer | +| `sid` | `pto.i64` | DMA stream/control operand, defaults to `0` | +| `n_burst` | `pto.i64` | Number of bursts | +| `len_burst` | `pto.i64` | Bytes copied by each burst | +| `left_padding_count` | `pto.i64` | Left padding count, defaults to `0` | +| `right_padding_count` | `pto.i64` | Right padding count, defaults to `0` | +| `enable_ub_pad` | `pto.i1` | Convenience alias for `data_select_bit`, defaults to `False` | +| `l2_cache_ctl` | `pto.i64` | L2 cache control operand, defaults to `0` | +| `gm_stride` | `pto.i64` | GM-side stride in bytes | +| `ub_stride` | `pto.i64` | UB-side stride in bytes | + +**Returns**: None (side-effect operation) + +**Notes**: +- **Keyword arguments**: The keyword form shown above is the recommended public API surface. Use named arguments for clarity. +- **Padding control**: Set `enable_ub_pad=True` to enable padded GM→UB transfers. The pad value must be configured separately using `pto.set_mov_pad_val` before the DMA operation (see [Pad Fill Semantics](#pad-fill-semantics) for details). +- **Pad value source**: When padding is enabled, the fill scalar comes from the hardware pad register configured by `pto.set_mov_pad_val`. You must call this operation explicitly before the DMA transfer. +- **ABI compatibility**: The lowering preserves the underlying PTO operand order while providing a more ergonomic keyword interface. + +**Example**: +```python +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=128, + gm_stride=128, + ub_stride=128, + enable_ub_pad=False, +) +``` + +**Padding Example**: +```python +# First configure the hardware pad register with a scalar value +pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float32 data + +# Then perform padded DMA transfer +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, +) +``` + +#### `pto.copy_ubuf_to_ubuf(src: UBPtr, dst: UBPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, dst_stride1: pto.i64) -> None` [Advanced Tier] + +**Description**: Copies data within Unified Buffer (UB → UB). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `UBPtr` | Source UB pointer | +| `dst` | `UBPtr` | Destination UB pointer | +| `src_offset` | `pto.i64` | Source offset | +| `src_stride0` | `pto.i64` | Source stride dimension 0 | +| `src_stride1` | `pto.i64` | Source stride dimension 1 | +| `dst_offset` | `pto.i64` | Destination offset | +| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | +| `dst_stride1` | `pto.i64` | Destination stride dimension 1 | + +**Returns**: None (side-effect operation) + +#### `pto.copy_ubuf_to_gm(src: UBPtr, dst: GMPtr, sid: pto.i64 = 0, n_burst: pto.i64, len_burst: pto.i64, reserved: pto.i64 = 0, gm_stride: pto.i64, ub_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Copies data from Unified Buffer (UB) to Global Memory (GM). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `UBPtr` | Source UB pointer | +| `dst` | `GMPtr` | Destination GM pointer | +| `sid` | `pto.i64` | DMA stream/control operand, defaults to `0` | +| `n_burst` | `pto.i64` | Number of bursts | +| `len_burst` | `pto.i64` | Bytes copied by each burst | +| `reserved` | `pto.i64` | Reserved operand, defaults to `0` | +| `gm_stride` | `pto.i64` | GM-side stride in bytes | +| `ub_stride` | `pto.i64` | UB-side stride in bytes | + +**Returns**: None (side-effect operation) + +**Notes**: +- In TileLang DSL, the keyword form above is the recommended public surface. +- `gm_stride`/`ub_stride` are ergonomic aliases for the low-level `burst_dst_stride`/`burst_src_stride` operands. +- The lowering still maps to the underlying low-level PTO operand ABI in positional order. + +**Example**: +```python +pto.copy_ubuf_to_gm( + src=ub_ptr, + dst=gm_ptr, + n_burst=32, + len_burst=128, + gm_stride=128, + ub_stride=128, +) +``` diff --git a/ptodsl/docs/user_guide/09-vector-memory-operations.md b/ptodsl/docs/user_guide/09-vector-memory-operations.md new file mode 100644 index 000000000..f7a20fd76 --- /dev/null +++ b/ptodsl/docs/user_guide/09-vector-memory-operations.md @@ -0,0 +1,1058 @@ +### Enum Types for Vector Memory Operations + +The current DSL exposes type-safe Enum operands for the dual load/store +distribution families: + +- **`VLoadDist`** for `pto.vlds` + - `VLoadDist.NORM`: ordinary load + - `VLoadDist.UNPK_B8`, `VLoadDist.UNPK_B16`, `VLoadDist.UNPK_B32`: unpacking loads + - `VLoadDist.BRC_B8`, `VLoadDist.BRC_B16`, `VLoadDist.BRC_B32`: broadcast loads + - `VLoadDist.US_B8`, `VLoadDist.US_B16`, `VLoadDist.DS_B8`, `VLoadDist.DS_B16`: strided/narrow load families + +- **`VStoreDist`** for `pto.vsts` + - `VStoreDist.NORM_B8`, `VStoreDist.NORM_B16`, `VStoreDist.NORM_B32`: ordinary stores + - `VStoreDist.ONE_POINT_B8`, `VStoreDist.ONE_POINT_B16`, `VStoreDist.ONE_POINT_B32`: one-point stores + - `VStoreDist.PK_B16`, `VStoreDist.PK_B32`, `VStoreDist.PK_B64`: packed stores + - `VStoreDist.PK4_B32`, `VStoreDist.MRG4CHN_B8`, `VStoreDist.MRG2CHN_B8`, `VStoreDist.MRG2CHN_B16`: merged packed stores + +- **`DeinterleaveDist`** for `pto.vldsx2` + - `DeinterleaveDist.DINTLV`: alternating-element deinterleave + - `DeinterleaveDist.BDINTLV`: block deinterleave + - compatibility aliases: `DeinterleaveDist.B8`, `DeinterleaveDist.B16`, + `DeinterleaveDist.B32`, `DeinterleaveDist.BD` + +- **`InterleaveDist`** for `pto.vstsx2` + - `InterleaveDist.INTLV`: interleave two vectors into one destination stream + - compatibility aliases: `InterleaveDist.B8`, `InterleaveDist.B16`, + `InterleaveDist.B32` + +- **`PostUpdateMode`** for `pto.vstur` + - `PostUpdateMode.NO_POST_UPDATE`: preserve the current hardware AR state + - `PostUpdateMode.POST_UPDATE`: advance the hardware AR state after the store + +The canonical VPTO v0.3 spellings are the enum values: + +- `VLoadDist.UNPK_B16.value == "UNPK_B16"` +- `VStoreDist.PK_B32.value == "PK_B32"` +- `DeinterleaveDist.DINTLV.value == "DINTLV"` +- `DeinterleaveDist.BDINTLV.value == "BDINTLV"` +- `InterleaveDist.INTLV.value == "INTLV"` +- `PostUpdateMode.NO_POST_UPDATE.value == "NO_POST_UPDATE"` +- `PostUpdateMode.POST_UPDATE.value == "POST_UPDATE"` + +`pto.vstur` mode is intentionally Enum-only in the DSL. Unlike the legacy +distribution-token compatibility retained for some older load/store families, +raw strings such as `"POST_UPDATE"` are not accepted for `PostUpdateMode`. + +For migration convenience, the implementation still accepts legacy raw strings +such as `"DINTLV_B32"` and `"INTLV_B32"`, but new DSL code should prefer the +Enum operands. + +- **`StrideMode`**: Stride modes for `pto.vsld` + - `S3_B16`: Stride 3, block size 16 + - `S4_B64`: Stride 4, block size 64 + - `S8_B32`: Stride 8, block size 32 + - `S2_B64`: Stride 2, block size 64 + +### Address Generation Syntax Sugar + +To simplify address calculation and reduce manual byte offset computation errors, TileLang DSL provides syntactic sugar for vector load/store operations using element-based indexing. This syntax automatically computes the byte offset based on tile shape, element type, and layout. + +#### Indexing Syntax + +The syntax supports two indexing modes for different operations: + +1. **Vector-range indexing** (for vector load/store operations): + - **Row-major layout (default)**: `tile[row_index, col_start:]` + - `row_index`: Row index (0-based) + - `col_start:`: Starting column index followed by colon, indicating a vector-width contiguous region starting from this column + - The colon (`:`) indicates an implicit vector-width range determined by hardware vector size (256 bytes) and element type + + - **Column-major layout**: `tile[row_start:, col_index]` + - `row_start:`: Starting row index followed by colon, indicating a vector-width contiguous region starting from this row + - `col_index`: Column index (0-based) + - Used for column-major tiles (`BLayout.COL_MAJOR`) where elements are stored column-wise + + - **1D tile indexing**: `tile[start:]` (or equivalently `tile[0, start:]` for row-major or `tile[start:, 0]` for column-major) + - `start:`: Starting element index followed by colon + + Tile indexing sugar only accepts an open-ended vector slice. Python slice + forms with an explicit `stop` or `step` are not supported for `Tile` + indexing. For example, `tile[row, col:col_end]`, `tile[row, col::2]`, + `tile[row_start:row_end, col]`, and `tile[start:stop:step]` are invalid. + +2. **Single-element indexing** (for scalar load operations like `pto.vsld`): + - **Row-major layout (default)**: `tile[row_index, col_index]` + - `row_index`: Row index (0-based) + - `col_index`: Column index (0-based) + - Loads a single element at the specified position and broadcasts it to all vector lanes + + - **Column-major layout**: `tile[row_index, col_index]` (same syntax) + - `row_index`: Row index (0-based) + - `col_index`: Column index (0-based) + - Same syntax as row-major; the layout determines how the offset is computed + + - **1D tile indexing**: `tile[pos]` + - `pos`: Element index (0-based) + - Loads a single element at the specified position and broadcasts it to all vector lanes + +#### Vector Width Calculation + +The number of elements loaded/stored in a single vector operation is determined by: + +``` +vector_lanes = 256 // element_size_bytes(element_type) +``` + +**Convenience API**: Use `pto.elements_per_vreg(dtype)` to compute the number of elements per vector register for a given element type (e.g., `pto.elements_per_vreg(pto.f32)` returns 64, `pto.elements_per_vreg(pto.f16)` returns 128). See [Type Query Operations](07-frontend-operations.md#type-query-operations) for full documentation. + +Where `element_size_bytes` is: +- 1 byte for `i8`, `si8`, `ui8` +- 2 bytes for `i16`, `si16`, `ui16`, `f16`, `bf16` +- 4 bytes for `i32`, `si32`, `ui32`, `f32` +- 8 bytes for `i64`, `si64`, `ui64` + +#### Offset Computation + +The byte offset is automatically computed based on tile layout: + +- **Row-major layout** (`BLayout.ROW_MAJOR`): + ``` + offset = (row_index * stride_row + col_start) * element_size_bytes + ``` + where `stride_row` is the row stride in elements (typically `tile.shape[1]` for contiguous tiles). + +- **Column-major layout** (`BLayout.COL_MAJOR`): + - For syntax `tile[row_start:, col_index]`: + ``` + offset = (col_index * stride_col + row_start) * element_size_bytes + ``` + - For backward compatibility with traditional offset calculation: + ``` + offset = (col_start * stride_col + row_index) * element_size_bytes + ``` + where `stride_col` is the column stride in elements (typically `tile.shape[0]` for contiguous tiles), `row_start` is the starting row index, and `col_index` is the column index. + +**Note**: +- For single-element indexing (`tile[row, col]` or `tile[pos]`), the same offset formulas apply with `col_start` replaced by `col_index` (or `start` replaced by `pos` for 1D tiles). +- For column-major vector-range indexing (`tile[row_start:, col_index]`), the offset formula uses `row_start` as the starting position along the contiguous dimension. +- The compiler automatically handles the appropriate substitution based on the indexing syntax and tile layout. + +#### Constraints + +1. **Boundary checks**: The requested region must be within tile bounds: + - **For vector-range indexing** (`:` syntax): + - **Row-major layout** (`tile[row_index, col_start:]`): + - `row_index < tile.shape[0]` and `col_start + vector_lanes <= tile.shape[1]` + - **Column-major layout** (`tile[row_start:, col_index]`): + - `row_start + vector_lanes <= tile.shape[0]` and `col_index < tile.shape[1]` + - **1D tile indexing**: `tile[start:]` + - `start + vector_lanes <= tile.shape[0]` (or `tile.shape[1]` for 1D tiles) + - **For single-element indexing** (no `:` syntax): + - 2D: `row_index < tile.shape[0]` and `col_index < tile.shape[1]` (same for both layouts) + - 1D: `pos < tile.shape[0]` (or `tile.shape[1]` for 1D tiles) + +2. **Alignment**: The computed offset must satisfy hardware alignment requirements for the operation. + +3. **Full vectors only**: The `:` syntax always loads/stores a full vector width. For partial vectors, use the traditional byte offset approach with explicit mask handling. + +4. **Single-element operations**: The single-element indexing syntax (`tile[row, col]` or `tile[pos]`) is only supported for scalar load operations like `pto.vsld`. For other operations, use vector-range indexing with `:` syntax. + +5. **No explicit slice bounds/stride for `Tile` indexing**: `Tile` vector-range + indexing only supports the open-ended forms `tile[start:]`, + `tile[row, col:]`, and `tile[row_start:, col_index]` (for column-major + layout). `stop` and `step` syntax are not accepted in user-guide Tile + indexing. + +#### Supported Operations + +The indexing syntax is supported for all vector load and store operations with the following syntax mapping: + +- **Vector-range indexing** (`tile[row, col:]` or `tile[start:]`): + - Load operations: `vlds`, `vldas`, `vldus`, `vldsx2` + - Store operations: `vsts`, `vsta`, `psts`, `vsst`, `vstsx2` + +- **Single-element indexing** (`tile[row, col]` or `tile[pos]`): + - Load operations: `vsld` (scalar load with broadcast) + +#### Examples + +The following examples use row-major layout syntax. For column-major tiles, use `tile[row_start:, col_index]` syntax instead of `tile[row_index, col_start:]`. + +```python +# 2D tile indexing (row-major layout) +vec = pto.vlds(tile[i, j:]) # Load vector from row i, columns j to j+vector_lanes-1 +pto.vsts(vec, tile[i, j:], mask) # Store vector with mask + +# 1D tile indexing +vec = pto.vlds(tile[k:]) # Load vector from elements k to k+vector_lanes-1 +pto.vsts(vec, tile[k:], mask) # Store vector with mask + +# Dual load with deinterleave +low, high = pto.vldsx2(tile[i, j:], "DINTLV") + +# Aligned load with indexing +vec = pto.vldas(tile[i, j:], align) + +# Scalar load (broadcast) +vec = pto.vsld(tile[i, j]) # Load scalar at tile[i,j] and broadcast to vector +``` + +#### Comparison with Manual Offset Calculation + +**Traditional approach (error-prone):** +```python +# Manual byte offset calculation for f32 tile +rows, cols = tile.shape +row_offset = i * cols * 4 # Hard-coded 4 bytes for f32 +col_offset = j * 4 +offset = row_offset + col_offset +vec = pto.vlds(tile, offset) +``` + +**New syntax (type-safe):** +```python +# Automatic offset calculation +vec = pto.vlds(tile[i, j:]) # Compiler computes correct offset for any element type +``` + +The syntax sugar eliminates manual byte calculations, reduces errors, and makes code generic across different element types (e.g., the same kernel works for both `f16` and `f32` without modification). + +### Vector Load Operations + +Operations for loading data from memory into vector registers. + +#### `pto.vlds(buf: ptr, offset: Index, dist: pto.VLoadDist | None = None) -> VRegType` [Advanced Tier] +#### `pto.vlds(tile[row, col:], dist: pto.VLoadDist | None = None) -> VRegType` [Basic Tier] +#### `pto.vlds(tile[start:], dist: pto.VLoadDist | None = None) -> VRegType` [Basic Tier] + +**Description**: Stateless vector load from buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `dist` | `pto.VLoadDist \| None` | Optional load distribution enum such as `pto.VLoadDist.NORM` or `pto.VLoadDist.UNPK_B16` | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `dist` | `pto.VLoadDist \| None` | Optional load distribution enum such as `pto.VLoadDist.NORM` or `pto.VLoadDist.UNPK_B16` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +**Constraints**: +- Buffer must be in UB memory space +- For byte-offset syntax: offset must be properly aligned based on element type +- For element-indexing syntax: the requested vector region must be within tile bounds and satisfy alignment requirements +- `dist` is optional. When omitted, the load uses the backend default layout for the vector family. +- `dist` must be a `pto.VLoadDist` enum value. + +**Examples**: +```python +# Traditional byte-offset syntax +vec = pto.vlds(ub_ptr, lane * 256) +vec_unpacked = pto.vlds(ub_ptr, lane * 128, dist=pto.VLoadDist.UNPK_B16) + +# New element-indexing syntax +vec = pto.vlds(tile[i, j:]) # Load from row i, columns j to j+vector_lanes-1 +vec = pto.vlds(tile[k:]) # Load from 1D tile, elements k to k+vector_lanes-1 + +# Generic kernel that works for both f16 and f32 +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def generic_scale(src: pto.Tile, dst: pto.Tile, scale: pto.f32): + rows, cols = src.shape + all_mask = pto.make_mask(src.element_type, PAT.ALL) + for i in range(0, rows): + for j in range(0, cols, vector_lanes): # vector_lanes computed from element type + # No manual byte calculation needed! + vec = pto.vlds(src[i, j:]) + scaled = pto.vmuls(vec, scale, all_mask) + pto.vsts(scaled, dst[i, j:], all_mask) +``` + +#### `pto.vldas(buf: ptr) -> pto.align` [Advanced Tier] +#### `pto.vldas(tile[row, col:]) -> pto.align` [Basic Tier] +#### `pto.vldas(tile[start:]) -> pto.align` [Basic Tier] + +**Description**: Prime alignment buffer for subsequent unaligned load. Returns alignment state for use with `pto.vldus`. Supports both pointer syntax and element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align` | `pto.align` | Alignment state for use with `pto.vldus` | + +**Examples**: +```python +# Pointer syntax +align = pto.vldas(ub_ptr) + +# Element-indexing syntax +align = pto.vldas(tile[i, j:]) +align = pto.vldas(tile[k:]) +``` + +#### `pto.vldus(buf: ptr, align: pto.align) -> (VRegType, pto.align, ptr)` [Advanced Tier] +#### `pto.vldus(tile[row, col:], align: pto.align) -> (VRegType, pto.align, ptr)` [Basic Tier] +#### `pto.vldus(tile[start:], align: pto.align) -> (VRegType, pto.align, ptr)` [Basic Tier] + +**Description**: Unaligned load using primed align state. Requires alignment state from `pto.vldas` or previous `pto.vldus`. Updates alignment state and base pointer for subsequent loads. Supports both pointer syntax and element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | +| _or_ | | | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Assembled vector value | +| `align_out` | `pto.align` | Updated alignment state for next load | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- A matching `pto.vldas` must appear before the first dependent `pto.vldus` stream in the same vector loop +- Both alignment state and base address advance across the stream +- If DSL authoring uses explicit byte/element offsets, the frontend first rewrites them into pointer/index expressions before lowering to this VPTO form. + +**Examples**: +```python +# Pointer syntax - requires alignment state priming +align = pto.vldas(ub_ptr) +vec, align_out, base_out = pto.vldus(ub_ptr, align) + +# Element-indexing syntax +align = pto.vldas(tile[i, j:]) +vec, align_out, base_out = pto.vldus(tile[i, j:], align) + +# Multiple unaligned loads in a stream +align = pto.vldas(tile[k:]) +for n in range(4): + vec, align, base = pto.vldus(tile[k:], align) # alignment state updates +``` + + +#### `pto.vldsx2(buf: ptr, offset: Index, dist: DeinterleaveDist) -> (VRegType, VRegType)` [Advanced Tier] +#### `pto.vldsx2(tile[row, col:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] +#### `pto.vldsx2(tile[start:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] + +**Description**: Dual vector load with deinterleave (AoS → SoA conversion). Loads interleaved data from a single buffer and deinterleaves into two vectors. Supports both byte-offset and element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to source buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | +| _or_ | | | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | + +**Constraints**: +- Source buffer must be in UB memory space +- Offset must satisfy alignment requirements for the selected distribution mode +- The requested vector region must be within tile bounds (for element-indexing syntax) +- Distribution mode must match element type (e.g., `"DINTLV"` for 32-bit elements) + +**Examples**: +```python +# Byte-offset syntax +low, high = pto.vldsx2(ub_ptr, offset, pto.DeinterleaveDist.DINTLV) + +# Element-indexing syntax +low, high = pto.vldsx2(tile[i, j:], pto.DeinterleaveDist.DINTLV) +low, high = pto.vldsx2(tile[k:], pto.DeinterleaveDist.DINTLV) + +# Example: Load interleaved XY pairs into separate X/Y vectors +x_vec, y_vec = pto.vldsx2(xy_tile[i, j:], pto.DeinterleaveDist.DINTLV) +``` + +#### `pto.vsld(buf: ptr, offset: Index, stride: StrideMode) -> VRegType` [Advanced Tier] +#### `pto.vsld(tile[row, col], stride: StrideMode) -> VRegType` [Basic Tier] +#### `pto.vsld(tile[pos], stride: StrideMode) -> VRegType` [Basic Tier] + +**Description**: Strided load with fixed stride pattern. Loads elements from memory with regular stride pattern. The offset parameter encodes displacement with selected stride mode. This is a deprecated compatibility family. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte displacement encoded with selected stride mode | +| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. Determines which sub-elements are read from each source block. | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col]` | `Tile` with indexing | 2D tile with row and column indices (single element) | +| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. | +| _or_ | | | +| `tile[pos]` | `Tile` with indexing | 1D tile with element index (single element) | +| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector with strided pattern | + +**Constraints**: +- The selected stride token determines which sub-elements are read from each source block +- This operation family is deprecated; prefer other load patterns when possible + +**Examples**: +```python +from pto import StrideMode + +# Byte-offset syntax +vec = pto.vsld(ub_ptr, offset, StrideMode.S4_B64) + +# Element-indexing syntax +vec = pto.vsld(tile[i, j], StrideMode.S3_B16) +vec = pto.vsld(tile[k], StrideMode.S8_B32) +``` + +#### `pto.vgather2(buf: ptr, offsets: Index, active_lanes: Index) -> VRegType` [Advanced Tier] + +**Description**: Indexed gather from UB. Gathers elements from a single buffer using per-lane offsets, with participation bounded by active lanes count. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to source buffer in UB memory space | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `active_lanes` | `Index` | Number of lanes that participate (bounds gather operation) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +**Constraints**: +- Only the first `active_lanes` offsets participate in the gather +- Index element width and interpretation must match selected gather form +- Each effective address must satisfy the gather form's alignment rules + +**Example**: +```python +vec = pto.vgather2(buf, offsets, active_lanes) +``` + +#### `pto.vgather2_bc(buf: ptr, offsets: Index, mask: MaskType) -> VRegType` [Advanced Tier] + +**Description**: Gather with broadcast, conditioned by mask. Gathers elements from a single buffer using per-lane offsets, with mask gating lane participation. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to source buffer in UB memory space | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `mask` | `MaskType` | Mask gating which lanes participate | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +**Constraints**: +- Masked-off lanes do not participate in address coalescing and do not trigger address overflow exceptions +- Destination lanes for masked-off lanes are zero-filled +- This is a backward-compatible operation family + +**Example**: +```python +vec = pto.vgather2_bc(buf, offsets, mask) +``` + +#### `pto.vgatherb(buf: ptr, offsets: Index) -> VRegType` [Advanced Tier] + +**Description**: Byte‑granularity gather load. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer | +| `offsets` | `Index` | Byte offsets | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +**Example**: +```python +vec = pto.vgatherb(buf, offsets) +``` + +#### `pto.vsldb(buf: ptr, offset: Index, mask: MaskType) -> VRegType` [Advanced Tier] +#### `pto.vsldb(tile[row, col], offset: Index, mask: MaskType) -> VRegType` [Basic Tier] +#### `pto.vsldb(tile[pos], offset: Index, mask: MaskType) -> VRegType` [Basic Tier] + +**Description**: Block-strided load for 2D tile access. Loads elements with block stride pattern controlled by packed offset word and mask. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space | +| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | +| `mask` | `MaskType` | Mask controlling which blocks participate | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col]` | `Tile` with indexing | 2D tile with row and column indices (single element) | +| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | +| `mask` | `MaskType` | Mask controlling which blocks participate | +| _or_ | | | +| `tile[pos]` | `Tile` with indexing | 1D tile with element index (single element) | +| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | +| `mask` | `MaskType` | Mask controlling which blocks participate | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector with block-strided pattern | + +**Constraints**: +- The offset encodes block stride and repeat pattern, not a plain byte displacement +- If a block is masked off, the corresponding destination block is zeroed +- Masked-off blocks must not raise address overflow exceptions + +**Example**: +```python +# Byte-offset syntax +vec = pto.vsldb(ub_ptr, control_word, mask) + +# Element-indexing syntax +vec = pto.vsldb(tile[i, j], control_word, mask) +vec = pto.vsldb(tile[k], control_word, mask) +``` + +### Vector Store Operations + +Operations for storing data from vector registers to memory. + +#### `pto.vsts(vec: VRegType, buf: ptr, offset: Index, mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Advanced Tier] +#### `pto.vsts(vec: VRegType, tile[row, col:], mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Basic Tier] +#### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Basic Tier] + +**Description**: Stateless vector store to buffer. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | +| `dist` | `pto.VStoreDist \| None` | Optional store distribution enum such as `pto.VStoreDist.NORM_B32` or `pto.VStoreDist.PK_B32` | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `mask` | `MaskType` | Predicate mask | +| `dist` | `pto.VStoreDist \| None` | Optional store distribution enum such as `pto.VStoreDist.NORM_B32` or `pto.VStoreDist.PK_B32` | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Buffer must be in UB memory space +- For byte-offset syntax: offset must be properly aligned based on element type +- For element-indexing syntax: the destination vector region must be within tile bounds and satisfy alignment requirements +- `dist` is optional. When omitted, the store uses the backend default layout for the vector family. +- Current TileLang DSL v1 accepts exactly one keyword attr on `pto.vsts`: `dist=...`. +- `dist` must be a `pto.VStoreDist` enum value. +- `mask` must match the effective store payload granularity, which may differ from the vector element family when `dist` repacks lanes. +- Common width-changing cases: + default / `NORM_B32` stores expect `mask_b32` for `f32`/`i32`-family vectors; + `PK_B32` also expects `mask_b32` and is used by narrow stores such as `f32 -> f16` `tcvt`; + `PK_B16` expects `mask_b16`. + +**Examples**: +```python +# Byte-offset syntax +pto.vsts(vec_f32, ub_ptr, lane * 256, mask32) + +# Element-indexing syntax +pto.vsts(vec, tile[i, j:], mask) # Store to row i, columns j to j+vector_lanes-1 +pto.vsts(vec, tile[k:], mask) # Store to 1D tile, elements k to k+vector_lanes-1 + +# VPTO-aligned packed store +vec_f16 = pto.vcvt( + vec_f32, + pto.f16, + mask32, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, +) +pto.vsts(vec_f16, tile[i, j:], mask32, dist=pto.VStoreDist.PK_B32) + +# In a generic kernel +@pto.vkernel(target="a5", op="copy", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def generic_store(src: pto.Tile, dst: pto.Tile): + rows, cols = src.shape + all_mask = pto.make_mask(src.element_type, PAT.ALL) + for i in range(0, rows): + for j in range(0, cols, vector_lanes): + vec = pto.vlds(src[i, j:]) + pto.vsts(vec, dst[i, j:], all_mask) # No manual offset calculation +``` + +#### `pto.psts(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] + +**Description**: Predicate store (`pto.psts`) writes the packed payload represented by +`MaskType` to UB memory. This is the dynamic-offset form of the VPTO predicate-store +family (`psts` vs `psti`): the payload semantics are identical, and only the offset +delivery form differs. + +**Parameters (advanced byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate payload to store | +| `buf` | `ptr` | Pointer to destination UB buffer (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Runtime offset (`index`) | +| `dist` | `PredicateDist` | Predicate distribution mode. Use `PredicateDist.NORM` or `PredicateDist.PK` (default: `PredicateDist.NORM`). | + +**Returns**: None (side-effect operation) + +**DIST semantics (VPTO-aligned)**: +- `PredicateDist.NORM`: store packed predicate payload into a normal destination space of size `VL/8`. +- `PredicateDist.PK`: store packed predicate payload into a destination space of size `VL/16`, keeping one bit out of every two bits. + +**Notes**: +- `pto.psts` is intentionally documented as explicit `buf + offset` surface in DSL v1. +- Packed predicate payload layout is bit-level (`VL/8` or `VL/16`), so tile element-indexing is not part of the stable Basic Tier contract. +- The pointer + offset form maps directly to explicit `base[offset]`. +- Authoritative predicate-memory-family semantics are documented in `10-predicate-operations.md`. + +#### `pto.vsst(scalar: ScalarType, buf: ptr, offset: Index, mask: MaskType) -> None` [Advanced Tier] +#### `pto.vsst(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` +#### `pto.vsst(scalar: ScalarType, tile[start:], mask: MaskType) -> None` + +**Description**: Scalar to vector store (broadcast scalar to all lanes). Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `buf` | `ptr` | Pointer to destination buffer (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +#### `pto.vstsx2(low: VRegType, high: VRegType, buf: ptr, offset: Index, dist: InterleaveDist, mask: MaskType) -> None` [Advanced Tier] +#### `pto.vstsx2(low: VRegType, high: VRegType, tile[row, col:], dist: InterleaveDist, mask: MaskType) -> None` +#### `pto.vstsx2(low: VRegType, high: VRegType, tile[start:], dist: InterleaveDist, mask: MaskType) -> None` + +**Description**: Dual interleaved store (SoA → AoS conversion). Stores two vectors interleaved into a single buffer. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Destination buffer must be in UB memory space +- Offset must satisfy alignment requirements for the selected distribution mode +- The destination vector region must be within tile bounds (for element-indexing syntax) +- Distribution mode must match element type (e.g., `"INTLV"` for 32-bit elements) +- The two source vectors form an ordered pair; interleave semantics must be preserved + +**Examples**: +```python +# Byte-offset syntax +pto.vstsx2(x_vec, y_vec, ub_ptr, offset, pto.InterleaveDist.INTLV, mask) + +# Element-indexing syntax +pto.vstsx2(x_vec, y_vec, tile[i, j:], pto.InterleaveDist.INTLV, mask) +pto.vstsx2(x_vec, y_vec, tile[k:], pto.InterleaveDist.INTLV, mask) + +# Example: Store separate X/Y vectors as interleaved XY pairs +pto.vstsx2(x_vec, y_vec, xy_tile[i, j:], pto.InterleaveDist.INTLV, all_mask) +``` + +#### `pto.vsta(align: pto.align, buf: ptr, offset: Index) -> None` [Advanced Tier] +#### `pto.vsta(align: pto.align, tile[row, col:]) -> None` [Basic Tier] +#### `pto.vsta(align: pto.align, tile[start:]) -> None` [Basic Tier] + +**Description**: Flush alignment state to memory. Writes buffered tail bytes from alignment state to UB memory. Consumes the alignment state after flush. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Flush displacement | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| _or_ | | | +| `align` | `pto.align` | Pending store-alignment state | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: None (side-effect operation) + +**Constraints**: +- The flush address must match the post-updated address expected by the preceding unaligned-store stream +- After the flush, the corresponding store alignment state is consumed +- A final flush operation is required to commit buffered bytes after unaligned-store sequences +- The `align` input should come from the latest `vstu`/`vstus`/`vstur` in the same stream + +**Example**: +```python +# Byte-offset syntax +pto.vsta(align, ub_ptr, offset) + +# Element-indexing syntax +pto.vsta(align, tile[i, j:]) +pto.vsta(align, tile[k:]) +``` + +#### `pto.vscatter(vec: VRegType, buf: ptr, offsets: Index, active_lanes: Index) -> None` [Advanced Tier] + +**Description**: Indexed scatter to UB. Stores vector elements to irregular locations using per-lane offsets, with participation bounded by active lanes count. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Source vector to scatter | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `active_lanes` | `Index` | Number of lanes that participate (bounds scatter operation) | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Only `b8`, `b16`, and `b32` element sizes are supported +- Current TileLang DSL / VPTO path requires `i32` index vectors +- Each computed address must be element-aligned +- If indices alias, only one write is guaranteed (winning lane is implementation-defined) +- Only the first `active_lanes` offsets participate in the scatter + +**Example**: +```python +pto.vscatter(vec, buf, offsets, active_lanes) +``` + +#### `pto.vsstb(scalar: ScalarType, buf: ptr, offset: Index, mask: MaskType) -> None` [Advanced Tier] +#### `pto.vsstb(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` [Basic Tier] +#### `pto.vsstb(scalar: ScalarType, tile[start:], mask: MaskType) -> None` [Basic Tier] + +**Description**: Scalar to vector store with broadcast (enhanced version of `vsst`). Supports both byte‑offset and element‑indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `buf` | `ptr` | Pointer to destination buffer | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +**Example**: +```python +# Byte-offset syntax +pto.vsstb(pto.f32(0.0), ub_ptr, offset, mask) + +# Element-indexing syntax +pto.vsstb(pto.f32(1.0), tile[i, j:], mask) +``` + +#### `pto.vstar(align: pto.align, buf: ptr) -> None` [Advanced Tier] +#### `pto.vstar(align: pto.align, tile[row, col:]) -> None` [Basic Tier] +#### `pto.vstar(align: pto.align, tile[start:]) -> None` [Basic Tier] + +**Description**: Flush alignment state using the register-update form. Writes buffered tail bytes from alignment state to UB memory. The implicit update state must correspond to the same store stream that produced the alignment state. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| _or_ | | | +| `align` | `pto.align` | Pending store-alignment state | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: None (side-effect operation) + +**Constraints**: +- The implicit update state consumed by this flush must correspond to the same store stream that produced the alignment state +- A final flush operation is required to commit buffered bytes after unaligned-store sequences +- The `align` input should come from the latest `vstu`/`vstus`/`vstur` in the same stream + +**Example**: +```python +# Byte-offset syntax +pto.vstar(align, ub_ptr) + +# Element-indexing syntax +pto.vstar(align, tile[i, j:]) +pto.vstar(align, tile[k:]) +``` + +#### `pto.vstas(align: pto.align, buf: ptr, offset: Index) -> None` [Advanced Tier] +#### `pto.vstas(align: pto.align, tile[row, col:], offset: Index) -> None` [Basic Tier] +#### `pto.vstas(align: pto.align, tile[start:], offset: Index) -> None` [Basic Tier] + +**Description**: Scalar-register-offset form of alignment-state flush. Writes buffered tail bytes from alignment state to UB memory with explicit scalar offset. Uses same buffered-tail semantics as `pto.vsta`. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | +| `offset` | `Index` | Scalar-register style displacement | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `offset` | `Index` | Scalar-register style displacement | +| _or_ | | | +| `align` | `pto.align` | Pending store-alignment state | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `offset` | `Index` | Scalar-register style displacement | + +**Returns**: None (side-effect operation) + +**Example**: +```python +# Byte-offset syntax +pto.vstas(align, ub_ptr, offset) + +# Element-indexing syntax +pto.vstas(align, tile[i, j:], offset) +pto.vstas(align, tile[k:], offset) +``` + +### Stateful Store Operations + +Operations for storing data with stateful semantics. + +#### `pto.pstu(align_in: pto.align, mask: MaskType, buf: ptr) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Predicate unaligned store with align state update. Stores predicate mask with alignment state threading. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated alignment state | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- Part of stateful unaligned-store sequence with alignment state threading + +#### `pto.vstu(align_in: pto.align, base_in: ptr, vec: VRegType, buf: ptr, mode: Index) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Unaligned store with explicit threaded alignment/base state. Models a stateful unaligned-store sequence in SSA form. Requires a final `pto.vsta`/`pto.vstas`/`pto.vstar` to flush trailing buffered bytes. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `base_in` | `ptr` | Current stream base pointer | +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Destination buffer in UB memory space | +| `mode` | `Index` | Mode selecting post-update behavior | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated buffered-tail state | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- Models stateful unaligned-store sequence in SSA form +- Final flush operation required to commit buffered bytes + +**Example**: +```python +# Stateful unaligned store + final flush (vsta form) +align1, base1 = pto.vstu(align0, base0, vec0, ub_ptr, mode) +align2, base2 = pto.vstu(align1, base1, vec1, ub_ptr, mode) +pto.vsta(align2, ub_ptr, tail_offset) +``` + +#### `pto.vstus(align_in: pto.align, base_in: ptr, vec: VRegType, buf: ptr, offset: Index) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Scalar-offset unaligned store with threaded state. Same roles as `pto.vstu` but with explicit scalar displacement. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `base_in` | `ptr` | Current stream base pointer | +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Destination buffer in UB memory space | +| `offset` | `Index` | Scalar displacement | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated buffered-tail state | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- Same final flush requirement and state-threading constraints as `pto.vstu` + +**Example**: +```python +# Scalar-offset threaded form + final flush (vstas form) +align1, base1 = pto.vstus(align0, base0, vec0, ub_ptr, offset0) +align2, base2 = pto.vstus(align1, base1, vec1, ub_ptr, offset1) +pto.vstas(align2, ub_ptr, flush_offset) +``` + +#### `pto.vstur(align_in: pto.align, vec: VRegType, buf: ptr, mode: PostUpdateMode = pto.PostUpdateMode.NO_POST_UPDATE) -> pto.align` [Advanced Tier] + +**Description**: Register-update unaligned store form. Updates only the residual alignment state without base pointer update. Requires matching flush operation to emit trailing bytes. The optional `mode` operand is a typed Enum and controls whether the hardware performs post-update on the implicit AR state. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Destination buffer in UB memory space | +| `mode` | `PostUpdateMode` | Optional post-update mode. Defaults to `pto.PostUpdateMode.NO_POST_UPDATE`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated buffered-tail state | + +**Constraints**: +- Updates only residual alignment state (no base pointer update) +- Matching flush operation still required to emit trailing bytes + +**Example**: +```python +# Residual-state form + final flush (vstar form) +align1 = pto.vstur(align0, vec0, ub_ptr) +align2 = pto.vstur(align1, vec1, ub_ptr) +pto.vstar(align2, ub_ptr) + +# Explicit post-update mode with typed Enum +align3 = pto.vstur(align2, vec2, ub_ptr, pto.PostUpdateMode.POST_UPDATE) +``` + +#### Align-State Store Closed Loop + +For unaligned store families, the state must form a closed loop: + +1. Start from an incoming `align` state. +2. Thread state through one or more `vstu` / `vstus` / `vstur` operations. +3. Terminate the stream with exactly one flush op: `vsta` or `vstas` or `vstar`. +4. Do not reuse a flushed `align` state in another stream. diff --git a/ptodsl/docs/user_guide/10-predicate-operations.md b/ptodsl/docs/user_guide/10-predicate-operations.md new file mode 100644 index 000000000..8cc92da2c --- /dev/null +++ b/ptodsl/docs/user_guide/10-predicate-operations.md @@ -0,0 +1,637 @@ +### Predicate Operations + +Operations for creating and manipulating typed masks. + +**Recommended API**: For most use cases, prefer the unified `pto.make_mask()` function which automatically selects the appropriate mask granularity based on element type and supports both tail processing (remaining element count) and pattern-based mask generation. This eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` (tail processing) and `pset_b8`/`pset_b16`/`pset_b32` (pattern generation) operations. + +**Pattern alias**: For brevity in examples, the documentation uses `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.ALL`). In practice, you can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. + +**Predicate Part Enum**: `pto.ppack` and `pto.punpack` require the `PredicatePart` enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`; these lower to the VPTO canonical `PART` tokens `"LOWER"` and `"HIGHER"`. + +**Predicate Dist Enum**: The `PredicateDist` enum provides type-safe distribution mode selection for predicate memory families. Load families (`plds`, `pld`, `pldi`) use `NORM`, `US`, and `DS`. Store families (`psts`, `pst`, `psti`) use `NORM` and `PK`. + +**Pattern coverage**: The VPTO canonical predicate-generation families use `PAT_*` tokens such as `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, `PAT_VL*`, `PAT_M3`, and `PAT_M4`. The Python DSL surface may expose only a subset through `pto.MaskPattern`; check the enum for currently available values. + +#### `pto.pset_b8(pattern: pto.MaskPattern) -> pto.mask_b8` + +**Description**: Creates an 8-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | + +**Constraints**: +- Used with `i8` vector operations + +**Example**: +```python +mask8 = pto.pset_b8(PAT.ALL) +``` + +#### `pto.pset_b16(pattern: pto.MaskPattern) -> pto.mask_b16` + +**Description**: Creates a 16-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | + +**Constraints**: +- Used with `f16`/`bf16`/`i16` vector operations + +**Example**: +```python +mask16 = pto.pset_b16(PAT.ALL) +``` + +#### `pto.pset_b32(pattern: pto.MaskPattern) -> pto.mask_b32` + +**Description**: Creates a 32-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | + +**Constraints**: +- Used with `f32`/`i32` vector operations + +**Example**: +```python +mask32 = pto.pset_b32(PAT.ALL) +``` + +#### `pto.pge_b8(pattern: pto.MaskPattern) -> pto.mask_b8` + +**Description**: Generate tail mask — first N lanes active based on pattern. Creates an 8-bit granularity mask where the first N lanes are active according to the specified pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity tail mask | + +**Constraints**: +- Used with `i8` vector operations +- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) + +**Example**: +```python +# Tail mask pattern lowered as `PAT_VL16` +tail_mask = pto.pge_b8(PAT.VL16) +``` + +#### `pto.pge_b16(pattern: pto.MaskPattern) -> pto.mask_b16` + +**Description**: Generate tail mask — first N lanes active based on pattern. Creates a 16-bit granularity mask where the first N lanes are active according to the specified pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity tail mask | + +**Constraints**: +- Used with `f16`/`bf16`/`i16` vector operations +- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) + +**Example**: +```python +# Tail mask for first 16 lanes +tail_mask = pto.pge_b16(PAT.VL16) +``` + +#### `pto.pge_b32(pattern: pto.MaskPattern) -> pto.mask_b32` + +**Description**: Generate tail mask — first N lanes active based on pattern. Creates a 32-bit granularity mask where the first N lanes are active according to the specified pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity tail mask | + +**Constraints**: +- Used with `f32`/`i32` vector operations +- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) + +**Example**: +```python +# Tail mask for first 32 lanes +tail_mask = pto.pge_b32(PAT.VL32) +``` + +#### `pto.plt_b8(scalar: pto.i32) -> (pto.mask_b8, pto.i32)` + +**Description**: Generate predicate state together with updated scalar state (tail processing). Creates an 8-bit granularity mask and returns updated scalar value for state progression. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | +| `scalar_out` | `pto.i32` | Updated scalar state | + +**Constraints**: +- Used with `i8` vector operations for tail processing +- The scalar input is typically a remaining element count that decrements across successive calls + +**Example**: +```python +remaining: pto.i32 = 64 +mask, remaining = pto.plt_b8(remaining) # generates mask for next chunk, updates remaining count +``` + +#### `pto.plt_b16(scalar: pto.i32) -> (pto.mask_b16, pto.i32)` + +**Description**: Generate predicate state together with updated scalar state (tail processing). Creates a 16-bit granularity mask and returns updated scalar value for state progression. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | +| `scalar_out` | `pto.i32` | Updated scalar state | + +**Constraints**: +- Used with `f16`/`bf16`/`i16` vector operations for tail processing +- The scalar input is typically a remaining element count that decrements across successive calls + +**Example**: +```python +remaining: pto.i32 = 64 +mask, remaining = pto.plt_b16(remaining) # generates mask for next chunk, updates remaining count +``` + +#### `pto.plt_b32(scalar: pto.i32) -> (pto.mask_b32, pto.i32)` + +**Description**: Generate predicate state together with updated scalar state (tail processing). Creates a 32-bit granularity mask and returns updated scalar value for state progression. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | +| `scalar_out` | `pto.i32` | Updated scalar state | + +**Constraints**: +- Used with `f32`/`i32` vector operations for tail processing +- The scalar input is typically a remaining element count that decrements across successive calls + +**Example**: +```python +remaining: pto.i32 = 64 +mask, remaining = pto.plt_b32(remaining) # generates mask for next chunk, updates remaining count +``` + +#### `pto.make_mask(element_type: Type, value: pto.i32 | pto.MaskPattern) -> MaskType | (MaskType, pto.i32)` + +**Description**: Creates a mask with appropriate bitwidth (8, 16, or 32) based on element type, automatically inferring whether to perform tail processing or pattern-based mask generation based on the `value` parameter type. This convenience function eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` and `pset_b8`/`pset_b16`/`pset_b32` operations. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `element_type` | `Type` | Element type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | +| `value` | `pto.i32` \| `pto.MaskPattern` | Either:
- Remaining element count (as `pto.i32`) for tail processing
- Mask pattern enum value for fixed mask generation (for example `pto.MaskPattern.ALL` or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Generated mask with appropriate granularity | +| `remaining` | `pto.i32` | Updated remaining element count (only returned when `value` is a `pto.i32` for tail processing) | + +**Constraints**: +- The `element_type` must be one of: `f32`, `f16`, `bf16`, or an 8/16/32-bit integer family member (`i*`, `si*`, `ui*`) +- The returned mask granularity matches the element type: 32-bit for `f32`/`i32`/`si32`/`ui32`, 16-bit for `f16`/`bf16`/`i16`/`si16`/`ui16`, and 8-bit for `i8`/`si8`/`ui8` +- The function infers the operation mode from the `value` parameter type at compile time: + - `pto.i32` value → tail processing mode (returns `(mask, updated_remaining)`) + - `pto.MaskPattern` enum value → pattern mode (returns `mask` only) + +**Implementation Note**: This function is a DSL macro that performs type-based dispatch at compile time: +- When `value` is a `pto.i32` expression: expands to corresponding `plt_b` instruction (`plt_b32`, `plt_b16`, or `plt_b8`) +- When `value` is a `pto.MaskPattern` enum value: expands to corresponding `pset_b` instruction (`pset_b32`, `pset_b16`, or `pset_b8`) + +**Example**: +```python +# Tail processing with f32 vectors: value is pto.i32 → expands to plt_b32 +mask_f32, remaining_f32 = pto.make_mask(pto.f32, remaining_elements) + +# Tail processing with f16 vectors: value is pto.i32 → expands to plt_b16 +mask_f16, remaining_f16 = pto.make_mask(pto.f16, remaining_elements) + +# Tail processing with i8 vectors: value is pto.i32 → expands to plt_b8 +mask_i8, remaining_i8 = pto.make_mask(pto.i8, remaining_elements) + +# Pattern-based mask with f32 vectors: value is MaskPattern enum → expands to pset_b32 +mask_all_f32 = pto.make_mask(pto.f32, PAT.ALL) + +# Pattern-based mask with f16 vectors: value is MaskPattern enum → expands to pset_b16 +mask_even_f16 = pto.make_mask(pto.f16, PAT.EVEN) + +# Pattern-based mask with i8 vectors: value is MaskPattern enum → expands to pset_b8 +mask_all_i8 = pto.make_mask(pto.i8, PAT.ALL) + +# Type annotations help clarify expected parameter types +remaining: pto.i32 = 1024 +mask1, updated = pto.make_mask(pto.f32, remaining) # tail processing +mask2 = pto.make_mask(pto.f32, PAT.ALL) # pattern mode +``` + +#### `pto.ppack(mask: MaskType, part: PredicatePart) -> MaskType` + +**Description**: Narrowing pack of a predicate register. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | +| `part` | `PredicatePart` | Part selector enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `packed` | `MaskType` | Packed mask | + +**Example**: +```python +packed = pto.ppack(mask, pto.PredicatePart.LOWER) +``` + +#### `pto.punpack(mask: MaskType, part: PredicatePart) -> MaskType` + +**Description**: Widening unpack of a predicate register. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask | +| `part` | `PredicatePart` | Part selector enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Unpacked mask | + +**Example**: +```python +unpacked = pto.punpack(mask, pto.PredicatePart.HIGHER) +``` + +#### `pto.pbitcast(mask: MaskType, to_type: MaskType) -> MaskType` + +**Description**: Reinterprets a typed predicate mask as another typed mask granularity without changing the underlying predicate bit image. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | +| `to_type` | `MaskType` | Target mask type marker such as `pto.mask_b16` or `pto.mask_b32` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Reinterpreted mask with the requested target granularity | + +**Constraints**: +- `mask` must already be a typed predicate value +- `to_type` must be one of the DSL mask type markers: `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32` +- this is a bit reinterpretation helper, not a logical predicate transform; it does not insert packing, unpacking, interleaving, or deinterleaving by itself +- use `pto.ppack`, `pto.punpack`, `pto.pdintlv_b8`, or `pto.pintlv_b16` when the predicate image itself must be rearranged + +**Example**: +```python +mask_b8 = pto.plds(mask_ptr, offset, pto.PredicateDist.US) +mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) + +mask0_b16, mask1_b16 = pto.pintlv_b16(mask_b16, pto.pset_b16(PAT.ALL)) +mask0_b32 = pto.pbitcast(mask0_b16, pto.mask_b32) +``` + +#### `pto.pnot(mask: MaskType, gate: MaskType) -> MaskType` + +**Description**: Predicate negation under a same-granularity mask gate. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask | +| `gate` | `MaskType` | Gating mask with the same granularity | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `negated` | `MaskType` | Negated mask | + +#### `pto.psel(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` + +**Description**: Selects between two masks using a third mask as selector. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Selection mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Selected mask | + +#### `pto.plds(buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> MaskType` [Advanced Tier] + +**Description**: Predicate load with scalar-index style offset form. This is the default DSL surface for loading predicate masks from UB memory. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Source pointer in UB memory space | +| `offset` | `Index` | Scalar/index-style offset | +| `dist` | `PredicateDist` | Distribution mode (default: `PredicateDist.NORM`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Loaded predicate mask | + +**Example**: +```python +mask = pto.plds(buf, offset, pto.PredicateDist.NORM) +``` + +#### `pto.pld(buf: ptr, offset: Index, dist: PredicateDist) -> MaskType` [Advanced Tier] + +**Description**: Predicate load with areg/index register style offset encoding. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Source pointer in UB memory space | +| `offset` | `Index` | Areg/index-style offset | +| `dist` | `PredicateDist` | Distribution mode | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Loaded predicate mask | + +**Example**: +```python +mask = pto.pld(buf, offset, pto.PredicateDist.NORM) +``` + +#### `pto.pldi(buf: ptr, imm_offset: pto.i32, dist: PredicateDist) -> MaskType` [Advanced Tier] + +**Description**: Predicate load with immediate-offset encoding form. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Source pointer in UB memory space | +| `imm_offset` | `pto.i32` | Immediate-offset operand | +| `dist` | `PredicateDist` | Distribution mode | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Loaded predicate mask | + +**Example**: +```python +mask = pto.pldi(buf, 0, pto.PredicateDist.NORM) +``` + +#### `pto.psts(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] + +**Description**: Stores a predicate mask to UB memory using the VPTO dynamic-offset +`psts` form. This is the dynamic counterpart of `psti`: both encode the same +predicate payload semantics, while offset delivery differs (runtime `index` vs +constant immediate). + +**Parameters (Advanced Tier: explicit pointer surface)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination UB buffer | +| `offset` | `Index` | Runtime offset (`index`) | +| `dist` | `PredicateDist` | Distribution mode. Use `PredicateDist.NORM` or `PredicateDist.PK` (default: `PredicateDist.NORM`). | + +**DIST semantics (VPTO-aligned)**: +- `NORM`: stores packed predicate payload into destination space of size `VL/8`. +- `PK`: stores packed predicate payload into destination space of size `VL/16`, + keeping one bit out of every two bits. + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.psts(mask, buf, offset, pto.PredicateDist.NORM) +``` + +#### `pto.pst(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] + +**Description**: Stores a predicate mask to UB memory using areg/index offset encoding. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination UB buffer | +| `offset` | `Index` | Areg/index-style offset | +| `dist` | `PredicateDist` | Distribution mode for predicate store. Use `PredicateDist.NORM` or `PredicateDist.PK`. Default is `PredicateDist.NORM`. | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.pst(mask, buf, offset, pto.PredicateDist.NORM) +``` + +#### `pto.psti(mask: MaskType, buf: ptr, imm_offset: pto.i32, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] + +**Description**: Stores a predicate mask to UB memory using immediate-offset encoding. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination UB buffer | +| `imm_offset` | `pto.i32` | Immediate-offset operand | +| `dist` | `PredicateDist` | Distribution mode for predicate store. Use `PredicateDist.NORM` or `PredicateDist.PK`. Default is `PredicateDist.NORM`. | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.psti(mask, buf, pto.i32(8), pto.PredicateDist.PK) +``` + +#### `pto.pstu(align_in: pto.align, mask: MaskType, buf: ptr) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Unaligned predicate store with align-state update. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Input alignment state | +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination UB buffer | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated alignment state | +| `base_out` | `ptr` | Updated destination pointer | + +**Example**: +```python +align_out, base_out = pto.pstu(align_in, mask, buf) +``` + +#### `pto.pand(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` + +**Description**: Bitwise AND of two predicate masks under a gating mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Gating mask with the same granularity | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Bitwise AND result | + +**Example**: +```python +result = pto.pand(mask1, mask2, gate) +``` + +#### `pto.por(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` + +**Description**: Bitwise OR of two predicate masks under a gating mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Gating mask with the same granularity | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Bitwise OR result | + +**Example**: +```python +result = pto.por(mask1, mask2, gate) +``` + +#### `pto.pxor(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` + +**Description**: Bitwise XOR of two predicate masks under a gating mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Gating mask with the same granularity | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Bitwise XOR result | + +**Example**: +```python +result = pto.pxor(mask1, mask2, gate) +``` + +#### `pto.pdintlv_b8(src0: pto.mask_b8, src1: pto.mask_b8) -> (pto.mask_b8, pto.mask_b8)` + +**Description**: Predicate deinterleave for 8-bit masks. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.mask_b8` | First input mask | +| `src1` | `pto.mask_b8` | Second input mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `pto.mask_b8` | First result mask | +| `high` | `pto.mask_b8` | Second result mask | + +**Example**: +```python +low8, high8 = pto.pdintlv_b8(mask_a, mask_b) +``` + +#### `pto.pintlv_b16(src0: pto.mask_b16, src1: pto.mask_b16) -> (pto.mask_b16, pto.mask_b16)` + +**Description**: Predicate interleave for 16-bit masks. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.mask_b16` | First input mask | +| `src1` | `pto.mask_b16` | Second input mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `pto.mask_b16` | First result mask | +| `high` | `pto.mask_b16` | Second result mask | + +**Example**: +```python +low16, high16 = pto.pintlv_b16(mask_a, mask_b) +``` + +**Note**: Prefer `pto.make_mask()` for automatic bitwidth selection and unified tail/pattern mask generation. diff --git a/ptodsl/docs/user_guide/11-vector-arithmetic-operations.md b/ptodsl/docs/user_guide/11-vector-arithmetic-operations.md new file mode 100644 index 000000000..ede8388df --- /dev/null +++ b/ptodsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -0,0 +1,1611 @@ +### Unary Vector Operations + +Element-wise unary operations on vector registers. + +#### `pto.vabs(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Absolute value of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask (granularity must match vector element type) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Absolute values | + +**Constraints**: +- Mask granularity must match vector element type (e.g., `f32` requires `mask_b32`) + +**Example**: +```python +abs_vec = pto.vabs(vec_f32, mask32) +``` + +#### `pto.vexp(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Exponential of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Exponential values | + +#### `pto.vln(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Natural logarithm of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Natural logarithm values | + +#### `pto.vsqrt(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Square root of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Square root values | + +#### `pto.vrec(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Reciprocal of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reciprocal values | + +#### `pto.vrelu(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: ReLU activation (max(0, x)) of vector elements. + +**Supported dtypes**: `si32`, `i32`, `f16`, `f32` + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated values | + +#### `pto.vnot(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Bitwise NOT of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise NOT values | + +#### `pto.vcadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Reduction add of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduction result vector | + +**Type Rules**: +- For floating-point inputs and `i32/ui32`, the result vector type matches the input vector type. +- For `i8/ui8` inputs, `pto.vcadd` returns a widened `i16/ui16` vector. +- For `i16/ui16` inputs, `pto.vcadd` returns a widened `i32/ui32` vector. +- The result mask granularity follows the result vector element type. + +#### `pto.vcmax(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Complex maximum of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Complex maximum result | + +#### `pto.vbcnt(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Bit count (population count) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bit count values | + +#### `pto.vneg(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Negation of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask (granularity must match vector element type) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Negated values | + +**Constraints**: +- Mask granularity must match vector element type + +**Example**: +```python +neg_vec = pto.vneg(vec_f32, mask32) +``` + +#### `pto.vcls(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Count leading sign bits of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Count of leading sign bits | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vcmin(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Complex minimum of vector elements (treating pairs as complex numbers). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Complex minimum result | + +#### `pto.vrsqrt(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Reciprocal square root of vector elements (1/√x). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reciprocal square root values | + +**Constraints**: +- For floating-point vector types only + +#### `pto.vprelu(vec: VRegType, alpha: VRegType, mask: MaskType) -> VRegType` + +**Description**: Parametric ReLU activation of vector elements: `x if x >= 0 else alpha * x`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `alpha` | `VRegType` | Slope parameter for negative values | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Parametric ReLU activated values | + +#### `pto.vmov(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector move (data movement). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Copied vector | + +#### `pto.vsunpack(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Signed unpack of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Unpacked signed values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vzunpack(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Zero-extended unpack of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Unpacked zero-extended values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vusqz(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Unsigned squeeze (compression) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Compressed unsigned values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vsqz(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Signed squeeze (compression) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Compressed signed values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vexpdif(vec: VRegType, max_vec: VRegType, mask: MaskType, part: pto.VcvtPartMode) -> VRegType` + +**Description**: Fused exponential difference `exp(vec - max_vec)` for numerically stable softmax lowering. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `max_vec` | `VRegType` | Per-lane max vector subtracted before exponentiation | +| `mask` | `MaskType` | Predicate mask. Use `b16` for `f16` inputs and `b32` for `f32` inputs. | +| `part` | `pto.VcvtPartMode` | Output part selector enum. Use `pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Exponential difference values; result element type is `f32` | + +**Constraints**: +- Supports `f16` and `f32` input vectors only +- `vec` and `max_vec` must use the same vector type +- `mask` granularity must match the input vector element width +- `part` should use `pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD` +- Canonical strings `"EVEN"` / `"ODD"` are still accepted for compatibility + +### Binary Vector Operations + +Element-wise binary operations on vector registers. + +#### `pto.vadd(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise addition of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sum of vectors | + +**Example**: +```python +sum_vec = pto.vadd(vec_a, vec_b, mask32) +``` + +#### `pto.vsub(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise subtraction of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference of vectors | + +#### `pto.vmul(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise multiplication of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Product of vectors | + +#### `pto.vdiv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise division of two vectors. + +- Supported element types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16` and `f32`. +- `f16`/`f32` authoring code stays on the public `pto.vdiv` VPTO path. +- Integer `pto.vdiv` also uses the same public surface, but lowers through an internal soft-helper path. +- For `i8`/`ui8`, the integer lowering widens to 16-bit lanes, computes the soft division, then narrows back to 8-bit lanes. +- Internal helper names such as `_tl_soft_vdiv_*` are implementation details and are not part of the supported DSL call surface. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Quotient of vectors | + +#### `pto.vmod(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise modulo of two vectors. + +- Supported element types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`). +- Floating-point `vmod` is not part of the current TileLang DSL v1 public surface. +- `pto.vmod` is the only public vector modulo entry point in TileLang DSL v1. +- The current implementation lowers through an internal soft-helper path; helper names such as `_tl_soft_vmod_*` are intentionally hidden implementation details. +- For `i8`/`ui8`, the modulo path uses an explicit widen-to-16-bit, soft-compute, narrow-back-to-8-bit profile. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | Dividend vector | +| `vec2` | `VRegType` | Divisor vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Remainder vector | + +#### `pto.vmax(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise maximum of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Element-wise maximum | + +#### `pto.vmin(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise minimum of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Element-wise minimum | + +#### `pto.vand(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise AND of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise AND result | + +#### `pto.vor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise OR of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise OR result | + +#### `pto.vxor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise XOR of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise XOR result | + +#### `pto.vshl(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise shift left (vector shift amounts). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `VRegType` | Shift amounts (per element) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vshr(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise shift right (vector shift amounts). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `VRegType` | Shift amounts (per element) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vaddrelu(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Addition with ReLU activation (max(0, vec1 + vec2)). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated sum of vectors | + +#### `pto.vaddreluconv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Convolution addition with ReLU activation (convolution-specific fused operation). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated convolution sum | + +**Constraints**: +- Optimized for convolution-specific patterns + +#### `pto.vsubrelu(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Subtraction with ReLU activation (max(0, vec1 - vec2)). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated difference of vectors | + +#### `pto.vaxpy(alpha: VRegType, x: VRegType, y: VRegType, mask: MaskType) -> VRegType` + +**Description**: BLAS AXPY operation (αx + y). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `alpha` | `VRegType` | Scaling factor | +| `x` | `VRegType` | Input vector x | +| `y` | `VRegType` | Input vector y | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result of αx + y | + +#### `pto.vmulconv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Convolution multiplication (convolution-specific multiplication). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Convolution product | + +**Constraints**: +- Optimized for convolution-specific patterns + +#### `pto.vmull(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, VRegType)` + +**Description**: Widening multiply with split low/high results (extended arithmetic). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `VRegType` | Low part of widened product (`r & 0xFFFFFFFF`) | +| `high` | `VRegType` | High part of widened product (`r >> 32`) | + +**Constraints**: +- Current A5 documented form is native `i32/u32` 32x32->64 widening multiply +- Result is split into two vector outputs instead of a single widened vector + +**Example**: +```python +low, high = pto.vmull(lhs_i32, rhs_i32, mask32) +``` + +#### `pto.vmula(vec1: VRegType, vec2: VRegType, vec3: VRegType, mask: MaskType) -> VRegType` + +**Description**: Fused multiply-add (vec1 * vec2 + vec3). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector (multiplier) | +| `vec2` | `VRegType` | Second input vector (multiplicand) | +| `vec3` | `VRegType` | Third input vector (addend) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result of vec1 * vec2 + vec3 | + +### Vector-Scalar Operations + +Operations between vectors and scalars. + +#### `pto.vmuls(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector multiplied by scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar multiplier | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Scaled vector | + +**Example**: +```python +scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) +``` + +#### `pto.vadds(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector plus scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar addend | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result vector | + +#### `pto.vmaxs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise maximum of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Maximum values | + +#### `pto.vmins(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise minimum of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Minimum values | + +#### `pto.vlrelu(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Leaky ReLU activation (max(αx, x)). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Alpha coefficient | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Leaky ReLU activated values | + +#### `pto.vshls(vec: VRegType, shift: i16, mask: MaskType) -> VRegType` + +**Description**: Vector shift left by scalar (uniform shift). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `i16` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vshrs(vec: VRegType, shift: i16, mask: MaskType) -> VRegType` + +**Description**: Vector shift right by scalar (uniform shift). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `i16` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vands(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise AND of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar operand | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise AND result | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vors(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise OR of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar operand | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise OR result | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vxors(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise XOR of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar operand | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise XOR result | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vsubs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector minus scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar subtrahend | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference vector | + +#### `pto.vbr(value: ScalarType) -> VRegType` + +**Description**: Broadcast scalar to all vector lanes. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `ScalarType` | Scalar source | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Vector whose active lanes all carry `value` | + +**Constraints**: +- Supported scalar types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32`. +- For integer types, only the low bits of the scalar source are consumed according to the bit width (8, 16, or 32 bits). + +**Example**: +```python +# Broadcast scalar constant to vector +zero_vec = pto.vbr(0.0) +one_vec = pto.vbr(1.0) + +# Reduction seed with explicit floating dtype +rowmax_seed_f32 = pto.vbr(pto.f32("-inf")) +rowmax_seed_f16 = pto.vbr(pto.f16("0xFC00")) +``` + +#### `pto.vdup(input: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vdup(input: VRegType, mask: MaskType, position: PositionMode = PositionMode.LOWEST) -> VRegType` + +**Description**: Duplicate a scalar value or one selected vector element into +the active lanes of a destination vector. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `input` | `ScalarType` or `VRegType` | Input scalar or source vector | +| `mask` | `MaskType` | Predicate mask controlling which lanes are written | +| `position` | `PositionMode` | Optional enum for the vector-input overload, selecting the source vector element to duplicate (default: `PositionMode.LOWEST`) | + +**Position Mode Enum**: The `PositionMode` enum provides type-safe source-lane +selection for `pto.vdup`. `LOWEST` selects the lowest-index element of the +source vector and `HIGHEST` selects the highest-index element. The enum is only +used by the vector-input overload. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Vector whose active lanes receive the duplicated value | + +**Constraints**: +- `mask` granularity must match the destination vector element type. For + example, `f32`/`i32`/`si32`/`ui32` vectors require `mask_b32`. +- When `input` is a scalar, the scalar value is duplicated to every active lane. +- When `input` is a vector, `position` selects a single source element and that + value is duplicated to every active lane. +- The scalar overload does not accept `position`. +- Inactive lanes follow VPTO predicate semantics and are not guaranteed to carry + meaningful values for subsequent masked-off use. +- Supported scalar types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32`. +- `position` is only meaningful for vector input. TileLang DSL currently exposes + `PositionMode.LOWEST` and `PositionMode.HIGHEST`, matching VPTO v0.3. + +**Example**: +```python +mask32 = pto.make_mask(pto.f32, pto.PAT.ALL) + +# Duplicate a scalar into all active lanes. +broadcast = pto.vdup(3.14, mask32) + +# Use dtype constructors for floating-point special values. +seed = pto.vdup(pto.f32("-inf"), mask32) +seed_f16 = pto.vdup(pto.f16("0xFC00"), pto.make_mask(pto.f16, pto.PAT.ALL)) + +# Assume `vec` is an existing `f32` vector register value. +vec = pto.vlds(src, 0) + +# Duplicate the lowest source lane to all active lanes. +dup_lowest = pto.vdup(vec, mask32) # position defaults to "LOWEST" + +# Duplicate the highest source lane to all active lanes. +dup_highest = pto.vdup(vec, mask32, pto.PositionMode.HIGHEST) +``` + +**Type Safety Note**: +- For floating-point seeds, prefer `pto.f16(...)` / `pto.bf16(...)` / `pto.f32(...)` constructors. +- Do not pass integer bit-pattern literals directly (for example `0xFF800000`) when a floating vector type is intended. + +### Carry & Select Operations + +Operations with carry propagation and selection. + +**Comparison Mode Enum**: The `CmpMode` enum provides type-safe comparison mode specification for `pto.vcmp` and `pto.vcmps` operations. It includes the following values: `EQ` (equal), `NE` (not equal), `LT` (less than), `LE` (less than or equal), `GT` (greater than), `GE` (greater than or equal). + +Implemented current-package carry/select surface also includes: +- `pto.vselr(vec0, vec1) -> VRegType` +- `pto.vselrv2(vec0, vec1) -> VRegType` +- `pto.vaddcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` +- `pto.vsubcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` + +#### `pto.vcmp(vec0: VRegType, vec1: VRegType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Element-wise vector comparison with seed mask. Compares two vectors element-wise and generates a predicate mask based on the specified comparison mode. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec0` | `VRegType` | First input vector | +| `vec1` | `VRegType` | Second input vector | +| `seed_mask` | `MaskType` | Seed mask that determines which lanes participate in the comparison | +| `cmp_mode` | `CmpMode` | Comparison mode enum: `CmpMode.EQ` (equal), `CmpMode.NE` (not equal), `CmpMode.LT` (less than), `CmpMode.LE` (less than or equal), `CmpMode.GT` (greater than), `CmpMode.GE` (greater than or equal) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Generated predicate mask based on element-wise comparison | + +**Constraints**: +- Only lanes enabled by `seed_mask` participate in the comparison +- The two input vectors must have the same element type and vector length +- The output mask granularity matches the input vector element type + +**Example**: +```python +# Compare two vectors for less-than relation +all_mask = pto.make_mask(pto.f32, PAT.ALL) +lt_mask = pto.vcmp(vec_a, vec_b, all_mask, CmpMode.LT) +``` + +#### `pto.vcmps(vec: VRegType, scalar: ScalarType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Vector-scalar comparison with seed mask. Compares each element of a vector against a scalar value and generates a predicate mask based on the specified comparison mode. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value to compare against (must match vector element type) | +| `seed_mask` | `MaskType` | Seed mask that determines which lanes participate in the comparison | +| `cmp_mode` | `CmpMode` | Comparison mode enum: `CmpMode.EQ`, `CmpMode.NE`, `CmpMode.LT`, `CmpMode.LE`, `CmpMode.GT`, `CmpMode.GE` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Generated predicate mask based on vector-scalar comparison | + +**Constraints**: +- Only lanes enabled by `seed_mask` participate in the comparison +- The scalar type must match the vector element type +- The output mask granularity matches the input vector element type + +**Example**: +```python +# Check which elements are greater than zero +all_mask = pto.make_mask(pto.f32, PAT.ALL) +positive_mask = pto.vcmps(values, pto.f32(0.0), all_mask, CmpMode.GT) +``` + +#### `pto.vaddc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` + +**Description**: Vector addition with carry output. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sum vector | +| `carry_out` | `MaskType` | Output carry mask | + +#### `pto.vsubc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` + +**Description**: Vector subtraction with borrow output. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference vector | +| `borrow_out` | `MaskType` | Output borrow mask | + +#### `pto.vsel(true_vec: VRegType, false_vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector select based on mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `true_vec` | `VRegType` | Vector selected when mask bit is 1 | +| `false_vec` | `VRegType` | Vector selected when mask bit is 0 | +| `mask` | `MaskType` | Selection mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Selected vector | + +**Example**: +```python +result = pto.vsel(scaled_vec, original_vec, mask32) +``` + +### Reduction Operations + +Reduction operations across vector lanes or channels. + +#### `pto.vcgadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-group addition reduction (reduction across VLanes). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced sum across groups | + +#### `pto.vcgmax(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-group maximum reduction (reduction across VLanes). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced maximum across groups | + +#### `pto.vcgmin(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-group minimum reduction (reduction across VLanes). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced minimum across groups | + +#### `pto.vcpadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-channel addition reduction (reduction across channels). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced sum across channels | + +### Data Rearrangement + +Operations for rearranging data within vectors. + +Predicate rearrangement ops `pto.pdintlv_b8` and `pto.pintlv_b16` are documented in `10-predicate-operations.md` because they operate on predicate masks rather than vector registers. + +Implemented current-package rearrangement surface also includes: +- `pto.vintlvv2(vec0, vec1, part) -> VRegType` +- `pto.vdintlvv2(vec0, vec1, part) -> VRegType` + +#### `pto.vintlv(vec1: VRegType, vec2: VRegType) -> (VRegType, VRegType)` + +**Description**: Interleave two vectors and return the low/high results. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `VRegType` | Low interleaved result | +| `high` | `VRegType` | High interleaved result | + +#### `pto.vdintlv(vec0: VRegType, vec1: VRegType) -> (VRegType, VRegType)` + +**Description**: Deinterleave a pair of vectors into low/high results. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec0` | `VRegType` | First input vector | +| `vec1` | `VRegType` | Second input vector | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec1` | `VRegType` | First deinterleaved vector | +| `vec2` | `VRegType` | Second deinterleaved vector | + +#### `pto.vpack(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector packing (combine elements from two vectors). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Packed vector | + +#### `pto.vperm(vec: VRegType, indices: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector permutation (reorder elements according to index vector). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `indices` | `VRegType` | Permutation indices | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Permuted vector | + +#### `pto.vshift(vec: VRegType, shift_amount: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Generic vector shift (shift all elements by same amount). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift_amount` | `ScalarType` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted vector | + +#### `pto.vslide(vec: VRegType, window_size: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector sliding window (create overlapping windows). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `window_size` | `ScalarType` | Size of sliding window | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sliding window result | + +#### `pto.vsort32(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: 32-element sorting of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (32 elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sorted vector | + +**Constraints**: +- Input vector must have exactly 32 elements + +#### `pto.vmrgsort(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Merge sort of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Merged and sorted vector | + +#### `pto.vtranspose(dest: ptr, src: ptr, config: pto.i64) -> None` [Advanced Tier] + +**Description**: UB-to-UB transpose operation. This op works on UB memory directly (not `vreg -> vreg`). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dest` | `ptr` | Destination pointer in UB memory space | +| `src` | `ptr` | Source pointer in UB memory space | +| `config` | `pto.i64` | ISA control/config operand that encodes transpose layout behavior | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `None` | `None` | Side-effect operation that writes transposed data to `dest` | + +**Constraints**: +- `dest` and `src` must be UB pointers +- Correctness depends on the `config` encoding and UB layout contract + +**Example**: +```python +pto.vtranspose(dst_ub_ptr, src_ub_ptr, config_word) +``` + +### Conversion & Special Operations + +Type conversion and specialized operations. + +#### `pto.vtrc(vec: VRegType, mask: MaskType, rnd: pto.VcvtRoundMode | None = None) -> VRegType` + +**Description**: Truncate/round float to integer-valued float (stays in float type). This is the TileLang DSL surface for the VPTO `pto.vtrc` operation. + +**Attribute Enums**: +- `pto.VcvtRoundMode`: `R`, `A`, `F`, `C`, `Z`, `O` (note: `vtrc` does not support `O`) + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | +| `rnd` | `pto.VcvtRoundMode` \| `None` | Optional rounding-mode attribute lowered to VPTO `round_mode`. Defaults to `R` if not specified. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Truncated vector with integer-valued float elements | + +**Constraints**: +- Current TileLang DSL v1 accepts exactly two positional arguments: `pto.vtrc(vec, mask)`. Optional `rnd` attribute is exposed as keyword argument: `rnd=...`. +- The underlying VPTO op syntax is `pto.vtrc %input, %mask, "RND"`. +- Supported rounding modes are `R` (round to nearest), `A` (round away from zero), `F` (floor), `C` (ceil), `Z` (truncate toward zero). +- The enum form is preferred. For compatibility, canonical strings such as `"R"`, `"A"`, `"F"`, `"C"`, `"Z"` are also accepted. +- This op does not change the element type; input and output have the same vector type. +- Only floating-point element types are supported: `f16`, `bf16`, `f32`. + +#### `pto.vcvt(vec: VRegType, to_type: Type, mask: MaskType, rnd: pto.VcvtRoundMode | None = None, sat: pto.VcvtSatMode | None = None, part: pto.VcvtPartMode | None = None) -> VRegType` + +**Description**: Convert vector elements between supported float and integer +families. This is the TileLang DSL surface for the VPTO `pto.vcvt` conversion +family. + +**Attribute Enums**: +- `pto.VcvtRoundMode`: `R`, `A`, `F`, `C`, `Z`, `O` +- `pto.VcvtSatMode`: `SAT`, `NOSAT` +- `pto.VcvtPartMode`: `EVEN`, `ODD`, `P0`, `P1`, `P2`, `P3` + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `to_type` | `Type` | Target scalar dtype symbol for the result vector element type | +| `mask` | `MaskType` | Predicate mask selecting active source lanes. Its granularity must match the source vector family, not the destination family | +| `rnd` | `pto.VcvtRoundMode` \| `None` | Optional rounding-mode attribute lowered to VPTO `rnd` | +| `sat` | `pto.VcvtSatMode` \| `None` | Optional saturation attribute lowered to VPTO `sat` | +| `part` | `pto.VcvtPartMode` \| `None` | Optional width-changing lane-placement selector lowered to VPTO `part` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Converted vector with the vreg shape implied by `to_type` | + +**Constraints**: +- Current TileLang DSL v1 accepts exactly three positional arguments: + `pto.vcvt(vec, to_type, mask)`. Optional attributes are exposed as keyword + arguments: `rnd=...`, `sat=...`, `part=...`. +- The underlying VPTO op family is the fuller + `pto.vcvt %input, %mask {rnd, sat, part}` surface, and the DSL keywords map + directly to those VPTO attributes. +- `mask` always follows the source vector family: + `f32`/`i32`/`si32`/`ui32` use `mask_b32`; + `f16`/`bf16`/`i16`/`si16`/`ui16` use `mask_b16`; + `i8`/`si8`/`ui8` use `mask_b8`. +- The enum form is preferred. For compatibility, canonical strings such as + `"R"`, `"SAT"`, and `"EVEN"` are also accepted. +- VPTO `part` supports two families: `Part` (`EVEN`/`ODD`) for ordinary + width-changing conversions (e.g. `32 -> 16`, `16 -> 32`), and `Part_T` + (`P0`–`P3`) for 4-way packed placement (e.g. `32 -> 8`, fp8/fp4 flows). + + | Mode | VPTO spelling | Family | Description | TileLang DSL v1 status | + |------|---------------|--------|-------------|------------------------| + | `EVEN` | `PART_EVEN` | `Part` | Output to even-indexed lanes | Exposed as `pto.VcvtPartMode.EVEN` | + | `ODD` | `PART_ODD` | `Part` | Output to odd-indexed lanes | Exposed as `pto.VcvtPartMode.ODD` | + | `P0` | `PART_P0` | `Part_T` | Output to sub-part 0 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P0` | + | `P1` | `PART_P1` | `Part_T` | Output to sub-part 1 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P1` | + | `P2` | `PART_P2` | `Part_T` | Output to sub-part 2 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P2` | + | `P3` | `PART_P3` | `Part_T` | Output to sub-part 3 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P3` | +- Only backend-supported source/destination type pairs are legal. For the full + A5 `vcvt` type matrix, width-changing packing rules, and attribute-sensitive + forms, refer to + [`../vpto_spec/vpto-spec-current.md`](../vpto_spec/vpto-spec-current.md). +- Attribute requirements are type-pair specific. The DSL enforces the same + per-form contract as VPTO, so some pairs require attributes while others + reject them. +- Examples: + `f32 -> si32` requires `rnd` and `sat`; + `f16 -> si32` requires `rnd` and `part`, and rejects `sat`; + `bf16 -> f16` requires `rnd` and `sat`; + `f16 -> f32` requires `part`; + `f32 -> f16` requires `rnd`, `sat`, and `part`; + `si32 -> f32` requires `rnd`. +- VPTO does not define a `mask_b64` form. Conversions that produce `si64` + results still use the typed mask granularity of the source vector family. +- Width-changing conversions continue to follow VPTO packing semantics even on + the simplified DSL surface. For example, `f16 -> f32` uses an `f16`-family + `mask_b16`, because the mask is attached to the source vector family. +- A common `tcvt`-style pair is: + `f16 -> f32`: `pto.vlds(..., dist=pto.VLoadDist.UNPK_B16)` + `pto.vcvt(..., part=pto.VcvtPartMode.EVEN)`; + `f32 -> f16`: `pto.vcvt(..., rnd=..., sat=..., part=pto.VcvtPartMode.EVEN)` + `pto.vsts(..., dist=pto.VStoreDist.PK_B32)`. +- In those `tcvt` flows, the `vcvt` mask still follows the source vector family: + `f16 -> f32` uses `mask_b16`, while `f32 -> f16` uses `mask_b32`. +- The follow-on `vsts` mask is checked against the store `dist`, not the narrowed element dtype alone. For example, `pto.vsts(vec_f16, ..., mask32, dist=pto.VStoreDist.PK_B32)` is valid and expected for `f32 -> f16` rowwise `tcvt`. + +**Example**: +```python +mask16 = pto.make_mask(pto.f16, PAT.ALL) +vec_f16 = pto.vlds(src, 0) +vec_f32 = pto.vcvt(vec_f16, pto.f32, mask16) + +mask32 = pto.make_mask(pto.f32, PAT.ALL) +vec_i32 = pto.vcvt(vec_f32, pto.si32, mask32) + +vec_i32_wide = pto.vcvt( + vec_f16, + pto.si32, + mask16, + rnd=pto.VcvtRoundMode.R, + part=pto.VcvtPartMode.EVEN, +) + +vec_f16_from_bf16 = pto.vcvt( + vec_bf16, + pto.f16, + mask16, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, +) + +vec_f16_narrow = pto.vcvt( + vec_f32, + pto.f16, + mask32, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.ODD, +) + +# Rowwise tcvt-style widening from f16 to f32 +vec_f16_unpacked = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) +vec_f32_from_f16 = pto.vcvt( + vec_f16_unpacked, + pto.f32, + mask16, + part=pto.VcvtPartMode.EVEN, +) + +# Rowwise tcvt-style narrowing from f32 to f16 +vec_f16_packed = pto.vcvt( + vec_f32, + pto.f16, + mask32, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, +) +pto.vsts(vec_f16_packed, dst, 0, mask32, dist=pto.VStoreDist.PK_B32) +``` + +#### `pto.vbitsort(dest: ptr, src: ptr, indices: ptr, repeat_times: index) -> None` [Advanced Tier] + +**Description**: Sort 32 region proposals by score and materialize sorted proposal +records into UB memory. This is a UB helper and not a `vreg -> vreg` operation. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dest` | `ptr` | Destination pointer in UB memory space | +| `src` | `ptr` | Source score pointer in UB memory space | +| `indices` | `ptr` | Source index pointer in UB memory space | +| `repeat_times` | `index` | Repeat count; each repeat processes the next adjacent group of 32 scores and 32 indices | + +**Returns**: +None. The op writes UB memory directly. + +**Constraints**: +- `dest`, `src`, and `indices` must be UB-backed pointers +- Scores are sorted in descending order +- Equal-score ties preserve the earlier input proposal first +- Output records occupy 8 bytes each: upper 4 bytes for the index and lower 4 bytes for the score + +#### `pto.vmrgsort4(dest: ptr, src0: ptr, src1: ptr, src2: ptr, src3: ptr, count: pto.i64, config: pto.i64) -> None` [Advanced Tier] + +**Description**: Merge-sort 4 pre-sorted UB inputs. This op writes UB memory +directly and does not return a vector SSA value. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dest` | `ptr` | Destination pointer in UB memory space | +| `src0` | `ptr` | First pre-sorted input pointer in UB memory space | +| `src1` | `ptr` | Second pre-sorted input pointer in UB memory space | +| `src2` | `ptr` | Third pre-sorted input pointer in UB memory space | +| `src3` | `ptr` | Fourth pre-sorted input pointer in UB memory space | +| `count` | `pto.i64` | Number of valid input elements participating in the merge | +| `config` | `pto.i64` | Operation control word encoding sort behavior | + +**Returns**: +None. The op writes UB memory directly. + +**Constraints**: +- `dest` and `src0` through `src3` must be UB-backed pointers +- Inputs must already be sorted according to the order encoded by `config` + +#### `pto.get_vms4_sr() -> (pto.i16, pto.i16, pto.i16, pto.i16)` [Advanced Tier] + +**Description**: Read `VMS4_SR` after exhausted `pto.vmrgsort4` and return the +finished element counts for source lists 0 through 3. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `list0` | `pto.i16` | Finished count from `VMS4_SR[15:0]` | +| `list1` | `pto.i16` | Finished count from `VMS4_SR[31:16]` | +| `list2` | `pto.i16` | Finished count from `VMS4_SR[47:32]` | +| `list3` | `pto.i16` | Finished count from `VMS4_SR[63:48]` | + +**Example**: +```python +list0, list1, list2, list3 = pto.get_vms4_sr() +``` + +**Order Mode Enum**: The `OrderMode` enum provides type-safe order selection for `pto.vci` operations. `ASC` and `DESC` are supported. + +#### `pto.vci(index: ScalarType, order: OrderMode = OrderMode.ASC) -> VRegType` + +**Description**: Generate a lane-index vector from a scalar seed/index value (DSA/SFU operation). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `index` | `ScalarType` | Scalar seed or base index value | +| `order` | `OrderMode` | Order mode enum (default: `OrderMode.ASC`; supported values: `ASC`, `DESC`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Generated index vector | + +**Constraints**: +- This is an index-generation family, not a numeric conversion +- The `order` parameter and result element type together determine how indices are generated +- Supported order modes are ascending (`OrderMode.ASC`) and descending (`OrderMode.DESC`) + +**Example**: +```python +# Generate ascending indices starting from 0 +indices = pto.vci(pto.i32(0), OrderMode.ASC) + +# Generate descending indices starting from the seed value +indices_desc = pto.vci(pto.i32(63), OrderMode.DESC) + +# Keyword form for the optional order argument is also supported +indices_kw = pto.vci(pto.i32(0), order=OrderMode.ASC) +``` diff --git a/ptodsl/docs/user_guide/12-cube-operations.md b/ptodsl/docs/user_guide/12-cube-operations.md new file mode 100644 index 000000000..275039838 --- /dev/null +++ b/ptodsl/docs/user_guide/12-cube-operations.md @@ -0,0 +1,454 @@ +# Cube Matrix Multiply Operations + +Cube operations target the AIC (Cube) hardware unit for matrix multiplication and +staged data movement. They are only available inside `@pto.ckernel` function +bodies. All Cube operands use `pto.ptr` raw pointers — no +`vecscope` execution scope is used. + +## Address Spaces + +Cube operations use the following address spaces via the `MemorySpace` enum. +The IR type column shows the canonical `!pto.ptr` spelling. Older +`mat`/`left`/`right`/`acc`/`bias`/`scaling` pointer spellings are accepted as +parser aliases and print back as `l1`/`l0a`/`l0b`/`l0c`/`bt`/`fb`. + +| Address Space | Enum Value | Canonical IR Type | Legacy ptr alias | Description | +|--------------|------------|-------------------|------------------|-------------| +| `GM` | `MemorySpace.GM` | `!pto.ptr` | - | Global memory | +| `MAT` | `MemorySpace.MAT` | `!pto.ptr` | `mat` | L1 buffer (cbuf) | +| `LEFT` | `MemorySpace.LEFT` | `!pto.ptr` | `left` | L0A left-operand buffer | +| `RIGHT` | `MemorySpace.RIGHT` | `!pto.ptr` | `right` | L0B right-operand buffer | +| `ACC` | `MemorySpace.ACC` | `!pto.ptr` | `acc` | L0C accumulator buffer | +| `BIAS` | `MemorySpace.BIAS` | `!pto.ptr` | `bias` | Bias table | +| `UB` | `MemorySpace.UB` | `!pto.ptr` | `vec` | Unified buffer (Vector side) | + +## Shared Infrastructure + +Cube operations reuse general tile and pointer facilities documented elsewhere: + +| Facility | Description | Reference | +|----------|-------------|-----------| +| `pto.Tile` | Allocate a tile buffer with address space | [Type System — Tile Type Definition](05-type-system.md#tile-type-definition) | +| `.as_ptr()` | Get raw pointer from Tile / TensorView | [Frontend Operations — Pointer Construction](07-frontend-operations.md#pointer-construction-advanced-tier) | +| `pto.addptr` | Element-offset a pointer | [Frontend Operations — Pointer Construction](07-frontend-operations.md#pointer-construction-advanced-tier) | + +--- + +## Matrix Compute Operations + +### `pto.mad` — zero-init matmul + +#### `pto.mad(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: Zero-init cube matrix multiply. Clears the accumulator and computes +`dst = lhs * rhs`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `lhs` | `pto.ptr` | L0A left operand | +| `rhs` | `pto.ptr` | L0B right operand | +| `dst` | `pto.ptr` | L0C accumulator destination | +| `m` | `int` | M dimension size | +| `n` | `int` | N dimension size | +| `k` | `int` | K dimension size | +| `unit_flag_ctrl` | `int` | Accumulator control flag (0 / 2 / 3) | +| `disable_gemv` | `bool` | GEMV disable control | + +**Constraints**: +- `lhs` must be in `l0a` address space. +- `rhs` must be in `l0b` address space. +- `dst` must be in `l0c` address space. + +**Example**: +```python +pto.mad(l0a, l0b, l0c, 16, 16, 64) +``` + +--- + +### `pto.mad_acc` — accumulating matmul + +#### `pto.mad_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: Accumulating cube matrix multiply. Computes `dst += lhs * rhs`. + +**Parameters**: Same as `pto.mad`. + +**Example**: +```python +pto.mad_acc(l0a, l0b, l0c, 16, 16, 64, unit_flag_ctrl=2) +``` + +--- + +### `pto.mad_bias` — bias-init matmul + +#### `pto.mad_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: Bias-init cube matrix multiply. Computes `dst = lhs * rhs + bias`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `bias` | `pto.ptr` | Bias table pointer | + +Other parameters are the same as `pto.mad`. + +**Constraints**: +- `bias` must be in `bt` address space. + +**Example**: +```python +pto.mad_bias(l0a, l0b, l0c, bt, 16, 16, 64) +``` + +--- + +### `pto.mad_mx` — zero-init MX matmul + +#### `pto.mad_mx(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: Zero-init MX (micro-scaling) cube matrix multiply. Same semantics +as `pto.mad`, for MX-capable dtypes such as `f8E4M3FN`. + +**Parameters**: Same as `pto.mad`. + +**Example**: +```python +pto.mad_mx(l0a, l0b, l0c, 16, 16, 64) +``` + +--- + +### `pto.mad_mx_acc` — accumulating MX matmul + +#### `pto.mad_mx_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: Accumulating MX cube matrix multiply. Computes `dst += lhs * rhs`. + +**Parameters**: Same as `pto.mad`. + +--- + +### `pto.mad_mx_bias` — MX bias-init matmul + +#### `pto.mad_mx_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: MX bias-init cube matrix multiply. Computes `dst = lhs * rhs + bias`. + +**Parameters**: Same as `pto.mad_bias`. + +--- + +## Data Movement Operations + +### `pto.cube_load` — GM → L1 (cbuf) + +#### `pto.cube_load(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0), loops: list[tuple[int, int, int]] | None = None) -> None` + +**Description**: Structured GM-to-L1 (`cbuf` / `l1`) data movement wrapper. Lowers +to loop/stride setup plus `pto.copy_gm_to_cbuf`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | Global memory source pointer | +| `dst` | `pto.ptr` | L1 (cbuf) destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_stride, dst_stride)` | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional nested loop params, each `(count_i, src_stride_i, dst_stride_i)` | + +**Constraints**: +- `src` must be in `gm` address space. +- `dst` must be in `l1` address space. + +**Example**: +```python +pto.cube_load(a_ptr, l1_a.as_ptr(), 16, nburst=(1, 0, 0)) +``` + +--- + +### `pto.cube_store` — L1 (cbuf) → UB + +#### `pto.cube_store(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0), loops: list[tuple[int, int, int]] | None = None) -> None` + +**Description**: Structured L1 (`cbuf`) to UB data movement wrapper. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L1 source pointer | +| `dst` | `pto.ptr` | UB destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_stride, dst_stride)` | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional nested loop params | + +**Example**: +```python +pto.cube_store(l1_src.as_ptr(), ub_dst.as_ptr(), 16, nburst=(1, 0, 0)) +``` + +--- + +### `pto.cube_load_frac` — fractal load + +#### `pto.cube_load_frac(src: PtrType, dst: PtrType, mode: pto.FractalMode, *, shape: tuple[int, int], src_layout: tuple[int, int], dst_group: tuple[int, int, int, int], ctrl: tuple[int, bool]) -> None` + +**Description**: Structured fractal-load wrapper for `nd2nz` and `dn2nz` modes. +Lowers to `set_mte2_nz_para` plus `copy_gm_to_cbuf_multi_nd2nz` or +`copy_gm_to_cbuf_multi_dn2nz`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | Global memory source pointer | +| `dst` | `pto.ptr` | L1 destination pointer | +| `mode` | `pto.FractalMode` | `pto.FractalMode.ND2NZ` or `pto.FractalMode.DN2NZ` | +| `shape` | `tuple[int, int]` | `(n_value, d_value)` | +| `src_layout` | `tuple[int, int]` | `(inner_stride, outer_stride)` | +| `dst_group` | `tuple[int, int, int, int]` | `(group_count, loop2_stride, loop3_stride, loop4_stride)` | +| `ctrl` | `tuple[int, bool]` | `(l2_cache_ctrl, smallc0_en)` | + +**Constraints**: +- `src` must be in `gm` address space. +- `dst` must be in `l1` address space. + +**Example**: +```python +pto.cube_load_frac(a_ptr, l1_a.as_ptr(), pto.FractalMode.ND2NZ, + shape=(16, 16), src_layout=(4, 8), + dst_group=(1, 0, 0, 0), ctrl=(0, False)) +``` + +--- + +### `pto.bias_load` — L1 (cbuf) → bias table + +#### `pto.bias_load(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0)) -> None` + +**Description**: Structured L1 (`cbuf`) to bias-table load wrapper. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L1 source pointer | +| `dst` | `pto.ptr` | Bias table destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_gap, dst_gap)` | + +**Constraints**: +- Supported source/destination type pairs: `f32→f32`, `i32→i32`, `f16→f32`, `bf16→f32`. + +**Example**: +```python +pto.bias_load(l1_bias.as_ptr(), bt.as_ptr(), 16, nburst=(1, 0, 0)) +``` + +--- + +### `pto.left_load` — L1 (cbuf) → L0A + +#### `pto.left_load(src: PtrType, dst: PtrType, m: int, k: int) -> None` + +**Description**: Structured L1-to-L0A wrapper. Lowers to `pto.load_cbuf_to_ca`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L1 source pointer | +| `dst` | `pto.ptr` | L0A destination pointer | +| `m` | `int` | M dimension size | +| `k` | `int` | K dimension size | + +**Constraints**: +- `src` must be in `l1` address space. +- `dst` must be in `l0a` address space. + +**Example**: +```python +pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), 16, 64) +``` + +--- + +### `pto.right_load` — L1 (cbuf) → L0B + +#### `pto.right_load(src: PtrType, dst: PtrType, k: int, n: int) -> None` + +**Description**: Structured L1-to-L0B wrapper. Lowers to `pto.load_cbuf_to_cb`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L1 source pointer | +| `dst` | `pto.ptr` | L0B destination pointer | +| `k` | `int` | K dimension size | +| `n` | `int` | N dimension size | + +**Constraints**: +- `src` must be in `l1` address space. +- `dst` must be in `l0b` address space. + +**Example**: +```python +pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), 64, 16) +``` + +--- + +### `pto.left_load_mx` — MX L1 → L0A + +#### `pto.left_load_mx(src: PtrType, dst: PtrType, m: int, k: int) -> None` + +**Description**: MX-mode L1-to-L0A wrapper. Lowers to `pto.load_cbuf_to_ca_mx`. + +**Parameters**: Same as `pto.left_load`. + +--- + +### `pto.right_load_mx` — MX L1 → L0B + +#### `pto.right_load_mx(src: PtrType, dst: PtrType, k: int, n: int) -> None` + +**Description**: MX-mode L1-to-L0B wrapper. Lowers to `pto.load_cbuf_to_cb_mx`. + +**Parameters**: Same as `pto.right_load`. + +--- + +## Result Writeback Operations + +### `pto.acc_store` — L0C (acc) → L1 (cbuf) + +#### `pto.acc_store(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, mode: pto.FractalMode = pto.FractalMode.NZ2ND, loop0_src_stride: int | None = None, split: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` + +**Description**: Structured L0C (`l0c`) to L1 (`cbuf`) writeback wrapper. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L0C source pointer | +| `dst` | `pto.ptr` | L1 (cbuf) destination pointer | +| `m` | `int` | M dimension size | +| `n` | `int` | N dimension size | +| `src_stride` | `int` | Source stride | +| `dst_stride` | `int` | Destination stride | +| `mode` | `pto.FractalMode` | Layout mode: `NZ2ND` / `NZ2DN` / `NZ2NZ` | + +Mode-dependent parameters: + +| Mode | Required | Not Accepted | +|------|----------|--------------| +| `pto.FractalMode.NZ2ND` | (none) | — | +| `pto.FractalMode.NZ2DN` | `loop0_src_stride` | — | +| `pto.FractalMode.NZ2NZ` | `split` | `loop3` | + +Optional for `pto.FractalMode.NZ2ND` and `pto.FractalMode.NZ2DN`: +`loop3=(count, src_stride3, dst_stride3)`. + +**Example**: +```python +pto.acc_store(l0c.as_ptr(), l1_out.as_ptr(), + 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND) +``` + +--- + +### `pto.acc_store_gm` — L0C (acc) → GM + +#### `pto.acc_store_gm(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, sid: int = 0, l2_cache_ctrl: int = 0, mode: pto.FractalMode = pto.FractalMode.NZ2ND, loop0_src_stride: int | None = None, split: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` + +**Description**: Structured L0C (`l0c`) to GM writeback wrapper. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L0C source pointer | +| `dst` | `pto.ptr` | GM destination pointer | +| `sid` | `int` | Stream ID | +| `l2_cache_ctrl` | `int` | L2 cache control | + +Other parameters are the same as `pto.acc_store`. + +**Example**: +```python +pto.acc_store_gm(l0c.as_ptr(), c_ptr, 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND) +``` + +--- + +### `pto.acc_store_ub` — L0C (acc) → UB + +#### `pto.acc_store_ub(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, dual_dst_mode: int = 0, sub_blockid: int = 0, mode: pto.FractalMode = pto.FractalMode.NZ2ND, loop0_src_stride: int | None = None, channel_split_en: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` + +**Description**: Structured L0C (`l0c`) to UB writeback wrapper. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L0C source pointer | +| `dst` | `pto.ptr` | UB destination pointer | +| `dual_dst_mode` | `int` | Dual destination mode | +| `sub_blockid` | `int` | Sub-block ID | +| `channel_split_en` | `int` or `None` | Channel split enable (required for `mode=pto.FractalMode.NZ2NZ`) | + +Other parameters are the same as `pto.acc_store`. + +**Example**: +```python +pto.acc_store_ub(l0c.as_ptr(), ub_out.as_ptr(), + 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND) +``` + +--- + +## Quick Reference + +### By Data Flow + +| Data Flow | Operation | Src Space | Dst Space | +|-----------|-----------|-----------|-----------| +| GM → L1 | `pto.cube_load` | gm | l1 | +| GM → L1 (fractal) | `pto.cube_load_frac` | gm | l1 | +| L1 → UB | `pto.cube_store` | l1 | ub | +| L1 → L0A | `pto.left_load` | l1 | l0a | +| L1 → L0B | `pto.right_load` | l1 | l0b | +| L1 → L0A (MX) | `pto.left_load_mx` | l1 | l0a | +| L1 → L0B (MX) | `pto.right_load_mx` | l1 | l0b | +| L1 → Bias | `pto.bias_load` | l1 | bt | +| L0A×L0B → L0C | `pto.mad` | l0a, l0b | l0c | +| L0A×L0B → L0C (acc) | `pto.mad_acc` | l0a, l0b | l0c | +| L0A×L0B+Bias → L0C | `pto.mad_bias` | l0a, l0b, bt | l0c | +| L0C → L1 | `pto.acc_store` | l0c | l1 | +| L0C → GM | `pto.acc_store_gm` | l0c | gm | +| L0C → UB | `pto.acc_store_ub` | l0c | ub | + +### MX Variants + +| Base Op | MX Variant | Description | +|---------|------------|-------------| +| `pto.mad` | `pto.mad_mx` | Zero-init MX matmul | +| `pto.mad_acc` | `pto.mad_mx_acc` | Accumulating MX matmul | +| `pto.mad_bias` | `pto.mad_mx_bias` | Bias-init MX matmul | + +--- + +## Template Slot Support + +Cube operations support `pto.tpl()` template-slot dispatch, consistent with the +Vector DSL mechanism. See [Template Kernels](04-template-kernels.md) for general +`pto.tpl()` usage. + +**Constraints**: Variants within the same slot must have identical parameter +signatures. For example, `mad` and `mad_acc` can share a slot, but `mad_bias` +(which adds a `bias` parameter) requires a separate slot. + +--- + +## See Also + +- [Kernel Declaration](03-kernel-declaration.md) — `@pto.ckernel` decorator specification +- [Examples](13-examples.md) — full Cube kernel code examples +- [Design doc](../../../docs/designs/tilelang-cube-dsl-design.md) — Cube DSL design details diff --git a/ptodsl/docs/user_guide/13-examples.md b/ptodsl/docs/user_guide/13-examples.md new file mode 100644 index 000000000..16105b853 --- /dev/null +++ b/ptodsl/docs/user_guide/13-examples.md @@ -0,0 +1,417 @@ +## Examples + +### Template-based Kernel Examples + +#### Unified Arithmetic Operations + +A single kernel implementing multiple arithmetic operations using templates: + +```python +T = pto.TypeVar('T') + +@pto.vkernel( + target="a5", + ops=["tadd", "tsub", "tmul", "tdiv"], + dtypes=[(T, T, T)], + advanced=True, + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + } + } +) +def elementwise_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + """Single implementation for four arithmetic operations.""" + dtype = dst.element_type + rows, cols = dst.valid_shape + + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, pto.elements_per_vreg(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("core", lhs, rhs, mask) + pto.vsts(out, dst[row, col:], mask) +``` + +#### Multiple Templates with Postprocess + +Kernel using separate templates for arithmetic and postprocess operations: + +```python +@pto.vkernel( + target="a5", + ops=["add_relu", "sub_relu", "add_abs", "sub_abs"], + dtypes=[(T, T, T)], + templates={ + "arithmetic": { + "add_relu": "vadd", + "sub_relu": "vsub", + "add_abs": "vadd", + "sub_abs": "vsub", + }, + "postprocess": { + "add_relu": "vrelu", + "sub_relu": "vrelu", + "add_abs": "vabs", + "sub_abs": "vabs", + } + } +) +def elementwise_with_postprocess(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + rows, cols = dst.valid_shape + + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, pto.elements_per_vreg(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + + # Use arithmetic template + arith_result = pto.tpl("arithmetic", lhs, rhs, mask) + + # Apply postprocess template + activated = pto.tpl("postprocess", arith_result, mask) + + pto.vsts(activated, dst[row, col:], mask) +``` + +#### Compile-time Substitution + +Template substitution happens before semantic analysis and lowering: + +```python +selected = pto.select_kernel("a5", "tadd", (ptype, ptype, ptype)) +# frontend resolves: +# pto.tpl("core", lhs, rhs, mask) +# into: +# pto.vadd(lhs, rhs, mask) +``` + +#### Benefits of Template-based Authoring + +1. **Code Reuse**: Single implementation serves multiple operations +2. **Maintenance**: Bug fixes and optimizations apply to all related operations +3. **Consistency**: Ensures uniform behavior across operation families +4. **Reduced Boilerplate**: Eliminates duplicate control flow and data movement code +5. **Type Safety**: Type variables ensure consistent operand types + +### Simple Vector Copy + +```python +@pto.vkernel(...) +def vector_copy(src: pto.Tile, dst: pto.Tile): + all_mask: pto.mask_b32 = pto.make_mask(pto.f32, PAT.ALL) + for offset in range(0, 256, 64): + vec = pto.vlds(src, offset) + pto.vsts(vec, dst, offset, all_mask) +``` + +### Conditional Computation + +```python +@pto.vkernel(...) +def conditional_scale(src: pto.ptr(pto.f32, MemorySpace.GM), + dst: pto.ptr(pto.f32, MemorySpace.GM), + threshold: pto.f32): + # ... setup ... + + with pto.strict_vecscope(ub_in, ub_out, threshold) as (vin, vout, thresh): + for i in range(0, 1024, 64): + vec = pto.vlds(vin, i) + + # Compare with threshold + mask = pto.pge_b32(vec, thresh) + + # Scale values above threshold + scaled = pto.vmuls(vec, pto.f32(2.0), mask) + + # Keep original values below threshold + result = pto.vsel(scaled, vec, mask) + + pto.vsts(result, vout, i, all_mask) +``` + +### Loop with Carry + +```python +@pto.vkernel(...) +def prefix_sum(src: pto.ptr(pto.i32, MemorySpace.UB), + dst: pto.ptr(pto.i32, MemorySpace.UB)): + all_mask = pto.make_mask(pto.i32, PAT.ALL) + carry = all_mask + + for i in range(0, 256, 64): + vec = pto.vlds(src, i) + result, carry = pto.vaddcs(vec, vec, carry, all_mask) + pto.vsts(result, dst, i, all_mask) +``` + +--- + +## Cube Kernel Examples + +Cube kernels target the AIC (Cube) hardware unit for matrix multiplication. GM data is expressed through `PartitionTensorView`, while hardware buffers in specific address spaces are constructed via `pto.Tile`. + +### Basic GEMM + +A full-pipeline matrix multiplication C = A × B: + +```python +from tilelang_dsl import ckernel, Tile, MemorySpace + +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm", +) +def gemm(a_tv: pto.PartitionTensorView, # [M, K] in GM + b_tv: pto.PartitionTensorView, # [K, N] in GM + c_tv: pto.PartitionTensorView, # [M, N] in GM, output + M: int, K: int, N: int): + # Get GM pointers from PartitionTensorViews + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + # Allocate L1 (MAT) tile buffers + l1_a_tile = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b_tile = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + + # Allocate L0 tile buffers + l0a_tile = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b_tile = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c_tile = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + # GM → L1 + pto.cube_load(a_ptr, l1_a_tile.as_ptr(), K, nburst=(1, 0, 0)) + pto.cube_load(b_ptr, l1_b_tile.as_ptr(), N, nburst=(1, 0, 0)) + + # L1 → L0 + pto.left_load(l1_a_tile.as_ptr(), l0a_tile.as_ptr(), M, K) + pto.right_load(l1_b_tile.as_ptr(), l0b_tile.as_ptr(), K, N) + + # Compute: C = A × B + pto.mad(l0a_tile.as_ptr(), l0b_tile.as_ptr(), l0c_tile.as_ptr(), M, N, K) + + # L0C → GM writeback + pto.acc_store_gm(l0c_tile.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +### Split-K GEMM + +Matrix multiplication with K-dimension splitting for large K values: + +```python +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_splitk", +) +def gemm_splitk(a_tv: pto.PartitionTensorView, # [M, K] + b_tv: pto.PartitionTensorView, # [K, N] + c_tv: pto.PartitionTensorView, # [M, N] + M: int, K: int, N: int, BASEK: int): + iters = K // BASEK + + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + # Allocate buffers sized for one split-K step + l1_a = pto.Tile([M, BASEK], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([BASEK, N], pto.f16, MemorySpace.MAT) + l0a = pto.Tile([M, BASEK], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([BASEK, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + for k_step in range(iters): + k_off = k_step * BASEK + + # Offset GM pointers for this K-slice + a_k = pto.addptr(a_ptr, k_off) + b_k = pto.addptr(b_ptr, k_off) + + # GM → L1 → L0 + pto.cube_load(a_k, l1_a.as_ptr(), BASEK, nburst=(1, 0, 0)) + pto.cube_load(b_k, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, BASEK) + pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), BASEK, N) + + # First step: zero-init; subsequent steps: accumulate + if k_step == 0: + pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, BASEK) + else: + pto.mad_acc(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, BASEK) + + # L0C → GM + pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +### GEMM with Bias + +Matrix multiplication with bias addition C = A × B + bias: + +```python +@pto.ckernel( + target="a5", + op="pto.mad_bias", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_bias", +) +def gemm_bias(a_tv: pto.PartitionTensorView, + b_tv: pto.PartitionTensorView, + c_tv: pto.PartitionTensorView, + bias_tv: pto.PartitionTensorView, + M: int, K: int, N: int): + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + bias_ptr = bias_tv.as_ptr() + + # L1 buffers + l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + l1_bias = pto.Tile([1, N], pto.f32, MemorySpace.MAT) + + # L0 buffers + l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + # Bias table + bt = pto.Tile([1, N], pto.f32, MemorySpace.BIAS) + + # Data movement + pto.cube_load(a_ptr, l1_a.as_ptr(), K, nburst=(1, 0, 0)) + pto.cube_load(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + pto.cube_load(bias_ptr, l1_bias.as_ptr(), N, nburst=(1, 0, 0)) + pto.bias_load(l1_bias.as_ptr(), bt.as_ptr(), N, nburst=(1, 0, 0)) + + # L1 → L0 + pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, K) + pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), K, N) + + # Compute: C = A × B + bias + pto.mad_bias(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), bt.as_ptr(), M, N, K) + + # Writeback + pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +### Fractal Load (nd2nz) Example + +Using fractal load for ND-layout to NZ-fractal data loading: + +```python +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_frac", +) +def gemm_frac(a_tv: pto.PartitionTensorView, + b_tv: pto.PartitionTensorView, + c_tv: pto.PartitionTensorView, + M: int, K: int, N: int): + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + # Fractal load: ND → NZ + pto.cube_load_frac(a_ptr, l1_a.as_ptr(), "nd2nz", + shape=(M, K), + src_layout=(K,), + dst_group=(1, 0, 0, 0), + ctrl=(0, False)) + pto.cube_load(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + + pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, K) + pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), K, N) + pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, K) + + pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +### Pure-Compute Kernel (Pre-Allocated Tiles) + +When tiles are pre-allocated externally, the kernel only performs computation: + +```python +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="matmul_compute", +) +def matmul_compute(a_left: pto.Tile, # Pre-allocated LEFT tile (L0A) + b_right: pto.Tile, # Pre-allocated RIGHT tile (L0B) + c_acc: pto.Tile, # Pre-allocated ACC tile (L0C) + M: int, K: int, N: int): + pto.mad(a_left.as_ptr(), b_right.as_ptr(), c_acc.as_ptr(), M, N, K) +``` + +### Template-based Multi-Op Cube Kernel + +Reusing a single template body for multiple Cube matmul variants: + +```python +@pto.ckernel( + target="a5", + ops=["mad", "mad_acc"], + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_template", + templates={ + "compute": {"mad": "mad", "mad_acc": "mad_acc"}, + }, +) +def gemm_template(a_tv: pto.PartitionTensorView, + b_tv: pto.PartitionTensorView, + c_tv: pto.PartitionTensorView, + M: int, K: int, N: int): + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + pto.cube_load(a_ptr, l1_a.as_ptr(), K, nburst=(1, 0, 0)) + pto.cube_load(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, K) + pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), K, N) + + # Template slot: resolved at specialization time + pto.tpl("compute", l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, K) + + pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +Usage: + +```python +k_mad = pto.select_kernel("a5", "gemm_template", selected_op="mad") +k_acc = pto.select_kernel("a5", "gemm_template", selected_op="mad_acc") +``` diff --git a/ptodsl/docs/user_guide/14-common-errors.md b/ptodsl/docs/user_guide/14-common-errors.md new file mode 100644 index 000000000..46abe09b9 --- /dev/null +++ b/ptodsl/docs/user_guide/14-common-errors.md @@ -0,0 +1,51 @@ +## Common Errors + +### Typed Mask Mismatch + +``` +Error: f32 vector operation cannot consume mask_b16 +``` + +**Solution:** Ensure mask granularity matches vector element size: +- `f32` vectors use `mask_b32` +- `f16` vectors use `mask_b16` +- `i8` vectors use `mask_b8` + +### Strict Scope Implicit Capture + +``` +Error: strict_vecscope body cannot capture outer value 'ub_in' implicitly +``` + +**Solution:** Pass all required values in the capture list: + +```python +# Wrong: +with pto.strict_vecscope() as (): + vec = pto.vlds(ub_in, offset) # ub_in from outer scope + +# Correct: +with pto.strict_vecscope(ub_in) as (ub): + vec = pto.vlds(ub, offset) +``` + +### Untyped Loop Carried State + +``` +Error: loop-carried value must have explicit machine type +``` + +**Solution:** Add type annotations to loop-carried variables: + +```python +# Wrong: +remaining = 1024 # Plain Python int +for i in range(0, N, step): + mask, remaining = pto.make_mask(pto.f32, remaining) + +# Correct: +remaining: pto.i32 = 1024 +# or +remaining = pto.i32(1024) +``` + diff --git a/ptodsl/docs/user_guide/15-compatibility-notes.md b/ptodsl/docs/user_guide/15-compatibility-notes.md new file mode 100644 index 000000000..defcf704c --- /dev/null +++ b/ptodsl/docs/user_guide/15-compatibility-notes.md @@ -0,0 +1,9 @@ +## Compatibility Notes + +The current experimental implementation in `python/pto/dialects/pto.py` differs from this specification in several ways: + +1. **Mask types**: The experimental version uses untyped `mask` instead of `mask_b8`/`mask_b16`/`mask_b32` +2. **Barrier operation**: Uses `pto.barrier()` instead of `pto.pipe_barrier()` +3. **Operation coverage**: Implements only a subset of operations + +When implementing new code, follow this specification. The experimental implementation will be updated to match over time. diff --git a/ptodsl/docs/user_guide/16-next-steps.md b/ptodsl/docs/user_guide/16-next-steps.md new file mode 100644 index 000000000..2fe63b9a4 --- /dev/null +++ b/ptodsl/docs/user_guide/16-next-steps.md @@ -0,0 +1,7 @@ +## Next Steps + +- Explore the ISA documentation in `docs/isa/` for detailed operation semantics +- Check `test/samples/` for example kernels +- Refer to `docs/vpto-spec.md` for the underlying VPTO instruction specification + +For compiler developers, see `docs/PTO_IR_manual.md` for MLIR-level details. From 60f4c6a693d6da79b4e1ebd09485f8d25b477194 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 15 May 2026 08:57:54 +0000 Subject: [PATCH 13/31] Add a flash attention demo --- ptodsl/docs/demos/flash_attention_sketch.py | 502 ++++++++++++++++++++ 1 file changed, 502 insertions(+) create mode 100644 ptodsl/docs/demos/flash_attention_sketch.py diff --git a/ptodsl/docs/demos/flash_attention_sketch.py b/ptodsl/docs/demos/flash_attention_sketch.py new file mode 100644 index 000000000..af809ea18 --- /dev/null +++ b/ptodsl/docs/demos/flash_attention_sketch.py @@ -0,0 +1,502 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +Flash Attention redesign sketch. + +This file is intentionally a design demo rather than runnable ``ptodsl`` code. +The goal is to make the *proposed* layering explicit and keep the semantic +contracts clean: + + @pto.tkernel L1: tile-level orchestration and logical blocking + └─ @pto.ukernel L2: one KV-block worth of MTE/sync orchestration + ├─ @pto.cube L3: matrix products (QK^T and P@V) + ├─ @pto.simd L3: row-wise online softmax + └─ @pto.simt L3: scalar metadata and output blending + +Design rules illustrated here: + +1. ``tkernel`` owns logical tiling, tile allocation, and loop scheduling. + It should not manually spell low-level DMA details for every micro step. +2. ``ukernel`` owns the per-block execution sandwich: stage the current K/V + block with explicit micro-instructions, synchronize, call hardware-bound + sub-kernels, and manage scratch/state. +3. ``tkernel`` may use tile ops such as ``tload`` / ``tstore`` at the logical + scheduling boundary, but ``ukernel`` stays below that abstraction level. + Once execution enters ``ukernel``, GM<->UB movement is expressed with + ptr-based micro-instructions such as ``dma_load`` instead of tile ops. + The DSL may make pointer materialization ergonomic, but the micro-instruction + boundary itself stays explicit in authored code via ``as_ptr()``. +4. ``simd`` / ``simt`` / ``cube`` are hardware boundaries. They do not expose + vreg values across the function boundary. Data crosses the boundary through + UB-backed tiles or typed UB pointers only. +5. Online-softmax state is made explicit with ping-pong tiles + (``m_prev``/``m_next``, ``l_prev``/``l_next``, ``o_prev``/``o_next``). + Hiding these dependencies with in-place aliases makes the algorithm harder + to read and obscures what the DSL needs to express. + +The API spellings below are approximate and intentionally favor the redesign +surface over today's exact binding details. + +Because this sketch targets a tracing-style frontend, any control flow that +must reach MLIR is expressed with structured DSL constructs such as +``pto.for_`` instead of native Python ``for`` loops. + +Scalar literals and simple index/integer conversions are also shown in their +authored form. The intended frontend behavior is to lift Python ``int`` +literals and obvious scalar arithmetic into the corresponding MLIR scalar ops +implicitly, rather than forcing authors to spell ``pto.const(...)`` or +``index_cast(...)`` at every use site. +""" + +from ptodsl import pto + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Level 3: hardware-bound sub-kernels +# ═══════════════════════════════════════════════════════════════════════════════ +# +# Boundary contract: +# - Tile arguments are UB-backed or cube-local buffers carrying addressable +# storage. +# - No vector register escapes a simd function. +# - No implicit global-memory access happens inside these kernels. + + +@pto.cube +def qk_matmul( + q_tile: pto.Tile, # UB, [Br, dim] + k_tile: pto.Tile, # UB, [Bc, dim] + q_l0a: pto.Tile, # LEFT scratch + k_l0b: pto.Tile, # RIGHT scratch + s_acc: pto.Tile, # ACC scratch + s_tile: pto.Tile, # UB, [Br, Bc] output +): + """ + Compute ``S = Q @ K^T`` for one attention block. + + The key point for the redesign is that the cube kernel consumes UB tiles and + explicit cube-local scratch, rather than pretending a UB tile can also stand + in for LEFT/RIGHT/ACC state. + """ + m = pto.tile_valid_rows(q_tile) + k = pto.tile_valid_cols(q_tile) + n = pto.tile_valid_rows(k_tile) + + # Caller owns scratch lifetime. The cube kernel only expresses dataflow. + pto.left_load(q_tile, q_l0a, m, k) + pto.right_load(k_tile, k_l0b, k, n, transpose=True) + pto.mad(q_l0a, k_l0b, s_acc) + pto.acc_store_ub(s_acc, s_tile, m, n) + + +@pto.cube +def pv_matmul( + p_tile: pto.Tile, # UB, [Br, Bc] + v_tile: pto.Tile, # UB, [Bc, dim] + p_l0a: pto.Tile, # LEFT scratch (reused) + v_l0b: pto.Tile, # RIGHT scratch (reused) + pv_acc: pto.Tile, # ACC scratch (reused) + pv_tile: pto.Tile, # UB, [Br, dim] output +): + """ + Compute ``PV = P @ V`` for the current block. + + This keeps the second matrix product on the cube path as well, instead of + accidentally collapsing it into an elementwise vector expression. + """ + m = pto.tile_valid_rows(p_tile) + k = pto.tile_valid_cols(p_tile) + n = pto.tile_valid_cols(v_tile) + + pto.left_load(p_tile, p_l0a, m, k) + pto.right_load(v_tile, v_l0b, k, n) + pto.mad(p_l0a, v_l0b, pv_acc) + pto.acc_store_ub(pv_acc, pv_tile, m, n) + + +@pto.simd +def online_softmax_rows( + s_tile: pto.Tile, # UB, [Br, Bc] + p_tile: pto.Tile, # UB, [Br, Bc], output + m_prev_tile: pto.Tile, # UB, [Br, 1] + l_prev_tile: pto.Tile, # UB, [Br, 1] + m_next_tile: pto.Tile, # UB, [Br, 1], output + l_next_tile: pto.Tile, # UB, [Br, 1], output + alpha_tile: pto.Tile, # UB, [Br, 1], output + beta_tile: pto.Tile, # UB, [Br, 1], output + row_start: pto.i32, + row_stop: pto.i32, + valid_cols: pto.i32, +): + """ + Per-row online softmax update. + + For each active row:: + + m_next = max(m_prev, row_max(S)) + P = exp(S - m_next) + l_next = l_prev * exp(m_prev - m_next) + row_sum(P) + alpha = l_prev * exp(m_prev - m_next) / l_next + beta = 1 / l_next + + ``alpha`` and ``beta`` are kept explicitly because the output update needs + both the old accumulator and the newly computed ``P @ V`` contribution. + """ + with pto.for_(row_start, row_stop, step=1) as row: + col_mask = pto.make_mask(pto.f32, valid_cols) + + s_row = pto.vlds(s_tile[row, 0:]) + m_prev = pto.lds(m_prev_tile[row, 0]) + l_prev = pto.lds(l_prev_tile[row, 0]) + + row_max = pto.vcgmax(s_row, col_mask) + m_next = pto.max(m_prev, row_max) + + s_shifted = pto.vsubs(s_row, m_next, col_mask) + p_row = pto.vexp(s_shifted, col_mask) + + row_sum = pto.vcgadd(p_row, col_mask) + l_scaled = l_prev * pto.exp(m_prev - m_next) + l_next = l_scaled + row_sum + + alpha = l_scaled / l_next + beta = 1.0 / l_next + + pto.vsts(p_row, p_tile[row, 0:], col_mask) + pto.sts(m_next_tile[row, 0], m_next) + pto.sts(l_next_tile[row, 0], l_next) + pto.sts(alpha_tile[row, 0], alpha) + pto.sts(beta_tile[row, 0], beta) + + +@pto.simt +def blend_output_rows( + o_prev_tile: pto.Tile, # UB, [Br, dim] + pv_tile: pto.Tile, # UB, [Br, dim] + alpha_tile: pto.Tile, # UB, [Br, 1] + beta_tile: pto.Tile, # UB, [Br, 1] + o_next_tile: pto.Tile, # UB, [Br, dim], output + row_start: pto.i32, + row_stop: pto.i32, + valid_dim: pto.i32, +): + """ + Update the output accumulator with SIMT-style scalar element work:: + + O_next[row, col] = alpha[row] * O_prev[row, col] + beta[row] * PV[row, col] + + This intentionally contrasts with ``online_softmax_rows``: the softmax step + stays on the SIMD path because it is dominated by row-wise vector math, + while the final blend is expressed here as explicit scalar work-items over + the tile domain. + """ + with pto.for_(row_start, row_stop, step=1) as row: + alpha = pto.lds(alpha_tile[row, 0]) + beta = pto.lds(beta_tile[row, 0]) + + with pto.for_(0, valid_dim, step=1) as col: + o_prev = pto.lds(o_prev_tile[row, col]) + pv_val = pto.lds(pv_tile[row, col]) + + o_next = alpha * o_prev + beta * pv_val + pto.sts(o_next_tile[row, col], o_next) + + +@pto.simt +def materialize_tile_bounds( + meta_ptr: pto.ptr(pto.i32, pto.MemorySpace.UB), # [out] {row_start, row_stop, valid_cols} + valid_rows: pto.i32, + valid_cols: pto.i32, +): + """ + Materialize tile-local loop bounds for the current block. + + The SIMT kernel stays intentionally small here: it is responsible for + scalar control metadata, not for rewriting the vector or cube logic. + """ + pto.sts(meta_ptr + 0, 0) + pto.sts(meta_ptr + 4, valid_rows) + pto.sts(meta_ptr + 8, valid_cols) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Level 2: ukernel — one KV block worth of execution orchestration +# ═══════════════════════════════════════════════════════════════════════════════ + + +@pto.ukernel +def kv_block_process( + q_tile: pto.Tile, # UB, reused across inner KV loop + k_part: pto.PartitionTensorView, # GM view for current K block + v_part: pto.PartitionTensorView, # GM view for current V block + k_tile: pto.Tile, # UB scratch + v_tile: pto.Tile, # UB scratch + o_prev_tile: pto.Tile, # UB state + o_next_tile: pto.Tile, # UB state + m_prev_tile: pto.Tile, # UB state + l_prev_tile: pto.Tile, # UB state + m_next_tile: pto.Tile, # UB state + l_next_tile: pto.Tile, # UB state + s_tile: pto.Tile, # UB scratch for QK^T + p_tile: pto.Tile, # UB scratch for probabilities + pv_tile: pto.Tile, # UB scratch for P@V + alpha_tile: pto.Tile, # UB scratch + beta_tile: pto.Tile, # UB scratch + q_l0a: pto.Tile, # LEFT scratch for Q + p_l0a: pto.Tile, # LEFT scratch for P + rhs_l0b: pto.Tile, # RIGHT scratch, reused by K/V + qk_acc_tile: pto.Tile, # ACC scratch for QK^T + pv_acc_tile: pto.Tile, # ACC scratch for P@V + meta_ptr: pto.ptr(pto.i32, pto.MemorySpace.UB), +): + """ + Process one KV block against an already-loaded Q tile. + + The ukernel owns: + - staging the current K/V block into reusable UB scratch with explicit + DMA-style micro-instructions, + - synchronizing the hand-off between MTE, cube, simd, and simt stages, + - wiring together the explicit state transition + (prev -> next for m/l/o). + """ + # ukernel deliberately stays below the tile-op abstraction boundary. + # Current-block GM->UB staging is expressed as ptr-based DMA instructions. + pto.dma_load(k_part.as_ptr(), k_tile.as_ptr()) + pto.dma_load(v_part.as_ptr(), v_tile.as_ptr()) + pto.mem_bar(pto.BarrierType.SYNC) + + materialize_tile_bounds( + meta_ptr, + pto.tile_valid_rows(q_tile), + pto.tile_valid_rows(k_tile), + ) + row_start = pto.lds(meta_ptr + 0) + row_stop = pto.lds(meta_ptr + 4) + valid_cols = pto.lds(meta_ptr + 8) + + # 1. S = Q @ K^T + qk_matmul(q_tile, k_tile, q_l0a, rhs_l0b, qk_acc_tile, s_tile) + pto.mem_bar(pto.BarrierType.SYNC) + + # 2. Row-wise online softmax over S + online_softmax_rows( + s_tile, + p_tile, + m_prev_tile, + l_prev_tile, + m_next_tile, + l_next_tile, + alpha_tile, + beta_tile, + row_start, + row_stop, + valid_cols, + ) + pto.mem_bar(pto.BarrierType.SYNC) + + # 3. PV = P @ V + pv_matmul(p_tile, v_tile, p_l0a, rhs_l0b, pv_acc_tile, pv_tile) + pto.mem_bar(pto.BarrierType.SYNC) + + # 4. O_next = alpha * O_prev + beta * PV + blend_output_rows( + o_prev_tile, + pv_tile, + alpha_tile, + beta_tile, + o_next_tile, + row_start, + row_stop, + pto.tile_valid_cols(v_tile), + ) + pto.mem_bar(pto.BarrierType.SYNC) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Level 1: tkernel — tile-level orchestration +# ═══════════════════════════════════════════════════════════════════════════════ + + +@pto.tkernel +def flash_attention( + Q: pto.TensorView, # [batch, seq_q, heads, dim] + K: pto.TensorView, # [batch, seq_k, heads, dim] + V: pto.TensorView, # [batch, seq_k, heads, dim] + O: pto.TensorView, # [batch, seq_q, heads, dim] +): + """ + Flash Attention top-level orchestration sketch. + + To keep the demo focused, batch/head loops are omitted and we show the + per-head 2D core: ``[seq, dim]`` for Q/K/V/O. + """ + Br = 128 + Bc = 128 + seq_q = 4096 + seq_k = 4096 + dim = 64 + + q_blocks = (seq_q + Br - 1) // Br + kv_blocks = (seq_k + Bc - 1) // Bc + + # UB resident logical tiles + q_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, dim) + k_tile = pto.alloc_tile(pto.TileType(pto.f32), Bc, dim) + v_tile = pto.alloc_tile(pto.TileType(pto.f32), Bc, dim) + + o_prev_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, dim) + o_next_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, dim) + m_prev_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) + m_next_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) + l_prev_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) + l_next_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) + + s_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, Bc) + p_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, Bc) + pv_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, dim) + alpha_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) + beta_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) + + # Cube-local scratch is explicit; it should not be conflated with UB tiles. + q_l0a = pto.alloc_tile(pto.TileType(pto.f16, pto.MemorySpace.LEFT), Br, dim) + p_l0a = pto.alloc_tile(pto.TileType(pto.f16, pto.MemorySpace.LEFT), Br, Bc) + rhs_l0b = pto.alloc_tile(pto.TileType(pto.f16, pto.MemorySpace.RIGHT), Bc, dim) + qk_acc_tile = pto.alloc_tile(pto.TileType(pto.f32, pto.MemorySpace.ACC), Br, Bc) + pv_acc_tile = pto.alloc_tile(pto.TileType(pto.f32, pto.MemorySpace.ACC), Br, dim) + + # SIMT metadata buffer. A tiny raw-pointer island is acceptable at the + # ukernel boundary because this is scalar control data, not user-facing math. + meta_tile = pto.alloc_tile(pto.TileType(pto.i32), 3, 1) + meta_ptr = pto.tile_buf_addr(meta_tile) + + q_view = pto.make_tensor_view(Q, shape=[seq_q, dim]) + k_view = pto.make_tensor_view(K, shape=[seq_k, dim]) + v_view = pto.make_tensor_view(V, shape=[seq_k, dim]) + o_view = pto.make_tensor_view(O, shape=[seq_q, dim]) + + with pto.for_(0, q_blocks, step=1) as qi: + q_part = pto.partition_view(q_view, offsets=[qi * Br, 0], sizes=[Br, dim]) + o_part = pto.partition_view(o_view, offsets=[qi * Br, 0], sizes=[Br, dim]) + + pto.tload(q_part, q_tile) + + # Initial online-softmax state for this Q block. + pto.tile_fill(m_prev_tile, float("-inf")) + pto.tile_fill(l_prev_tile, 0.0) + pto.tile_fill(o_prev_tile, 0.0) + + with pto.for_( + 0, + kv_blocks, + step=1, + iter_args=(m_prev_tile, m_next_tile, l_prev_tile, l_next_tile, o_prev_tile, o_next_tile), + ) as kv_loop: + kj = kv_loop.iv + m_prev_cur, m_next_cur, l_prev_cur, l_next_cur, o_prev_cur, o_next_cur = kv_loop.iter_args + k_part = pto.partition_view(k_view, offsets=[kj * Bc, 0], sizes=[Bc, dim]) + v_part = pto.partition_view(v_view, offsets=[kj * Bc, 0], sizes=[Bc, dim]) + + kv_block_process( + q_tile, + k_part, + v_part, + k_tile, + v_tile, + o_prev_cur, + o_next_cur, + m_prev_cur, + l_prev_cur, + m_next_cur, + l_next_cur, + s_tile, + p_tile, + pv_tile, + alpha_tile, + beta_tile, + q_l0a, + p_l0a, + rhs_l0b, + qk_acc_tile, + pv_acc_tile, + meta_ptr, + ) + + # Loop-carried state makes the ping-pong ownership part of the IR. + pto.yield_( + m_next_cur, + m_prev_cur, + l_next_cur, + l_prev_cur, + o_next_cur, + o_prev_cur, + ) + + _, _, _, _, o_final_tile, _ = kv_loop.results + pto.tstore(o_final_tile, o_part) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Layer summary +# ═══════════════════════════════════════════════════════════════════════════════ +# +# ┌──────────────────────────────────────────────────────────────────────────┐ +# │ L1 @pto.tkernel Tile orchestration │ +# │ │ +# │ alloc_tile / make_tensor_view / partition_view / tload / tstore │ +# │ outer Q loop + inner KV loop + ping-pong state ownership │ +# │ │ +# │ Key idea: speak in logical tiles and block scheduling, not in │ +# │ instruction-sized address arithmetic. │ +# ├──────────────────────────────────────────────────────────────────────────┤ +# │ L2 @pto.ukernel Per-block execution sandwich │ +# │ │ +# │ explicit dma_load(ptr, ptr) staging for current K/V block, mem_bar, │ +# │ call cube/simd/simt sub-kernels, │ +# │ manage scratch/state hand-off │ +# │ │ +# │ Key idea: one place owns the "how this block runs on hardware" story. │ +# ├──────────────────────────────────────────────────────────────────────────┤ +# │ L3a @pto.cube Matrix-product kernels │ +# │ │ +# │ qk_matmul: Q @ K^T │ +# │ pv_matmul: P @ V │ +# │ explicit LEFT/RIGHT/ACC scratch + UB output │ +# │ │ +# │ Key idea: UB tiles are inputs/outputs; cube-local state is explicit. │ +# ├──────────────────────────────────────────────────────────────────────────┤ +# │ L3b @pto.simd Row-wise vector math │ +# │ │ +# │ online_softmax_rows │ +# │ vreg stays local; persistent state is written back to UB tiles │ +# │ │ +# │ Key idea: no cross-kernel vreg values, only UB-backed state. │ +# ├──────────────────────────────────────────────────────────────────────────┤ +# │ L3c @pto.simt Scalar metadata and pointwise blend │ +# │ │ +# │ materialize_tile_bounds / blend_output_rows │ +# │ │ +# │ Key idea: SIMT handles scalar control facts and scalar tile walks. │ +# └──────────────────────────────────────────────────────────────────────────┘ +# +# dataflow for one KV block +# +# tkernel alloc/schedule +# │ +# ▼ +# ukernel loads K/V block and sequences the pipeline +# │ +# ├─ cube: Q + K ───────────────► S +# ├─ simd: S + (m_prev, l_prev) ─► P, (m_next, l_next), alpha, beta +# ├─ cube: P + V ───────────────► PV +# └─ simt: (o_prev, PV, alpha, beta) ─► o_next +# +# After each KV block: +# (m_prev, l_prev, o_prev) := (m_next, l_next, o_next) +# +# The important part for the redesign is not the exact helper spelling, but +# that every cross-stage dependency is visible in the surface language. From d094fa28262b312a3a4086bab6e3507fc2fdc367 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 15 May 2026 21:54:02 +0800 Subject: [PATCH 14/31] Completed the first version of PTODSL user guide --- ptodsl/docs/demos/flash_attention_sketch.py | 478 +++-- ptodsl/docs/user_guide/01-introduction.md | 219 ++- ptodsl/docs/user_guide/02-quick-start.md | 281 ++- .../docs/user_guide/03-kernel-declaration.md | 528 ------ .../03-kernel-entry-and-subkernels.md | 412 +++++ ptodsl/docs/user_guide/04-template-kernels.md | 333 ---- .../user_guide/04-type-system-and-buffer.md | 209 +++ ptodsl/docs/user_guide/05-control-flow.md | 228 +++ ptodsl/docs/user_guide/05-type-system.md | 686 ------- ptodsl/docs/user_guide/06-control-flow.md | 181 -- .../user_guide/06-scalar-and-pointer-ops.md | 376 ++++ .../docs/user_guide/07-data-movement-ops.md | 1019 +++++++++++ .../docs/user_guide/07-frontend-operations.md | 352 ---- .../docs/user_guide/08-compute-operations.md | 659 +++++++ .../docs/user_guide/08-sync-dma-operations.md | 622 ------- .../user_guide/09-predicate-and-mask-ops.md | 392 ++++ .../user_guide/09-vector-memory-operations.md | 1058 ----------- .../user_guide/10-predicate-operations.md | 637 ------- ptodsl/docs/user_guide/10-sync-ops.md | 447 +++++ .../11-flash-attention-walkthrough.md | 527 ++++++ .../11-vector-arithmetic-operations.md | 1611 ----------------- .../docs/user_guide/12-additional-examples.md | 400 ++++ ptodsl/docs/user_guide/12-cube-operations.md | 454 ----- ptodsl/docs/user_guide/13-examples.md | 417 ----- ptodsl/docs/user_guide/14-common-errors.md | 51 - .../docs/user_guide/15-compatibility-notes.md | 9 - ptodsl/docs/user_guide/16-next-steps.md | 7 - 27 files changed, 5374 insertions(+), 7219 deletions(-) delete mode 100644 ptodsl/docs/user_guide/03-kernel-declaration.md create mode 100644 ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md delete mode 100644 ptodsl/docs/user_guide/04-template-kernels.md create mode 100644 ptodsl/docs/user_guide/04-type-system-and-buffer.md create mode 100644 ptodsl/docs/user_guide/05-control-flow.md delete mode 100644 ptodsl/docs/user_guide/05-type-system.md delete mode 100644 ptodsl/docs/user_guide/06-control-flow.md create mode 100644 ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md create mode 100644 ptodsl/docs/user_guide/07-data-movement-ops.md delete mode 100644 ptodsl/docs/user_guide/07-frontend-operations.md create mode 100644 ptodsl/docs/user_guide/08-compute-operations.md delete mode 100644 ptodsl/docs/user_guide/08-sync-dma-operations.md create mode 100644 ptodsl/docs/user_guide/09-predicate-and-mask-ops.md delete mode 100644 ptodsl/docs/user_guide/09-vector-memory-operations.md delete mode 100644 ptodsl/docs/user_guide/10-predicate-operations.md create mode 100644 ptodsl/docs/user_guide/10-sync-ops.md create mode 100644 ptodsl/docs/user_guide/11-flash-attention-walkthrough.md delete mode 100644 ptodsl/docs/user_guide/11-vector-arithmetic-operations.md create mode 100644 ptodsl/docs/user_guide/12-additional-examples.md delete mode 100644 ptodsl/docs/user_guide/12-cube-operations.md delete mode 100644 ptodsl/docs/user_guide/13-examples.md delete mode 100644 ptodsl/docs/user_guide/14-common-errors.md delete mode 100644 ptodsl/docs/user_guide/15-compatibility-notes.md delete mode 100644 ptodsl/docs/user_guide/16-next-steps.md diff --git a/ptodsl/docs/demos/flash_attention_sketch.py b/ptodsl/docs/demos/flash_attention_sketch.py index af809ea18..39db1af0e 100644 --- a/ptodsl/docs/demos/flash_attention_sketch.py +++ b/ptodsl/docs/demos/flash_attention_sketch.py @@ -9,32 +9,45 @@ Flash Attention redesign sketch. This file is intentionally a design demo rather than runnable ``ptodsl`` code. -The goal is to make the *proposed* layering explicit and keep the semantic +The goal is to make the *proposed* API layering explicit and keep the semantic contracts clean: - @pto.tkernel L1: tile-level orchestration and logical blocking - └─ @pto.ukernel L2: one KV-block worth of MTE/sync orchestration - ├─ @pto.cube L3: matrix products (QK^T and P@V) - ├─ @pto.simd L3: row-wise online softmax - └─ @pto.simt L3: scalar metadata and output blending + flash_attention(...) user-facing wrapper + └─ @pto.jit flash_attention_kernel + ├─ Tile Ops tload / tstore at the GM↔UB boundary + └─ @pto.ukernel one KV-block worth of MTE/sync orchestration + ├─ @pto.cube matrix products (QK^T and P@V) + ├─ @pto.simd row-wise online softmax + └─ @pto.simt scalar metadata and output blending Design rules illustrated here: -1. ``tkernel`` owns logical tiling, tile allocation, and loop scheduling. - It should not manually spell low-level DMA details for every micro step. -2. ``ukernel`` owns the per-block execution sandwich: stage the current K/V +1. ``@pto.jit`` marks a launchable kernel template. It owns JIT compilation, + cache lookup, and runtime launch binding, instead of forcing users to hop + through extra builder objects for common cases. +2. The Python wrapper owns ergonomic runtime concerns such as output allocation, + default stream handling, and extracting shape/stride metadata from tensors. +3. ``@pto.jit`` also owns the top-level logical tiling, tile allocation, and + loop scheduling for one already-selected per-head 2D slice. It should not + manually spell low-level DMA details for every micro step. +4. ``ukernel`` owns the per-block execution sandwich: stage the current K/V block with explicit micro-instructions, synchronize, call hardware-bound sub-kernels, and manage scratch/state. -3. ``tkernel`` may use tile ops such as ``tload`` / ``tstore`` at the logical +5. ``@pto.jit`` may use tile ops such as ``tload`` / ``tstore`` at the logical scheduling boundary, but ``ukernel`` stays below that abstraction level. Once execution enters ``ukernel``, GM<->UB movement is expressed with - ptr-based micro-instructions such as ``dma_load`` instead of tile ops. - The DSL may make pointer materialization ergonomic, but the micro-instruction - boundary itself stays explicit in authored code via ``as_ptr()``. -4. ``simd`` / ``simt`` / ``cube`` are hardware boundaries. They do not expose + MTE micro-instructions such as ``mte_load`` instead of tile ops. + ``mte_load`` / ``mte_store`` accept partitions and tiles directly, + deriving strides and burst sizes from the type metadata. +6. ``simd`` / ``simt`` / ``cube`` are hardware boundaries. They do not expose vreg values across the function boundary. Data crosses the boundary through UB-backed tiles or typed UB pointers only. -5. Online-softmax state is made explicit with ping-pong tiles +7. L3 sub-kernels can also be called directly from ``@pto.jit`` (compiler + handles MTE + sync) or written inline as context managers + (``with pto.simd():`` etc.). This sketch uses the explicit + ``@pto.ukernel`` → L3 path for full micro-instruction control, but + simpler kernels can skip the ukernel layer. +8. Online-softmax state is made explicit with ping-pong tiles (``m_prev``/``m_next``, ``l_prev``/``l_next``, ``o_prev``/``o_next``). Hiding these dependencies with in-place aliases makes the algorithm harder to read and obscures what the DSL needs to express. @@ -56,6 +69,240 @@ from ptodsl import pto +# ═══════════════════════════════════════════════════════════════════════════════ +# Public API sketch +# ═══════════════════════════════════════════════════════════════════════════════ +# +# This section intentionally sketches the *desired* public surface, not today's +# exact implementation details. The split follows the common industry pattern: +# +# - a user-facing tensor wrapper +# - a launchable JIT kernel entry +# - hardware-bound sub-kernels below it +# +# The low-level kernel body should not double as the user-facing runtime API. +# +# Two intended usage styles: +# +# 1. Direct call (most users): +# out = flash_attention(Q, K, V, causal=True) +# +# 2. Compile first, then launch repeatedly: +# compiled = flash_attention_kernel.compile(BLOCK_Q=128, BLOCK_KV=128, CAUSAL=True) +# compiled[batch * heads, stream]( +# Q, K, V, O, +# ) + +def flash_attention( + Q, + K, + V, + *, + O=None, + causal=False, + block_q=128, + block_kv=128, + stream=None, +): + """ + User-facing convenience wrapper. + + This is the API most end users should call. It mirrors mainstream tensor + libraries: infer runtime metadata from tensors, allocate the output when the + caller does not provide one, then compile and launch the JIT kernel. + """ + if O is None: + O = pto.empty_like(Q) + + batch, seq_q, heads, dim = Q.shape + _, seq_k, _, _ = K.shape + + compiled = flash_attention_kernel.compile( + BLOCK_Q=block_q, + BLOCK_KV=block_kv, + CAUSAL=causal, + ) + + compiled[batch * heads, stream](Q, K, V, O) + return O + +@pto.jit(target="a5") +def flash_attention_kernel( + Q, # Python/framework tensor, logical [batch, seq_q, heads, dim] + K, # Python/framework tensor, logical [batch, seq_k, heads, dim] + V, # Python/framework tensor, logical [batch, seq_k, heads, dim] + O, # Python/framework tensor, logical [batch, seq_q, heads, dim] + *, + BLOCK_Q: pto.constexpr = 128, + BLOCK_KV: pto.constexpr = 128, + CAUSAL: pto.constexpr = False, + NUM_STAGES: pto.constexpr = 2, +): + """ + Launchable device entry. + + ``@pto.jit`` is the compile + launch boundary. Inputs/outputs at this + boundary are Python-native tensor objects; PTO-specific ``TensorView`` + descriptors are materialized inside the JIT body rather than exposed in the + public signature. Tile sizes and specialization knobs remain constexpr + metadata. + + A launch instance is responsible for one ``(batch, head)`` slice. The + per-slice logical tiling is expressed directly in this top-level JIT entry. + """ + batch, seq_q, heads, dim = Q.shape + _, seq_k, _, _ = K.shape + + q_view = pto.make_tensor_view(Q, shape=[batch, seq_q, heads, dim], strides=Q.strides) + k_view = pto.make_tensor_view(K, shape=[batch, seq_k, heads, dim], strides=K.strides) + v_view = pto.make_tensor_view(V, shape=[batch, seq_k, heads, dim], strides=V.strides) + o_view = pto.make_tensor_view(O, shape=[batch, seq_q, heads, dim], strides=O.strides) + + # Make the SPMD launch contract explicit in the authored surface. + # This sketch uses one block per (batch, head) slice and does not further + # split work across subblocks, but the runtime indices still belong in a + # realistic launchable entry. + block_idx = pto.get_block_idx() + block_num = pto.get_block_num() + subblock_idx = pto.get_subblock_idx() + subblock_num = pto.get_subblock_num() + + # Current mapping: + # - launch grid = batch * heads + # - block_idx selects one (batch, head) slice + # - subblock_idx is queried explicitly, but no extra intra-block partition + # is modeled in this sketch yet + _ = block_num + _ = subblock_idx + _ = subblock_num + + batch_idx = block_idx // heads + head_idx = block_idx % heads + + q_head = pto.select_head_view( + q_view, + batch=batch_idx, + head=head_idx, + shape=[seq_q, dim], + ) + k_head = pto.select_head_view( + k_view, + batch=batch_idx, + head=head_idx, + shape=[seq_k, dim], + ) + v_head = pto.select_head_view( + v_view, + batch=batch_idx, + head=head_idx, + shape=[seq_k, dim], + ) + o_head = pto.select_head_view( + o_view, + batch=batch_idx, + head=head_idx, + shape=[seq_q, dim], + ) + + Br = BLOCK_Q + Bc = BLOCK_KV + + q_blocks = (seq_q + Br - 1) // Br + kv_blocks = (seq_k + Bc - 1) // Bc + + # UB resident logical tiles for one selected (batch, head) slice. + q_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) + k_tile = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f32) + v_tile = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f32) + + o_prev_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) + o_next_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) + m_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) + m_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) + l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) + l_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) + + s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32) + p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32) + pv_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) + alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) + beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) + + # Cube-local scratch is explicit; it should not be conflated with UB tiles. + q_l0a = pto.alloc_tile(shape=[Br, dim], dtype=pto.f16, memory_space=pto.MemorySpace.LEFT) + p_l0a = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f16, memory_space=pto.MemorySpace.LEFT) + rhs_l0b = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f16, memory_space=pto.MemorySpace.RIGHT) + qk_acc_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, memory_space=pto.MemorySpace.ACC) + pv_acc_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32, memory_space=pto.MemorySpace.ACC) + + # SIMT metadata buffer. A tiny raw-pointer island is acceptable at the + # ukernel boundary because this is scalar control data, not user-facing math. + meta_tile = pto.alloc_tile(shape=[3, 1], dtype=pto.i32) + meta_ptr = pto.tile_buf_addr(meta_tile) + + with pto.for_(0, q_blocks, step=1) as qi: + q_part = pto.partition_view(q_head, offsets=[qi * Br, 0], sizes=[Br, dim]) + o_part = pto.partition_view(o_head, offsets=[qi * Br, 0], sizes=[Br, dim]) + + pto.tload(q_part, q_tile) + + # Initial online-softmax state for this Q block. + # ``CAUSAL`` is threaded at the API boundary even though the masking + # details are intentionally omitted from this design-focused sketch. + m_prev_tile.fill(float("-inf")) + l_prev_tile.fill(0.0) + o_prev_tile.fill(0.0) + + kv_loop = pto.for_(0, kv_blocks, step=1).carry( + m=m_prev_tile, + l=l_prev_tile, + o=o_prev_tile, + ) + with kv_loop: + kj = kv_loop.iv + m_cur = kv_loop.m + l_cur = kv_loop.l + o_cur = kv_loop.o + k_part = pto.partition_view(k_head, offsets=[kj * Bc, 0], sizes=[Bc, dim]) + v_part = pto.partition_view(v_head, offsets=[kj * Bc, 0], sizes=[Bc, dim]) + + kv_block_process( + q_tile, + k_part, + v_part, + k_tile, + v_tile, + o_cur, + o_next_tile, + m_cur, + l_cur, + m_next_tile, + l_next_tile, + s_tile, + p_tile, + pv_tile, + alpha_tile, + beta_tile, + q_l0a, + p_l0a, + rhs_l0b, + qk_acc_tile, + pv_acc_tile, + meta_ptr, + ) + + # Loop-carried state is still explicit, but the authored surface no + # longer mirrors raw scf.iter_args / scf.yield spellings. + kv_loop.update( + m=m_next_tile, + l=l_next_tile, + o=o_next_tile, + ) + + o_final_tile = kv_loop.final("o") + pto.tstore(o_final_tile, o_part) + + # ═══════════════════════════════════════════════════════════════════════════════ # Level 3: hardware-bound sub-kernels # ═══════════════════════════════════════════════════════════════════════════════ @@ -88,10 +335,10 @@ def qk_matmul( n = pto.tile_valid_rows(k_tile) # Caller owns scratch lifetime. The cube kernel only expresses dataflow. - pto.left_load(q_tile, q_l0a, m, k) - pto.right_load(k_tile, k_l0b, k, n, transpose=True) + pto.mte_l1_l0a(q_tile, q_l0a, m, k) + pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) pto.mad(q_l0a, k_l0b, s_acc) - pto.acc_store_ub(s_acc, s_tile, m, n) + pto.mte_l0c_ub(s_acc, s_tile, m, n) @pto.cube @@ -113,10 +360,10 @@ def pv_matmul( k = pto.tile_valid_cols(p_tile) n = pto.tile_valid_cols(v_tile) - pto.left_load(p_tile, p_l0a, m, k) - pto.right_load(v_tile, v_l0b, k, n) + pto.mte_l1_l0a(p_tile, p_l0a, m, k) + pto.mte_l1_l0b(v_tile, v_l0b, k, n) pto.mad(p_l0a, v_l0b, pv_acc) - pto.acc_store_ub(pv_acc, pv_tile, m, n) + pto.mte_l0c_ub(pv_acc, pv_tile, m, n) @pto.simd @@ -151,27 +398,27 @@ def online_softmax_rows( col_mask = pto.make_mask(pto.f32, valid_cols) s_row = pto.vlds(s_tile[row, 0:]) - m_prev = pto.lds(m_prev_tile[row, 0]) - l_prev = pto.lds(l_prev_tile[row, 0]) + m_prev = scalar.load(m_prev_tile[row, 0]) + l_prev = scalar.load(l_prev_tile[row, 0]) row_max = pto.vcgmax(s_row, col_mask) - m_next = pto.max(m_prev, row_max) + m_next = scalar.max(m_prev, row_max) s_shifted = pto.vsubs(s_row, m_next, col_mask) p_row = pto.vexp(s_shifted, col_mask) row_sum = pto.vcgadd(p_row, col_mask) - l_scaled = l_prev * pto.exp(m_prev - m_next) + l_scaled = l_prev * scalar.exp(m_prev - m_next) l_next = l_scaled + row_sum alpha = l_scaled / l_next beta = 1.0 / l_next pto.vsts(p_row, p_tile[row, 0:], col_mask) - pto.sts(m_next_tile[row, 0], m_next) - pto.sts(l_next_tile[row, 0], l_next) - pto.sts(alpha_tile[row, 0], alpha) - pto.sts(beta_tile[row, 0], beta) + scalar.sts(m_next_tile[row, 0], m_next) + scalar.sts(l_next_tile[row, 0], l_next) + scalar.sts(alpha_tile[row, 0], alpha) + scalar.sts(beta_tile[row, 0], beta) @pto.simt @@ -196,15 +443,15 @@ def blend_output_rows( the tile domain. """ with pto.for_(row_start, row_stop, step=1) as row: - alpha = pto.lds(alpha_tile[row, 0]) - beta = pto.lds(beta_tile[row, 0]) + alpha = scalar.load(alpha_tile[row, 0]) + beta = scalar.load(beta_tile[row, 0]) with pto.for_(0, valid_dim, step=1) as col: - o_prev = pto.lds(o_prev_tile[row, col]) - pv_val = pto.lds(pv_tile[row, col]) + o_prev = scalar.load(o_prev_tile[row, col]) + pv_val = scalar.load(pv_tile[row, col]) o_next = alpha * o_prev + beta * pv_val - pto.sts(o_next_tile[row, col], o_next) + scalar.sts(o_next_tile[row, col], o_next) @pto.simt @@ -219,9 +466,9 @@ def materialize_tile_bounds( The SIMT kernel stays intentionally small here: it is responsible for scalar control metadata, not for rewriting the vector or cube logic. """ - pto.sts(meta_ptr + 0, 0) - pto.sts(meta_ptr + 4, valid_rows) - pto.sts(meta_ptr + 8, valid_cols) + scalar.sts(meta_ptr + 0, 0) + scalar.sts(meta_ptr + 4, valid_rows) + scalar.sts(meta_ptr + 8, valid_cols) # ═══════════════════════════════════════════════════════════════════════════════ @@ -264,10 +511,9 @@ def kv_block_process( - wiring together the explicit state transition (prev -> next for m/l/o). """ - # ukernel deliberately stays below the tile-op abstraction boundary. - # Current-block GM->UB staging is expressed as ptr-based DMA instructions. - pto.dma_load(k_part.as_ptr(), k_tile.as_ptr()) - pto.dma_load(v_part.as_ptr(), v_tile.as_ptr()) + # Current-block GM->UB staging via MTE micro-instructions. + pto.mte_load(k_part, k_tile) + pto.mte_load(v_part, v_tile) pto.mem_bar(pto.BarrierType.SYNC) materialize_tile_bounds( @@ -275,9 +521,9 @@ def kv_block_process( pto.tile_valid_rows(q_tile), pto.tile_valid_rows(k_tile), ) - row_start = pto.lds(meta_ptr + 0) - row_stop = pto.lds(meta_ptr + 4) - valid_cols = pto.lds(meta_ptr + 8) + row_start = scalar.load(meta_ptr + 0) + row_stop = scalar.load(meta_ptr + 4) + valid_cols = scalar.load(meta_ptr + 8) # 1. S = Q @ K^T qk_matmul(q_tile, k_tile, q_l0a, rhs_l0b, qk_acc_tile, s_tile) @@ -317,145 +563,29 @@ def kv_block_process( pto.mem_bar(pto.BarrierType.SYNC) -# ═══════════════════════════════════════════════════════════════════════════════ -# Level 1: tkernel — tile-level orchestration -# ═══════════════════════════════════════════════════════════════════════════════ - - -@pto.tkernel -def flash_attention( - Q: pto.TensorView, # [batch, seq_q, heads, dim] - K: pto.TensorView, # [batch, seq_k, heads, dim] - V: pto.TensorView, # [batch, seq_k, heads, dim] - O: pto.TensorView, # [batch, seq_q, heads, dim] -): - """ - Flash Attention top-level orchestration sketch. - - To keep the demo focused, batch/head loops are omitted and we show the - per-head 2D core: ``[seq, dim]`` for Q/K/V/O. - """ - Br = 128 - Bc = 128 - seq_q = 4096 - seq_k = 4096 - dim = 64 - - q_blocks = (seq_q + Br - 1) // Br - kv_blocks = (seq_k + Bc - 1) // Bc - - # UB resident logical tiles - q_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, dim) - k_tile = pto.alloc_tile(pto.TileType(pto.f32), Bc, dim) - v_tile = pto.alloc_tile(pto.TileType(pto.f32), Bc, dim) - - o_prev_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, dim) - o_next_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, dim) - m_prev_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) - m_next_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) - l_prev_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) - l_next_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) - - s_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, Bc) - p_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, Bc) - pv_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, dim) - alpha_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) - beta_tile = pto.alloc_tile(pto.TileType(pto.f32), Br, 1) - - # Cube-local scratch is explicit; it should not be conflated with UB tiles. - q_l0a = pto.alloc_tile(pto.TileType(pto.f16, pto.MemorySpace.LEFT), Br, dim) - p_l0a = pto.alloc_tile(pto.TileType(pto.f16, pto.MemorySpace.LEFT), Br, Bc) - rhs_l0b = pto.alloc_tile(pto.TileType(pto.f16, pto.MemorySpace.RIGHT), Bc, dim) - qk_acc_tile = pto.alloc_tile(pto.TileType(pto.f32, pto.MemorySpace.ACC), Br, Bc) - pv_acc_tile = pto.alloc_tile(pto.TileType(pto.f32, pto.MemorySpace.ACC), Br, dim) - - # SIMT metadata buffer. A tiny raw-pointer island is acceptable at the - # ukernel boundary because this is scalar control data, not user-facing math. - meta_tile = pto.alloc_tile(pto.TileType(pto.i32), 3, 1) - meta_ptr = pto.tile_buf_addr(meta_tile) - - q_view = pto.make_tensor_view(Q, shape=[seq_q, dim]) - k_view = pto.make_tensor_view(K, shape=[seq_k, dim]) - v_view = pto.make_tensor_view(V, shape=[seq_k, dim]) - o_view = pto.make_tensor_view(O, shape=[seq_q, dim]) - - with pto.for_(0, q_blocks, step=1) as qi: - q_part = pto.partition_view(q_view, offsets=[qi * Br, 0], sizes=[Br, dim]) - o_part = pto.partition_view(o_view, offsets=[qi * Br, 0], sizes=[Br, dim]) - - pto.tload(q_part, q_tile) - - # Initial online-softmax state for this Q block. - pto.tile_fill(m_prev_tile, float("-inf")) - pto.tile_fill(l_prev_tile, 0.0) - pto.tile_fill(o_prev_tile, 0.0) - - with pto.for_( - 0, - kv_blocks, - step=1, - iter_args=(m_prev_tile, m_next_tile, l_prev_tile, l_next_tile, o_prev_tile, o_next_tile), - ) as kv_loop: - kj = kv_loop.iv - m_prev_cur, m_next_cur, l_prev_cur, l_next_cur, o_prev_cur, o_next_cur = kv_loop.iter_args - k_part = pto.partition_view(k_view, offsets=[kj * Bc, 0], sizes=[Bc, dim]) - v_part = pto.partition_view(v_view, offsets=[kj * Bc, 0], sizes=[Bc, dim]) - - kv_block_process( - q_tile, - k_part, - v_part, - k_tile, - v_tile, - o_prev_cur, - o_next_cur, - m_prev_cur, - l_prev_cur, - m_next_cur, - l_next_cur, - s_tile, - p_tile, - pv_tile, - alpha_tile, - beta_tile, - q_l0a, - p_l0a, - rhs_l0b, - qk_acc_tile, - pv_acc_tile, - meta_ptr, - ) - - # Loop-carried state makes the ping-pong ownership part of the IR. - pto.yield_( - m_next_cur, - m_prev_cur, - l_next_cur, - l_prev_cur, - o_next_cur, - o_prev_cur, - ) - - _, _, _, _, o_final_tile, _ = kv_loop.results - pto.tstore(o_final_tile, o_part) - - # ═══════════════════════════════════════════════════════════════════════════════ # Layer summary # ═══════════════════════════════════════════════════════════════════════════════ # # ┌──────────────────────────────────────────────────────────────────────────┐ -# │ L1 @pto.tkernel Tile orchestration │ +# │ L0 Python wrapper flash_attention(...) │ +# │ │ +# │ output allocation, shape/stride extraction, compile, launch │ +# │ │ +# │ Key idea: user-facing tensor API, not IR authoring. │ +# ├──────────────────────────────────────────────────────────────────────────┤ +# │ L1 @pto.jit compile + cache + launch + top-level orchestration │ # │ │ -# │ alloc_tile / make_tensor_view / partition_view / tload / tstore │ +# │ flash_attention_kernel[grid, stream](...) │ +# │ TensorView metadata / alloc_tile / partition_view / tload / tstore │ # │ outer Q loop + inner KV loop + ping-pong state ownership │ # │ │ -# │ Key idea: speak in logical tiles and block scheduling, not in │ -# │ instruction-sized address arithmetic. │ +# │ Key idea: one launchable entry owns both runtime binding and logical │ +# │ tile scheduling. │ # ├──────────────────────────────────────────────────────────────────────────┤ -# │ L2 @pto.ukernel Per-block execution sandwich │ +# │ L2 @pto.ukernel Per-block execution sandwich │ # │ │ -# │ explicit dma_load(ptr, ptr) staging for current K/V block, mem_bar, │ +# │ explicit mte_load(part, tile) staging for current K/V block, mem_bar, │ # │ call cube/simd/simt sub-kernels, │ # │ manage scratch/state hand-off │ # │ │ @@ -485,7 +615,7 @@ def flash_attention( # # dataflow for one KV block # -# tkernel alloc/schedule +# jit kernel alloc/schedule # │ # ▼ # ukernel loads K/V block and sequences the pipeline diff --git a/ptodsl/docs/user_guide/01-introduction.md b/ptodsl/docs/user_guide/01-introduction.md index 26012f781..cc6e8f134 100644 --- a/ptodsl/docs/user_guide/01-introduction.md +++ b/ptodsl/docs/user_guide/01-introduction.md @@ -1,47 +1,194 @@ -# TileLang Python DSL Guide +# 1. Introduction -The TileLang Python DSL provides a high-level, Pythonic interface for authoring vector compute kernels targeting the Ascend NPU hardware. This guide is intended for library developers and performance engineers who need to write efficient, hardware-aware kernels using the PTO micro instruction set. +**PTO** is a virtual instruction set designed for the Ascend NPU — a hardware-abstracted programming model that exposes the full capability of the Cube, Vector, and Scalar compute units through a unified operation set. **PTODSL** is the Python frontend for PTO. It wraps the PTO instruction set in a Python-embedded DSL with tracing-based compilation, so you can write PTO programs using familiar Python syntax. Under the hood, PTODSL traces your kernel function into PTO IR, which the PTOAS compiler then lowers, optimizes, and emits as NPU executables. In short: PTO defines the *what* (the instruction set), PTODSL provides the *how* (the authoring experience), and together they give you direct access to all three NPU compute units without leaving Python. -The DSL is designed to generate MLIR function libraries rather than direct binary executables. These MLIR libraries are intended to be consumed by other compilation frameworks that transform high-level tile semantics into low-level vector operations. This enables library developers to focus on hardware-aware kernel authoring while relying on upstream compilers for tile-level optimizations and code generation. +## 1.1 Target hardware -## Language Tier +The Ascend NPU is organized around three compute units and a shared on-chip buffer, connected through the Memory Transfer Engine (MTE): -The DSL surface is organized into multiple maturity tiers, reflecting the stability and intended use of different language features. As the design evolves, the basic authoring path is being explicitly separated from more advanced surfaces. Refer to the following table when reading this guide: +``` + ┌─────────────────────────┐ + │ Global Memory (GM) │ + │ (off-chip HBM) │ + └────────────┬──────────────┘ + │ + ┌──────────┴──────────┐ + │ MTE (DMA engine) │ + └──────────┬──────────┘ + │ + ┌────────────┴──────────────┐ + │ Unified Buffer (UB) │ + │ (on-chip scratchpad) │ + └──┬───────────┬──────────┬──┘ + │ │ │ + ┌────────┴──┐ ┌─────┴──────┐ │ + │ LEFT/RIGHT│ │ │ │ + │ /ACC/BIAS│ │ Vector │ │ + │ │ │ (SIMD) │ │ + │ Cube │ │ │ │ + │ │ └────────────┘ │ + └───────────┘ │ + ┌──────────┴──┐ + │ SIMT │ + │ (scalar PG) │ + └─────────────┘ +``` -| Surface Family | Tier | Usage Guidance | -|----------------|------|----------------| -| `TensorView` | `basic` | Default GM-facing data model for starter kernels. | -| `Tile` | `basic` | Default UB-facing compute tile for starter kernels. | -| Base vector ops (`make_mask`, `vlds`, `vsts`, `vadd`, `vmuls`, etc.) | `basic` | Default compute skeleton for starter kernels. | -| `strict_vecscope` | `advanced` | Explicit vector-scope management for expert authoring. | -| Raw pointer family (`ptr(...)`, `castptr`, `addptr`) | `advanced` | For expert authoring and migration; not required for Quick Start. | -| DMA family (`copy_*`, `set_loop*_stride_*`, `set_loop_size_*`, pad-fill control) | `advanced` | Direct DMA engine control for expert authoring, including GM→UB padding behavior. | -| Tile pointer helper (`tile.as_ptr()`) | `advanced` | Expert-only helper when advanced authoring needs explicit typed pointers. | +| Unit | Role | Typical workload | +|------|------|------------------| +| **Cube** | Matrix multiplication | GEMM, convolution | +| **SIMD** | Row-wise vector math | activation, normalization, reduction | +| **SIMT** | Scalar-programmable unit | pointwise tile walks, metadata | -For the authoritative tier classification, consult `tilelang-dsl/python/tilelang_dsl/support_matrix.py`. For known implementation gaps, refer to `tilelang-dsl/docs/unsupported-features.md`. +- **Global Memory (GM)** is off-chip HBM. All input and output tensors reside here. +- **Unified Buffer (UB)** is the on-chip scratchpad shared by all three compute units. Tile buffers and intermediate results live here during kernel execution. +- **MTE** (Memory Transfer Engine) handles DMA transfers between GM and UB, and between UB regions. +- **Cube** has its own private on-chip buffers — LEFT, RIGHT, ACC, and BIAS — for staging matrix operands and accumulators. +- **SIMD** executes row-wise vector instructions directly on UB-resident data. +- **SIMT** is a scalar-programmable processor group that executes scalar instructions across many work-items in parallel. It is well-suited for per-element control logic, tile boundary metadata, and pointwise blends. -### Basic vs Advanced Authoring Modes +PTODSL gives you direct access to all three units and explicit control over data movement, without abstracting away the hardware boundaries. -The TileLang DSL provides two distinct authoring modes: +## 1.2 Abstraction hierarchy -**Basic Mode (default)** -- Uses **Tile element/slice semantics** for buffer access -- Direct tile indexing syntax: `tile[start:]`, `tile[row, col:]`, `tile[row:, col]` (Tile indexing sugar only supports open-ended vector slices; explicit `stop` and `step` forms are not accepted for `Tile` indexing) -- Vector operations use element-indexing syntax: `pto.vlds(tile[row, col:])`, `pto.vsts(vec, tile[start:], mask)` -- No pointer arithmetic or explicit offset calculations -- Suitable for most kernel authoring with high-level abstractions +PTODSL organizes kernel code into three layers, each building on the one below it: -**Advanced Mode (`advanced=True` in `@pto.vkernel`)** -- Uses **raw pointer semantics** for explicit memory management -- Direct pointer operations correspond to `pto.ptr` types in MLIR -- Explicit pointer arithmetic: `ptr(...)`, `castptr`, `addptr` -- Manual DMA engine control with low-level copy operations and explicit GM→UB padding behavior -- Requires explicit buffer management and pointer arithmetic -- Intended for expert users and performance-critical optimizations +``` +Python Wrapper L0 user-facing wrapper (NumPy, torch-npu, pure Python) + └─ @pto.jit L1 compile + cache + launch + ├─ Tile Ops tile-level: tload, tstore, tadd, ... + └─ @pto.ukernel L2 micro-instruction orchestration + ├─ MTE Ops mte_load / mte_store / copy_gm_to_ubuf / ... + ├─ @pto.cube matrix products (mad, mte_l1_l0a, mte_l0c_ub, ...) + ├─ @pto.simd row-wise vector math (vlds, vadd, vexp, vsts, ...) + └─ @pto.simt scalar-like compute (lds, sts, pointwise blends, ...) +``` -**Key Differences** -- **Basic mode**: Uses tile element-indexing syntax (`tile[row, col:]`, `tile[start:]`) for vector operations -- **Advanced mode**: Uses pointer byte-offset syntax (`pto.vlds(buf: ptr, offset)`) for vector operations -- Tile slices in basic mode correspond to MLIR `memref` types -- Raw pointers in advanced mode correspond to MLIR `pto.ptr` types -- No automatic conversion between tile and pointer semantics - choose the appropriate syntax for your authoring mode +### L0 — Python wrapper + +The outermost layer is plain Python. It handles ergonomic runtime concerns: allocating output tensors, extracting shapes and strides from framework tensors, compiling the JIT kernel, and launching it. Because L0 is just Python, you can freely mix in NumPy, torch-npu, or any other Python framework for pre- and post-processing, data preparation, or composing multiple kernel launches. This layer knows nothing about NPU internals — it is just a convenience function that most end users will call. + +```python +def flash_attention(Q, K, V, *, O=None, causal=False): + if O is None: + O = pto.empty_like(Q) + compiled = flash_attention_kernel.compile( + BLOCK_Q=128, BLOCK_KV=128, CAUSAL=causal + ) + compiled[batch * heads, stream](Q, K, V, O) + return O +``` + +### L1 — `@pto.jit` + +Decorating a function with `@pto.jit` marks it as a launchable PTO kernel. This decoration means: + +- **Compilation**: the function body is traced once to record all PTO instructions, then lowered through the PTOAS compiler pipeline into an optimized NPU executable. +- **Caching**: compiled kernels are cached by key (function identity + constexpr parameter values), so repeated calls with the same configuration skip recompilation. +- **Launch binding**: the compiled kernel can be invoked with a grid and stream — `compiled[grid, stream](args...)` — which launches the executable on the NPU with the given SPMD grid. + +The parameters of a `@pto.jit` function are Python-native tensors (not PTODSL-specific descriptors). The kernel body materializes `TensorView` descriptors from them via `make_tensor_view`, then partitions the problem with `partition_view`. Compile-time constants are declared as keyword-only arguments with `pto.constexpr`: + +```python +@pto.jit(target="a5") +def flash_attention_kernel( + Q, K, V, O, + *, + BLOCK_Q: pto.constexpr = 128, + BLOCK_KV: pto.constexpr = 128, + CAUSAL: pto.constexpr = False, +): + ... +``` + +L1 is the primary layer for expressing **tile-level semantics**. Inside `@pto.jit`, you allocate tile buffers (`alloc_tile`), move data between GM and UB at block granularity (`tload`, `tstore`), and perform tile-level compute (`tadd`, `texp`, `trowsum`, etc.). When the built-in Tile Ops are not sufficient, you can drop down to `@pto.ukernel` to write custom tile-level semantics with micro-instructions. + +The SPMD launch contract is also owned here: the runtime grid (e.g., `batch * heads` blocks) is declared at the call site, and block/subblock indices are queried via `pto.get_block_idx()` and friends. + +### L2 — `@pto.ukernel` + +`@pto.ukernel` (short for *micro-instruction kernel*) is the entry point for expressing **PTO micro-instruction semantics**. Where L1 works with tile buffers as opaque wholes, L2 gives you direct control over individual MTE, vector, and scalar instructions. This layer is intended for users who pursue peak performance and need precise control over low-level hardware details — instruction ordering, DMA scheduling, per-byte data placement, and synchronization. + +Inside a ukernel, you write instructions targeting the three hardware units, and orchestrate data movement between them via **MTE Ops**: + +- **MTE Ops** (`mte_load`, `mte_store`, `copy_gm_to_ubuf`, etc.) move data between GM and UB, or between UB regions, at the DMA engine level. +- **`@pto.cube`**, **`@pto.simd`**, and **`@pto.simt`** sub-kernels execute the actual compute on their respective hardware units. + +The ukernel manages the execution sandwich for one block: staging data with MTE Ops, issuing synchronization barriers, dispatching sub-kernels, and managing loop-carried state between invocations. + +### L3 — `@pto.cube` / `@pto.simd` / `@pto.simt` + +These are hardware-bound compute sub-kernels, each mapped to a specific NPU compute unit: + +- **`@pto.cube`** consumes UB tiles and explicit cube-local scratch (LEFT, RIGHT, ACC, BIAS). Typical operations: `mad`, `mte_l1_l0a`, `mte_l1_l0b`, `mte_l0c_ub`. + +- **`@pto.simd`** operates on vector registers (`vreg`). Typical operations: `vlds`, `vadd`, `vexp`, `vcgmax`, `vsts`. Vector registers never cross the simd function boundary — persistent state is written back to UB tiles. + +- **`@pto.simt`** is a scalar-programmable processor group that executes scalar instructions across many work-items in parallel. Typical operations: `lds`, `sts`, scalar arithmetic and comparison. Well-suited for per-element tile walks, boundary metadata, and pointwise blends. + +L3 sub-kernels can be invoked in two ways: as named decorated functions (`@pto.cube` / `@pto.simd` / `@pto.simt`) — reusable and callable from `@pto.ukernel` or directly from `@pto.jit` — or inline as context managers (`with pto.cube():` / `with pto.simd():` / `with pto.simt():`) for quick prototyping. When called directly from `@pto.jit`, you stage data with `tload`/`tstore` instead of `mte_load`/`mte_store`; PTOAS handles the synchronization between Tile Ops and L3 compute automatically. + +The boundary contract is strict: vreg values do not escape a simd kernel, cube-local state does not leak into UB, and data crosses layer boundaries only through UB-backed tiles or typed UB pointers. + +## 1.3 Tracing execution model + +PTODSL uses a **tracing** compilation model. When you call `kernel.compile(...)`, PTODSL executes your Python function body once to record every PTO instruction into an intermediate representation — this pass is called *tracing*. The traced IR is then lowered and optimized into device code. Once compiled, invoking `compiled[grid, stream](args...)` launches the already-built device code directly on the NPU. + +This has one critical implication for how you write control flow and scalar logic: + +- **Python native control flow** (`for`, `if`, Python arithmetic) runs at trace time. A `for i in range(4)` loop gets unrolled — the device code contains four copies of the body, not a loop instruction. An `if` branch condition is evaluated at trace time, and only the taken branch is recorded. + +- **`pto.for_` / `pto.if_`** are recorded as structured control-flow IR. They preserve loop and branch semantics into the compiler pipeline, where the PTOAS compiler may further optimize them — unrolling, folding, or keeping them as runtime control flow depending on what is known at compile time. + +- **Python scalar expressions** (`alpha * x`, `1.0 / sqrt(d)`) are evaluated at trace time and their results are baked into the IR as constants — the compiler never sees the original expression. + +- **PTO scalar instructions** (`scalar.load(...)`, `scalar.max(...)`, `scalar.exp(...)`) are recorded as scalar IR and enter the compiler pipeline, where they may be constant-folded or lowered to runtime scalar operations depending on whether their inputs are compile-time known. + +A simple rule of thumb: **Python constructs are resolved before the compiler sees them. PTO constructs are recorded into IR and the compiler decides.** + +Chapter 5 (Control Flow) and Chapter 6 (Scalar & Pointer Operations) cover this in detail. + +## 1.4 A worked example + +The flash attention kernel from Section 1.2 is not just an architectural diagram — it is a complete, runnable design sketch distributed with PTODSL (`demos/flash_attention_sketch.py`). Here is how the layers map to actual code: + +**L1 (`@pto.jit`)** allocates tiles for the Q block, KV block, online-softmax state (m/l/o ping-pong tiles), and cube-local scratch. It loops over Q blocks (outer `pto.for_`) and KV blocks (inner `pto.for_` with carry state), calling `kv_block_process` for each KV block and using `tload`/`tstore` at the GM boundary. + +**L2 (`@pto.ukernel`)** stages the current K and V blocks with `mte_load`, issues `mem_bar` for synchronization, then sequences four sub-kernel calls: `qk_matmul` (cube), `online_softmax_rows` (simd), `pv_matmul` (cube), `blend_output_rows` (simt). + +**L3a (`@pto.cube`)** performs `mte_l1_l0a` / `mte_l1_l0b` / `mad` / `mte_l0c_ub` for both QK^T and P@V products. + +**L3b (`@pto.simd`)** implements the online softmax update: per-row max, exp, sum, and alpha/beta computation using vector ops (`vlds`, `vcgmax`, `vexp`, `vcgadd`, `vsts`). + +**L3c (`@pto.simt`)** blends the old and new output accumulators with per-element `lds`/`sts` and scalar arithmetic. + +Chapter 11 walks through this example in full detail. + +## 1.5 Reading guide + +| If you are... | Start with... | +|---------------|---------------| +| New to PTODSL | Chapter 2 (Quick Start), then Chapter 3 (Kernel Entries) | + +| Writing your first kernel | Chapter 2 → Chapter 4 (Type System) → Chapter 5 (Control Flow) | +| Looking up a specific operation | Chapters 6–10 (organized by topic) | +| Understanding the flash attention reference | Chapter 11 | + +**Chapter overview:** + +| Chapter | Topic | +|---------|-------| +| 1 | Introduction (this chapter) | +| 2 | Quick Start — a minimal working kernel | +| 3 | Kernel entry points: `@pto.jit`, `@pto.ukernel`, `@pto.cube`, `@pto.simd`, `@pto.simt` | +| 4 | Type system and buffer management: scalars, tiles, views, allocation | +| 5 | Control flow: trace-time Python vs device-side `pto.for_` / `pto.if_` | +| 6 | Scalar and pointer operations | +| 7 | Data movement: tile loads/stores, DMA, vector loads/stores, cube data movement | +| 8 | Compute operations: tile-level, vector, and cube arithmetic | +| 9 | Predicate and mask operations | +| 10 | Synchronization: barriers, flags, memory fences | +| 11 | Flash attention walkthrough | +| 12 | Additional examples | +| 13 | Migration from the old `@pto.vkernel`/`@pto.ckernel` API | +| 14 | Common errors and compatibility notes | diff --git a/ptodsl/docs/user_guide/02-quick-start.md b/ptodsl/docs/user_guide/02-quick-start.md index 26b0ba58b..af0ec6298 100644 --- a/ptodsl/docs/user_guide/02-quick-start.md +++ b/ptodsl/docs/user_guide/02-quick-start.md @@ -1,78 +1,233 @@ -## Quick Start +# 2. Quick Start -**Note on mask pattern enums**: For brevity, examples in this guide use `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). You can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. +This chapter walks through a minimal but complete PTODSL kernel — elementwise vector addition — covering the essential concepts you need to start writing your own kernels. -TileLang DSL provides the following core constructs for kernel authoring: +## 2.1 A first kernel: elementwise vector add -- `TensorView` – Access global memory (GM) tensors -- `Tile` – Local computation buffers in unified buffer (UB) -- Base vector operations (`make_mask`, `vlds`, `vmuls`, `vadd`, `vsts`) – Perform vector computations +```python +from ptodsl import pto + + +@pto.jit(target="a5") +def vec_add(A, B, O, *, N: pto.constexpr): + """O = A + B, elementwise, for vectors of length N.""" + + # Describe the GM tensors. + a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + # Allocate a UB tile to hold one block of each vector. + a_tile = pto.alloc_tile(shape=[N], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[N], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[N], dtype=pto.f32) + + # Partition the GM views to cover the whole vector. + a_part = pto.partition_view(a_view, offsets=[0], sizes=[N]) + b_part = pto.partition_view(b_view, offsets=[0], sizes=[N]) + o_part = pto.partition_view(o_view, offsets=[0], sizes=[N]) + + # Load A and B from GM into UB tiles. + pto.tload(a_part, a_tile) + pto.tload(b_part, b_tile) + + # Elementwise add on the tiles. + pto.tadd(a_tile, b_tile, o_tile) + + # Store the result back to GM. + pto.tstore(o_tile, o_part) +``` + +Let us step through each piece. + +### The entry point + +```python +@pto.jit(target="a5") +def vec_add(A, B, O, *, N: pto.constexpr): +``` + +`@pto.jit` marks this function as a launchable PTO kernel. The positional parameters `A`, `B`, `O` are Python-native tensors — they arrive from NumPy, torch-npu, or any framework that provides a shape and strides. The keyword-only argument `N` is a compile-time constant declared with `pto.constexpr`; the compiler specializes the kernel for each value of `N`. + +### Describing GM tensors + +```python +a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) +``` + +`make_tensor_view` wraps a Python tensor into a `TensorView` — a descriptor that tells the kernel how to address the tensor in global memory. You provide the logical shape and the stride (in elements) of each dimension. + +### Allocating on-chip buffers + +```python +a_tile = pto.alloc_tile(shape=[N], dtype=pto.f32) +``` + +`alloc_tile` reserves space in the Unified Buffer (UB). A `Tile` is a 2D buffer that lives on-chip during kernel execution. Every tile has a `shape` and a `dtype`. + +### Partitioning GM views + +```python +a_part = pto.partition_view(a_view, offsets=[0], sizes=[N]) +``` + +`partition_view` creates a sub-view of a `TensorView` at a given offset and size. It describes *which part* of the GM tensor a `tload` or `tstore` should operate on. For this simple whole-vector example the offset is zero and the size equals the full length; in a blocked kernel you would slide the offset through a loop. + +### Moving data: tload and tstore + +```python +pto.tload(a_part, a_tile) # GM → UB +pto.tstore(o_tile, o_part) # UB → GM +``` + +`tload` copies a block of data from GM (described by a partition) into a UB tile. `tstore` copies a UB tile back to GM. These are **Tile Ops** — they operate on entire tile buffers at once. + +### Computing on tiles + +```python +pto.tadd(a_tile, b_tile, o_tile) +``` + +`tadd` performs elementwise addition of two tiles. The result is written to a third tile. PTODSL provides a rich set of Tile-level compute instructions — `texp`, `trowsum`, `tcvt`, `tsel`, and many more — covered in Chapter 8. + +## 2.2 A blocked version with a loop + +The kernel above assumes the entire vector fits in one UB tile. For vectors longer than the maximum tile size, you need to process them in blocks. The length `N` is not known until the kernel is launched — it comes from the actual input tensor: + +```python +@pto.jit(target="a5") +def vec_add_blocked(A, B, O, *, BLOCK: pto.constexpr): + N = A.shape[0] + + a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + + num_blocks = (N + BLOCK - 1) // BLOCK + + with pto.for_(0, num_blocks, step=1) as i: + offset = i * BLOCK + + a_part = pto.partition_view(a_view, offsets=[offset], sizes=[BLOCK]) + b_part = pto.partition_view(b_view, offsets=[offset], sizes=[BLOCK]) + o_part = pto.partition_view(o_view, offsets=[offset], sizes=[BLOCK]) + + pto.tload(a_part, a_tile) + pto.tload(b_part, b_tile) + pto.tadd(a_tile, b_tile, o_tile) + pto.tstore(o_tile, o_part) +``` + +Here `N` is dynamic — it comes from `A.shape[0]` and can differ across launches. The loop bound `num_blocks` depends on `N`, so `pto.for_` records a structured loop in the IR rather than unrolling at trace time. The `BLOCK` parameter stays `constexpr` because it is a tuning knob, not data-dependent. Chapter 5 covers this distinction in detail. + +## 2.3 Compile and launch -A typical kernel follows the GM → UB → vector compute → GM pattern: +Once the kernel is defined, you compile it and then launch it: ```python -import tilelang_dsl as pto - -@pto.vkernel(target="a5", op="scale", dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32)]) -def tile_scale( - input_tensor: pto.TensorView, - output_tensor: pto.TensorView, - work_tile: pto.Tile, - scale_factor: pto.f32, -): - dim0 = 4 - dim1 = 16 - - # Stage one GM tile into UB. - # GM -> UB data movement (implementation detail) - - # Run vector compute over the UB tile using tile indexing sugar. - for i in range(0, dim0): - mask = pto.make_mask(pto.f32, PAT.ALL) - vec = pto.vlds(work_tile[i, 0:]) - scaled = pto.vmuls(vec, scale_factor, mask) - pto.vsts(scaled, work_tile[i, 0:], mask) - - # Write the UB result back to GM. - # UB -> GM data movement (implementation detail) +# Compile once, cache the result. +compiled = vec_add.compile(N=1024) + +# Allocate or obtain input/output tensors (NumPy, torch-npu, ...). +import numpy as np +A = np.random.randn(1024).astype(np.float32) +B = np.random.randn(1024).astype(np.float32) +O = np.empty_like(A) + +# Launch on the NPU. +compiled[1, None](A, B, O) ``` -The example illustrates the key components of a TileLang kernel: +- `.compile(**constexprs)` traces the kernel body, lowers it through the PTOAS pipeline, and returns a compiled handle. Repeated calls with the same configuration hit the cache. +- `compiled[grid, stream](args...)` launches the compiled kernel. `grid` is the number of SPMD blocks; `stream` is the NPU stream (or `None` for the default). -1. **`TensorView` parameters** – Access global memory tensors -2. **`Tile` parameters** – Local computation buffers in unified buffer (UB) -3. **Base vector operations** (`make_mask`, `vlds`, `vmuls`, `vadd`, `vsts`) – Perform vector computations +## 2.4 SPMD launch + +For workloads that can be parallelized across multiple blocks, specify a grid: + +```python +# Process batch * heads slices in parallel. +compiled[batch * heads, stream](Q, K, V, O) +``` -Here is a second example with two inputs and one output: +Inside the kernel, each block queries its index: ```python -@pto.vkernel( - target="a5", - op="elementwise_add", - dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.f32, pto.f32)], -) -def elementwise_add( - lhs_gm: pto.TensorView, - rhs_gm: pto.TensorView, - out_gm: pto.TensorView, - lhs_tile: pto.Tile, - rhs_tile: pto.Tile, - dst_tile: pto.Tile, -): - dim0 = 4 - dim1 = 16 - - # GM -> UB data movement (implementation detail) - - for lane in range(0, 256, 64): - mask = pto.make_mask(pto.f32, PAT.ALL) - lhs_vec = pto.vlds(lhs_tile, lane) - rhs_vec = pto.vlds(rhs_tile, lane) - summed = pto.vadd(lhs_vec, rhs_vec, mask) - pto.vsts(summed, dst_tile, lane, mask) - - # UB -> GM data movement (implementation detail) +block_idx = pto.get_block_idx() +block_num = pto.get_block_num() ``` -Both examples follow the same fundamental pattern: load data from global memory into local tiles, perform vector operations, and store results back. The compiler automatically infers vector-scope boundaries for the base vector operations. The `Tile` parameters are specialized to concrete shapes during compilation. Later sections cover advanced features such as matchers, template slots, raw pointer operations, and explicit scope management with `strict_vecscope`. +This lets you map different data slices to different blocks — for example, one block per (batch, head) pair in flash attention. + +## 2.5 Dropping down to micro-instructions + +The examples above used Tile Ops (`tload`, `tadd`, `tstore`), which operate on entire tiles at once. When you need finer control — for instance, writing a custom softmax or an activation that maps directly to vector hardware — you can drop down to the micro-instruction level. This involves three layers working together: + +```python +# L3: hardware-bound SIMD kernel — vector instructions on individual rows. +@pto.simd +def add_rows(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, + rows: pto.i32, cols: pto.i32): + VEC = pto.elements_per_vreg(pto.f32) + with pto.for_(0, rows, step=1) as r: + col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) + with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + a_vec = pto.vlds(a_tile[r, c:]) + b_vec = pto.vlds(b_tile[r, c:]) + o_vec = pto.vadd(a_vec, b_vec, mask) + pto.vsts(o_vec, o_tile[r, c:], mask) + col_loop.update(remained=remained) + + +# L2: ukernel — DMA staging, then dispatch the SIMD kernel. +@pto.ukernel +def add_block(a_part: pto.PartitionTensorView, + b_part: pto.PartitionTensorView, + o_part: pto.PartitionTensorView, + a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, + rows: pto.i32, cols: pto.i32): + pto.mte_load(a_part, a_tile) + pto.mte_load(b_part, b_tile) + pto.mem_bar(pto.BarrierType.SYNC) + + add_rows(a_tile, b_tile, o_tile, rows, cols) + pto.mem_bar(pto.BarrierType.SYNC) + + pto.mte_store(o_tile, o_part) + + +# L1: JIT entry — tile allocation, partitioning, launch. +@pto.jit(target="a5") +def vec_add_micro(A, B, O, *, BLOCK: pto.constexpr): + N = A.shape[0] + a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + + num_blocks = (N + BLOCK - 1) // BLOCK + with pto.for_(0, num_blocks, step=1) as i: + offset = i * BLOCK + a_part = pto.partition_view(a_view, offsets=[offset], sizes=[BLOCK]) + b_part = pto.partition_view(b_view, offsets=[offset], sizes=[BLOCK]) + o_part = pto.partition_view(o_view, offsets=[offset], sizes=[BLOCK]) + add_block(a_part, b_part, o_part, a_tile, b_tile, o_tile, 1, BLOCK) +``` + +- **L1 `@pto.jit`**: allocates tiles, partitions the GM views, and loops over blocks — the same tile-level orchestration as Section 2.2, but now calling a ukernel instead of Tile Ops. + +- **L2 `@pto.ukernel`**: stages data with `mte_load`, synchronizes with `mem_bar`, dispatches the SIMD kernel, synchronizes again, then writes back with `mte_store`. The ukernel owns the hardware-level sequencing. + +- **L3 `@pto.simd`**: the outer `pto.for_` iterates over rows, the inner `pto.for_` iterates over column chunks of the hardware vector width (`elements_per_vreg`). Each iteration loads a vector-width slice into a `vreg`, does the addition under a mask (for tail elements), and stores the result back. Both loops are recorded as structured control flow IR — the compiler decides whether to keep them or unroll them. +Chapter 3 covers the full decorator family; Chapters 7–10 cover each operation family in detail. diff --git a/ptodsl/docs/user_guide/03-kernel-declaration.md b/ptodsl/docs/user_guide/03-kernel-declaration.md deleted file mode 100644 index 73c9e1800..000000000 --- a/ptodsl/docs/user_guide/03-kernel-declaration.md +++ /dev/null @@ -1,528 +0,0 @@ -## Core Concepts - -### Kernel Declaration - -TileLang DSL exposes two kernel decorators: - -- `@pto.vkernel` for the Vector (AIV) execution model -- `@pto.ckernel` for the Cube (AIC) execution model - -#### Basic Syntax - -```python -@pto.vkernel( - target="a5", # Target architecture - op="pto.matmul ins(a, b) -> outs(c)", # PTO op + operand schema - dtypes=[(pto.f16, pto.f16, pto.f32)], # Type signatures - constraints=[ # Additional constraints - lambda a, b: a.shape[1] == b.shape[0], - lambda batch=1: batch >= 1, - ], - priority=100 # Priority for selection -) -def matmul_fallback(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: - # kernel implementation -``` - -#### Decorator Parameters - -| Parameter | Type | Required | Description | -|-----------|------|----------|-------------| -| `target` | `str` | Yes | Target hardware architecture (e.g., `"a5"` for Ascend 950). | -| `op` | `str` | No* | PTO operation matcher. Preferred form is schema mode: `"pto.op_name ins(in0, in1, ...) -> outs(out0, out1, ...)"`. Legacy bare-op form (`"pto.op_name"`) is still accepted for compatibility. **Mutually exclusive with `ops`**. | -| `ops` | `List[str]` | No* | List of PTO operation names to match. **Mutually exclusive with `op`**. Use this when one descriptor should match multiple concrete ops (schema mode is currently only supported in `op`). | -| `dtypes` | `List[Tuple[Type, ...]]` | Yes | List of type signatures. Each tuple specifies the expected data types for the operation's operands (inputs and outputs) in order. | -| `templates` | `Dict[str, Dict[str, str]]` | No | Static template-slot mappings. Each slot maps concrete matcher ops to real `pto.*` op names. Required when the kernel body uses `pto.tpl(...)`. | -| `constraints` | `List[Callable[..., bool]]` | No | Additional selection-time predicates. Constraint arguments bind by name to kernel parameter proxy objects or `context_attrs` keys. Default: empty list. | -| `priority` | `int` | No | Selection priority when multiple kernels match. Higher values have higher priority. Default: `0`. | -| `name` | `str` | No | Kernel name (used for debugging and profiling). Defaults to the decorated function's name. | -| `advanced` | `bool` | No | Enable advanced-tier DSL surfaces (for example `strict_vecscope`, raw pointer family, and low-level DMA family). Implicit vecscope inference is available in both modes and runs only when no explicit `with pto.vecscope():` is present. Default: `False`. | - -#### Operation Schema in `op` (ins/outs) - -`op` supports a schema string that declares how kernel parameter names map to PTO op operands: - -```python -op="pto.tadds ins(src, scalar) -> outs(dst)" -``` - -Schema form: - -```text - ins(, , ...) -> outs(, , ...) -``` - -Rules: - -1. `ins(...)` and `outs(...)` are both required in schema mode. -2. Names in `ins` and `outs` must be valid, unique Python identifiers. -3. The decorated function parameter list must exactly match `ins + outs` by both count and name. -4. MLIR function argument ordering is defined by schema order (`ins` first, then `outs`). -5. Constraint binding keeps using parameter names; schema mode makes these names explicit and stable. -6. Schema mode applies to `op=...` (single matcher op). `ops=[...]` remains bare-op matching. - -Example: - -```python -@pto.vkernel( - target="a5", - op="pto.tadds ins(src, scalar) -> outs(dst)", - dtypes=[(pto.f32, pto.f32, pto.f32)], -) -def template_tadds(src: pto.Tile, scalar: pto.f32, dst: pto.Tile): - return None -``` - -If names or order do not match, descriptor construction fails early with a schema mismatch error. - - -#### Type Matching Rules - -The `dtypes` parameter supports flexible type matching: - -1. **Concrete Types**: Exact type matches using DSL scalar types: - - `pto.f16`, `pto.f32`, `pto.bf16` - - `pto.i8`, `pto.si8`, `pto.ui8` - - `pto.i16`, `pto.si16`, `pto.ui16` - - `pto.i32`, `pto.si32`, `pto.ui32` - - `pto.i64`, `pto.si64`, `pto.ui64` - - `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32` - - Builtin vector operands still use their element dtype in `dtypes=[...]`. - For example, a parameter annotated as `ex_vec: pto.vector(pto.i16, (4,))` - contributes `pto.i16` to the signature tuple, while the vector shape - contract stays in the parameter annotation. - -2. **Type Wildcards**: Generic type patterns: - - `pto.AnyFloat`: Matches any floating-point type (`f16`, `bf16`, `f32`) - - `pto.AnyInt`: Matches any integer type (`i*`, `si*`, `ui*`) - - `pto.AnyType`: Matches any scalar type - - `pto.AnyMask`: Matches any mask type (`mask_b8`, `mask_b16`, `mask_b32`) - -3. **Type Variables**: Named type variables that enforce consistency within a signature: - ```python - T = pto.TypeVar('T') # Define a type variable - - @pto.vkernel( - target="a5", - op="elementwise", - dtypes=[(T, T, T)], # All three operands must have the same type - constraints=[] - ) - def elementwise_same_type(x: pto.Tile, y: pto.Tile, out: pto.Tile) -> None: - # x, y, and out must have identical element types - pass - ``` - -4. **Mixed Signatures**: Multiple type signatures for the same operation: - ```python - @pto.vkernel( - target="a5", - op="add", - dtypes=[ - (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), # Float addition - (pto.AnyInt, pto.AnyInt, pto.AnyInt) # Integer addition - ] - ) - def generic_add(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: - # Supports both float and integer types - pass - ``` - -#### Constraint System - -Constraints are compile-time predicates that refine kernel selection. In the current implementation, each entry in `constraints=[...]` is a Python callable returning `True` or `False`. - -##### Predefined Constraints - -| Constraint | Description | -|------------|-------------| -| `k_dim_aligned_64` | K dimension is aligned to 64 elements (for matmul kernels). | -| `continuous_memory` | Operands reside in contiguous memory regions. | -| `requires_ub_memory` | Operation requires Unified Buffer memory (vs. Global Memory). | -| `tensor_rank(rank)` | Operand tensor has specified rank (e.g., `tensor_rank(2)` for 2D tensors). | -| `broadcastable` | Operands are broadcastable according to NumPy-style broadcasting rules. | -| `static_shape` | All tensor dimensions are known at compile time (no dynamic shapes). | - -##### Logical Constraint Combinators - -| Combinator | Description | Example | -|------------|-------------|---------| -| `AnyOf(c1, c2, ...)` | At least one of the constraints must be satisfied. | `AnyOf(k_dim_aligned_64, continuous_memory)` | -| `AllOf(c1, c2, ...)` | All constraints must be satisfied. | `AllOf(tensor_rank(2), static_shape)` | -| `Not(c)` | The constraint must not be satisfied. | `Not(requires_ub_memory)` | - -##### Custom Constraints - -Users can define custom constraints using predicate functions: - -```python -# Define a custom constraint that consumes one context attr by name. -def large_batch(min_batch: int): - return lambda batch=0: batch >= min_batch - -@pto.vkernel( - target="a5", - op="pto.matmul ins(a, b) -> outs(c)", - dtypes=[(pto.f16, pto.f16, pto.f32)], - constraints=[large_batch(1024)] -) -def large_batch_matmul(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: - # Optimized for large batch sizes - pass -``` - -Constraint callables bind by parameter name. - -- Kernel parameter names such as `src`, `dst`, `a`, `b` receive lightweight proxy objects, so constraints can use direct expressions like `src.shape[0] <= dst.shape[0]`. -- Extra `context_attrs` passed to `pto.select_kernel(...)` bind by key name, for example `batch`, `enabled`, or `expected_rows`. - -##### Parameter Proxy Objects - -When a constraint argument name matches a kernel parameter name, the callable receives a lightweight proxy object rather than raw Python data. - -- For `TensorView` parameters, the proxy exposes `rank`, `shape`, `strides`, `dtype`, and `memory_space`. -- For `Tile` parameters, the proxy exposes `rank`, `shape`, `valid_shape`, `dtype`, `memory_space`, and `config`. -- `shape`, `strides`, and `valid_shape` support index access such as `src.shape[0]` or `dst.valid_shape[1]`. -- Missing or not-yet-known metadata evaluates as "unknown", so comparisons conservatively pass rather than failing early. - -Example: - -```python -def tload_preconditions(src, dst): - logical_rows = src.shape[0] * src.shape[1] * src.shape[2] * src.shape[3] - logical_cols = src.shape[4] - return ( - src.rank == 5 - and src.strides[4] == 1 - and dst.valid_shape[0] <= logical_rows - and dst.valid_shape[1] <= logical_cols - and logical_rows <= dst.shape[0] - and logical_cols <= dst.shape[1] - ) - -@pto.vkernel( - target="a5", - op="pto.tload", - dtypes=[(pto.f32, pto.f32)], - constraints=[tload_preconditions], -) -def template_tload(src: pto.TensorView, dst: pto.Tile): - return None -``` - -This is the recommended constraint style for current TileLang DSL head. - -##### Builtin Vector Parameters - -When a kernel needs to match a builtin MLIR vector operand, annotate that -parameter with `pto.vector(element_dtype, shape)`. - -```python -@pto.vkernel( - target="a5", - op="pto.tmrgsort ins(src0, src1, tmp) -> outs(dst, ex_vec)", - dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.i16)], -) -def template( - src0: pto.Tile, - src1: pto.Tile, - tmp: pto.Tile, - dst: pto.Tile, - ex_vec: pto.vector(pto.i16, (4,)), -): - return None -``` - -Rules: - -- Use `pto.vector(...)` for builtin vector operands, not Python `list`. -- `shape` is a Python tuple. A 1-D vector of length 4 is written `(4,)`. -- `dtypes=[...]` still records only the element dtype for that operand (`pto.i16` - in the example above). -- `pto.vector(...)` is distinct from `pto.vreg(...)`: the former models builtin - `vector<...>`, the latter models fixed-width VPTO vector registers. - -#### Kernel Selection Mechanism - -When a PTO operation needs implementation, the system performs the following matching process: - -1. **Target Filtering**: Select kernels with matching `target` architecture. -2. **Operation Filtering**: Select kernels whose matcher metadata covers the concrete query op: - - `op="foo"` requires exact match - - `op="foo ins(...) -> outs(...)"` still matches by op name `foo`; `ins/outs` additionally defines parameter naming/order contract for descriptor validation and materialization - - `ops=[...]` requires the concrete query op to appear in that list -3. **Type Matching**: For each kernel's `dtypes` list, check if any signature matches the operation's operand types: - - Concrete types must match exactly. - - Wildcard types match according to their category. - - Type variables must be consistent within the signature. -4. **Constraint Validation**: For each matching kernel, evaluate all `constraints`. If any constraint fails, the kernel is rejected. -5. **Priority Selection**: From the remaining kernels, select the one with the highest `priority` value. -6. **Fallback**: If no kernel matches, compilation fails with an error. - -For multi-op descriptors selected through `ops=[...]`, `pto.select_kernel(...)` -also binds the concrete query op before materialization. This bound -`selected_op` is what template-slot expansion uses later. - -The package also exposes explicit selection utilities: - -```python -registry = pto.KernelRegistry() -registry.register(my_kernel) - -selected = pto.select_kernel( - "a5", - "matmul", - (pto.f16, pto.f16, pto.f32), - context_attrs={"k_aligned": True}, - registry=registry, -) -``` - -`pto.select_kernel(...)` also supports an opt-in diagnostics path for matcher debugging: - -```python -report = pto.select_kernel( - "a5", - "matmul", - (pto.f16, pto.f16, pto.f32), - context_attrs={"k_aligned": False}, - return_metadata=True, - include_mlir=False, -) -``` - -When `return_metadata=True`, the result is a `KernelSelectionReport` instead of one -selected descriptor. - -- `report.selected` carries the winner when one candidate is selected. -- `report.final_status` is one of `selected`, `no_candidate`, or `priority_tie`. -- `report.final_error` summarizes the final selection outcome. -- `report.candidates` contains one `KernelSelectionCandidateMetadata` per - `target/op`-matched descriptor, including `dtype_mismatch`, - `constraint_failed`, `constraint_error`, `priority_shadowed`, `selected`, and - `priority_tie` states. - -Constraint diagnostics in report mode include: - -- `failed_constraint_index` -- `failed_constraint_name` -- `failed_constraint_location` as `file:line` - -For best diagnostics, prefer splitting compound predicates into multiple -constraint entries instead of writing one large `cond0 and cond1 and cond2` -callable. Report mode can precisely identify which constraint entry failed, but -it does not introspect which sub-expression inside one Python boolean -expression returned `False`. - -When `include_mlir=True`, report mode also attempts `mlir_text()` for candidates -that pass constraint evaluation. - -- On success, the candidate carries `mlir_text`. -- On materialization failure such as missing `specialize()` bindings, the - candidate carries `mlir_error`. -- Use `include_mlir=False` to skip this extra materialization attempt. - -#### Examples - -##### Matmul with Multiple Implementations - -```python -# High-performance kernel for aligned K dimension -def k_aligned_64(k=0): - return k % 64 == 0 - -@pto.vkernel( - target="a5", - op="pto.matmul ins(a, b) -> outs(c)", - dtypes=[(pto.f16, pto.f16, pto.f32)], - constraints=[k_aligned_64], - priority=200 -) -def matmul_aligned_k(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: - # Optimized implementation for aligned K - pass - -# General-purpose fallback -@pto.vkernel( - target="a5", - op="pto.matmul ins(a, b) -> outs(c)", - dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], - constraints=[], - priority=100 -) -def matmul_general(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: - # Generic implementation - pass -``` - -##### Elementwise Operation with Type Polymorphism - -```python -def same_shape(a, b, out): - return a.shape[0] == out.shape[0] and b.shape[0] == out.shape[0] - -@pto.vkernel( - target="a5", - op="pto.add ins(a, b) -> outs(out)", - dtypes=[ - (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), - (pto.AnyInt, pto.AnyInt, pto.AnyInt) - ], - constraints=[same_shape] -) -def polymorphic_add(a: pto.Tile, b: pto.Tile, out: pto.Tile) -> None: - # Single implementation handles both float and integer types - dtype = a.element_type - all_mask = pto.make_mask(dtype, PAT.ALL) - # ... implementation using generic vector operations - pass -``` - -##### Constrained Convolution Kernel - -```python -def prefer_static_nhwc(src, weight): - return src.rank == 4 and weight.rank == 4 - -@pto.vkernel( - target="a5", - op="pto.conv2d ins(input, filter) -> outs(output)", - dtypes=[(pto.f16, pto.f16, pto.f32)], - constraints=[prefer_static_nhwc], - priority=150 -) -def conv2d_nhwc_f16_f32(input: pto.Tile, filter: pto.Tile, output: pto.Tile) -> None: - # Optimized for NHWC layout with static shapes - pass -``` - ---- - -### Cube Kernel Declaration - -Cube kernels target the AIC (Cube) hardware unit for matrix multiplication operations. Unlike Vector kernels, Cube kernels operate on raw `pto.ptr` pointers and do not use `vecscope` execution scopes. - -#### Basic Syntax - -```python -@pto.ckernel( - target="a5", - op="pto.mad", # concrete matcher op - dtypes=[(pto.f16, pto.f16, pto.f32)], # selection dtype signature - name="my_gemm", # optional registry/debug name -) -def gemm(inp: pto.TensorView): - # Cube kernel body — linear cube authoring IR - ... -``` - -#### Parameter Type Conventions - -Cube kernel parameters represent different roles in the data flow: - -| Parameter Type | Role | Description | -|---------------|------|-------------| -| `PartitionTensorView` | GM input/output | Tiled view of a logical tensor in GM, partitioned by the caller | -| `TensorView` | GM input/output | Full logical tensor view in GM (for non-partitioned use) | -| `Tile` (specific addr space) | Pre-allocated hardware buffer | Tile already allocated in LEFT/RIGHT/ACC/MAT/BIAS address space | -| `int` | Dimension | Scalar dimension parameter (M, K, N, etc.) | -| `pto.f16` / `pto.f32` etc. | Scalar | Scalar parameters (threshold, alpha, etc.) | - -GM payload is modeled through `TensorView` and `PartitionTensorView`. `Tile` -values represent staged hardware buffers allocated in concrete hardware address -spaces such as `MAT`, `LEFT`, `RIGHT`, `ACC`, and `BIAS` via `pto.Tile`. - -#### Decorator Parameters - -| Parameter | Type | Required | Description | -|-----------|------|----------|-------------| -| `target` | `str` | No | Target hardware architecture. Cube DSL v1 supports `"a5"`. Default: `"a5"`. | -| `op` | `str` | 与 `ops` 二选一 | Single concrete matcher op. Bare-op strings such as `"pto.mad"` are supported. **Mutually exclusive with `ops`**. | -| `ops` | `List[str]` | 与 `op` 二选一 | List of concrete matcher ops for shared-body selection and template-slot dispatch. **Mutually exclusive with `op`**. | -| `dtypes` | `List[Tuple[Type, ...]]` | Recommended | List of selection dtype signatures. For cube kernels, these signatures describe the concrete query op rather than necessarily mirroring the Python parameter list. | -| `templates` | `Dict[str, Dict[str, str]]` | No | Static template-slot mappings. Each slot maps concrete op names to real `pto.*` calls. Required when the kernel body uses `pto.tpl(...)`. | -| `name` | `str` | No | Descriptor name used for registration, debugging, and emitted symbol naming. Defaults to the decorated function name. | -| `priority` | `int` | No | Selection priority when multiple kernels match. Default: `0`. | - -#### Key Differences from `@pto.vkernel` - -| Feature | `@pto.vkernel` (Vector) | `@pto.ckernel` (Cube) | -|---------|--------------------------|------------------------| -| Hardware unit | AIV (Vector) | AIC (Cube) | -| Execution scope | `pto.vecscope` / `pto.strict_vecscope` | **No scope** — function body is linear IR | -| GM data input | `TensorView` / `Tile` | `TensorView` / `PartitionTensorView` | -| Operand abstraction | Tile + vector registers + masks | `pto.ptr` raw pointers | -| Core operations | Vector ALU, load/store | Data movement (cube_load/store) + matmul (mad) | -| Address spaces | GM, UB (VEC) | GM, MAT, LEFT, RIGHT, ACC, BIAS, UB | -| Generated IR attr | `#pto.kernel_kind` | `#pto.kernel_kind` | - -#### Programming Model - -Cube kernels follow a GM → L1 → L0 → compute → L0 → GM data flow: - -```python -@pto.ckernel( - target="a5", - op="pto.mad", - dtypes=[(pto.f16, pto.f16, pto.f32)], - name="gemm", -) -def gemm(a_tv: pto.PartitionTensorView, # [M, K] in GM - b_tv: pto.PartitionTensorView, # [K, N] in GM - c_tv: pto.PartitionTensorView): # [M, N] in GM, output - # 1. Get GM pointers from PartitionTensorViews - a_ptr = a_tv.as_ptr() # -> pto.ptr - b_ptr = b_tv.as_ptr() # -> pto.ptr - c_ptr = c_tv.as_ptr() # -> pto.ptr - - # 2. Allocate L1 (MAT) tile buffers (returns Tile, then get ptr) - l1_a = pto.Tile([16, 32], pto.f16, pto.MemorySpace.MAT) - l1_b = pto.Tile([32, 16], pto.f16, pto.MemorySpace.MAT) - - # 3. Allocate L0 tile buffers - l0a = pto.Tile([16, 32], pto.f16, pto.MemorySpace.LEFT) - l0b = pto.Tile([32, 16], pto.f16, pto.MemorySpace.RIGHT) - l0c = pto.Tile([16, 16], pto.f32, pto.MemorySpace.ACC) - - # 4. GM → L1 data movement - pto.cube_load(a_ptr, l1_a.as_ptr(), 16, nburst=(1, 0, 0)) - pto.cube_load(b_ptr, l1_b.as_ptr(), 16, nburst=(1, 0, 0)) - - # 5. L1 → L0 data movement - pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), 16, 32) - pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), 32, 16) - - # 6. Matrix multiplication - pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), 16, 16, 32) - - # 7. L0C → GM writeback - pto.acc_store_gm( - l0c.as_ptr(), c_ptr, 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND - ) -``` - -This example shows a **full-pipeline** kernel that handles data movement and compute. Alternatively, a **pure-compute** kernel can take pre-allocated tiles directly: - -```python -@pto.ckernel( - target="a5", - op="pto.mad", - dtypes=[(pto.f16, pto.f16, pto.f32)], - name="matmul_compute", -) -def matmul_compute(a_left: pto.Tile, # Pre-allocated LEFT tile (L0A) - b_right: pto.Tile, # Pre-allocated RIGHT tile (L0B) - c_acc: pto.Tile): # Pre-allocated ACC tile (L0C) - pto.mad_acc(a_left.as_ptr(), b_right.as_ptr(), c_acc.as_ptr(), 16, 16, 32) -``` - -#### Hardware Isolation - -- `@pto.ckernel` functions generate `#pto.kernel_kind` IR attribute. -- `@pto.vkernel` functions generate `#pto.kernel_kind` IR attribute. -- The IR verifier prevents Cube and Vector operations from appearing in the same function. -- The DSL semantic analyzer additionally checks that Cube kernel bodies do not contain Vector-specific operations (`vlds`, `vadd`, etc.) or `vecscope` scopes. -- Both kernel types can coexist in the same `.py` file; each compiles independently with conditional compilation macros (`__DAV_CUBE__` / `__DAV_VEC__`). - -For the complete Cube operation reference and `pto.Tile` constructor details, see [Cube Matrix Multiply Operations](12-cube-operations.md). diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md new file mode 100644 index 000000000..c7cd157eb --- /dev/null +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -0,0 +1,412 @@ +# 3. Kernel Entry Points and Sub-Kernels + +PTODSL provides five decorators that mark functions as PTO kernels, plus three context managers for inline use. This chapter is a reference for each entry point — its role, parameter contract, and boundary constraints. + +## 3.1 Decorator family overview + +``` +@pto.jit L1 Top-level JIT entry — compile, cache, launch +@pto.ukernel L2 Micro-instruction orchestration (MTE + sync) +@pto.cube L3 Matrix multiplication on the Cube unit +@pto.simd L3 Vector math on the SIMD unit +@pto.simt L3 Scalar compute on the SIMT unit +``` + +L3 sub-kernels can be invoked in two ways: + +1. **As decorated functions** (`@pto.cube` / `@pto.simd` / `@pto.simt`) — reusable, named sub-kernels that can be called from `@pto.ukernel` or directly from `@pto.jit`. +2. **As context managers** (`with pto.cube():` / `with pto.simd():` / `with pto.simt():`) — inline L3 blocks for quick prototyping or one-off compute snippets inside any kernel. + +Calling an L3 sub-kernel directly from `@pto.jit` skips the ukernel layer: you stage data with `tload`/`tstore` instead of `mte_load`/`mte_store`, and PTOAS handles the synchronization between Tile Ops and L3 compute automatically. This is the recommended path for most users — drop down to `@pto.ukernel` only when you need explicit control over micro-instruction ordering and synchronization. + +## 3.2 `@pto.jit` — top-level JIT entry + +### Role + +`@pto.jit` marks a function as a launchable PTO kernel. It owns compilation (tracing + lowering), caching, and runtime launch binding. This is the only decorator that can be invoked directly from the host — all other decorators define sub-kernels that are called from within `@pto.jit` or `@pto.ukernel`. + +### Signature + +```python +@pto.jit(target="a5") +def kernel_name( + tensor_arg_1, # Python-native tensor (positional) + tensor_arg_2, # Python-native tensor (positional) + ..., + *, + CONST_A: pto.constexpr = default, # compile-time constant (keyword-only) + CONST_B: pto.constexpr = default, # compile-time constant (keyword-only) +): +``` + +**Positional parameters** are Python-native tensors — they arrive from NumPy, torch-npu, or any framework with `.shape` and `.strides`. Inside the body, wrap them with `make_tensor_view` to create GM descriptors. + +**Keyword-only parameters** annotated with `pto.constexpr` are compile-time constants. They must be provided at `.compile()` time and cannot change between launches of the same compiled kernel. Use them for tile sizes, algorithmic knobs (e.g., `CAUSAL`), and other values that the compiler can specialize against. + +### Compilation and launch + +```python +# Compile (traces the body, lowers through PTOAS, caches the result) +compiled = kernel_name.compile(CONST_A=128, CONST_B=64) + +# Launch on NPU +compiled[grid, stream](tensor_1, tensor_2, ...) +``` + +- `.compile(**constexprs)` — traces the kernel body with the given constexpr values, lowers the IR, and returns a compiled handle. Subsequent calls with the same (function identity, constexpr values) hit the cache. +- `compiled[grid, stream](args...)` — launches the compiled kernel. `grid` is the number of SPMD blocks (an integer); `stream` is the NPU stream (`None` for default). + +### SPMD built-ins + +Available inside a `@pto.jit` body: + +| Built-in | Returns | Description | +|----------|---------|-------------| +| `pto.get_block_idx()` | `int` | Index of the current block (0-based) | +| `pto.get_block_num()` | `int` | Total number of blocks in the grid | +| `pto.get_subblock_idx()` | `int` | Index of the current sub-block | +| `pto.get_subblock_num()` | `int` | Total number of sub-blocks | + +### Typical body + +```python +@pto.jit(target="a5") +def my_kernel(A, B, O, *, BLOCK: pto.constexpr): + N = A.shape[0] + a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + + num_blocks = (N + BLOCK - 1) // BLOCK + with pto.for_(0, num_blocks, step=1) as i: + offset = i * BLOCK + a_part = pto.partition_view(a_view, offsets=[offset], sizes=[BLOCK]) + b_part = pto.partition_view(b_view, offsets=[offset], sizes=[BLOCK]) + o_part = pto.partition_view(o_view, offsets=[offset], sizes=[BLOCK]) + + pto.tload(a_part, a_tile) + pto.tload(b_part, b_tile) + pto.tadd(a_tile, b_tile, o_tile) + pto.tstore(o_tile, o_part) +``` + +### Calling L3 sub-kernels directly + +When you call an L3 sub-kernel directly from `@pto.jit`, data movement is handled by Tile Ops (`tload`/`tstore`) instead of MTE micro-instructions. PTOAS handles the synchronization between Tile Ops and L3 compute — the sub-kernel itself is unchanged: + +```python +@pto.cube +def my_matmul(a_tile, b_tile, l0a, l0b, acc, o_tile): + m = pto.tile_valid_rows(a_tile) + k = pto.tile_valid_cols(a_tile) + n = pto.tile_valid_rows(b_tile) + pto.mte_l1_l0a(a_tile, l0a, m, k) + pto.mte_l1_l0b(b_tile, l0b, k, n, transpose=True) + pto.mad(l0a, l0b, acc) + pto.mte_l0c_ub(acc, o_tile, m, n) + +@pto.jit(target="a5") +def my_kernel(A, B, O, *, BLOCK: pto.constexpr): + N = A.shape[0] + a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + a_tile = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32) + l0a = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32, memory_space=pto.MemorySpace.LEFT) + l0b = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32, memory_space=pto.MemorySpace.RIGHT) + acc = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32, memory_space=pto.MemorySpace.ACC) + + num_blocks = (N + BLOCK - 1) // BLOCK + with pto.for_(0, num_blocks, step=1) as i: + offset = i * BLOCK + a_part = pto.partition_view(a_view, offsets=[offset, 0], sizes=[BLOCK, BLOCK]) + b_part = pto.partition_view(b_view, offsets=[offset, 0], sizes=[BLOCK, BLOCK]) + o_part = pto.partition_view(o_view, offsets=[offset, 0], sizes=[BLOCK, BLOCK]) + + # Tile Ops stage data from GM to UB (replaces mte_load at L1) + pto.tload(a_part, a_tile) + pto.tload(b_part, b_tile) + + # Direct L3 call — PTOAS handles sync between tload and compute + my_matmul(a_tile, b_tile, l0a, l0b, acc, o_tile) + + pto.tstore(o_tile, o_part) +``` + +This is the recommended path for users who want hardware-unit compute without writing explicit MTE Ops and manual sync. Mixing direct L3 calls with Tile Ops and ukernel calls in the same `@pto.jit` body is supported — the compiler unifies the lowering. + +## 3.3 `@pto.ukernel` — micro-instruction orchestration + +### Role + +`@pto.ukernel` (short for *micro-instruction kernel*) is the entry point for writing PTO micro-instructions directly. Unlike `@pto.jit` where you work with tile-level ops (`tload`, `tadd`, etc.), a ukernel lets you write explicit MTE, SIMD, SIMT, and Cube instructions — staging data with `mte_load`, synchronizing with `mem_bar`, and dispatching L3 sub-kernels. This is an advanced programming mode for expert users who need precise control over instruction ordering and hardware-level data movement. + +### Signature + +```python +@pto.ukernel +def my_ukernel( + part: pto.PartitionTensorView, # GM partition descriptors + tile: pto.Tile, # UB tile buffers + scratch: pto.Tile, # cube-local scratch (LEFT, RIGHT, ...) + ptr: pto.ptr(dtype, space), # typed UB pointers + scalar: pto.i32, # PTO scalar values +): +``` + +Parameters are PTO-specific types — `Tile`, `PartitionTensorView`, `pto.ptr`, and PTO scalar types. Unlike `@pto.jit`, a ukernel does not accept Python-native tensors. + +### Typical body + +```python +@pto.ukernel +def process_block(k_part, v_part, k_tile, v_tile, + s_tile, o_tile, rows: pto.i32, cols: pto.i32): + # Stage current block from GM to UB + pto.mte_load(k_part, k_tile) + pto.mte_load(v_part, v_tile) + pto.mem_bar(pto.BarrierType.SYNC) + + # Dispatch sub-kernels + qk_matmul(q_tile, k_tile, s_tile) + pto.mem_bar(pto.BarrierType.SYNC) + + online_softmax(s_tile, o_tile, rows, cols) + pto.mem_bar(pto.BarrierType.SYNC) + + # Write result back + pto.mte_store(o_tile, o_part) +``` + +A ukernel stays below the tile-op boundary — GM↔UB movement is expressed with `mte_load`/`mte_store` (MTE Ops) rather than `tload`/`tstore`. + +## 3.4 `@pto.cube` — Cube unit sub-kernel + +### Role + +`@pto.cube` marks a function that executes on the Cube unit (matrix multiplication engine). It consumes UB-resident tiles and explicit cube-local scratch buffers. + +### Signature + +```python +@pto.cube +def my_cube_kernel( + input_tile: pto.Tile, # UB tile (source data) + output_tile: pto.Tile, # UB tile (destination) + left_scratch: pto.Tile, # LEFT buffer (cube-local) + right_scratch: pto.Tile, # RIGHT buffer (cube-local) + acc_scratch: pto.Tile, # ACC buffer (cube-local) +): +``` + +All parameters are `Tile` references. Tiles marked as cube-local must be allocated with the appropriate `memory_space` (e.g., `pto.MemorySpace.LEFT`, `pto.MemorySpace.ACC`). + +### Typical body + +```python +@pto.cube +def qk_matmul( + q_tile: pto.Tile, + k_tile: pto.Tile, + q_l0a: pto.Tile, + k_l0b: pto.Tile, + s_acc: pto.Tile, + s_tile: pto.Tile, +): + m = pto.tile_valid_rows(q_tile) + k = pto.tile_valid_cols(q_tile) + n = pto.tile_valid_rows(k_tile) + + pto.mte_l1_l0a(q_tile, q_l0a, m, k) + pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) + pto.mad(q_l0a, k_l0b, s_acc) + pto.mte_l0c_ub(s_acc, s_tile, m, n) +``` + +Cube-local state (LEFT, RIGHT, ACC, BIAS) never leaks into UB — it is the caller's responsibility to allocate scratch buffers and pass them in explicitly. + +**Invocation modes**: `@pto.cube` functions can be: +- Called from `@pto.ukernel` (manual MTE + sync in the ukernel's hands). +- Called directly from `@pto.jit` (compiler infers MTE + sync). +- Used inline as a context manager: `with pto.cube():` (see Section 3.7). + +## 3.5 `@pto.simd` — SIMD unit sub-kernel + +### Role + +`@pto.simd` marks a function that executes on the SIMD unit (vector engine). It operates on vector registers (`vreg`) loaded from UB tiles and stores results back to UB tiles. Vector registers are local to the function and never cross its boundary. + +### Signature + +```python +@pto.simd +def my_simd_kernel( + input_tile: pto.Tile, # UB tile + output_tile: pto.Tile, # UB tile + rows: pto.i32, # PTO scalar + cols: pto.i32, # PTO scalar +): +``` + +Parameters are UB `Tile` references and PTO scalar values (`pto.i32`, `pto.f32`, etc.). Scalar parameters may come from `lds` reads or compile-time constants. + +### Typical body + +```python +@pto.simd +def add_rows(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, + rows: pto.i32, cols: pto.i32): + VEC = pto.elements_per_vreg(pto.f32) + with pto.for_(0, rows, step=1) as r: + col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) + with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + a_vec = pto.vlds(a_tile[r, c:]) + b_vec = pto.vlds(b_tile[r, c:]) + o_vec = pto.vadd(a_vec, b_vec, mask) + pto.vsts(o_vec, o_tile[r, c:], mask) + col_loop.update(remained=remained) +``` + +The boundary contract: `vreg` values (`a_vec`, `b_vec`, `o_vec`) are local to the function. The only way to persist data across a `@pto.simd` call is to write it back to a UB tile via `vsts` (or `psts`, etc.). + +**Invocation modes**: `@pto.simd` functions can be: +- Called from `@pto.ukernel` (manual MTE + sync in the ukernel's hands). +- Called directly from `@pto.jit` (compiler infers MTE + sync). +- Used inline as a context manager: `with pto.simd():` (see Section 3.7). + +## 3.6 `@pto.simt` — SIMT unit sub-kernel + +### Role + +`@pto.simt` marks a function that executes on the SIMT unit. SIMT (Single Instruction, Multiple Threads) is a programming model where you write instructions in scalar syntax, and the hardware executes them in parallel across many threads — analogous to how a GPU SM runs a CUDA kernel. Each instruction appears to operate on a single element (`lds`, `sts`, `a + b`), but the same instruction is issued across a large number of work-items simultaneously. + +### Signature + +```python +@pto.simt +def my_simt_kernel( + tile: pto.Tile, # UB tile + ptr: pto.ptr(dtype, space), # typed UB pointer + scalar: pto.i32, # PTO scalar +): +``` + +### Typical body + +```python +@pto.simt +def blend_output_rows( + o_prev_tile: pto.Tile, pv_tile: pto.Tile, + alpha_tile: pto.Tile, beta_tile: pto.Tile, + o_next_tile: pto.Tile, + row_start: pto.i32, row_stop: pto.i32, valid_dim: pto.i32, +): + with pto.for_(row_start, row_stop, step=1) as row: + alpha = scalar.load(alpha_tile[row, 0]) + beta = scalar.load(beta_tile[row, 0]) + with pto.for_(0, valid_dim, step=1) as col: + o_prev = scalar.load(o_prev_tile[row, col]) + pv_val = scalar.load(pv_tile[row, col]) + o_next = alpha * o_prev + beta * pv_val + scalar.store(o_next, o_next_tile[row, col]) +``` + +SIMT kernels read and write individual scalar elements from tiles. The unit executes the same scalar instruction across many work-items in parallel, making it efficient for per-element operations. + +**Invocation modes**: `@pto.simt` functions can be: +- Called from `@pto.ukernel` (manual MTE + sync in the ukernel's hands). +- Called directly from `@pto.jit` (compiler infers MTE + sync). +- Used inline as a context manager: `with pto.simt():` (see Section 3.7). + +## 3.7 Context manager syntax for L3 sub-kernels + +In addition to the decorator form, each L3 sub-kernel unit provides a context manager: `with pto.cube():`, `with pto.simd():`, and `with pto.simt():`. These open an inline L3 block without requiring a separate named function — useful for quick prototyping, one-off compute snippets, or when the logic is too trivial to extract. + +### Syntax + +```python +with pto.simd(): + # Direct L3 instructions — vreg ops, scalar loads/stores + a_vec = pto.vlds(a_tile[r, c:]) + b_vec = pto.vlds(b_tile[r, c:]) + o_vec = pto.vadd(a_vec, b_vec, mask) + pto.vsts(o_vec, o_tile[r, c:], mask) +``` + +```python +with pto.simt(): + alpha = scalar.load(alpha_tile[row, 0]) + beta = scalar.load(beta_tile[row, 0]) + o_next = alpha * o_prev + beta * pv_val + scalar.store(o_next, o_next_tile[row, col]) +``` + +```python +with pto.cube(): + pto.mte_l1_l0a(q_tile, q_l0a, m, k) + pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) + pto.mad(q_l0a, k_l0b, s_acc) + pto.mte_l0c_ub(s_acc, s_tile, m, n) +``` + +### Semantics + +- Inside the `with` block, instructions execute on the corresponding hardware unit. +- `vreg` values created inside `with pto.simd():` are scoped to the block — they do not escape. +- Cube-local scratch (`l0a`, `l0b`, `acc`) must be allocated by the caller before entering the block. +- The context manager form is equivalent to defining an inline anonymous sub-kernel. The compiler treats it identically to a named `@pto.simd` / `@pto.cube` / `@pto.simt` function. + +### Comparison + +| | Decorator form | Context manager form | +|---|---|---| +| Reuse | Named, callable from multiple call sites | Inline, single-use | +| Readability | Good for complex, multi-step logic | Good for short (3-10 line) snippets | +| Testing | Can be unit-tested independently | Tested only through the enclosing kernel | +| Cube-local args | Explicit parameters | Captured from enclosing scope | + +The two forms can be freely mixed in the same `@pto.jit` or `@pto.ukernel` body. + +## 3.8 Boundary contracts + +Data crosses decorator boundaries only through UB-backed tiles or typed UB pointers: + +| Boundary | Allowed | +|----------|---------| +| Host → `@pto.jit` | Python-native tensors | +| `@pto.jit` → `@pto.ukernel` | `Tile`, `PartitionTensorView`, `pto.ptr`, PTO scalars | +| `@pto.jit` → L3 sub-kernel (direct call) | `Tile`, PTO scalars (compiler handles MTE + sync) | +| `@pto.jit` → `with pto.{cube,sid,sitm}:` | `Tile` captured from enclosing scope | +| `@pto.ukernel` → L3 sub-kernel | `Tile`, PTO scalars | +| L3 sub-kernel → L3 sub-kernel | Not allowed (go through UB tiles via the caller) | +| `@pto.simd` → caller | Only via `vsts`/`psts` to UB tiles; `vreg` cannot escape | +| Cube-local → UB | Only via `mte_l0c_ub`; LEFT/RIGHT/ACC/BIAS are private | + +## 3.9 `pto.constexpr` + +`pto.constexpr` marks a `@pto.jit` keyword-only parameter as a compile-time constant. The compiler specializes the kernel for each combination of constexpr values, and the compiled artifact is cached by those values. + +```python +@pto.jit(target="a5") +def kernel(A, *, BLOCK: pto.constexpr = 128, DTYPE: pto.constexpr = pto.f32): + ... +``` + +- Must appear as a keyword-only argument (after `*`). +- Must have a default value. +- Must be provided at `.compile()` time if the caller needs to override the default. +- Cannot change between launches of the same compiled instance — compile a new variant for a different value. + +`pto.constexpr` parameters can be used anywhere in the kernel body where a Python value is expected: tile shapes, loop bounds that are known at compile time, dtype arguments, etc. They are evaluated at trace time, so `for i in range(BLOCK)` would unroll `BLOCK` times. + +In contrast, values derived from runtime tensor shapes (e.g., `A.shape[0]`) are dynamic — they vary per launch and should be used with `pto.for_` to produce device-side loops. diff --git a/ptodsl/docs/user_guide/04-template-kernels.md b/ptodsl/docs/user_guide/04-template-kernels.md deleted file mode 100644 index 9fcda0fd0..000000000 --- a/ptodsl/docs/user_guide/04-template-kernels.md +++ /dev/null @@ -1,333 +0,0 @@ -### Template-based Kernel Authoring - -For operations that share similar computation patterns but differ in their core vector operations, the DSL supports template-based kernel authoring. This allows a single kernel implementation to serve multiple related operations through parameterized templates. - -#### Multi-operation Kernels with `ops` Parameter - -Instead of specifying a single `op` parameter, you can provide an `ops` list to match multiple operations: - -```python -@pto.vkernel( - target="a5", - ops=["tadd", "tsub", "tmul", "tdiv"], # List of operations - dtypes=[(T, T, T)], # Type signature using type variable - advanced=True, - templates={ - "core": { - "tadd": "vadd", - "tsub": "vsub", - "tmul": "vmul", - "tdiv": "vdiv", - } - } -) -def elementwise_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): - dtype = dst.element_type - rows, cols = dst.valid_shape - elems_per_vreg = pto.elements_per_vreg(dtype) # Number of elements per vector register - for row in range(0, rows, 1): - remained = cols - for col in range(0, cols, elems_per_vreg): - mask, remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - out = pto.tpl("core", lhs, rhs, mask) # Template dispatch - pto.vsts(out, dst[row, col:], mask) -``` - -`op` and `ops` are mutually exclusive, and exactly one of them must be -provided. `ops=[...]` only widens the matcher set; callers still use -`pto.select_kernel(target, concrete_op, operand_types, ...)` with a concrete -PTO op such as `"tadd"` or `"tmul"`. - -#### Template System - -The template system consists of three components: - -1. **`templates` parameter**: A dictionary mapping template names to operation-specific implementations -2. **`pto.tpl()` function**: A compile-time placeholder that resolves to the appropriate implementation for the currently selected concrete op -3. **`ops` parameter**: Replaces the singular `op` parameter for multi-operation kernels - -##### Template Definition - -Templates are defined in the `templates` parameter of `@pto.vkernel`. Each template is a dictionary mapping operation names to implementation strings: - -```python -templates={ - "template_name": { - "op1": "implementation_for_op1", - "op2": "implementation_for_op2", - # ... - }, - "another_template": { - "op1": "different_implementation_for_op1", - # ... - } -} -``` - -Template-slot metadata is static and validated when the descriptor is -registered: - -- slot names must be non-empty strings -- mapping keys must be concrete ops covered by the descriptor matcher set -- mapping values must be supported real `pto.*` op names - -The implementation strings are typically vector operation names such as -`"vadd"`, `"vsub"`, `"vmul"`, and `"vdiv"`, which are resolved during kernel -expansion. - -##### Template Usage with `pto.tpl()` - -The `pto.tpl()` operation enables template dispatch for multi-operation kernels, allowing code reuse across related operations through compile-time substitution. - -#### `pto.tpl(template_name: str, *args) -> Any` - -**Description**: Template dispatch operation for multi-operation kernels. Resolves to different implementations based on the current operation being expanded. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `template_name` | `str` | Name of the template to dispatch | -| `*args` | `Any` | Positional arguments passed unchanged to the resolved real implementation | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `Any` | Result of the template implementation | - -**Behavior**: -- Only valid inside kernels decorated with `@pto.vkernel` that have a `templates` parameter -- The first argument must be a string literal template-slot name -- During kernel expansion for a specific operation `op_name`, `pto.tpl("template_name", ...)` is replaced with the implementation specified in `templates["template_name"]["op_name"]` -- The replacement is a direct compile-time substitution; positional arguments are passed unchanged -- Template implementations are typically string names of vector operations (e.g., `"vadd"`, `"vsub"`) -- `pto.select_kernel(...)` must bind a concrete op before template expansion can happen -- Python dict lookup, callable values, lambdas, and other runtime dispatch patterns are not part of the supported kernel-body surface - -**Example**: -```python -@pto.vkernel( - ops=["tadd", "tsub"], - dtypes=[(T, T, T)], - templates={ - "core": { - "tadd": "vadd", - "tsub": "vsub", - } - } -) -def elementwise_kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): - # ... load vectors - result = pto.tpl("core", lhs, rhs, mask) # Expands to vadd for tadd, vsub for tsub - # ... store result -``` - -**Constraints**: -- Template names must be defined in the `templates` parameter of the `@pto.vkernel` decorator -- When a kernel body uses `pto.tpl("slot", ...)`, that slot must define an implementation for the currently selected concrete op -- Template implementations must be valid operation names in the DSL - -#### Decorator Parameters Update - -| Parameter | Type | Required | Description | -|-----------|------|----------|-------------| -| `target` | `str` | Yes | Target hardware architecture (e.g., `"a5"` for Ascend 950). | -| `op` | `str` | No* | Name of the PTO operation to match. **Mutually exclusive with `ops`**. | -| `ops` | `List[str]` | No* | List of PTO operation names to match. **Mutually exclusive with `op`**. | -| `dtypes` | `List[Tuple[Type, ...]]` | Yes | List of type signatures. Each tuple specifies the expected data types for the operation's operands. | -| `templates` | `Dict[str, Dict[str, str]]` | No | Static slot mappings from concrete matcher ops to real `pto.*` op names. Required when the kernel body uses `pto.tpl(...)`. | -| `constraints` | `List[Constraint]` | No | Additional constraints that must be satisfied for kernel selection. | -| `priority` | `int` | No | Selection priority when multiple kernels match. Default: `0`. | -| `name` | `str` | No | Kernel name (used for debugging and profiling). Defaults to the decorated function's name. | -| `advanced` | `bool` | No | Enable advanced-tier DSL surfaces (for example `strict_vecscope`, raw pointer family, and low-level DMA family). Implicit vecscope inference is mode-independent and runs only when no explicit `with pto.vecscope():` is present. Default: `False`. | - -**Note**: -- Either `op` or `ops` must be provided, but not both. -- `templates` is only needed when the kernel body uses `pto.tpl(...)`. -- `pto.select_kernel(...)` still queries with a concrete op even for `ops=[...]` descriptors. - -#### Advanced Template Patterns - -##### Multiple Templates per Kernel - -A kernel can define multiple templates for different aspects of the computation: - -```python -@pto.vkernel( - target="a5", - ops=["tadd_relu", "tsub_relu", "tadd_abs", "tsub_abs"], - dtypes=[(T, T, T)], - templates={ - "arithmetic": { - "tadd_relu": "vadd", - "tsub_relu": "vsub", - "tadd_abs": "vadd", - "tsub_abs": "vsub", - }, - "postprocess": { - "tadd_relu": "vrelu", - "tsub_relu": "vrelu", # Same activation for both - "tadd_abs": "vabs", - "tsub_abs": "vabs", - } - } -) -def elementwise_with_postprocess(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): - # ... load vectors - arith_result = pto.tpl("arithmetic", lhs, rhs, mask) - postprocessed = pto.tpl("postprocess", arith_result, mask) - # ... store result -``` - -##### Compile-time Substitution Model - -Template-slot expansion happens before semantic checking and lowering: - -- `pto.select_kernel(...)` first binds a concrete op such as `"tadd"` -- the frontend then resolves `pto.tpl("core", ...)` using `templates["core"]["tadd"]` -- the placeholder is rewritten to a real `pto.*` call before semantic analysis -- diagnostics for unknown slots, missing mappings, or unsupported resolved surfaces are raised before any VPTO IR is generated - -#### Type Variables in Template Kernels - -Template kernels often use type variables to enforce type consistency: - -```python -T = pto.TypeVar('T') - -@pto.vkernel( - target="a5", - ops=["tadd", "tsub"], - dtypes=[(T, T, T)], # All three operands share type T - templates={ - "core": { - "tadd": "vadd", - "tsub": "vsub", - } - } -) -def typed_elementwise(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): - # Type variable T ensures all tiles have same element type - dtype = dst.element_type # This is type T - # ... implementation -``` - -#### Selection Mechanism for Template Kernels - -When a PTO operation matches a template kernel: -1. The system selects the descriptor based on `op` exact match or `ops` list inclusion. -2. `pto.select_kernel(...)` binds the concrete query op as the descriptor's `selected_op`. -3. During frontend expansion, `pto.tpl()` calls are resolved using that bound concrete op. -4. For operation `"op_name"`, template `"template_name"` resolves to `templates["template_name"]["op_name"]`. -5. The resolved string (e.g., `"vadd"`) is replaced with the corresponding real DSL operation before semantic analysis and lowering. - -#### Example: Unified Arithmetic Kernel - -```python -T = pto.TypeVar('T') - -@pto.vkernel( - ops=["tadd", "tsub", "tmul", "tdiv", "tmax", "tmin"], - dtypes=[(T, T, T)], - advanced=True, - templates={ - "arithmetic": { - "tadd": "vadd", - "tsub": "vsub", - "tmul": "vmul", - "tdiv": "vdiv", - "tmax": "vmax", - "tmin": "vmin", - } - } -) -def unified_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): - """Single implementation for six arithmetic operations.""" - dtype = dst.element_type - rows, cols = dst.valid_shape - elems_per_vreg = pto.elements_per_vreg(dtype) # Number of elements per vector register - - for row in range(0, rows, 1): - remained = cols - for col in range(0, cols, elems_per_vreg): - mask, remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - out = pto.tpl("arithmetic", lhs, rhs, mask) - pto.vsts(out, dst[row, col:], mask) -``` - -#### Compile-time Specialization with `pto.constexpr` - -The `pto.constexpr` construct enables compile-time branching for kernel specialization, allowing different code paths to be selected based on static compile-time information. Unlike runtime conditionals that generate control flow, `pto.constexpr` branches are resolved during kernel descriptor materialization, with only the selected branch retained for lowering. - -**Syntax and Usage**: -```python -if pto.constexpr(condition): - # Branch taken if condition evaluates to True at compile time - ... -else: - # Branch taken if condition evaluates to False at compile time - ... -``` - -**Semantics**: -- The `condition` must be evaluable at compile time during kernel descriptor materialization. -- Only the selected branch is analyzed, semantically checked, and lowered to VPTO IR. -- The non-selected branch is discarded entirely and does not contribute to runtime control flow or value merging. -- If the condition cannot be proven static, descriptor materialization fails with a frontend diagnostic. - -**Comparison with Runtime Conditionals**: - -| Aspect | Runtime `if` | `pto.constexpr` | -|--------|--------------|-----------------| -| **Evaluation time** | Runtime | Compile-time (descriptor materialization) | -| **Control flow** | Generates `scf.if` with merge logic | No runtime control flow; branch eliminated | -| **Value merging** | Both branches must produce compatible values for merge | No value merging; only one branch exists after elimination | -| **Use case** | Dynamic decision making based on runtime values | Code generation specialization based on static parameters | - -**Typical Static Inputs**: -- Literal integers, booleans, and strings -- Data type symbols (`src.element_type`, `dst.element_type`) and comparisons derived from them -- Statically specialized `Tile.shape` and `Tile.valid_shape` values -- Frontend query helpers such as `pto.bytewidth(dtype)` and `pto.elements_per_vreg(dtype)` (which computes elements per vector register) - -**Constraints and Notes**: -- `TensorView.shape` and `TensorView.strides` may be represented by hidden kernel parameters rather than descriptor-time constants. They should not be assumed constexpr unless separately bound through specialization or other compile-time context. -- `pto.constexpr` is a frontend-only authoring construct; it does not correspond to any runtime VPTO instruction. - -**Guidelines**: -- Use `constraints=[...]` and `pto.select_kernel(...)` when specialization requires selecting an entirely different kernel descriptor. -- Use `pto.constexpr` when the kernel remains the same but internal regions require specialization based on compile-time parameters. - -**Example**: -```python -@pto.vkernel(target="a5", op="pto.trowsum") -def template_trowsum(dst: pto.Tile, src: pto.Tile, tmp: pto.Tile): - acc_dtype = tmp.element_type - dst_dtype = dst.element_type - acc_mask_1, _ = pto.make_mask(acc_dtype, 1) - dst_mask_1, _ = pto.make_mask(dst_dtype, 1) - - if pto.constexpr(acc_dtype != dst_dtype): - # Type conversion required - v_acc_casted = pto.vcvt(v_acc, dst_dtype, acc_mask_1) - pto.vsts(v_acc_casted, dst[row, 0:], dst_mask_1) - else: - # No conversion needed - pto.vsts(v_acc, dst[row, 0:], dst_mask_1) -``` - -### Value Model - -The DSL operates on symbolic values, not Python runtime values: -- **Constants**: Python literals that are typed to machine types -- **Operation results**: Values produced by DSL operations -- **Block arguments**: Values introduced by control flow structures - -### Memory Spaces - -The DSL supports different memory spaces: -- `MemorySpace.GM`: Global Memory -- `MemorySpace.UB`: Unified Buffer (local storage for vector computation) diff --git a/ptodsl/docs/user_guide/04-type-system-and-buffer.md b/ptodsl/docs/user_guide/04-type-system-and-buffer.md new file mode 100644 index 000000000..f0944cbf4 --- /dev/null +++ b/ptodsl/docs/user_guide/04-type-system-and-buffer.md @@ -0,0 +1,209 @@ +# 4. Type System and Buffer Management + +This chapter covers every type you can use in a PTODSL kernel, plus the operations for managing buffers in global memory (GM) and on-chip Unified Buffer (UB). + +## 4.1 Scalar types + +### Numeric scalar types + +| DSL Type | Description | Bit Width | +|----------|-------------|-----------| +| `pto.i1` | Boolean | 1 | +| `pto.i8` | 8-bit signless integer | 8 | +| `pto.si8` | 8-bit signed integer | 8 | +| `pto.ui8` | 8-bit unsigned integer | 8 | +| `pto.i16` | 16-bit signless integer | 16 | +| `pto.si16` | 16-bit signed integer | 16 | +| `pto.ui16` | 16-bit unsigned integer | 16 | +| `pto.i32` | 32-bit signless integer | 32 | +| `pto.si32` | 32-bit signed integer | 32 | +| `pto.ui32` | 32-bit unsigned integer | 32 | +| `pto.i64` | 64-bit signless integer | 64 | +| `pto.si64` | 64-bit signed integer | 64 | +| `pto.ui64` | 64-bit unsigned integer | 64 | +| `pto.f16` | Half-precision float | 16 | +| `pto.bf16` | Brain float 16 | 16 | +| `pto.f32` | Single-precision float | 32 | + +Python literals are automatically typed by the tracer: `bool` → `pto.i1`, `int` → context-dependent (typically `pto.i32` or `pto.i64`), `float` → `pto.f32`. + +For explicit typing, use type constructors: + +```python +x = pto.i32(1024) +y = pto.ui16(7) +z: pto.i32 = 1024 +``` + +### Low-precision types (storage only) + +The following types are available for storage and data movement, but **not** for computation. Use them to reduce memory bandwidth; convert to a compute-capable type before arithmetic. + +| DSL Type | Description | +|----------|-------------| +| `pto.hif8` | HiFloat8 format | +| `pto.f4e1m2x2` | 4-bit float (E1M2, 2-wide packed) | +| `pto.f4e2m1x2` | 4-bit float (E2M1, 2-wide packed) | +| `pto.f8e4m3` | 8-bit float (E4M3) | +| `pto.f8e5m2` | 8-bit float (E5M2) | + +### Integer literal guidance + +Prefer plain integer literals. Hex string literals are reserved for explicit bit-pattern authoring: + +```python +count = pto.i32(1024) +delta = pto.i16(-12) +hi_bit = pto.i32("0x80000000") # bit-pattern: -2147483648 +``` + +### Floating-point literal forms + +```python +a = pto.f16(-1.5) +b = pto.f32("inf") +c = pto.f32("-inf") +d = pto.f32("nan") +# Bit-pattern hex +f16_neg_inf = pto.f16("0xFC00") +``` + +## 4.2 Vector register type + +Vector registers hold a fixed 256-byte payload. `pto.vreg(dtype)` infers the element count automatically: + +| `dtype` | Result | Elements | +|---------|--------|----------| +| `pto.f32` / `pto.i32` / ... | `vreg<64xT>` | 64 | +| `pto.f16` / `pto.bf16` / `pto.i16` / ... | `vreg<128xT>` | 128 | +| `pto.i8` / `pto.si8` / `pto.ui8` | `vreg<256xT>` | 256 | + +Constraint: `element_count × bitwidth(dtype) = 2048`. + +Use `pto.elements_per_vreg(dtype)` to query the element count: + +```python +lanes = pto.elements_per_vreg(pto.f32) # 64 +``` + +### vbitcast + +Reinterpret the bits of a vector register as a different element type: + +```python +fvec = pto.vlds(ptr, offset) # !pto.vreg<64xf32> +ivec = pto.vbitcast(fvec, pto.i32) # !pto.vreg<64xi32> +f16_vec = pto.vbitcast(fvec, pto.f16) # !pto.vreg<128xf16> +``` + +`vbitcast` preserves the exact bit pattern (type punning). Use `vcvt` for numeric value conversion. + +## 4.3 Mask (predicate) types + +Masks are typed by bit granularity and must match the vector element width: + +| DSL Type | Granularity | Used with | +|----------|-------------|-----------| +| `pto.mask_b8` | 8-bit | `i8`, `si8`, `ui8` | +| `pto.mask_b16` | 16-bit | `f16`, `bf16`, `i16`, `si16`, `ui16` | +| `pto.mask_b32` | 32-bit | `f32`, `i32`, `si32`, `ui32` | + +Bitcast between mask types with `pto.pbitcast`: + +```python +mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) +``` + +## 4.4 Pointer types + +Pointers combine an element type and a memory space: + +```python +ptr_gm = pto.ptr(pto.f32, pto.MemorySpace.GM) +ptr_ub = pto.ptr(pto.f16, pto.MemorySpace.UB) +``` + +### MemorySpace enum + +| Enum Value | Description | +|------------|-------------| +| `MemorySpace.GM` | Global Memory (off-chip HBM) | +| `MemorySpace.UB` | Unified Buffer (on-chip scratchpad) | +| `MemorySpace.MAT` | Cube L1 / cbuf staging buffer | +| `MemorySpace.LEFT` | Cube L0A left-operand buffer | +| `MemorySpace.RIGHT` | Cube L0B right-operand buffer | +| `MemorySpace.ACC` | Cube L0C accumulator buffer | +| `MemorySpace.BIAS` | Cube bias table buffer | + +## 4.5 TensorView + +`TensorView` is a descriptor for a tensor in Global Memory. Create one inside a `@pto.jit` body with `make_tensor_view`: + +```python +@pto.jit(target="a5") +def kernel(A, *, BLOCK: pto.constexpr): + tv = pto.make_tensor_view(A, shape=[N], strides=A.strides) +``` + +`make_tensor_view` wraps a Python-native tensor. You provide the logical shape and the stride of each dimension in **elements** (not bytes). The resulting `TensorView` can be partitioned for `tload`/`tstore`. + +### TensorView attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Logical dimensions (up to 5D) | +| `element_type` | `Type` | Element dtype (e.g., `pto.f32`) | +| `strides` | `tuple[int, ...]` | Stride of each dimension, in elements | + +Strides support non-contiguous tensors. Pass `strides=A.strides` from the source tensor for the default row-major layout, or supply explicit strides for sub-views. Use `tv.as_ptr()` to obtain a typed GM pointer for use with MTE Ops in a ukernel. + +## 4.6 PartitionTensorView + +`partition_view` creates a sub-view of a TensorView at a given offset and size. It describes *which part* of the GM tensor a `tload` or `tstore` should operate on: + +```python +part = pto.partition_view(tv, offsets=[row_offset, 0], sizes=[BLOCK, dim]) +``` + +The result is a `PartitionTensorView` — a lightweight descriptor, not a data buffer. It carries the partition's shape, strides, and element type (inherited from the source TensorView). Use `part.as_ptr()` to obtain a typed GM pointer for MTE Ops in a ukernel. + +## 4.7 Tile + +A `Tile` is an on-chip buffer allocated in UB or cube-local memory. Allocate tiles with `alloc_tile`: + +```python +# UB tile +a_tile = pto.alloc_tile(shape=[BLOCK, dim], dtype=pto.f32) + +# Cube-local scratch with explicit memory space +q_l0a = pto.alloc_tile(shape=[Br, dim], dtype=pto.f16, memory_space=pto.MemorySpace.LEFT) +s_acc = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, memory_space=pto.MemorySpace.ACC) +``` + +`alloc_tile` returns a `Tile` object. The `shape` must be a compile-time constant. The default memory space is UB. + +### Tile attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Physical tile dimensions (compile-time constant) | +| `element_type` | `Type` | Element dtype | +| `memory_space` | `MemorySpace` | Where the tile lives (UB, LEFT, RIGHT, ACC, BIAS) | +| `valid_shape` | `tuple[int, ...]` | Logical data region, ≤ `shape` in each dimension | + +### Tile methods + +| Method | Description | +|--------|-------------| +| `tile.fill(value)` | Fill the entire tile with a scalar value | +| `tile.as_ptr()` | Obtain a typed pointer to the tile's base address | + +```python +m_prev_tile.fill(float("-inf")) +l_prev_tile.fill(0.0) + +rows = q_tile.valid_shape[0] +cols = k_tile.valid_shape[1] + +meta_ptr = meta_tile.as_ptr() +``` diff --git a/ptodsl/docs/user_guide/05-control-flow.md b/ptodsl/docs/user_guide/05-control-flow.md new file mode 100644 index 000000000..2dc4ce142 --- /dev/null +++ b/ptodsl/docs/user_guide/05-control-flow.md @@ -0,0 +1,228 @@ +# 5. Control Flow + +PTODSL uses a **tracing** compilation model. When you call `kernel.compile(...)`, PTODSL executes your Python function body once to record every PTO instruction — this pass is called *tracing*. The recorded program is then lowered and optimized into device code. Once compiled, launching the kernel runs the already-built device code directly on the NPU. + +This has one critical implication for how you write loops and branches: + +- **Python native `for`/`if`** runs at trace time. A `for i in range(4)` loop gets unrolled — the device code contains four copies of the body, not a loop instruction. An `if` condition is evaluated at trace time, and only the taken branch is recorded. +- **`pto.for_` / `pto.if_`** produce device-side control flow. The loop bound or branch condition can be a runtime value, and the hardware will execute the loop or take the branch dynamically. + +**Simple rule: Python control flow = trace time (compile-time). `pto.*` control flow = device-side (runtime).** + +## 5.1 Python native `for` — trace-time unrolling + +When you write a plain Python `for` loop inside a kernel body, Python executes it immediately during tracing. Each iteration records its instructions separately, so the device code gets a linear sequence with the body repeated: + +```python +@pto.jit(target="a5") +def unrolled_kernel(A, O, *, N: pto.constexpr): + a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + # N is constexpr, so range(N) is known at trace time. + # The loop unrolls: the device gets N copies of the body. + for i in range(N): + a_part = pto.partition_view(a_view, offsets=[i], sizes=[1]) + o_part = pto.partition_view(o_view, offsets=[i], sizes=[1]) + a_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) + pto.tload(a_part, a_tile) + pto.tadd(a_tile, a_tile, o_tile) + pto.tstore(o_tile, o_part) +``` + +This works when the loop bound is a compile-time constant (like a `constexpr` parameter). But if `N` comes from a tensor shape and varies per launch, `range(N)` would trace a different number of iterations each time — you would get a cache miss and recompilation on every new value. For dynamic bounds, use `pto.for_`. + +## 5.2 `pto.for_` — device-side loops + +`pto.for_` records a structured loop that executes on the device. Its bound can be any expression involving runtime values (tensor shapes, scalar computations, block indices), and the compiler may optimize it further — unrolling when the bound is known at compile time, or keeping it as a runtime loop otherwise. + +### Basic form + +```python +with pto.for_(start, stop, step) as iv: + # iv is the loop index (0-based relative to start) + ... +``` + +- `start`, `stop`, `step` are PTO scalar expressions. They are evaluated on the device. +- The loop body executes `(stop - start + step - 1) // step` times. +- Use with `step=1` unless you need a strided iteration. + +Compare the two approaches: + +```python +# Trace-time unrolling — BLOCK must be constexpr +for i in range(BLOCK): + ... + +# Device-side loop — num_blocks can be dynamic +with pto.for_(0, num_blocks, step=1) as i: + offset = i * BLOCK + ... +``` + +### Nested loops + +```python +with pto.for_(0, rows, step=1) as r: + with pto.for_(0, cols, step=1) as c: + val = scalar.load(tile[r, c]) + ... +``` + +Both loops execute on the device. The outer loop bound `rows` and inner loop bound `cols` can be runtime values. + +### Loop with carry state + +When a loop needs to propagate state from one iteration to the next, use the `.carry(...)` method. This is the PTODSL equivalent of a loop that accumulates or updates variables across iterations: + +```python +kv_loop = pto.for_(0, num_blocks, step=1).carry( + m=m_prev_tile, + l=l_prev_tile, + o=o_prev_tile, +) +with kv_loop: + i = kv_loop.iv # current iteration index + m_cur = kv_loop.m # value carried in from previous iteration + l_cur = kv_loop.l + o_cur = kv_loop.o + + # ... compute m_next, l_next, o_next from m_cur, l_cur, o_cur ... + + kv_loop.update( + m=m_next_tile, + l=l_next_tile, + o=o_next_tile, + ) + +# After the loop, retrieve the final carried values +final_o = kv_loop.final("o") +``` + +`.carry(name=initial_value)` declares named state variables that are passed from one iteration to the next. Inside the loop body, access the current value with `loop.name`. At the end of the body, call `loop.update(name=new_value)` to set what the next iteration receives. After the loop exits, `loop.final("name")` retrieves the value from the last iteration. + +This pattern is central to algorithms like online softmax, where each KV block updates running statistics (row max, sum, output accumulator). The ping-pong tile pattern — allocating two tiles and swapping them each iteration — is the idiomatic way to manage this state: + +```python +# Allocate ping-pong state tiles +m_prev = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) +m_next = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) +l_prev = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) +l_next = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) + +# Initialize prev tiles +m_prev.fill(float("-inf")) +l_prev.fill(0.0) + +loop = pto.for_(0, num_blocks, step=1).carry(m=m_prev, l=l_prev) +with loop: + m_cur = loop.m + l_cur = loop.l + + # ... compute new m and l into m_next, l_next ... + + loop.update(m=m_next, l=l_next) +``` + +### Chunked inner loop with carry (tail handling) + +For SIMD kernels that process data in vector-width chunks, use a carry loop to track the remaining element count across column iterations: + +```python +VEC = pto.elements_per_vreg(pto.f32) +col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) +with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(tile[r, c:]) + # ... operate under mask ... + pto.vsts(vec, out_tile[r, c:], mask) + col_loop.update(remained=remained) +``` + +`make_mask(dtype, n)` returns two values: the predicate mask for the current chunk and the updated remaining count. Passing the updated count back via `col_loop.update(remained=...)` feeds it into the next iteration, so each chunk correctly computes how many elements are left. + +## 5.3 `pto.if_` — device-side conditionals + +`pto.if_` records a device-side conditional branch. Unlike a Python `if`, the condition can be a runtime PTO scalar, and both branches are recorded into the program so the hardware can choose at runtime. + +The condition must be a PTO scalar value (e.g., the result of a comparison like `scalar.gt(a, b)` or a value loaded from a tile). Python booleans evaluated at trace time should use a plain `if` instead. + +### Value merge across branches + +When a variable is assigned inside both branches of `pto.if_`/`pto.else_`, the assignments are recorded and the variable holds the merged value after the conditional block. This is the standard SSA-style merge — the downstream code sees whichever value was produced by the taken branch: + +```python +@pto.simt +def conditional_scale( + tile: pto.Tile, + threshold: pto.f32, + scale: pto.f32, + rows: pto.i32, + cols: pto.i32, +): + with pto.for_(0, rows, step=1) as r: + with pto.for_(0, cols, step=1) as c: + val = scalar.load(tile[r, c]) + big = scalar.gt(val, threshold) + + with pto.if_(big): + # Branch A: scale the value up + val = val * scale + with pto.else_(): + # Branch B: leave it as-is + pass + + # val is usable here — it is the merged result from both branches. + # If big was true, val = original * scale. + # If big was false, val = original (passed through unchanged). + scalar.store(val, tile[r, c]) +``` + +In this example, `val` is reassigned in the `if_` branch but left untouched in the `else_` branch. After the conditional block, `val` correctly represents the merged result and is stored back to the tile. You can reassign the same variable in both branches as well — the downstream code always sees the correct value. + +### Expression form + +For simple either-or selection, `pto.if_` also works as an expression that directly returns the merged value: + +```python +result = pto.if_(cond, then_value, else_value) +``` + +This is equivalent to the block form above and is convenient when each branch simply produces a different scalar or tile reference. + +## 5.4 `pto.constexpr` and tracing + +`pto.constexpr` parameters (Section 3.8) are compile-time constants. They are fixed at `.compile()` time and cannot change between launches of the same compiled kernel. Because their values are known during tracing, they interact naturally with Python control flow: + +```python +@pto.jit(target="a5") +def kernel(A, *, BLOCK: pto.constexpr = 128, UNROLL: pto.constexpr = False): + N = A.shape[0] + num_blocks = (N + BLOCK - 1) // BLOCK + + if UNROLL: + # Trace-time: UNROLL is known, so this branch resolves at compile time. + # Each iteration records separately — the loop is fully unrolled. + for i in range(num_blocks): + ... + else: + # Device-side: a single loop instruction is recorded. + with pto.for_(0, num_blocks, step=1) as i: + ... +``` + +This lets you write a single kernel that specializes into different strategies based on constexpr knobs. + +## 5.5 Summary + +| Construct | When evaluated | Use for | +|-----------|---------------|---------| +| Python `for` | Trace time | Bounds known at compile time (constexpr), deliberate unrolling | +| Python `if` | Trace time | Conditions known at compile time, variant selection | +| `pto.for_` | Device-side | Dynamic bounds, runtime loop counts | +| `pto.for_(...).carry(...)` | Device-side | Loops with accumulated state across iterations | +| `pto.if_` | Device-side | Runtime conditions, data-dependent branching | diff --git a/ptodsl/docs/user_guide/05-type-system.md b/ptodsl/docs/user_guide/05-type-system.md deleted file mode 100644 index c40f12475..000000000 --- a/ptodsl/docs/user_guide/05-type-system.md +++ /dev/null @@ -1,686 +0,0 @@ - - -## Type System - -### Scalar Types - -| DSL Type | Description | Bit Width | -|----------|-------------|-----------| -| `pto.i1` | Boolean | 1 | -| `pto.i8` | 8-bit signless integer | 8 | -| `pto.si8` | 8-bit signed integer | 8 | -| `pto.ui8` | 8-bit unsigned integer | 8 | -| `pto.i16` | 16-bit signless integer | 16 | -| `pto.si16` | 16-bit signed integer | 16 | -| `pto.ui16` | 16-bit unsigned integer | 16 | -| `pto.i32` | 32-bit signless integer | 32 | -| `pto.si32` | 32-bit signed integer | 32 | -| `pto.ui32` | 32-bit unsigned integer | 32 | -| `pto.i64` | 64-bit signless integer | 64 | -| `pto.si64` | 64-bit signed integer | 64 | -| `pto.ui64` | 64-bit unsigned integer | 64 | -| `pto.f16` | Half precision float | 16 | -| `pto.bf16` | Brain float 16 | 16 | -| `pto.f32` | Single precision float | 32 | - -Python literals are automatically typed: -- `bool` → `pto.i1` -- `int` → Context-dependent (typically `pto.i32` or `pto.i64`) -- `float` → `pto.f32` - -For explicit typing, use type constructors: -```python -x = pto.i32(1024) # Explicit i32 constant -y: pto.i32 = 1024 # Type annotation -z = pto.ui16(7) # Explicit unsigned 16-bit constant -``` - -Static dtype bindings can also be called like constructors. This is useful when -the dtype comes from compile-time metadata such as `element_type`: - -```python -idx_dtype = tile.element_type -zero_idx = idx_dtype(0) -v_col = idx_dtype(col) -``` - -Integer sign semantics are part of the DSL type surface. `pto.si16`, -`pto.ui16`, and `pto.i16` are distinct scalar dtypes and lower to `si16`, -`ui16`, and `i16` respectively in VPTO IR. - -### Integer Literal Guidance - -For ordinary integer constants, prefer plain integer literals instead of -string forms. - -```python -count = pto.i32(1024) -delta = pto.i16(-12) -min_i32 = pto.i32(-2147483648) -unsigned_hi = pto.ui16(32768) -``` - -Integer string literals are reserved for explicit bit-pattern authoring. They -must use hex form. - -```python -# Use hex strings only when you intentionally want fixed-width bit-pattern -# interpretation at the target dtype width. -hi_bit = pto.i32("0x80000000") # -2147483648 -all_ones = pto.i16("0xFFFF") # -1 -unsigned_hi = pto.ui16("0x8000") # 32768 -``` - -Rules: -- Prefer plain integer literals such as `pto.i32(1024)` or `pto.i16(-12)` for normal integer authoring. -- Integer string literals must use hex bit-pattern form such as `"0xFFFF"`. -- Ordinary integer strings such as `"1024"` or `"-12"` are rejected; write them as integer literals instead. -- For signed and signless integer dtypes (`pto.i*`, `pto.si*`), hex strings use two's-complement interpretation at the target dtype width. -- For unsigned integer dtypes (`pto.ui*`), hex strings keep their unsigned value. -- Hex strings must fit within the target bit width. For example, `pto.i16("0x10000")` is rejected because the literal exceeds 16 bits. - -### Floating-Point Literal Forms - -`pto.f16(...)`, `pto.bf16(...)`, and `pto.f32(...)` accept multiple literal forms. - -```python -# Signed numeric literals -a = pto.f16(-1.5) -b = pto.bf16(+2.5) -c = pto.f32(-3.5) - -# Special floating-point values -pos_inf = pto.f32("inf") -neg_inf = pto.f32("-inf") -qnan = pto.f32("nan") - -# Bit-pattern form (hex string, interpreted by target dtype) -f16_neg_inf = pto.f16("0xFC00") -bf16_neg_inf = pto.bf16("0xFF80") -f32_neg_inf = pto.f32("0xFF800000") -``` - -Notes: -- Prefer dtype constructors for reduction seeds and boundary values (for example rowmax initialization). -- For float bit-pattern constants, pass a **string** hex literal to the matching dtype constructor. -- Avoid passing raw integer bit-patterns directly into vector broadcast/dup APIs when a floating vector is expected. -- `float(...)` function calls are not part of the TileLang DSL public call surface; use constructor forms above. - -### Vector Register Type - -Vector registers have fixed 256-byte width: - -```python -v_f32 = pto.vreg(pto.f32) # !pto.vreg<64xf32> -v_f16 = pto.vreg(pto.f16) # !pto.vreg<128xf16> -v_i8 = pto.vreg(pto.i8) # !pto.vreg<256xi8> -``` - -`pto.vreg(dtype)` only takes the element type. The frontend infers the element count automatically from the fixed 256-byte register width: - -- `pto.f32` → `!pto.vreg<64xf32>` -- `pto.f16` → `!pto.vreg<128xf16>` -- `pto.bf16` → `!pto.vreg<128xbf16>` -- `pto.i32` → `!pto.vreg<64xi32>` -- `pto.si32` → `!pto.vreg<64xsi32>` -- `pto.ui32` → `!pto.vreg<64xui32>` -- `pto.i16` → `!pto.vreg<128xi16>` -- `pto.si16` → `!pto.vreg<128xsi16>` -- `pto.ui16` → `!pto.vreg<128xui16>` -- `pto.i8` → `!pto.vreg<256xi8>` -- `pto.si8` → `!pto.vreg<256xsi8>` -- `pto.ui8` → `!pto.vreg<256xui8>` - -Constraint: `element_count × bitwidth(element_type) = 2048` - -Use `pto.elements_per_vreg(dtype)` when you need the inferred element count explicitly: - -```python -v_dtype = pto.vreg(pto.f32) -lanes0 = v_dtype.elements_per_vreg # 64 -lanes1 = pto.elements_per_vreg(pto.f32) # 64 -``` - -Current TileLang DSL v1 vector lowering supports the 8/16/32-bit integer -families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32` element types. - -### Builtin Vector Type - -TileLang DSL v1 also exposes builtin MLIR vector types through -`pto.vector(element_dtype, shape)`. - -```python -executed_ty = pto.vector(pto.i16, (4,)) # vector<4xi16> -``` - -This type is different from `pto.vreg(...)`: - -- `pto.vreg(dtype)` models a VPTO vector register with fixed 256-byte width. -- `pto.vector(dtype, shape)` models a builtin MLIR `vector<...>` type with an - explicit static shape. - -Use `pto.vector(...)` when a kernel parameter or intermediate value must match -an existing builtin vector operand in PTO IR, for example an auxiliary -`vector<4xi16>` operand carried by a tile op template. - -```python -@pto.vkernel( - target="a5", - op="pto.tmrgsort ins(src0, src1, tmp) -> outs(dst, ex_vec)", - dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.i16)], -) -def template( - src0: pto.Tile, - src1: pto.Tile, - tmp: pto.Tile, - dst: pto.Tile, - ex_vec: pto.vector(pto.i16, (4,)), -): - return None -``` - -Notes: - -- `shape` must be a Python tuple of integers. For a 1-D vector, write `(4,)`, - not `(4)`. The trailing comma is Python's single-element tuple syntax. -- The current public surface is intended for static builtin vector types. -- In descriptor `dtypes=[...]`, builtin vector operands are matched by their - element dtype (`pto.i16` in the example above). The vector shape contract is - carried by the parameter annotation `pto.vector(...)`. - -### Vector Type Reinterpretation (vbitcast) - -Vector registers support bitwise type reinterpretation via `pto.vbitcast`: - -```python -result = pto.vbitcast(vector, to_type) -``` - -Interface summary: -- `vector`: a vector register value of type `!pto.vreg` -- `to_type`: target element dtype such as `pto.i32`, `pto.ui32`, `pto.f16`, `pto.bf16`, `pto.f32` -- return: a new vector register `!pto.vreg` whose element count is inferred from the fixed 256-byte vreg width - -Constraints: -- `vector` must be a vreg value; scalar values, pointers, `Tile`, and `TensorView` are rejected -- `to_type` must be a DSL-supported vreg element dtype -- `vbitcast` preserves the total register storage size, so only reinterpretations with the same total bit count are allowed -- the operation has no mask, rounding, saturation, or lane-placement parameters - -Lane count is recomputed from `to_type`: -- `!pto.vreg<64xf32> + pto.i32 -> !pto.vreg<64xi32>` -- `!pto.vreg<64xf32> + pto.f16 -> !pto.vreg<128xf16>` -- `!pto.vreg<128xbf16> + pto.ui16 -> !pto.vreg<128xui16>` - -```python -# Float to integer bitwise reinterpretation -fvec = pto.vlds(ub_ptr, lane) # !pto.vreg<64xf32> -ivec = pto.vbitcast(fvec, pto.i32) # !pto.vreg<64xi32> - -# Signed to unsigned integer reinterpretation -signed_vec = pto.vlds(ptr, lane) # !pto.vreg<64xsi32> -unsigned_vec = pto.vbitcast(signed_vec, pto.ui32) # !pto.vreg<64xui32> - -# Element size change (32-bit to 16-bit) -f32_vec = pto.vlds(ptr, lane) # !pto.vreg<64xf32> -f16_vec = pto.vbitcast(f32_vec, pto.f16) # !pto.vreg<128xf16> -``` - -Pythonic syntax sugar via `astype()` method: - -```python -ivec = fvec.astype(pto.i32) # Float to integer -unsigned_vec = signed_vec.astype(pto.ui32) # Signed to unsigned -f16_vec = f32_vec.astype(pto.f16) # 32-bit to 16-bit -``` - -`astype()` on a vector register is syntax sugar for `pto.vbitcast(...)`. In other words, it is a bit reinterpretation API, not a numeric conversion API. - -**Note**: `vbitcast` preserves the exact bit pattern (type punning), unlike `vcvt` which performs value conversion with rounding/saturation. Use `vcvt` when you want numeric conversion semantics; use `vbitcast` when you want the bits to stay unchanged. - -### Typed Masks - -Masks are typed by their bit granularity: - -| DSL Type | VPTO Type | Description | -|----------|-----------|-------------| -| `pto.mask_b8` | `!pto.mask` | 8-bit granularity mask | -| `pto.mask_b16` | `!pto.mask` | 16-bit granularity mask | -| `pto.mask_b32` | `!pto.mask` | 32-bit granularity mask | - -```python -mask_ty = pto.mask_b32 -mask: pto.mask_b32 = pto.make_mask(pto.f32, PAT.ALL) -``` - -Typed masks also support explicit type reinterpretation via `pto.pbitcast`: - -```python -mask_b8 = pto.plds(mask_ptr, offset, pto.PredicateDist.US) -mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) -mask_b32 = pto.pbitcast(mask_b16, pto.mask_b32) -``` - -`pto.pbitcast(...)` is the predicate analogue of `pto.vbitcast(...)`: -- it changes the static mask granularity seen by later DSL/VPTO consumers -- it preserves the underlying predicate bit image -- it does not perform pack/unpack or interleave/deinterleave by itself - -Mask operations must match the vector element family: -- `f32`, `i32`, `si32`, and `ui32` vectors use `mask_b32` -- `f16`, `bf16`, `i16`, `si16`, and `ui16` vectors use `mask_b16` -- `i8`, `si8`, and `ui8` vectors use `mask_b8` - -```python -# Correct: f32 vector with b32 mask -mask32 = pto.make_mask(pto.f32, PAT.ALL) -vec_f32 = pto.vlds(ptr, offset) -out = pto.vabs(vec_f32, mask32) - -# Error: mismatched mask granularity -mask16 = pto.make_mask(pto.f16, PAT.ALL) -out = pto.vabs(vec_f32, mask16) # Type error! -``` - -### Pointer Types [Advanced Tier] - -Pointers combine element type and memory space: - -```python -from pto import MemorySpace - -ptr_gm = pto.ptr(pto.f32, MemorySpace.GM) # GM pointer to f32 -ptr_ub = pto.ptr(pto.f16, MemorySpace.UB) # UB pointer to f16 -``` - -The `MemorySpace` enum provides type-safe memory space specification: - -| Enum Value | Description | -|------------|-------------| -| `MemorySpace.GM` | Global Memory (off-chip HBM/DDR) | -| `MemorySpace.MAT` | Cube L1 / cbuf staging buffer | -| `MemorySpace.LEFT` | Cube L0A left-operand buffer | -| `MemorySpace.RIGHT` | Cube L0B right-operand buffer | -| `MemorySpace.ACC` | Cube L0C accumulator buffer | -| `MemorySpace.BIAS` | Cube bias table buffer | -| `MemorySpace.UB` | Unified Buffer (on-chip SRAM, 256KB) | - -This replaces ad-hoc string literals with compile-time checked enums and is -shared by both the Vector and Cube DSL surfaces. - -### Public Buffer Types - -TileLang uses three public buffer-facing type names in kernel signatures: - -| Public Type | Description | -|-------------|-------------| -| `pto.TensorView` | GM-facing tensor view descriptor used for DMA-oriented data access | -| `pto.PartitionTensorView` | Logical GM partition (slice) descriptor, corresponding to `!pto.partition_tensor_view<...>` | -| `pto.Tile` | Tile buffer value for hardware-resident staged compute/storage buffers | - -### TensorView Types - -TensorView types represent multi-dimensional (up to 5D) views into tensors residing in Global Memory (GM). They are used as kernel parameters for describing GM data and support slicing operations to create logical partitions for DMA load/store operations. - -#### TensorView Type Definition - -TensorView types are parameterized by shape (a tuple of up to 5 dimensions) and element type: - -```python -# Kernel parameter using TensorView -@pto.vkernel(target="a5", op="custom", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10) -def tiled_kernel( - input_tensor: pto.TensorView, # GM tensor view - output_tensor: pto.TensorView, # GM tensor view - tile_buf: pto.Tile # UB tile -): - # Access tensor view properties - shape = input_tensor.shape # tuple of dimensions (dynamic or static, up to 5D) - dtype = input_tensor.element_type # e.g., pto.f32 - strides = input_tensor.strides # stride in elements -``` - -Important notes: -- TensorView is a read-only descriptor for GM data, though DMA store operations can write through it. -- Shape can be static (compile-time constants) or dynamic (determined at runtime). -- Strides are expressed in elements, not bytes. -- Memory space is always GM (Global Memory). -- Maximum rank is 5. PTO ISA right-aligns lower-rank shapes to 5D. -- When higher dimensions are 1, a 5D TensorView can be abbreviated to lower-rank forms. For example, shape `(1, 1, 64, 32, 16)` can be written as `(64, 32, 16)`, and shape `(1, 1, 1, 32, 16)` can be written as `(32, 16)`. - -#### TensorView Attributes - -| Attribute | Type | Description | -|-----------|------|-------------| -| `shape` | `tuple[int, ...]` | Tensor dimensions (supports up to 5 dimensions, right-aligned to 5D in PTO ISA) | -| `element_type` | `Type` | Element data type (for example `pto.f32`, `pto.f16`) | -| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | -| `offset` | `pto.i64` | Byte offset from base pointer (internal) | - -#### Padding Mode Enum - -Padding mode controls how out-of-bounds accesses are handled during DMA load/store operations: - -| Enum Value | Description | -|------------|-------------| -| `PadMode.PadNull` | No padding. Out-of-bounds access is invalid | -| `PadMode.PadFirstElem` | Pad using the first element of the source | -| `PadMode.PadValue` | Pad using a specified value and requires `pad_value` | - -#### Slicing Syntax - -TensorView supports Python slicing syntax to create logical partitions: - -```python -# Create a partition from a tensor view -partition = tensor_view[dim0_start:dim0_end, dim1_start:dim1_end] - -# Example: extract a 16x16 tile from a larger tensor -tile_view = large_tensor[0:16, 0:16] - -# Dynamic offsets and sizes -dim0_start = tensor_view.shape[0] // 2 -dynamic_partition = tensor_view[dim0_start:tensor_view.shape[0], 4:20] - -# Static positive step on dimension 0 -stepped_partition = tensor_view[0:32:2, 0:16] - -# Right-aligned shorthand on a 5D descriptor -partition_3d = tensor_view[d2_start:d2_end, d3_start:d3_end, d4_start:d4_end] - -# Full 5D spelling remains available when needed -partition_5d = tensor_view[ - d0_start:d0_end, - d1_start:d1_end, - d2_start:d2_end, - d3_start:d3_end, - d4_start:d4_end, -] -``` - -Constraints: -- Slicing returns a new `pto.PartitionTensorView` representing the logical partition. -- The partition must be within the original tensor bounds. -- When fewer than 5 slice axes are written, they are right-aligned to the trailing physical axes of the 5D descriptor. -- `stop` must be explicit on all dimensions. -- `start` may be static or dynamic. -- `step` must be a static positive integer. -- Dimension 0 may use `step > 1`. -- Dimension 1 must keep `step == 1` in the current DMA-oriented implementation. - -### PartitionTensorView Types - -`pto.PartitionTensorView` models a logical partition of GM tensor data and maps to -`!pto.partition_tensor_view` in PTO IR. -Like `TensorView`, it is a descriptor type and does not own storage. - -#### PartitionTensorView Type Definition - -```python -@pto.vkernel(target="a5", op="custom_partition", dtypes=[(pto.f32, pto.f32)]) -def kernel(inp: pto.TensorView, out: pto.TensorView): - part: pto.PartitionTensorView = inp[0:16, 0:16] - p_rows, p_cols = part.shape - s_row, s_col = part.strides - return None -``` - -Important notes: -- A `PartitionTensorView` carries partition `shape` and `strides` metadata in element units. -- Element dtype is inherited from the source tensor view. -- Memory space remains GM. -- Rank handling follows the same right-aligned 5D contract as `TensorView`. -- `PartitionTensorView` can be used where DMA-oriented TensorView-like descriptors are accepted. -- Prefer direct indexing or tuple unpacking for `shape`/`strides` metadata values in current DSL v1 lowering. - -#### PartitionTensorView Attributes - -| Attribute | Type | Description | -|-----------|------|-------------| -| `shape` | `tuple[int, ...]` | Partition dimensions | -| `element_type` | `Type` | Element data type inherited from source tensor view | -| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | -| `offset` | `pto.i64` | Byte offset from the base tensor pointer (internal) | - -### Tile Types - -Tile types represent data blocks in memory with layout and configuration information, corresponding to `!pto.tile_buf` in the VPTO IR. Tiles are commonly used as kernel parameters for tiled computations. - -#### Tile Type Definition - -`pto.Tile` is the public tile type used for hardware buffer allocation in specific -address spaces. Tiles are constructed directly via the `pto.Tile` constructor: - -```python -pto.Tile( - shape: tuple[int, ...], # Buffer shape (required) - dtype: Type, # Element type (required) - memory_space: MemorySpace, # Address space (required) - valid_shape: tuple[int, ...] | None = None, # Valid region, defaults to shape - blayout: BLayout | None = None, # B layout, auto-detected from address space - slayout: SLayout | None = None, # S layout, auto-detected from address space - fractal_size: int | None = None, # Fractal size, auto-detected from address space - pad_value: PadValue = PadValue.Null, # Pad policy - compact_mode: CompactMode = CompactMode.Null, # Compact mode - addr: int | None = None, # Pre-assigned address (level3 only) -) -> Tile -``` - -Layout defaults are selected automatically based on the address space: - -| Address Space | blayout default | slayout default | fractal_size default | -|--------------|----------------|----------------|---------------------| -| `MAT` | `ColMajor` | `RowMajor` | `TileConfig.fractalABSize` (512) | -| `LEFT` | `ColMajor` | `RowMajor` | `TileConfig.fractalABSize` (512) | -| `RIGHT` | `RowMajor` | `ColMajor` | `TileConfig.fractalABSize` (512) | -| `ACC` | `ColMajor` | `RowMajor` | `TileConfig.fractalCSize` (1024) | -| `BIAS` | `RowMajor` | `NoneBox` | `TileConfig.fractalABSize` (512) | -| `UB` / `VEC` | `RowMajor` | `NoneBox` | `TileConfig.fractalABSize` (512) | - -Related enum types: - -| Enum | Values | -|------|--------| -| `BLayout` | `ColMajor` (0), `RowMajor` (1) | -| `SLayout` | `NoneBox` (0), `RowMajor` (1), `ColMajor` (2) | -| `PadValue` | `Null` (0), `Zero` (1), `Max` (2), `Min` (3) | -| `CompactMode` | `Null` (0), `Normal` (1), `RowPlusOne` (2) | - -Usage: - -```python -# Allocate tiles in @vkernel or @ckernel -tile_ub = pto.Tile([256, 128], pto.f32, MemorySpace.UB) -tile_left = pto.Tile([16, 64], pto.f16, MemorySpace.LEFT) -tile_acc = pto.Tile([16, 16], pto.f32, MemorySpace.ACC, valid_shape=(12, 12)) -``` - -Important notes on shape and valid shape: -- `shape` must be a compile-time constant. Tile dimensions are fixed at compilation time and cannot change at runtime. -- `valid_shape` can be either static or dynamic and must be less than or equal to `shape` in each dimension. -- When `valid_shape` is not specified, it defaults to the full `shape`. - -#### Tile Attributes - -| Attribute | Type | Description | -|-----------|------|-------------| -| `shape` | `tuple[int, ...]` | Full tile dimensions. These are compile-time constants | -| `element_type` | `Type` | Element data type (for example `pto.f32`) | -| `memory_space` | `MemorySpace` | Memory space such as UB, MAT, LEFT, RIGHT, ACC, or BIAS | -| `valid_shape` | `tuple[int, ...]` | Actual data dimensions within the tile. Must be less than or equal to `shape` in each dimension | -| `config` | `TileConfig` | Layout and padding configuration | - -#### Tile Pad Values - -`TileConfig.pad_value` is modeled after the C++ `PadValue : uint64_t` design. - -Standard pad values use small integer encodings: - -| DSL Value | Encoded Value | Meaning | -|-----------|---------------|---------| -| `pto.PadValue.NULL` | `0` | No concrete fill value | -| `pto.PadValue.ZERO` | `1` | Zero fill | -| `pto.PadValue.MAX` | `2` | Maximum finite / integer max for the tile element dtype | -| `pto.PadValue.MIN` | `3` | Minimum finite / integer min for the tile element dtype | - -Custom pad values use the `CustomBase = 0x100000000` convention and are authored with `pto.PadValue.custom_f32(...)`: - -```python -pad0 = pto.PadValue.ZERO -pad1 = pto.PadValue.custom_f32(-1.0) -pad2 = pto.PadValue.custom_f32("0xBF800000") # float32 bit pattern for -1.0f -``` - -Notes: -- `PadValue.encoded` exposes the host-side uint64 payload. `PadValue.value` is intentionally unavailable to avoid confusion with `.eval(...)` scalar materialization. -- `PadValue.text` exposes the standard textual spelling for built-ins such as `null` and `zero`. -- Custom pad values currently model an `f32` payload. In DSL v1, materializing a custom pad into a scalar is only supported for floating tile element dtypes. -- `PadValue.NULL` does not denote a usable scalar fill constant. Calling `tile.pad_value.eval()` or `tile.config.pad_value.eval()` when the enum is `NULL` is a frontend error. -- **DMA padding**: When performing GM→UB DMA transfers with padding enabled (via `enable_ub_pad=True` in `pto.copy_gm_to_ubuf`), the pad value must be configured explicitly using `pto.set_mov_pad_val`. Tile `PadValue` descriptors are not automatically translated to hardware register configurations in TileLang DSL v1. See [Pad Fill Semantics](08-sync-dma-operations.md#pad-fill-semantics) for usage details. - -Host-side code can materialize a scalar with an explicit dtype: - -```python -pad_max_f32 = pto.PadValue.MAX.eval(pto.f32) -pad_min_i16 = pto.PadValue.MIN.eval(pto.i16) -``` - -#### Tile Shape Concepts - -- `shape` is the static physical allocation size of the tile buffer. -- `valid_shape` is the logical data region and may be static or dynamic. -- `valid_shape[i] <= shape[i]` must hold for each dimension. -- Fixed-size tiles with smaller valid regions are useful for padding and partial-tile cases. - -#### Basic Access Operations - -```python -# Get tile properties -shape = tile.shape # (256, 128) -elem_type = tile.element_type # pto.f32 -mem_space = tile.memory_space # MemorySpace.UB -valid_shape = tile.valid_shape # (240, 120) or same as shape - -# Get configuration properties -config = tile.config -b_layout = config.b_layout # pto.BLayout.ROW_MAJOR -s_layout = config.s_layout # pto.SLayout.NONE_BOX -s_fractal = config.s_fractal_size # pto.i32(512) -pad_desc = tile.config.pad_value # PadValue enum bound to the tile element dtype -pad_desc2 = tile.pad_value # direct sugar for the same PadValue enum - -# Dynamic properties -rank = tile.rank # 2 -``` - -`tile.config.pad_value` and `tile.pad_value` are enum-typed inside kernel code. Use `.eval()` to materialize the configured pad descriptor against the tile element dtype: - -- `tile.pad_value.eval()` with `PadValue.ZERO` becomes `0` / `0.0` -- `tile.pad_value.eval()` with `PadValue.MAX` becomes dtype-aware max -- `tile.pad_value.eval()` with `PadValue.MIN` becomes dtype-aware min -- `tile.pad_value.eval()` with `PadValue.custom_f32(...)` becomes the authored floating scalar -- `tile.pad_value.eval()` with `PadValue.NULL` raises a frontend error - -For dtype-dependent fill seeds, prefer `tile.pad_value.eval()` over handwritten -`if dtype == ...` ladders. - -For standalone `PadValue` symbols that are not bound to a tile, pass the target dtype explicitly: - -```python -pad_scalar = pto.PadValue.MAX.eval(pto.f32) -``` - -```python -@pto.vkernel(op="fill_pad_value", dtypes=[(pto.AnyType,)]) -def fill_pad_value(dst: pto.Tile): - pad_scalar = dst.pad_value.eval() - pad_vec = pto.vbr(pad_scalar) - # ... -``` - -Typical materialized values: - -- `PadValue.ZERO` -> `0` / `0.0` -- `PadValue.MAX` -> dtype-aware max, for example `4294967295` for `pto.ui32` -- `PadValue.MIN` -> dtype-aware min, for example `-2147483648` for `pto.i32` and `0` for `pto.ui32` - -This is usually simpler than spelling every dtype case manually with -`pto.constexpr(dst.element_type == ...)`. - -Example: reading pad value from a `Tile` - -```python -@pto.vkernel(op="fill_pad_demo", dtypes=[(pto.f16,)]) -def kernel(dst: pto.Tile): - mask, _ = pto.make_mask(pto.f16, 8) - - # Read the Tile-bound PadValue enum. - pad0 = dst.pad_value - - # Equivalent form through TileConfig metadata. - pad1 = dst.config.pad_value - - if pto.constexpr(pad0 != pto.PadValue.NULL): - scalar0 = pad0.eval() - scalar1 = pad1.eval() - vec0 = pto.vdup(scalar0, mask) - vec1 = pto.vdup(scalar1, mask) - pto.vsts(vec0, dst[0, 0:], mask) - pto.vsts(vec1, dst[1, 0:], mask) -``` - -If `dst` is specialized with `config=pto.TileConfig.from_mapping({"pad_value": pto.PadValue.ZERO})`, -both `pad0` and `pad1` are `PadValue.ZERO`, and `pad0.eval()` / `pad1.eval()` materialize to the scalar `0.0` for an `f16` tile. - -#### Conversion Operations - -Basic mode syntax uses tile element-indexing directly in vector operations: - -```python -# 2D tile indexing -vec = pto.vlds(tile[row, col:]) -pto.vsts(vec, tile[row, col:], mask) - -# 1D tile indexing -vec = pto.vlds(tile[start:]) -pto.vsts(vec, tile[start:], mask) -``` - -Advanced mode syntax converts tiles to typed pointers for byte-offset operations: - -```python -# Convert tile to pointer -ptr = tile.as_ptr() # Returns pto.ptr(pto.f32, MemorySpace.UB) - -# Use pointer with byte offset -vec = pto.vlds(ptr, offset) -pto.vsts(vec, ptr, offset, mask) -``` - -#### Kernel Parameter Usage - -```python -@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) -def tiled_kernel( - input_tile: pto.Tile, - output_tile: pto.Tile, - scale: pto.f32 -): - all_mask = pto.make_mask(pto.f32, PAT.ALL) - for i in range(0, 256, 64): - vec = pto.vlds(input_tile[i, 0:]) - scaled = pto.vmuls(vec, scale, all_mask) - pto.vsts(scaled, output_tile[i, 0:], all_mask) -``` - -### Alignment Type - -The `pto.align` type is used for alignment carrier operations and maps to `!pto.align`. diff --git a/ptodsl/docs/user_guide/06-control-flow.md b/ptodsl/docs/user_guide/06-control-flow.md deleted file mode 100644 index 41b623d1d..000000000 --- a/ptodsl/docs/user_guide/06-control-flow.md +++ /dev/null @@ -1,181 +0,0 @@ -## Control Flow - -### Vector Scopes - -The TileLang DSL supports implicit vector scope inference, allowing developers to write vector operations directly without explicit `pto.vecscope()` blocks. The compiler automatically groups consecutive, data-dependent vector operations into implicit vector scopes during lowering. - -#### Implicit Scope Inference - -**Note:** `pto.vecscope()` is supported. Automatic scope inference runs only when the kernel does **not** contain explicit `with pto.vecscope():` blocks. - -When you write vector operations like `pto.vlds`, `pto.vadd`, `pto.vsts` directly in your code, the compiler's **Scope Inference Pass** analyzes the control flow graph and automatically creates vector scopes: - -```python -# No explicit vecscope needed - compiler infers scope boundaries -vec = pto.vlds(outer_ptr, offset) -result = pto.vadd(vec, vec, all_mask) -pto.vsts(result, dst_ptr, offset, all_mask) -``` - -The compiler automatically groups these three operations into a single implicit vector scope because they form a data-dependent chain (when no explicit `pto.vecscope()` appears in the kernel). - -**Scope boundary rules:** -1. **Control flow boundaries**: Branches (`if`/`else`), loops (`for`/`while`), and function calls create implicit scope boundaries -2. **Scalar operations**: Non-vector operations (e.g., scalar arithmetic, pointer arithmetic) create boundaries -3. **Explicit scope blocks**: User-defined `vecscope` and `strict_vecscope` blocks create hard boundaries - -#### Explicit Scope Boundaries with `strict_vecscope` [Advanced Tier] - -##### `pto.strict_vecscope(*captures: AnyType) -> ContextManager[Tuple[AnyType, ...]]` - -**Description**: Creates an explicit vector scope boundary with explicit value captures. Values used inside the scope must be passed as arguments; implicit capture from outer scope is rejected. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `*captures` | `AnyType` | Variable number of values to be captured and passed into the scope | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `context_manager` | `ContextManager[Tuple[AnyType, ...]]` | Context manager that yields a tuple of captured values when entered | - -**Constraints**: -- The scope body cannot implicitly capture values from the surrounding scope; all used values must be passed as `captures`. -- Creates a hard boundary that prevents the compiler from merging vector operations across the scope boundary. -- Useful for performance optimization, debugging, resource management, and hardware compatibility. - -For precise control over scope boundaries, use explicit `strict_vecscope` blocks. These create hard boundaries that prevent the compiler from merging operations across the block boundary: - -```python -with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): - # Operations inside this block are isolated from outside - # Compiler will not merge operations across this boundary - for i in range(lb, ub, 64): - vec = pto.vlds(s, i) - pto.vsts(vec, d, i, all_mask) -``` - -**Use cases for strict_vecscope:** -- Performance optimization: Isolate critical vector computation regions -- Debugging: Create explicit boundaries to isolate vector operations -- Resource management: Control vector register allocation boundaries -- Compatibility: Ensure deterministic scope placement for hardware constraints - -#### Explicit Scope Blocks with `vecscope` - -`pto.vecscope` provides an explicit vector-scope boundary without strict capture ABI constraints: - -```python -with pto.vecscope(): - vec = pto.vlds(src, 0) - vec = pto.vadd(vec, vec, mask) - pto.vsts(vec, dst, 0, mask) -``` - -**Rules**: -- `pto.vecscope()` takes no positional/keyword arguments. -- `pto.vecscope()` does not support `as (...)` bindings. -- When any explicit `pto.vecscope()` is present in a kernel body, automatic vecscope inference is disabled for that kernel. - -### Inline Procedures (`@pto.inline_proc`) - -TileLang DSL supports reusable top-level procedures decorated with `@pto.inline_proc`. -`inline_proc` follows function-call semantics in frontend IR and is force-inlined -later by the VPTO backend mainline in `ptoas`. - -```python -@pto.inline_proc -def store_row(dst: pto.Tile, src: pto.Tile, row: pto.i32): - vec = pto.vlds(src[row, 0:]) - mask = pto.make_mask(dst.element_type, pto.PAT.ALL) - pto.vsts(vec, dst[row, 0:], mask) - return None - -@pto.vkernel(op="pto.row_copy", dtypes=[(pto.f32, pto.f32, pto.i32)]) -def row_copy(dst: pto.Tile, src: pto.Tile, row: pto.i32): - store_row(dst, src, row) - return None -``` - -Important semantics: - -- `pto.(...)` and bare helper calls are different mechanisms. -- Calls written as `pto.vadd(...)`, `pto.vdiv(...)`, `pto.vlds(...)`, etc. target - built-in TileLang/VPTO surfaces directly. -- Calls written as bare Python names such as `store_row(...)` target a - user-defined `@pto.inline_proc` helper when the callee name resolves to a - registered top-level inline procedure in the current module. -- `inline_proc` helpers do not live in the `pto` namespace; using the same - basename as a `pto.` op is allowed because the frontend distinguishes - `pto.xxx(...)` from bare `xxx(...)` calls. -- Frontend preserves helper `func.func` and `func.call` in `mlir_text()` output. -- VPTO backend mainline force-inlines helper calls before downstream lowering. -- Helper definitions support default parameter values. -- Helper calls support positional arguments and keyword arguments. -- Helper calls can appear in statement and expression positions. -- Helper definitions can use trailing `return ` to return values. -- Implicit capture is rejected except module-level globals whose current bound value is `bool`/`int`/`float`/`str`; pass other required values as explicit arguments. -- Recursive/mutually-recursive helper call graphs are rejected. -- `*args`, `**kwargs`, and keyword-only parameters are unsupported in current version. - -Shared helpers can live in a separate Python file in the template directory and -be imported directly by templates: - -```python -# shared_rows.py -import tilelang_dsl as pto - -@pto.inline_proc -def touch_row(dst: pto.Tile, row: pto.i32): - mask = pto.make_mask(dst.element_type, pto.PAT.ALL) - vec = pto.vlds(dst[row, 0:]) - pto.vsts(vec, dst[row, 0:], mask) - return None - -# trow_template.py -import tilelang_dsl as pto -from shared_rows import touch_row - -@pto.vkernel(op="pto.row_touch", dtypes=[(pto.f32, pto.i32)]) -def row_touch(dst: pto.Tile, row: pto.i32): - touch_row(dst, row) - return None -``` - -Only directly imported `@pto.inline_proc` helpers are part of this shared-helper -surface. Ordinary Python functions remain unsupported in DSL bodies, and -qualified calls such as `shared_rows.touch_row(...)` are not part of this -version. If multiple imported helpers expose the same bare name, the frontend -rejects the template instead of choosing one by import order. - -### Loops - -Counted loops use Python's `range` syntax: - -```python -for i in range(lb, ub, step): - # Loop body - mask, rem = pto.make_mask(pto.f32, remaining) - # ... -``` - -Loop-carried state is automatically handled through variable updates within the loop. - -### Conditionals - -`if` statements support value merging: - -```python -flag: pto.i1 = some_condition -step: pto.i32 = 0 - -if flag: - step = pto.i32(64) -else: - step = pto.i32(128) - -# 'step' here is the merged result from both branches -``` - -Variables defined in only one branch are local to that branch. diff --git a/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md new file mode 100644 index 000000000..ba428a313 --- /dev/null +++ b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md @@ -0,0 +1,376 @@ +# 6. Scalar and Pointer Operations + +Chapter 5 established the rule: Python constructs are resolved at trace time, PTO constructs produce device-side behavior. This chapter applies that distinction to scalars and pointers — when to use a plain Python number, when to use a `scalar.*` operation, and how to work with typed pointers. + +## 6.1 Python scalars vs PTO scalars + +A **Python scalar** is any value computed by Python during tracing: a literal (`3.14159`), a shape dimension (`A.shape[0]`), a constexpr parameter (`BLOCK`), or an arithmetic expression built from these (`1.0 / sqrt(dim)`). These are evaluated at trace time and their results are baked into the device code as constants. + +A **PTO scalar** is a value that lives on the device at runtime. It comes from a `scalar.load` read, a device-side computation (`scalar.max`, `scalar.exp`), or a runtime query (`pto.get_block_idx()`). PTO scalars flow through the recorded program and are not resolved until the kernel executes. + +### The mixed expression + +In practice, a single expression can mix both kinds: + +```python +alpha * o_prev + beta * pv_val +# ^ Python float (trace-time constant, e.g. 1.0 / sqrt(dim)) +# ^ PTO scalar (loaded from tile at runtime) +# ^ PTO scalar (loaded from tile at runtime) +``` + +`alpha` is a Python float computed from compile-time information — it becomes an immediate constant in the device code. `o_prev` and `pv_val` are PTO scalars read from tiles at runtime. The `*` and `+` operators are recorded as device-side multiply-add instructions. The tracer sees the whole expression and produces the appropriate device instructions, embedding the constant operand where possible. + +### Rule of thumb + +| If the value... | Use... | Example | +|-----------------|--------|---------| +| Is known at compile time | Python scalar | `BLOCK`, `1.0 / sqrt(dim)`, `A.shape[0]` | +| Comes from device memory | PTO scalar | `scalar.load(tile[r, c])` | +| Depends on a runtime value | PTO scalar | `scalar.max(m_prev, row_max)` | +| Is a block/subblock index | PTO scalar | `pto.get_block_idx()` | + +When in doubt, ask: *can this value change between launches of the same compiled kernel?* If yes, it must be a PTO scalar. + +## 6.2 Scalar access: load and store + +`scalar.load` reads a single scalar element from a typed pointer or tile location. `scalar.store` writes a scalar back. These are the canonical scalar memory ops for SIMT authoring. The offset is counted in elements, not bytes. + +#### `scalar.load(ptr: PtrType, offset: Index) -> ScalarType` + +**Description**: Loads one scalar element from a typed pointer at the given element offset. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Typed pointer (`pto.ptr`) or the result of `tile.as_ptr()` | +| `offset` | `Index` | Element displacement from `ptr` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `value` | `ScalarType` | The loaded scalar, matching the pointer's element type | + +**Tile-index form** — the preferred syntax when loading from a tile: + +```python +val = scalar.load(tile[row, col]) +``` + +`tile[row, col]` selects one element. Row and column indices are PTO scalars (or Python integers that the tracer promotes). This form is equivalent to computing the pointer and offset from the tile's base address and layout. + +**Pointer forms**: + +```python +val = scalar.load(ptr, offset) # explicit offset +val = scalar.load(ptr + offset) # pointer arithmetic shorthand +``` + +--- + +#### `scalar.store(value: ScalarType, ptr: PtrType, offset: Index) -> None` + +**Description**: Stores one scalar element to a typed pointer at the given element offset. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `ScalarType` | Scalar value to write | +| `ptr` | `PtrType` | Typed destination pointer | +| `offset` | `Index` | Element displacement from `ptr` | + +**Returns**: None (side-effect operation). + +**Tile-index form**: + +```python +scalar.store(value, tile[row, col]) +``` + +**Pointer forms**: + +```python +scalar.store(value, ptr, offset) +``` + +--- + +### Typical SIMT usage + +`scalar.load` and `scalar.store` are the primary data access pattern inside `@pto.simt` kernels. Each `load`/`store` operates on one element per work-item, but the SIMT unit executes the same instruction across many work-items in parallel: + +```python +@pto.simt +def blend_output_rows( + o_prev_tile: pto.Tile, pv_tile: pto.Tile, + alpha_tile: pto.Tile, beta_tile: pto.Tile, + o_next_tile: pto.Tile, + row_start: pto.i32, row_stop: pto.i32, valid_dim: pto.i32, +): + with pto.for_(row_start, row_stop, step=1) as row: + alpha = scalar.load(alpha_tile[row, 0]) + beta = scalar.load(beta_tile[row, 0]) + with pto.for_(0, valid_dim, step=1) as col: + o_prev = scalar.load(o_prev_tile[row, col]) + pv_val = scalar.load(pv_tile[row, col]) + o_next = alpha * o_prev + beta * pv_val + scalar.store(o_next, o_next_tile[row, col]) +``` + +When writing to a raw pointer (e.g., a small metadata buffer obtained via `as_ptr()`), use the pointer-plus-offset form: + +```python +meta_ptr = meta_tile.as_ptr() +scalar.store(0, meta_ptr, 0) # store at element offset 0 +scalar.store(valid_rows, meta_ptr, 4) # store at element offset 4 +row_start = scalar.load(meta_ptr, 0) +row_stop = scalar.load(meta_ptr, 4) +``` + +## 6.3 Scalar arithmetic and comparisons + +### Python operators for basic arithmetic + +Addition, subtraction, multiplication, and division of PTO scalars use standard Python syntax. The tracer records the corresponding device-side instructions automatically: + +```python +o_next = alpha * o_prev + beta * pv_val # multiply-add +l_scaled = l_prev * scalar.exp(m_prev - m_next) # subtraction inside exp +step = (N + BLOCK - 1) // BLOCK # Python int arithmetic (trace-time) +``` + +When both operands are PTO scalars (loaded from device memory or produced by another device-side op), `+`, `-`, `*`, `/` produce device-side arithmetic instructions. When one operand is a Python scalar (trace-time constant), the tracer embeds it as an immediate. + +### Math functions: `scalar.*` + +Non-trivial scalar math functions live under the `scalar` namespace (imported as `from pto import scalar` or accessed as `pto.scalar`): + +#### `scalar.max(a: ScalarType, b: ScalarType) -> ScalarType` + +**Description**: Returns the maximum of two scalars. + +#### `scalar.min(a: ScalarType, b: ScalarType) -> ScalarType` + +**Description**: Returns the minimum of two scalars. + +#### `scalar.exp(x: ScalarType) -> ScalarType` + +**Description**: Exponential, e^x. + +#### `scalar.log(x: ScalarType) -> ScalarType` + +**Description**: Natural logarithm. + +#### `scalar.sqrt(x: ScalarType) -> ScalarType` + +**Description**: Square root. + +#### `scalar.abs(x: ScalarType) -> ScalarType` + +**Description**: Absolute value. + +#### `scalar.gt(a: ScalarType, b: ScalarType) -> pto.i1` + +**Description**: Greater-than comparison. Returns `pto.i1`. + +#### `scalar.lt(a: ScalarType, b: ScalarType) -> pto.i1` + +**Description**: Less-than comparison. Returns `pto.i1`. + +#### `scalar.eq(a: ScalarType, b: ScalarType) -> pto.i1` + +**Description**: Equality comparison. Returns `pto.i1`. + +**Example**: + +```python +m_next = scalar.max(m_prev, row_max) +l_scaled = l_prev * scalar.exp(m_prev - m_next) +need_scale = scalar.gt(val, threshold) +``` + +For readability in files with many scalar operations, assign `pto.scalar` to a short local name: + +```python +scalar = pto.scalar + +m_next = scalar.max(m_prev, row_max) +l_scaled = l_prev * scalar.exp(m_prev - m_next) +``` + +These are the scalar-path counterparts of the vector math operations covered in Chapter 8. Use them inside `@pto.simt` kernels and in `@pto.ukernel` orchestration code where you need to compute a loop bound or a scalar coefficient from runtime data. + +## 6.4 Pointer operations + +Typed pointers (Section 4.4) carry both an element type and a memory space. This section covers the operations that create and manipulate them. + +### Obtaining pointers: as_ptr() + +Tiles and tensor views expose their base address via `as_ptr()`: + +```python +gm_ptr = partition.as_ptr() # GM pointer from a PartitionTensorView +ub_ptr = tile.as_ptr() # UB pointer from a Tile +``` + +`as_ptr()` is the preferred way to get a typed pointer from a high-level descriptor. The result carries the correct element type and memory space from the source. + +--- + +#### `pto.addptr(ptr: PtrType, offset: Index) -> PtrType` + +**Description**: Advances a pointer by a number of elements (not bytes). + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Source pointer | +| `offset` | `Index` | Number of elements to advance | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `new_ptr` | `PtrType` | Pointer advanced by `offset` elements | + +**Example**: + +```python +ptr = pto.addptr(base_ptr, 1024) # advances by 1024 * sizeof(T) bytes +``` + +The `+` shorthand on pointers also counts in elements, not bytes. + +--- + +#### `pto.castptr(address: Index, ptr_type: Type) -> PtrType` + +**Description**: Creates a typed pointer from an integer address or reinterprets a pointer as a different type. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `address` | `Index` | Integer address or existing pointer value | +| `ptr_type` | `Type` | Target pointer type, e.g. `pto.ptr(pto.f32, pto.MemorySpace.UB)` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `ptr` | `PtrType` | Typed pointer value | + +This is an advanced operation. Prefer `as_ptr()` when the source already carries type information. + +## 6.5 Compile-time queries + +These functions return values that are known at trace time from type information or hardware constants. + +#### `pto.bytewidth(dtype: Type) -> int` + +**Description**: Returns the size in bytes of a single element of the given data type. The result is a Python `int` evaluated at trace time. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Data type, e.g. `pto.f32`, `pto.f16`, `pto.i8` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `size` | `int` | Element size in bytes | + +**Example**: + +```python +bw = pto.bytewidth(pto.f32) # 4 +bw = pto.bytewidth(pto.f16) # 2 +bw = pto.bytewidth(pto.i8) # 1 +``` + +--- + +#### `pto.elements_per_vreg(dtype: Type) -> int` + +**Description**: Returns how many elements of `dtype` fit in one 256-byte vector register. The result is a Python `int` evaluated at trace time. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Data type, e.g. `pto.f32`, `pto.f16`, `pto.i8` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `elems` | `int` | Number of elements per vector register | + +**Example**: + +```python +vec = pto.elements_per_vreg(pto.f32) # 64 +vec = pto.elements_per_vreg(pto.f16) # 128 +vec = pto.elements_per_vreg(pto.i8) # 256 +``` + +This is the standard stride for chunking column loops in SIMD kernels: + +```python +VEC = pto.elements_per_vreg(pto.f32) +with pto.for_(0, cols, step=VEC) as c: + ... +``` + +## 6.6 Per-element tile traversal in @pto.simt + +`@pto.simt` kernels are the natural home for per-element scalar work. A typical pattern uses nested `pto.for_` loops to walk over a tile row by row, column by column: + +```python +@pto.simt +def elementwise_scale( + src_tile: pto.Tile, + dst_tile: pto.Tile, + scale: pto.f32, + rows: pto.i32, + cols: pto.i32, +): + with pto.for_(0, rows, step=1) as r: + with pto.for_(0, cols, step=1) as c: + val = scalar.load(src_tile[r, c]) + scaled = val * scale + scalar.store(scaled, dst_tile[r, c]) +``` + +This reads each element from `src_tile`, multiplies by `scale`, and writes to `dst_tile`. The SIMT unit executes the body in parallel across work-items, so this scalar-looking code achieves high throughput — each work-item handles a different `(r, c)` pair. + +For operations that need per-row metadata alongside per-element computation, lift the row-level scalar out of the inner loop: + +```python +@pto.simt +def blend_with_per_row_coeffs( + o_prev_tile: pto.Tile, + pv_tile: pto.Tile, + alpha_tile: pto.Tile, # [rows, 1] — one coefficient per row + beta_tile: pto.Tile, # [rows, 1] + o_next_tile: pto.Tile, + rows: pto.i32, + cols: pto.i32, +): + with pto.for_(0, rows, step=1) as r: + alpha = scalar.load(alpha_tile[r, 0]) # read once per row + beta = scalar.load(beta_tile[r, 0]) # read once per row + with pto.for_(0, cols, step=1) as c: + o_prev = scalar.load(o_prev_tile[r, c]) + pv_val = scalar.load(pv_tile[r, c]) + o_next = alpha * o_prev + beta * pv_val + scalar.store(o_next, o_next_tile[r, c]) +``` + +This hoists `alpha` and `beta` out of the inner loop — the row coefficients are loaded once and broadcast across all columns in that row. diff --git a/ptodsl/docs/user_guide/07-data-movement-ops.md b/ptodsl/docs/user_guide/07-data-movement-ops.md new file mode 100644 index 000000000..bf6cf7dec --- /dev/null +++ b/ptodsl/docs/user_guide/07-data-movement-ops.md @@ -0,0 +1,1019 @@ +# 7. Data Movement Operations + +This chapter covers every operation that moves data between memory spaces in PTODSL — tile-level transfers, DMA micro-instructions, vector loads and stores, and cube data movement. Operations are organized by abstraction level: tile ops (L1), DMA ops (L2), vector memory ops (L3 SIMD), and cube memory ops (L3 cube). + +## 7.1 Tile-level movement: tload and tstore + +Tile ops move entire blocks between Global Memory and the Unified Buffer in a single call. They are the primary data movement interface inside `@pto.jit`. + +#### `pto.tload(partition: PartitionTensorView, tile: Tile) -> None` + +**Description**: Copies data from a GM partition into a UB tile. The transfer size is determined by the partition's `sizes` and the tile's shape — they must be compatible. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `partition` | `PartitionTensorView` | Source region in GM | +| `tile` | `Tile` | Destination buffer in UB | + +**Returns**: None (side-effect operation). + +**Example**: + +```python +a_part = pto.partition_view(a_view, offsets=[offset], sizes=[BLOCK]) +a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) +pto.tload(a_part, a_tile) +``` + +--- + +#### `pto.tstore(tile: Tile, partition: PartitionTensorView) -> None` + +**Description**: Copies data from a UB tile back to a GM partition. The tile's `valid_shape` determines how many elements are written; elements outside `valid_shape` are not stored. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile` | `Tile` | Source buffer in UB | +| `partition` | `PartitionTensorView` | Destination region in GM | + +**Returns**: None (side-effect operation). + +**Example**: + +```python +pto.tstore(o_tile, o_part) +``` + +--- + +Both `tload` and `tstore` operate at **tile granularity** — they are the idiomatic choice inside `@pto.jit` loops. When you need finer control over DMA scheduling, drop down to the micro-instruction level. + +## 7.2 DMA micro-instructions (ukernel) + +Inside `@pto.ukernel`, data movement between memory spaces is expressed with grouped DMA instructions on typed pointers. There are four operations covering the four data-movement directions: + +| Operation | Direction | Stride unit | Padding | +|-----------|-----------|-------------|---------| +| `pto.mte_gm_ub` | GM → UB | bytes | Supported | +| `pto.mte_ub_gm` | UB → GM | bytes | — (de-padded on read) | +| `pto.mte_ub_ub` | UB → UB | 32B units | — | +| `pto.mte_ub_l1` | UB → L1 | 32B units | — | + +All four share a common structure: a required innermost `nburst(...)` group that defines the repeated burst transfer, plus optional outer `loop(...)` groups for multi-level repetition. `pto.mte_gm_ub` additionally supports `pad(...)` for UB row padding. + +> **Convenience wrappers**: `pto.mte_load(src, dst)` and `pto.mte_store(src, dst)` are Python-level shorthands that expand to `mte_gm_ub` / `mte_ub_gm` with inferred strides. The reference operations below are the full grouped MTE interfaces. + +### 7.2.1 GM → UB: `pto.mte_gm_ub` + +#### `pto.mte_gm_ub(gm_src: PtrType, ub_dst: PtrType, l2_cache_ctl: int, len_burst: int, *, nburst: tuple[int, int, int], loops: list[tuple[int, int, int]] | None = None, pad: tuple[ScalarType, int, int] | tuple[ScalarType] | None = None) -> None` + +**Description**: Grouped DMA transfer from Global Memory to Unified Buffer. `nburst(...)` defines the innermost repeated burst (count, source stride in bytes, destination stride in bytes). Optional `loop(...)` groups add outer repetition levels. Optional `pad(...)` fills the gap between `len_burst` and `dst_stride` up to the 32B-aligned boundary. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `gm_src` | `PtrType` (gm) | GM source pointer | +| `ub_dst` | `PtrType` (ub) | UB destination pointer (must be 32B-aligned) | +| `l2_cache_ctl` | `int` | L2 cache allocate control (2 bits) | +| `len_burst` | `int` | Contiguous bytes transferred per burst row | +| `nburst` | `tuple[int, int, int]` | `(n_burst, src_stride, dst_stride)` — innermost burst group (required) | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional outer loop groups, each `(count, src_stride, dst_stride)`. Ordered inner to outer | +| `pad` | `tuple[ScalarType, int, int]` or `tuple[ScalarType]` or `None` | Optional padding: `(pad_value, left_count, right_count)` or `(pad_value,)`. Omitted counts default to 0 | + +**Returns**: None (side-effect operation). + +**Constraints**: +- `nburst` is always required. +- `loop` groups are ordered from inner (wrapping `nburst`) to outer. +- If `pad` specifies either left or right count, both must be provided. + +**Example** — load a 32×32 f32 tile from contiguous GM into contiguous UB: + +```python +pto.mte_gm_ub(gm_ptr, ub_ptr, 0, 128, + nburst=(32, 128, 128)) +# 32 rows, 128 bytes per row, contiguous in both GM and UB +``` + +**Example** — load a 64×128 f16 tile from a larger GM matrix (1024×512) into UB: + +```python +pto.mte_gm_ub(gm_ptr, ub_ptr, 0, 256, + nburst=(64, 1024, 256)) +# 64 rows of 256 bytes each. +# GM: each row is 1024 bytes apart (full matrix row stride). +# UB: rows are packed contiguously (256-byte stride). +``` + +**Example** — load with padding (100 valid f16 columns into a 128-wide UB tile): + +```python +pto.mte_gm_ub(gm_ptr, ub_ptr, 0, 200, + nburst=(64, 200, 256), + pad=(0.0, 0, 0)) +# 64 rows, 200 valid bytes per row, 256-byte UB stride. +# Gap (56 bytes) between len_burst and dst_stride is zero-padded. +``` + +**Example** — multi-level loop: load 4 batches of 8×128 f16 tiles: + +```python +pto.mte_gm_ub(gm_ptr, ub_ptr, 0, 256, + nburst=(8, 256, 256), + loops=[(4, 2048, 2048)]) +# Innermost: 8 rows × 256B (one tile). +# Outer loop: 4 iterations, each advancing 2048 bytes in both GM and UB. +``` + +--- + +### 7.2.2 UB → GM: `pto.mte_ub_gm` + +#### `pto.mte_ub_gm(ub_src: PtrType, gm_dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int], loops: list[tuple[int, int, int]] | None = None) -> None` + +**Description**: Grouped DMA transfer from Unified Buffer to Global Memory. The MTE reads `len_burst` bytes from each UB row (skipping any padding), writing only valid data to GM. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ub_src` | `PtrType` (ub) | UB source pointer (must be 32B-aligned) | +| `gm_dst` | `PtrType` (gm) | GM destination pointer | +| `len_burst` | `int` | Contiguous bytes transferred per burst row | +| `nburst` | `tuple[int, int, int]` | `(n_burst, src_stride, dst_stride)` — innermost burst group (required) | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional outer loop groups, ordered inner to outer | + +**Returns**: None (side-effect operation). + +**Example** — store a 32×32 f32 tile from UB to GM: + +```python +pto.mte_ub_gm(ub_ptr, gm_ptr, 128, + nburst=(32, 128, 128)) +``` + +**Example** — store a 64×128 f16 tile back to a larger GM matrix: + +```python +pto.mte_ub_gm(ub_ptr, gm_ptr, 256, + nburst=(64, 256, 1024)) +# UB: contiguous rows (256-byte stride). +# GM: rows spaced at 1024-byte intervals (full matrix width). +``` + +--- + +### 7.2.3 UB → UB: `pto.mte_ub_ub` + +#### `pto.mte_ub_ub(ub_src: PtrType, ub_dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int]) -> None` + +**Description**: Grouped UB-to-UB copy. Stride and gap values are in units of 32 bytes. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ub_src` | `PtrType` (ub) | UB source pointer (must be 32B-aligned) | +| `ub_dst` | `PtrType` (ub) | UB destination pointer (must be 32B-aligned) | +| `len_burst` | `int` | Burst length in units of 32 bytes | +| `nburst` | `tuple[int, int, int]` | `(n_burst, src_gap, dst_gap)` — count, source gap, destination gap (all in 32B units) | + +**Returns**: None (side-effect operation). + +Each burst copies `len_burst * 32` bytes. The next burst starts at `src + (len_burst + src_gap) * 32` and `dst + (len_burst + dst_gap) * 32`. + +**Example**: + +```python +pto.mte_ub_ub(ub_src, ub_dst, 8, + nburst=(16, 0, 4)) +# 16 bursts, each copying 8×32=256 bytes. +# Source: contiguous (src_gap=0). +# Destination: 4×32=128-byte gap between bursts. +``` + +--- + +### 7.2.4 UB → L1: `pto.mte_ub_l1` + +#### `pto.mte_ub_l1(ub_src: PtrType, l1_dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int]) -> None` + +**Description**: Grouped UB-to-L1 (CBUF) copy. Identical structure to `mte_ub_ub` but the destination is L1 cube buffer space. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ub_src` | `PtrType` (ub) | UB source pointer (must be 32B-aligned) | +| `l1_dst` | `PtrType` (l1) | L1 destination pointer (must be 32B-aligned) | +| `len_burst` | `int` | Burst length in units of 32 bytes | +| `nburst` | `tuple[int, int, int]` | `(n_burst, src_gap, dst_gap)` — all in 32B units | + +**Returns**: None (side-effect operation). + +--- + +### 7.2.5 The nburst / loop / pad model + +All grouped DMA operations follow a nested-loop execution model. `nburst` is the innermost group; each `loop` wraps the previous group as an outer iteration level. + +For `mte_gm_ub` and `mte_ub_gm`, strides are **byte distances** from the start of one burst row to the start of the next: + +``` +GM → UB (nburst only): + + for r in range(n_burst): + memcpy(ub_dst + r * dst_stride, + gm_src + r * src_stride, + len_burst) + if pad enabled: + memset(ub_dst + r * dst_stride + len_burst, + pad_value, + dst_stride_aligned - len_burst) +``` + +Each additional `loop(count, src_stride, dst_stride)` adds one outer `for` level that advances both base pointers by the corresponding strides. + +For `mte_ub_ub` and `mte_ub_l1`, the parameters are in **32-byte units**. Each burst copies `len_burst * 32` bytes, and the next burst starts at `src + (len_burst + src_gap) * 32` / `dst + (len_burst + dst_gap) * 32`. + +**UB address alignment**: For all four operations, every UB address (source and destination) must be 32-byte aligned. The `pad(...)` on `mte_gm_ub` ensures each UB row is padded to the 32B-aligned boundary of `dst_stride`, so subsequent rows stay aligned. + +### 7.2.6 Typical ukernel DMA pattern + +```python +@pto.ukernel +def process_block(k_part, v_part, k_tile, v_tile, o_tile, o_part, + rows: pto.i32, cols: pto.i32): + # Stage K and V blocks from GM to UB + pto.mte_gm_ub(k_part.as_ptr(), k_tile.as_ptr(), 0, + cols * pto.bytewidth(pto.f16), + nburst=(rows, cols * pto.bytewidth(pto.f16), + cols * pto.bytewidth(pto.f16))) + pto.mte_gm_ub(v_part.as_ptr(), v_tile.as_ptr(), 0, + cols * pto.bytewidth(pto.f16), + nburst=(rows, cols * pto.bytewidth(pto.f16), + cols * pto.bytewidth(pto.f16))) + pto.mem_bar(pto.BarrierType.SYNC) + + # ... compute on tiles ... + + pto.mem_bar(pto.BarrierType.SYNC) + pto.mte_ub_gm(o_tile.as_ptr(), o_part.as_ptr(), + cols * pto.bytewidth(pto.f32), + nburst=(rows, cols * pto.bytewidth(pto.f32), + cols * pto.bytewidth(pto.f32))) +``` + +## 7.3 Vector loads (simd) + +Inside `@pto.simd`, data moves between UB tiles and vector registers (`vreg`). Vector loads read a contiguous chunk of a tile row into a `vreg`; the chunk size equals the hardware vector width for the element type (e.g., 64 elements for `f32`, 128 for `f16`). + +### Tile-index syntax + +All vector load and store operations support the element-indexing syntax, which eliminates manual byte-offset calculation: + +```python +vec = pto.vlds(tile[row, col:]) # load from row, starting at column col +vec = pto.vlds(tile[start:]) # 1D tile, starting at element start +``` + +The compiler automatically computes the byte offset from the tile's shape, element type, and layout. The `:` indicates a full vector-width range — the number of elements loaded is `elements_per_vreg(dtype)`. + +--- + +#### `pto.vlds(tile[row, col:], dist: VLoadDist | None = None) -> VRegType` +#### `pto.vlds(tile[start:], dist: VLoadDist | None = None) -> VRegType` +#### `pto.vlds(buf: PtrType, offset: Index, dist: VLoadDist | None = None) -> VRegType` + +**Description**: Stateless vector load from UB. Reads one vector-width slice. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | Tile index | 2D tile row with starting column (vector-width range) | +| `tile[start:]` | Tile index | 1D tile with starting element (vector-width range) | +| `buf` | `PtrType` (UB) | Pointer to buffer in UB (pointer form) | +| `offset` | `Index` | Byte offset (pointer form) | +| `dist` | `VLoadDist` or `None` | Optional load distribution: `NORM` (default), `UNPK_B8`/`UNPK_B16`/`UNPK_B32`, `BRC_B8`/`BRC_B16`/`BRC_B32` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +--- + +#### `pto.vldsx2(tile[row, col:], dist: DeinterleaveDist) -> (VRegType, VRegType)` +#### `pto.vldsx2(tile[start:], dist: DeinterleaveDist) -> (VRegType, VRegType)` +#### `pto.vldsx2(buf: PtrType, offset: Index, dist: DeinterleaveDist) -> (VRegType, VRegType)` + +**Description**: Dual vector load with deinterleave (AoS → SoA). Loads interleaved data and deinterleaves into two vectors. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | Tile index | 2D tile row with starting column (vector-width range) | +| `tile[start:]` | Tile index | 1D tile with starting element (vector-width range) | +| `buf` | `PtrType` (UB) | Pointer to buffer in UB (pointer form) | +| `offset` | `Index` | Byte offset (pointer form) | +| `dist` | `DeinterleaveDist` | `DINTLV` (alternating elements) or `BDINTLV` (block deinterleave) | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `VRegType` | Even-indexed elements | +| `high` | `VRegType` | Odd-indexed elements | + +--- + +#### `pto.vldas(tile[row, col:]) -> AlignType` +#### `pto.vldas(tile[start:]) -> AlignType` +#### `pto.vldas(buf: PtrType) -> AlignType` + +**Description**: Primes the alignment buffer for a subsequent unaligned load stream. Returns alignment state consumed by `vldus`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | Tile index | 2D tile row with starting column | +| `tile[start:]` | Tile index | 1D tile with starting element | +| `buf` | `PtrType` | Pointer to buffer in UB | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `align` | `AlignType` | Alignment state for use with `vldus` | + +--- + +#### `pto.vldus(tile[row, col:], align: AlignType) -> (VRegType, AlignType, PtrType)` +#### `pto.vldus(tile[start:], align: AlignType) -> (VRegType, AlignType, PtrType)` +#### `pto.vldus(buf: PtrType, align: AlignType) -> (VRegType, AlignType, PtrType)` + +**Description**: Unaligned load with alignment state threading. Requires alignment state from `vldas` or a previous `vldus`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | Tile index | 2D tile row with starting column (vector-width range) | +| `tile[start:]` | Tile index | 1D tile with starting element (vector-width range) | +| `buf` | `PtrType` (UB) | Pointer to buffer in UB (pointer form) | +| `offset` | `Index` | Byte offset (pointer form) | +| `align` | `AlignType` | Alignment state from `vldas` or previous `vldus` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Assembled vector | +| `align_out` | `AlignType` | Updated alignment state for next load | +| `base_out` | `PtrType` | Post-update base pointer | + +**Example**: + +```python +align = pto.vldas(tile[row, col:]) +vec, align, base = pto.vldus(tile[row, col:], align) +``` + +--- + +#### `pto.vsld(tile[row, col], stride: StrideMode) -> VRegType` +#### `pto.vsld(tile[pos], stride: StrideMode) -> VRegType` +#### `pto.vsld(buf: PtrType, offset: Index, stride: StrideMode) -> VRegType` + +**Description**: Strided scalar load with broadcast. Loads a single element using a strided access pattern and broadcasts to all vector lanes. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col]` | Tile index | 2D single-element index | +| `tile[pos]` | Tile index | 1D single-element index | +| `stride` | `StrideMode` | `S3_B16`, `S4_B64`, `S8_B32`, or `S2_B64` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Broadcast vector | + +--- + +#### `pto.vgather2(buf: PtrType, offsets: Index, active_lanes: Index) -> VRegType` + +**Description**: Indexed gather from UB using per-lane offsets. Only the first `active_lanes` lanes participate. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `PtrType` (UB) | Source buffer | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `active_lanes` | `Index` | Number of participating lanes | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +--- + +#### `pto.vgather2_bc(buf: PtrType, offsets: Index, mask: MaskType) -> VRegType` + +**Description**: Indexed gather with mask. Masked-off lanes are zero-filled. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `PtrType` (UB) | Source buffer | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `mask` | `MaskType` | Mask gating lane participation | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +--- + +#### `pto.vgatherb(buf: PtrType, offsets: Index) -> VRegType` + +**Description**: Byte-granularity gather load. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `PtrType` | Source buffer | +| `offsets` | `Index` | Byte offsets | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +--- + +#### `pto.vsldb(tile[row, col], offset: Index, mask: MaskType) -> VRegType` +#### `pto.vsldb(tile[pos], offset: Index, mask: MaskType) -> VRegType` +#### `pto.vsldb(buf: PtrType, offset: Index, mask: MaskType) -> VRegType` + +**Description**: Block-strided load. The `offset` encodes a packed stride/control word, not a plain byte displacement. Masked-off blocks are zeroed. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `offset` | `Index` | Packed stride/control word | +| `mask` | `MaskType` | Mask controlling which blocks participate | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Block-strided vector | + +## 7.4 Vector stores (simd) + +Vector stores write `vreg` contents back to UB tiles. Like loads, they support tile-index syntax. + +#### `pto.vsts(vec: VRegType, tile[row, col:], mask: MaskType, dist: VStoreDist | None = None) -> None` +#### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType, dist: VStoreDist | None = None) -> None` +#### `pto.vsts(vec: VRegType, buf: PtrType, offset: Index, mask: MaskType, dist: VStoreDist | None = None) -> None` + +**Description**: Stateless vector store to UB. The mask gates which lanes are written. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | +| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | +| `offset` | `Index` | Byte offset (pointer form) | +| `mask` | `MaskType` | Predicate mask gating writes | +| `dist` | `VStoreDist` or `None` | Optional store distribution: `NORM_B32` (default), `PK_B16`/`PK_B32`/`PK_B64`, `ONE_POINT_B8`/`ONE_POINT_B16`/`ONE_POINT_B32` | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.psts(mask: MaskType, buf: PtrType, offset: Index) -> None` + +**Description**: Predicate store. Writes the packed predicate payload of `mask` to UB memory. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate payload to store | +| `buf` | `PtrType` (UB) | Destination buffer | +| `offset` | `Index` | Byte offset | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vstsx2(low: VRegType, high: VRegType, tile[row, col:], dist: InterleaveDist, mask: MaskType) -> None` +#### `pto.vstsx2(low: VRegType, high: VRegType, tile[start:], dist: InterleaveDist, mask: MaskType) -> None` +#### `pto.vstsx2(low: VRegType, high: VRegType, buf: PtrType, offset: Index, dist: InterleaveDist, mask: MaskType) -> None` + +**Description**: Dual interleaving store (SoA → AoS). Interleaves two vectors into one destination. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `low` | `VRegType` | First vector (even elements) | +| `high` | `VRegType` | Second vector (odd elements) | +| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | +| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | +| `offset` | `Index` | Byte offset (pointer form) | +| `dist` | `InterleaveDist` | `INTLV` | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vsst(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` +#### `pto.vsst(scalar: ScalarType, tile[start:], mask: MaskType) -> None` +#### `pto.vsst(scalar: ScalarType, buf: PtrType, offset: Index, mask: MaskType) -> None` + +**Description**: Scalar broadcast store. Stores a scalar value replicated to all lanes under `mask`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value to broadcast | +| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | +| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | +| `offset` | `Index` | Byte offset (pointer form) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vsstb(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` +#### `pto.vsstb(scalar: ScalarType, tile[start:], mask: MaskType) -> None` +#### `pto.vsstb(scalar: ScalarType, buf: PtrType, offset: Index, mask: MaskType) -> None` + +**Description**: Enhanced scalar broadcast store. Same semantics as `vsst`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value to broadcast | +| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | +| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | +| `offset` | `Index` | Byte offset (pointer form) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vsta(align: AlignType, tile[row, col:]) -> None` +#### `pto.vsta(align: AlignType, tile[start:]) -> None` +#### `pto.vsta(align: AlignType, buf: PtrType, offset: Index) -> None` + +**Description**: Flush alignment state to memory. Commits buffered tail bytes from an unaligned store stream. Consumes the alignment state. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `AlignType` | Pending store-alignment state | +| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | +| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | +| `offset` | `Index` | Byte offset (pointer form) | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vstas(align: AlignType, tile[row, col:], offset: Index) -> None` +#### `pto.vstas(align: AlignType, tile[start:], offset: Index) -> None` +#### `pto.vstas(align: AlignType, buf: PtrType, offset: Index) -> None` + +**Description**: Scalar-register-offset form of alignment-state flush. Same buffered-tail semantics as `vsta` with an explicit scalar offset. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `AlignType` | Pending store-alignment state | +| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | +| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | +| `offset` | `Index` | Byte offset (all forms) | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vstar(align: AlignType, tile[row, col:]) -> None` +#### `pto.vstar(align: AlignType, tile[start:]) -> None` +#### `pto.vstar(align: AlignType, buf: PtrType) -> None` + +**Description**: Register-update form of alignment-state flush. Consumes the implicit update state from the matching store stream. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `AlignType` | Pending store-alignment state | + +| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | +| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vscatter(vec: VRegType, buf: PtrType, offsets: Index, active_lanes: Index) -> None` + +**Description**: Indexed scatter to UB. Stores vector lanes to irregular locations using per-lane offsets. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Source vector to scatter | +| `buf` | `PtrType` (UB) | Destination buffer | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `active_lanes` | `Index` | Number of participating lanes | + +**Returns**: None (side-effect operation). + +--- + +### Stateful store family + +For streaming unaligned stores with explicit alignment threading: + +#### `pto.vstu(align_in: AlignType, base_in: PtrType, vec: VRegType, buf: PtrType, mode: Index) -> (AlignType, PtrType)` + +**Description**: Unaligned store with explicit threaded alignment/base state. Returns updated state for the next store in the stream. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming store-alignment state | +| `base_in` | `PtrType` | Current stream base pointer | +| `vec` | `VRegType` | Vector to store | +| `buf` | `PtrType` (UB) | Destination buffer | +| `mode` | `Index` | Post-update mode | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `AlignType` | Updated buffered-tail state | +| `base_out` | `PtrType` | Post-update base pointer | + +--- + +#### `pto.vstus(align_in: AlignType, base_in: PtrType, vec: VRegType, buf: PtrType, offset: Index) -> (AlignType, PtrType)` + +**Description**: Scalar-offset unaligned store. Same roles as `vstu` with explicit scalar displacement. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming store-alignment state | +| `base_in` | `PtrType` | Current stream base pointer | +| `vec` | `VRegType` | Vector to store | +| `buf` | `PtrType` (UB) | Destination buffer | +| `offset` | `Index` | Scalar displacement | +| `mode` | `Index` | Post-update mode | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `AlignType` | Updated buffered-tail state | +| `base_out` | `PtrType` | Post-update base pointer | + +--- + +#### `pto.vstur(align_in: AlignType, vec: VRegType, buf: PtrType, mode: PostUpdateMode = PostUpdateMode.NO_POST_UPDATE) -> AlignType` + +**Description**: Register-update unaligned store. Updates only residual alignment state without base pointer update. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming store-alignment state | +| `vec` | `VRegType` | Vector to store | +| `buf` | `PtrType` (UB) | Destination buffer | +| `mode` | `PostUpdateMode` | `NO_POST_UPDATE` (default) or `POST_UPDATE` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `AlignType` | Updated buffered-tail state | + +--- + +#### `pto.pstu(align_in: AlignType, mask: MaskType, buf: PtrType) -> (AlignType, PtrType)` + +**Description**: Predicate unaligned store with alignment state threading. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming store-alignment state | +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `PtrType` (UB) | Destination buffer | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `AlignType` | Updated alignment state | +| `base_out` | `PtrType` | Post-update base pointer | + +--- + +**Unaligned store stream pattern** — prime, thread, flush: + +```python +align, base = pto.vstu(align0, base0, vec0, ub_ptr, mode) +align, base = pto.vstu(align, base, vec1, ub_ptr, mode) +pto.vsta(align, ub_ptr, flush_offset) +``` + +### Distribution enums reference + +| Enum | Values | Used with | +|------|--------|-----------| +| `VLoadDist` | `NORM`, `UNPK_B8`, `UNPK_B16`, `UNPK_B32`, `BRC_B8`, `BRC_B16`, `BRC_B32`, `US_B8`, `US_B16`, `DS_B8`, `DS_B16` | `vlds` | +| `VStoreDist` | `NORM_B8`, `NORM_B16`, `NORM_B32`, `ONE_POINT_B8`, `ONE_POINT_B16`, `ONE_POINT_B32`, `PK_B16`, `PK_B32`, `PK_B64`, `PK4_B32`, `MRG4CHN_B8`, `MRG2CHN_B8`, `MRG2CHN_B16` | `vsts` | +| `DeinterleaveDist` | `DINTLV`, `BDINTLV` | `vldsx2` | +| `InterleaveDist` | `INTLV` | `vstsx2` | +| `StrideMode` | `S3_B16`, `S4_B64`, `S8_B32`, `S2_B64` | `vsld` | +| `PostUpdateMode` | `NO_POST_UPDATE`, `POST_UPDATE` | `vstur` | + +## 7.5 Cube data movement (cube) + +Inside `@pto.cube`, data flows through a hierarchy of private buffers: GM → L1 (cbuf) → L0A/L0B (operand buffers) → L0C (accumulator) → UB or back to GM. + +### Staging: GM → L1 and L1 → UB + +#### `pto.mte_gm_l1(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0), loops: list[tuple[int, int, int]] | None = None) -> None` + +**Description**: Structured GM-to-L1 (cbuf) data movement. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (GM) | Global Memory source pointer | +| `dst` | `PtrType` (L1) | L1 (cbuf) destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_stride, dst_stride)` | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional nested loop parameters | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_gm_l1_frac(src: PtrType, dst: PtrType, mode: FractalMode, *, shape: tuple[int, int], src_layout: tuple[int, int], dst_group: tuple[int, int, int, int], ctrl: tuple[int, bool]) -> None` + +**Description**: Fractal GM-to-L1 load for specialized layouts (`ND2NZ`, `DN2NZ`). + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (GM) | Global Memory source pointer | +| `dst` | `PtrType` (L1) | L1 destination pointer | +| `mode` | `FractalMode` | `ND2NZ` or `DN2NZ` | +| `shape` | `tuple[int, int]` | `(n_value, d_value)` | +| `src_layout` | `tuple[int, int]` | `(inner_stride, outer_stride)` | +| `dst_group` | `tuple[int, int, int, int]` | `(group_count, loop2_stride, loop3_stride, loop4_stride)` | +| `ctrl` | `tuple[int, bool]` | `(l2_cache_ctrl, smallc0_en)` | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_l1_ub(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0), loops: list[tuple[int, int, int]] | None = None) -> None` + +**Description**: Structured L1 (cbuf) to UB data movement. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L1) | L1 source pointer | +| `dst` | `PtrType` (UB) | UB destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_stride, dst_stride)` | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional nested loop parameters | + +**Returns**: None (side-effect operation). + +--- + +### Operand loading: L1 → L0A / L0B + +#### `pto.mte_l1_l0a(src: PtrType, dst: PtrType, m: int, k: int) -> None` + +**Description**: Structured L1-to-L0A (left-operand buffer) load. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L1) | L1 source pointer | +| `dst` | `PtrType` (L0A) | L0A destination pointer | +| `m` | `int` | M dimension size | +| `k` | `int` | K dimension size | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_l1_l0b(src: PtrType, dst: PtrType, k: int, n: int, *, transpose: bool = False) -> None` + +**Description**: Structured L1-to-L0B (right-operand buffer) load. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L1) | L1 source pointer | +| `dst` | `PtrType` (L0B) | L0B destination pointer | +| `k` | `int` | K dimension size | +| `n` | `int` | N dimension size | +| `transpose` | `bool` | Whether to load in transposed order | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_l1_l0a_mx(src: PtrType, dst: PtrType, m: int, k: int) -> None` +#### `pto.mte_l1_l0b_mx(src: PtrType, dst: PtrType, k: int, n: int) -> None` + +**Description**: MX-mode variants of `mte_l1_l0a` and `mte_l1_l0b` for MX-capable dtypes. Parameters same as their non-MX counterparts. + +--- + +### Bias loading + +#### `pto.mte_l1_bias(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0)) -> None` + +**Description**: Structured L1 (cbuf) to bias table load. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L1) | L1 source pointer | +| `dst` | `PtrType` (BIAS) | Bias table destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_gap, dst_gap)` | + +**Returns**: None (side-effect operation). + +--- + +### Accumulator writeback: L0C → L1 / GM / UB + +#### `pto.mte_l0c_l1(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, mode: FractalMode = FractalMode.NZ2ND, loop0_src_stride: int | None = None, split: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` + +**Description**: Structured L0C (acc) to L1 (cbuf) writeback. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L0C) | L0C accumulator source pointer | +| `dst` | `PtrType` (L1) | L1 destination pointer | +| `m` | `int` | M dimension size | +| `n` | `int` | N dimension size | +| `src_stride` | `int` | Source stride | +| `dst_stride` | `int` | Destination stride | +| `mode` | `FractalMode` | `NZ2ND` (default), `NZ2DN`, or `NZ2NZ` | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_l0c_gm(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, sid: int = 0, l2_cache_ctrl: int = 0, mode: FractalMode = FractalMode.NZ2ND, loop0_src_stride: int | None = None, split: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` + +**Description**: Structured L0C (acc) to GM writeback. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L0C) | L0C accumulator source pointer | +| `dst` | `PtrType` (gm) | GM destination pointer | +| `m` | `int` | M dimension size | +| `n` | `int` | N dimension size | +| `src_stride` | `int` | Source stride | +| `dst_stride` | `int` | Destination stride | +| `sid` | `int` | Stream ID (default 0) | +| `l2_cache_ctrl` | `int` | L2 cache control (default 0) | +| `mode` | `FractalMode` | `NZ2ND` (default), `NZ2DN`, or `NZ2NZ` | +| `loop0_src_stride` | `int` or `None` | Loop level 0 source stride | +| `split` | `int` or `None` | Split parameter | +| `loop3` | `tuple[int, int, int]` or `None` | Loop level 3 parameters | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_l0c_ub(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, dual_dst_mode: int = 0, sub_blockid: int = 0, mode: FractalMode = FractalMode.NZ2ND, loop0_src_stride: int | None = None, channel_split_en: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` + +**Description**: Structured L0C (acc) directly to UB. This is the most common writeback path for cube kernels that feed results into subsequent processing. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L0C) | L0C accumulator source pointer | +| `dst` | `PtrType` (ub) | UB destination pointer | +| `m` | `int` | M dimension size | +| `n` | `int` | N dimension size | +| `src_stride` | `int` | Source stride | +| `dst_stride` | `int` | Destination stride | +| `dual_dst_mode` | `int` | Dual destination mode (default 0) | +| `sub_blockid` | `int` | Sub-block ID (default 0) | +| `mode` | `FractalMode` | `NZ2ND` (default), `NZ2DN`, or `NZ2NZ` | +| `loop0_src_stride` | `int` or `None` | Loop level 0 source stride | +| `channel_split_en` | `int` or `None` | Channel split enable (required for `NZ2NZ` mode) | +| `loop3` | `tuple[int, int, int]` or `None` | Loop level 3 parameters | + +**Returns**: None (side-effect operation). + +--- + +### Cube data movement quick reference + +| Data Flow | Operation | Src Space | Dst Space | +|-----------|-----------|-----------|-----------| +| GM → L1 | `mte_gm_l1` | gm | l1 | +| GM → L1 (fractal) | `mte_gm_l1_frac` | gm | l1 | +| L1 → UB | `mte_l1_ub` | l1 | ub | +| L1 → L0A | `mte_l1_l0a` | l1 | l0a | +| L1 → L0B | `mte_l1_l0b` | l1 | l0b | +| L1 → L0A (MX) | `mte_l1_l0a_mx` | l1 | l0a | +| L1 → L0B (MX) | `mte_l1_l0b_mx` | l1 | l0b | +| L1 → Bias | `mte_l1_bias` | l1 | bt | +| L0C → L1 | `mte_l0c_l1` | l0c | l1 | +| L0C → GM | `mte_l0c_gm` | l0c | gm | +| L0C → UB | `mte_l0c_ub` | l0c | ub | + +### Typical cube dataflow in a matmul + +A full cube matmul (`@pto.cube`) follows this dataflow pattern: + +```python +@pto.cube +def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): + m = q_tile.valid_shape[0] + k = q_tile.valid_shape[1] + n = k_tile.valid_shape[0] + + pto.mte_l1_l0a(q_tile, q_l0a, m, k) # UB tile → L0A + pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) # UB tile → L0B + pto.mad(q_l0a, k_l0b, s_acc) # L0A × L0B → L0C + pto.mte_l0c_ub(s_acc, s_tile, m, n) # L0C → UB tile +``` + +The `mte_l1_l0a`/`mte_l1_l0b` operations take UB `Tile` references directly (not raw pointers) — the tile-to-cube-local transfer is implicit. `mad` performs the matrix multiply. `mte_l0c_ub` writes the result back to a UB tile. diff --git a/ptodsl/docs/user_guide/07-frontend-operations.md b/ptodsl/docs/user_guide/07-frontend-operations.md deleted file mode 100644 index 621a8c78f..000000000 --- a/ptodsl/docs/user_guide/07-frontend-operations.md +++ /dev/null @@ -1,352 +0,0 @@ - -### Frontend-only Authoring Operations - -Operations in this family affect descriptor construction and code generation -shape. They are consumed by the frontend and do not correspond to runtime VPTO -instructions by themselves. - -#### `pto.constexpr(value: bool) -> bool` - -**Description**: Compile-time conditional construct for kernel specialization. Marks a boolean expression for evaluation during descriptor materialization, enabling branch elimination based on static compile-time information. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `value` | `bool` | Boolean expression that must be evaluable at compile time. | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `bool` | A frontend-only compile-time boolean used to guard `if` statements. | - -**Behavior**: -- Evaluated during kernel descriptor materialization before semantic analysis and lowering. -- When used in `if pto.constexpr(...):` statements, only the selected branch is retained; the other branch is discarded entirely. -- If the condition cannot be proven static, descriptor materialization fails with a frontend diagnostic. -- Does not generate runtime control flow or value merging logic. - -**Examples**: -```python -# Specialize based on element size -dtype = dst.element_type -elem_bytes = pto.bytewidth(dtype) - -if pto.constexpr(elem_bytes == 2): - # Specialized path for 16-bit types (f16/bf16) - ... -else: - # Fallback path for other types - ... -``` - -```python -# Specialize based on tile shape -rows, cols = dst.shape - -if pto.constexpr(rows == 1 and cols == 16): - # Fast path for specific tile configuration - ... -``` - -**Constraints**: -- `pto.constexpr` is a frontend-only authoring construct with no runtime representation. -- The condition must be statically evaluable from descriptor-time information (data types, tile shapes, literals, etc.). -- For kernel-level specialization, prefer `constraints=[...]` and `pto.select_kernel(...)`. -- See [Compile-time Specialization with `pto.constexpr`](04-template-kernels.md#compile-time-specialization-with-ptoconstexpr) for detailed usage guidelines. - -### Type Query Operations - -Operations for querying type properties. - -#### `pto.bytewidth(dtype: Type) -> pto.i32` - -**Description**: Returns the size in bytes of a single element of the given data type. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`, `pto.si16`, `pto.ui32`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `size` | `pto.i32` | Element size in bytes | - -**Example**: -```python -f32_size = pto.bytewidth(pto.f32) # Returns 4 -f16_size = pto.bytewidth(pto.f16) # Returns 2 -i8_size = pto.bytewidth(pto.i8) # Returns 1 -ui64_size = pto.bytewidth(pto.ui64) # Returns 8 -``` - -**Common Use Case**: Calculate byte offsets for memory access: -```python -element_type = pto.f32 -byte_offset = index * pto.bytewidth(element_type) -``` - -#### `pto.elements_per_vreg(dtype: Type) -> pto.i32` - -**Description**: Returns the number of elements per vector register for a given element type, based on the hardware vector register size (256 bytes). This function computes `256 // bytewidth(dtype)`, which represents the maximum number of elements of the given type that can fit in a single vector register. Useful for determining vector width and loop stride calculations. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`, `pto.si16`, `pto.ui32`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `elems` | `pto.i32` | Number of elements per vector register for the given element type | - -**Example**: -```python -f32_elems_per_vreg = pto.elements_per_vreg(pto.f32) # Returns 64 (256 / 4) -f16_elems_per_vreg = pto.elements_per_vreg(pto.f16) # Returns 128 (256 / 2) -i8_elems_per_vreg = pto.elements_per_vreg(pto.i8) # Returns 256 (256 / 1) -si16_elems_per_vreg = pto.elements_per_vreg(pto.si16) # Returns 128 (256 / 2) -``` - -**Common Use Case**: Loop stride calculation for vector operations: -```python -dtype = pto.f32 -elems_per_vreg = pto.elements_per_vreg(dtype) # Returns 64 for f32 -for col in range(0, cols, elems_per_vreg): - # Load/store vectors of 'elems_per_vreg' elements - pass -``` - -**Relationship with `pto.bytewidth`**: -```python -# The relationship between bytewidth and elements per vector register: -elems = 256 // pto.bytewidth(dtype) -# This is equivalent to: -elems = pto.elements_per_vreg(dtype) -``` - -### Runtime Block Query Operations - -These ops expose the current kernel instance's execution coordinates to scalar -code. They are pure scalar producers: - -- they do not move data -- they do not allocate buffers -- they do not by themselves create `vecscope` boundaries - -Their main purpose is workload partitioning. A common pattern is: - -1. query the current block or subblock id -2. compute a per-instance starting offset -3. use that offset to derive GM/UB pointers or TensorView slices -4. run the local tile or vector loop for that partition - -#### `pto.get_block_idx() -> pto.i64` - -**Description**: Returns the current block ID for the running kernel instance. - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `block` | `pto.i64` | Current block index in the range `[0, pto.get_block_num())` | - -**Behavior**: -- The returned value is launch-instance-local and may differ across concurrently running blocks. -- The value is stable for the lifetime of one kernel instance. -- The op is scalar-only and can be used before pointer arithmetic, TensorView partitioning, DMA setup, or loop construction. - -#### `pto.get_subblock_idx() -> pto.i64` - -**Description**: Returns the current subblock ID visible to the running kernel instance. - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `subblock` | `pto.i64` | Current subblock index in the range `[0, pto.get_subblock_num())` | - -**Behavior**: -- Used when one block is further subdivided by the launch/runtime model. -- Like `pto.get_block_idx()`, this is a pure scalar query with no side effects. - -#### `pto.get_block_num() -> pto.i64` - -**Description**: Returns the total number of blocks visible to the current kernel launch. - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `block_num` | `pto.i64` | Total block count for the current launch domain | - -**Behavior**: -- Typically paired with `pto.get_block_idx()` to compute per-block ranges. -- The result is a runtime value and should not be assumed to be a compile-time constant. - -#### `pto.get_subblock_num() -> pto.i64` - -**Description**: Returns the total number of subblocks visible to the current execution instance. - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `subblock_num` | `pto.i64` | Total subblock count in the current runtime execution domain | - -**Behavior**: -- Typically paired with `pto.get_subblock_idx()` for finer-grained partitioning inside one block. - -**Example**: -```python -block = pto.get_block_idx() -block_num = pto.get_block_num() -subblock = pto.get_subblock_idx() -subblock_num = pto.get_subblock_num() -``` - -**Typical Use Case**: Compute a per-block base pointer. -```python -block = pto.get_block_idx() -block_len = 2048 -base_elem = block * block_len -block_src = pto.addptr(src_gm, base_elem) -block_dst = pto.addptr(dst_gm, base_elem) -``` - -**Constraints**: -- These ops return runtime scalar values, not compile-time specialization constants. -- They are intended for scalar address/control computation, not as vector operands. -- When mixing them with pointer arithmetic, remember that `pto.addptr(...)` uses element offsets, not byte offsets. - -### Scalar Pointer Helpers [Advanced Tier] - -These ops perform scalar element access on typed PTO pointers. Unlike -`pto.vlds(...)` / `pto.vsts(...)`, they operate on exactly one element and do -not create or consume vector registers or masks. - -They are useful when a kernel needs a small amount of scalar state next to -vector code, for example: - -- reading a scalar coefficient or loop-carried value from UB -- writing a scalar flag or reduction result -- patching a small header/metadata area without vector load-store semantics - -#### `pto.load_scalar(ptr: PtrType, offset: Index) -> ScalarType` -#### `pto.load_scalar(dtype: Type, ptr: PtrType, offset: Index) -> ScalarType` - -**Description**: Loads one scalar element from a typed PTO pointer at the given element offset. - -**Parameters (`load_scalar(ptr, offset)`)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `ptr` | `PtrType` | Typed pointer created by `pto.ptr(...)`, `pto.castptr(...)`, `Tile.as_ptr()`, or `TensorView.as_ptr()` | -| `offset` | `Index` | Element displacement from `ptr` | - -**Parameters (`load_scalar(dtype, ptr, offset)`)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `dtype` | `Type` | Optional explicit result dtype; must match the pointer element type | -| `ptr` | `PtrType` | Typed pointer source | -| `offset` | `Index` | Element displacement from `ptr` | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `value` | `ScalarType` | One scalar element loaded from `ptr[offset]` | - -**Behavior**: -- Access is element-based, not byte-based. -- The loaded value has the same scalar dtype as the pointer element type. -- This is a scalar memory helper; it does not participate in vector distribution families such as `dist`. -- It may target any memory space represented by the pointer type; the memory-space legality follows the pointer producer. - -#### `pto.store_scalar(ptr: PtrType, offset: Index, value: ScalarType) -> None` -#### `pto.store_scalar(value: ScalarType, ptr: PtrType, offset: Index) -> None` - -**Description**: Stores one scalar element to a typed PTO pointer at the given element offset. - -**Parameters (`store_scalar(ptr, offset, value)`)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `ptr` | `PtrType` | Typed destination pointer | -| `offset` | `Index` | Element displacement from `ptr` | -| `value` | `ScalarType` | Scalar value to write | - -**Parameters (`store_scalar(value, ptr, offset)`)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `value` | `ScalarType` | Scalar value to write | -| `ptr` | `PtrType` | Typed destination pointer | -| `offset` | `Index` | Element displacement from `ptr` | - -**Returns**: None (side-effect operation) - -**Behavior**: -- Stores exactly one scalar element to `ptr[offset]`. -- Does not consume a predicate mask. -- Does not imply vector-store ordering semantics such as `dist` or unaligned store state. - -**Example**: -```python -value = pto.load_scalar(src_ptr, 0) -pto.store_scalar(dst_ptr, 0, value) -``` - -**Typical Use Case**: Read-modify-write scalar metadata next to vector code. -```python -flag = pto.load_scalar(status_ptr, 0) -# scalar compute on `flag` -pto.store_scalar(status_ptr, 0, flag) -``` - -**Constraints**: -- `ptr` must be a typed `pto.ptr(...)` value. -- `offset` is element-based and must be index-typed after frontend normalization. - Plain integer literals such as `0` are accepted and lowered as index constants. -- The scalar dtype must match the pointer element dtype. -- These ops are advanced pointer-surface operations; prefer Tile/TensorView authoring surfaces when scalar pointer manipulation is not required. - -### Pointer Construction [Advanced Tier] - -Operations for creating and manipulating typed pointers. - -#### `pto.castptr(offset: pto.i64, ptr_type: Type) -> PtrType` - -**Description**: Creates a typed pointer from an integer address, a memref-backed address value, or another typed pointer in the same memory space. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `offset` | `pto.i64` / address-like value | Integer address, memref-backed address value, or existing pointer | -| `ptr_type` | `Type` | Target pointer type (e.g., `pto.ptr(pto.f32, MemorySpace.GM)`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `ptr` | `PtrType` | Typed pointer value | - -**Example**: -```python -ub_ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) -``` - -`TensorView.as_ptr()` and `Tile.as_ptr()` remain the preferred high-level APIs. They lower directly to address-extraction intrinsics (`pto.tensor_view_addr` / `pto.tile_buf_addr`) with pointer result types, while tile slice / buffer-view authoring paths continue to materialize memref results from the same intrinsics. - -#### `pto.addptr(ptr: PtrType, offset: pto.i64) -> PtrType` - -**Description**: Adds an element offset to an existing pointer. The offset is counted in elements, not bytes. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `ptr` | `PtrType` | Source pointer | -| `offset` | `pto.i64` | Element offset to add (counted in elements, not bytes) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `new_ptr` | `PtrType` | Pointer with element offset applied | - -**Example**: -```python -# Advance pointer by 1024 f32 elements (not bytes) -next_ptr = pto.addptr(ub_ptr, 1024) -``` - diff --git a/ptodsl/docs/user_guide/08-compute-operations.md b/ptodsl/docs/user_guide/08-compute-operations.md new file mode 100644 index 000000000..6fddecdf4 --- /dev/null +++ b/ptodsl/docs/user_guide/08-compute-operations.md @@ -0,0 +1,659 @@ +# 8. Compute Operations + +Chapters 6 and 7 covered scalars, pointers, and data movement. This chapter covers everything that actually *computes* — arithmetic, math functions, reductions, comparisons, and matrix multiplication — organized by abstraction level: tile ops (L1), vector ops (L3 SIMD), and cube ops (L3 cube). + +## 8.1 Tile-level compute (L1) + +Tile compute ops are the primary arithmetic surface inside `@pto.jit`. They operate on `Tile` buffers in UB and follow a consistent pattern: each op reads one or more source tiles, optionally a scalar, and writes a destination tile. Shapes and valid regions must be compatible across all operands. + +### 8.1.1 Binary tile-tile arithmetic + +Element-wise operations between two tiles of the same shape. + +#### `pto.tadd(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tsub(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tmul(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tmax(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tmin(src0: Tile, src1: Tile, dst: Tile) -> None` + +**Description**: Element-wise `dst[i,j] = src0[i,j] src1[i,j]`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | First source tile | +| `src1` | `Tile` | Second source tile | +| `dst` | `Tile` | Destination tile (must be pre-allocated, shape-compatible) | + +**Returns**: None (writes to `dst`). + +**Example**: + +```python +pto.tadd(a_tile, b_tile, o_tile) +pto.tmul(scale_tile, data_tile, scaled_tile) +``` + +--- + +#### `pto.tdiv(src0: Tile, src1: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` + +**Description**: Element-wise division. `precision_mode` can be `DEFAULT` or `HIGH_PRECISION` (f16/f32 only). + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | Numerator tile | +| `src1` | `Tile` | Denominator tile | +| `dst` | `Tile` | Destination tile | +| `precision_mode` | `PrecisionMode` | `DEFAULT` (default) or `HIGH_PRECISION` | + +**Returns**: None. + +--- + +### 8.1.2 Tile-scalar arithmetic + +Element-wise operations between a tile and a scalar. + +#### `pto.tadds(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tsubs(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tmuls(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tmaxs(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tmins(src: Tile, scalar: ScalarType, dst: Tile) -> None` + +**Description**: Element-wise `dst[i,j] = src[i,j] scalar`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile | +| `scalar` | `ScalarType` | Scalar operand (Python number or PTO scalar) | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +--- + +#### `pto.tdivs(numer: Tile | ScalarType, denom: Tile | ScalarType, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` + +**Description**: Element-wise tile-scalar division. Accepts both `(tile, scalar)` and `(scalar, tile)` operand orders. + +--- + +### 8.1.3 Unary math + +Single-source element-wise math functions. + +#### `pto.texp(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tlog(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tsqrt(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.trsqrt(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.trecip(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` + +**Description**: Element-wise `exp`, `ln`, `sqrt`, `1/sqrt`, `1/x`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile | +| `dst` | `Tile` | Destination tile | +| `precision_mode` | `PrecisionMode` | `DEFAULT` or `HIGH_PRECISION` | + +**Returns**: None. + +--- + +#### `pto.tabs(src: Tile, dst: Tile) -> None` +#### `pto.tneg(src: Tile, dst: Tile) -> None` + +**Description**: Element-wise absolute value and negation. No precision mode attribute. + +--- + +### 8.1.4 Activation + +#### `pto.trelu(src: Tile, dst: Tile) -> None` + +**Description**: `dst[i,j] = max(0, src[i,j])`. Supported on f16, f32, i32. + +#### `pto.tlrelu(src: Tile, slope: float, dst: Tile) -> None` + +**Description**: Leaky ReLU — `dst[i,j] = src[i,j] >= 0 ? src[i,j] : slope * src[i,j]`. + +--- + +### 8.1.5 Row and column reductions + +Reductions collapse one dimension of a 2D tile, producing a tile with one row or one column. + +#### Row reductions + +#### `pto.trowsum(src: Tile, tmp: Tile, dst: Tile) -> None` +#### `pto.trowmax(src: Tile, tmp: Tile, dst: Tile) -> None` +#### `pto.trowmin(src: Tile, tmp: Tile, dst: Tile) -> None` +#### `pto.trowprod(src: Tile, tmp: Tile, dst: Tile) -> None` +#### `pto.trowargmax(src: Tile, tmp: Tile, dst: Tile) -> None` +#### `pto.trowargmin(src: Tile, tmp: Tile, dst: Tile) -> None` + +**Description**: For each row `i`, reduce across columns: `dst[i, 0] = _j src[i, j]`. `trowargmax`/`trowargmin` return the column index of the extremum. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile (`[rows, cols]`) | +| `tmp` | `Tile` | Scratch tile for intermediate reduction state | +| `dst` | `Tile` | Destination tile (`[rows, 1]`) | + +**Returns**: None. + +--- + +#### Column reductions + +#### `pto.tcolsum(src: Tile, dst: Tile) -> None` +#### `pto.tcolmax(src: Tile, dst: Tile) -> None` +#### `pto.tcolmin(src: Tile, dst: Tile) -> None` +#### `pto.tcolprod(src: Tile, dst: Tile) -> None` + +**Description**: For each column `j`, reduce across rows: `dst[0, j] = _i src[i, j]`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile (`[rows, cols]`) | +| `dst` | `Tile` | Destination tile (`[1, cols]`) | + +**Returns**: None. + +--- + +### 8.1.6 Broadcast and expansion + +Expansion ops take a narrow source (scalar, row vector, or column vector) and broadcast it to a full tile shape. They are useful for applying per-row or per-column coefficients to a tile. + +#### Scalar broadcast + +#### `pto.texpands(scalar: ScalarType, dst: Tile) -> None` + +**Description**: `dst[i,j] = scalar` — fills every element of `dst` with the same scalar value. + +--- + +#### Row expansion + +#### `pto.trowexpand(src: Tile, dst: Tile) -> None` + +**Description**: `dst[row, col] = src[row, 0]` — broadcasts each row's single value across all columns of `dst`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile (`[rows, 1]`) | +| `dst` | `Tile` | Destination tile (`[rows, cols]`) | + +**Returns**: None. + +--- + +#### Column expansion + +#### `pto.tcolexpand(src: Tile, dst: Tile) -> None` + +**Description**: `dst[row, col] = src[0, col]` — broadcasts each column's single value across all rows of `dst`. + +--- + +#### Row-expand arithmetic + +These combine broadcasting with an arithmetic operation: `src1` is a per-row coefficient tile (`[rows, 1]`) that gets expanded row-wise before the element-wise op with `src0`. + +| Op | Semantics | +|----|-----------| +| `pto.trowexpandadd(src0, src1, dst)` | `dst = src0 + expand_rows(src1)` | +| `pto.trowexpandsub(src0, src1, dst)` | `dst = src0 - expand_rows(src1)` | +| `pto.trowexpandmul(src0, src1, dst)` | `dst = src0 * expand_rows(src1)` | +| `pto.trowexpanddiv(src0, src1, dst)` | `dst = src0 / expand_rows(src1)` (f-only) | +| `pto.trowexpandmax(src0, src1, dst)` | `dst = max(src0, expand_rows(src1))` | +| `pto.trowexpandmin(src0, src1, dst)` | `dst = min(src0, expand_rows(src1))` | +| `pto.trowexpandexpdif(src0, src1, dst)` | `dst = exp(src0 - expand_rows(src1))` (f-only) | + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | Full-shape source tile (`[rows, cols]`) | +| `src1` | `Tile` | Per-row coefficient tile (`[rows, 1]`) | +| `dst` | `Tile` | Destination tile (`[rows, cols]`) | + +**Returns**: None. + +**Example** — apply per-row scale and bias: + +```python +# alpha_tile: [rows, 1], beta_tile: [rows, 1], data_tile: [rows, cols] +pto.trowexpandmul(data_tile, alpha_tile, scaled_tile) +pto.trowexpandadd(scaled_tile, beta_tile, result_tile) +``` + +--- + +#### Column-expand arithmetic + +Same pattern as row-expand arithmetic, but `src1` is a per-column coefficient tile (`[1, cols]`): + +| Op | Semantics | +|----|-----------| +| `pto.tcolexpandadd(src0, src1, dst)` | `dst = src0 + expand_cols(src1)` | +| `pto.tcolexpandsub(src0, src1, dst)` | `dst = src0 - expand_cols(src1)` | +| `pto.tcolexpandmul(src0, src1, dst)` | `dst = src0 * expand_cols(src1)` | +| `pto.tcolexpanddiv(src0, src1, dst)` | `dst = src0 / expand_cols(src1)` (f-only) | +| `pto.tcolexpandmax(src0, src1, dst)` | `dst = max(src0, expand_cols(src1))` | +| `pto.tcolexpandmin(src0, src1, dst)` | `dst = min(src0, expand_cols(src1))` | +| `pto.tcolexpandexpdif(src0, src1, dst)` | `dst = exp(src0 - expand_cols(src1))` (f-only) | + +--- + +### 8.1.7 Selection + +#### `pto.tsel(mask: Tile, src0: Tile, src1: Tile, tmp: Tile, dst: Tile) -> None` + +**Description**: Element-wise ternary: `dst[i,j] = mask[i,j] ? src0[i,j] : src1[i,j]`. The `mask` is an integer tile where zero means false and non-zero means true. + +#### `pto.tsels(mask: Tile, src: Tile, scalar: ScalarType, tmp: Tile, dst: Tile) -> None` + +**Description**: Element-wise select with scalar fallback: `dst[i,j] = mask[i,j] ? src[i,j] : scalar`. + +--- + +### 8.1.8 Type conversion + +#### `pto.tcvt(src: Tile, dst: Tile, *, rmode: RoundMode = RoundMode.NONE) -> None` + +**Description**: Element-wise type conversion. The destination tile's `dtype` determines the target type. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile | +| `dst` | `Tile` | Destination tile (with target dtype) | +| `rmode` | `RoundMode` | Rounding mode: `NONE`, `RINT`, `ROUND`, `FLOOR`, `CEIL`, `TRUNC`, `ODD`, `CAST_RINT` | + +**Returns**: None. + +--- + +### 8.1.9 Tile compute quick reference + +| Category | Operations | +|----------|------------| +| Binary tile-tile | `tadd`, `tsub`, `tmul`, `tdiv`, `tmax`, `tmin` | +| Tile-scalar | `tadds`, `tsubs`, `tmuls`, `tdivs`, `tmaxs`, `tmins` | +| Unary math | `texp`, `tlog`, `tsqrt`, `trsqrt`, `trecip`, `tabs`, `tneg` | +| Activation | `trelu`, `tlrelu` | +| Row reductions | `trowsum`, `trowmax`, `trowmin`, `trowprod`, `trowargmax`, `trowargmin` | +| Column reductions | `tcolsum`, `tcolmax`, `tcolmin`, `tcolprod` | +| Broadcast | `texpands`, `trowexpand`, `tcolexpand` | +| Row-expand arith | `trowexpandadd`, `trowexpandsub`, `trowexpandmul`, `trowexpanddiv`, `trowexpandmax`, `trowexpandmin`, `trowexpandexpdif` | +| Col-expand arith | `tcolexpandadd`, `tcolexpandsub`, `tcolexpandmul`, `tcolexpanddiv`, `tcolexpandmax`, `tcolexpandmin`, `tcolexpandexpdif` | +| Selection | `tsel`, `tsels` | +| Type conversion | `tcvt` | +| Bitwise | `tnot`, `tand`, `tor`, `txor`, `tshl`, `tshr`, `tands`, `tors`, `txors`, `tshls`, `tshrs` | +| Partial elementwise | `tpartadd`, `tpartmul`, `tpartmax`, `tpartmin` | +| Fill/padding | `tfillpad`, `tfillpad_expand`, `tfillpad_inplace` | + +--- + +## 8.2 Vector compute (L3 — `@pto.simd`) + +Vector compute ops operate on `VRegType` values inside `@pto.simd` sub-kernels. Every vector op takes a `MaskType` predicate that gates which lanes participate; masked-off lanes produce an unspecified result (use the result only where the mask is true, or feed it to a masked store). + +All vector ops in this section follow the pattern established in Section 7.3 for tile-index and pointer-form addressing. The signatures below use the vector-register form — tile-index forms load into `vreg` first, then compute. + +### 8.2.1 Unary vector ops + +#### `pto.vexp(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vln(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vsqrt(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vabs(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vneg(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vrec(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vrsqrt(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vrelu(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vnot(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise unary operation under mask. `vrec` = reciprocal, `vrsqrt` = inverse square root, `vrelu` = `max(0, x)`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask (granularity must match element type) | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result vector | + +**Example**: + +```python +exp_vec = pto.vexp(s_row, col_mask) +``` + +--- + +### 8.2.2 Binary vector ops + +#### `pto.vadd(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` +#### `pto.vsub(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` +#### `pto.vmul(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` +#### `pto.vdiv(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` +#### `pto.vmax(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` +#### `pto.vmin(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise binary operation: `result[i] = v0[i] v1[i]` for lanes where `mask[i]` is true. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `v0` | `VRegType` | First operand vector | +| `v1` | `VRegType` | Second operand vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result vector | + +--- + +**Bitwise binary ops** (integer types only): + +| Op | Semantics | +|----|-----------| +| `pto.vand(v0, v1, mask) -> VRegType` | `v0 & v1` | +| `pto.vor(v0, v1, mask) -> VRegType` | `v0 \| v1` | +| `pto.vxor(v0, v1, mask) -> VRegType` | `v0 ^ v1` | +| `pto.vshl(vec, shift, mask) -> VRegType` | `vec << shift` (per-element) | +| `pto.vshr(vec, shift, mask) -> VRegType` | `vec >> shift` (per-element) | + +--- + +### 8.2.3 Vector-scalar ops + +#### `pto.vadds(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vsubs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vmuls(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vmaxs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vmins(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise `result[i] = vec[i] scalar`. The scalar is broadcast to all active lanes. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar operand (uniform across all lanes) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result vector | + +**Example** — subtract row max from score row (online softmax): + +```python +s_shifted = pto.vsubs(s_row, m_next, col_mask) +``` + +--- + +#### `pto.vlrelu(vec: VRegType, alpha: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Leaky ReLU — `vec[i] >= 0 ? vec[i] : alpha * vec[i]`. + +--- + +### 8.2.4 Full-vector and group reductions + +#### Full-vector reductions + +#### `pto.vcadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Full-vector sum reduction. Result placed in lane 0. + +#### `pto.vcmax(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Full-vector max with argmax. Result lane 0 = max value, lane 1 = max index. + +#### `pto.vcmin(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Full-vector min with argmin. Result lane 0 = min value, lane 1 = min index. + +--- + +#### Group reductions (per-VLane) + +These reduce within each hardware vector lane group (typically 8 groups per vector). Useful when a vector register holds multiple independent sub-vectors that need separate reductions. + +#### `pto.vcgadd(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vcgmax(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vcgmin(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Per-group sum, max, or min. Each group's result is placed in the first lane of that group. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Vector with per-group reduction results | + +**Example** — row max and row sum from online softmax: + +```python +row_max = pto.vcgmax(s_row, col_mask) # per-group max → first lane of each group +row_sum = pto.vcgadd(p_row, col_mask) # per-group sum → first lane of each group +``` + +--- + +#### `pto.vcpadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Inclusive prefix sum (scan). `result[i] = sum_{k=0}^{i} vec[k]` for active lanes. f16 and f32 only. + +--- + +### 8.2.5 Fused and compound ops + +These combine an arithmetic operation with a math function or activation in a single instruction. + +#### `pto.vexpdif(vec: VRegType, max_vec: VRegType, mask: MaskType, *, part: PartMode = PartMode.EVEN) -> VRegType` + +**Description**: `exp(vec[i] - max_vec[i])` — the stable softmax numerator. `part` controls which half of the vector is computed: `EVEN` or `ODD`. Result type is always f32. + +--- + +#### `pto.vaxpy(alpha: ScalarType, x: VRegType, y: VRegType, mask: MaskType) -> VRegType` + +**Description**: Fused multiply-add: `alpha * x[i] + y[i]`. + +--- + +#### `pto.vaddrelu(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` + +**Description**: `max(0, v0[i] + v1[i])` — fused add + ReLU. + +#### `pto.vsubrelu(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` + +**Description**: `max(0, v0[i] - v1[i])` — fused sub + ReLU. + +--- + +### 8.2.6 Comparison and selection + +#### `pto.vcmp(v0: VRegType, v1: VRegType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Element-wise comparison producing a predicate mask. `seed_mask` selects which lanes participate; the result inherits its granularity (e.g., `mask_b32` for f32). + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `v0` | `VRegType` | First operand | +| `v1` | `VRegType` | Second operand | +| `seed_mask` | `MaskType` | Seed mask gating participation | +| `cmp_mode` | `CmpMode` | `EQ`, `NE`, `LT`, `LE`, `GT`, `GE` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `pred` | `MaskType` | Result predicate mask | + +--- + +#### `pto.vcmps(vec: VRegType, scalar: ScalarType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Vector-scalar comparison. Same semantics as `vcmp` with a uniform scalar second operand. + +--- + +#### `pto.vsel(true_v: VRegType, false_v: VRegType, mask: MaskType) -> VRegType` + +**Description**: Per-lane select: `mask[i] ? true_v[i] : false_v[i]`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `true_v` | `VRegType` | Values when mask is true | +| `false_v` | `VRegType` | Values when mask is false | +| `mask` | `MaskType` | Selection predicate | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Selected vector | + +--- + +### 8.2.7 Vector compute quick reference + +| Category | Operations | +|----------|------------| +| Unary | `vexp`, `vln`, `vsqrt`, `vabs`, `vneg`, `vrec`, `vrsqrt`, `vrelu`, `vnot`, `vmov`, `vcls`, `vbcnt` | +| Binary | `vadd`, `vsub`, `vmul`, `vdiv`, `vmax`, `vmin`, `vand`, `vor`, `vxor`, `vshl`, `vshr`, `vmod` | +| Vector-scalar | `vadds`, `vsubs`, `vmuls`, `vmaxs`, `vmins`, `vshls`, `vshrs`, `vlrelu`, `vands`, `vors`, `vxors` | +| Broadcast | `vbr`, `vdup` | +| Full reduction | `vcadd`, `vcmax`, `vcmin` | +| Group reduction | `vcgadd`, `vcgmax`, `vcgmin` | +| Scan | `vcpadd` | +| Fused | `vexpdif`, `vaxpy`, `vprelu`, `vaddrelu`, `vsubrelu`, `vmulconv`, `vaddreluconv` | +| Compare/select | `vcmp`, `vcmps`, `vsel`, `vselr`, `vselrv2` | +| Carry | `vaddc`, `vsubc`, `vaddcs`, `vsubcs` | +| Extended arith | `vmull`, `vmula` | +| Conversion | `vcvt`, `vtrc`, `vbitcast`, `pbitcast` | +| Index gen | `vci` | +| Rearrangement | `vintlv`, `vdintlv`, `vintlvv2`, `vdintlvv2`, `vsqz`, `vusqz`, `vpack`, `vsunpack`, `vzunpack`, `vperm`, `vshift`, `vslide`, `vsort32`, `vmrgsort`, `vtranspose` | + +--- + +## 8.3 Cube compute (L3 — `@pto.cube`) + +The Cube unit performs matrix multiplication. Its operands are typed pointers into cube-local buffers — L0A (left operand), L0B (right operand), L0C (accumulator), and BIAS. Cube data movement (`mte_l1_l0a`, `mte_l1_l0b`, `mte_l0c_ub`, etc.) was covered in Section 7.5; this section covers the compute instruction itself. + +### 8.3.1 Matrix multiply: `pto.mad` + +#### `pto.mad(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, k: int, n: int) -> None` + +**Description**: Zero-initialized matrix multiply: `dst[M×N] = lhs[M×K] * rhs[K×N]`. `lhs` is an L0A pointer, `rhs` is an L0B pointer, `dst` is an L0C pointer. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `lhs` | `PtrType` (L0A) | Left operand matrix (M × K) | +| `rhs` | `PtrType` (L0B) | Right operand matrix (K × N) | +| `dst` | `PtrType` (L0C) | Destination accumulator (M × N) | +| `m` | `int` | M dimension size | +| `k` | `int` | K dimension (inner/reduction dimension) | +| `n` | `int` | N dimension size | + +**Returns**: None (writes to `dst` in L0C). + +--- + +#### `pto.mad_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, k: int, n: int) -> None` + +**Description**: Accumulating matrix multiply: `dst[M×N] += lhs[M×K] * rhs[K×N]`. `dst` must already hold a prior accumulation result. + +--- + +#### `pto.mad_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, k: int, n: int) -> None` + +**Description**: Bias-initialized matrix multiply: `dst[M×N] = lhs[M×K] * rhs[K×N] + bias[M×N]`. `bias` is a BIAS pointer. + +--- + +### 8.3.2 Typical cube matmul pattern + +A full cube matmul follows a three-stage pattern: stage operands into L0A/L0B, compute, write back to UB. + +```python +@pto.cube +def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): + m = pto.tile_valid_rows(q_tile) + k = pto.tile_valid_cols(q_tile) + n = pto.tile_valid_rows(k_tile) + + # Stage: UB → L0A / L0B + pto.mte_l1_l0a(q_tile, q_l0a, m, k) + pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) + + # Compute: L0A × L0B → L0C + pto.mad(q_l0a, k_l0b, s_acc, m, k, n) + + # Writeback: L0C → UB + pto.mte_l0c_ub(s_acc, s_tile, m, n) +``` + +The `mte_l1_l0a`/`mte_l1_l0b` stage operands from UB into cube-local buffers. `mad` performs the matrix multiply into L0C. `mte_l0c_ub` writes the result back to a UB tile for downstream processing. + +--- + +### 8.3.3 Cube compute quick reference + +| Operation | Semantics | +|-----------|-----------| +| `pto.mad(lhs, rhs, dst, m, k, n)` | `dst = lhs * rhs` (zero-init) | +| `pto.mad_acc(lhs, rhs, dst, m, k, n)` | `dst += lhs * rhs` (accumulating) | +| `pto.mad_bias(lhs, rhs, dst, bias, m, k, n)` | `dst = lhs * rhs + bias` | +| `pto.mad_mx(lhs, rhs, dst, m, k, n)` | MX-format zero-init matmul | +| `pto.mad_mx_acc(lhs, rhs, dst, m, k, n)` | MX-format accumulating matmul | +| `pto.mad_mx_bias(lhs, rhs, dst, bias, m, k, n)` | MX-format bias-init matmul | + +MX variants require MX-enabled dtypes (f8) and pre-loaded scale payloads. For most users, the standard `mad`, `mad_acc`, and `mad_bias` are the primary interface. diff --git a/ptodsl/docs/user_guide/08-sync-dma-operations.md b/ptodsl/docs/user_guide/08-sync-dma-operations.md deleted file mode 100644 index 883e5104a..000000000 --- a/ptodsl/docs/user_guide/08-sync-dma-operations.md +++ /dev/null @@ -1,622 +0,0 @@ -### Synchronization & Buffer Control - -Operations for pipeline synchronization and buffer management. - -#### Enum Types for Synchronization - -The following enum types provide type-safe parameter specification for synchronization operations: - -- **`BarrierType`**: Memory barrier types for `pto.mem_bar` - - `VV_ALL`, `VST_VLD`, `VLD_VST`, `VST_VST`: vector→vector barriers - - `VS_ALL`, `VST_LD`, `VLD_ST`, `VST_ST`: vector→scalar barriers - - `SV_ALL`, `ST_VLD`, `LD_VST`, `ST_VST`: scalar→vector barriers - -- **`Pipe`**: Hardware pipeline identifiers - - `MTE2`: Memory Transfer Engine 2 pipeline - - `V`: Vector pipeline - - `MTE3`: Memory Transfer Engine 3 pipeline - - `ALL`: All pipelines (for barrier operations) - -- **`Event`**: Event identifiers for synchronization - - `ID0`, `ID1`, `ID2`, `ID3`, ..., `ID31`: Event IDs 0-31 (A5 supports 32 event IDs, 0-15 for subblock 0, 16-31 for subblock 1) - -#### `pto.set_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` - -**Description**: Sets a synchronization flag between hardware pipelines. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | -| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | -| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import PIPE, EVENT - -pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) -``` - -#### `pto.wait_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` - -**Description**: Waits for a synchronization flag between hardware pipelines. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | -| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | -| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import PIPE, EVENT - -pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) -``` - -#### `pto.pipe_barrier(pipes: PIPE) -> None` - -**Description**: Executes a barrier across specified pipelines. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pipes` | `PIPE` | Pipeline specification (e.g., `PIPE.ALL`) | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import PIPE - -pto.pipe_barrier(PIPE.ALL) -``` - -#### `pto.get_buf(pipe: Pipe, buf_id: pto.i64, mode: pto.i64) -> None` - -**Description**: Acquire buffer slot for inter-pipeline double-buffering coordination. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pipe` | `Pipe` | Pipeline identifier (e.g., `Pipe.MTE2`, `Pipe.V`, `Pipe.MTE3`) | -| `buf_id` | `pto.i64` | Buffer identifier | -| `mode` | `pto.i64` | Acquisition mode | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import Pipe - -# Acquire buffer for MTE2 pipeline -pto.get_buf(Pipe.MTE2, 0, 0) -``` - -#### `pto.rls_buf(pipe: Pipe, buf_id: pto.i64, mode: pto.i64) -> None` - -**Description**: Release buffer slot to allow other pipeline to proceed. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pipe` | `Pipe` | Pipeline identifier (e.g., `Pipe.MTE2`, `Pipe.V`, `Pipe.MTE3`) | -| `buf_id` | `pto.i64` | Buffer identifier | -| `mode` | `pto.i64` | Release mode | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import Pipe - -# Release buffer for MTE2 pipeline -pto.rls_buf(Pipe.MTE2, 0, 0) -``` - -#### `pto.mem_bar(barrier_type: BarrierType) -> None` - -**Description**: Memory barrier for pipeline synchronization within vector scope. Required when UB addresses alias between vector load/store operations. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `barrier_type` | `BarrierType` | Barrier type controlling prior/subsequent instruction ordering. Supported values are `BarrierType.VV_ALL`, `BarrierType.VST_VLD`, `BarrierType.VLD_VST`, `BarrierType.VST_VST`, `BarrierType.VS_ALL`, `BarrierType.VST_LD`, `BarrierType.VLD_ST`, `BarrierType.VST_ST`, `BarrierType.SV_ALL`, `BarrierType.ST_VLD`, `BarrierType.LD_VST`, and `BarrierType.ST_VST`. | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import BarrierType - -# Ensure stores are visible before loads to same UB region -pto.mem_bar(BarrierType.VST_VLD) -``` - -#### `pto.set_cross_core(core_id: pto.i64, event_id: Event) -> None` - -**Description**: Signal event to another core (cross-core synchronization). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `core_id` | `pto.i64` | Target/source core identifier (platform-specific mapping) | -| `event_id` | `Event` | Cross-core event identifier (e.g., `Event.ID0`) | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import Event - -# Signal event ID0 to core 0 -pto.set_cross_core(0, Event.ID0) -``` - -#### `pto.set_intra_block(block_id: pto.i64, event_id: Event) -> None` - -**Description**: Signal event within a block (A5). Specifies trigger pipe. 1:1 per subblock. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `block_id` | `pto.i64` | Block/pipeline identifier specifying trigger pipe | -| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import Event - -# Signal event ID0 on block/pipeline 0 -pto.set_intra_block(0, Event.ID0) -``` - -#### `pto.set_intra_core(config: pto.i32) -> None` - -**Description**: Configures intra-core synchronization settings. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `config` | `pto.i32` | Configuration value for intra-core synchronization | - -**Returns**: None (side-effect operation) - -**Example**: -```python -pto.set_intra_core(3) -``` - -#### `pto.wait_flag_dev(core_id: pto.i64, event_id: Event) -> None` - -**Description**: Wait for event from another core. SU-level blocking — entire core stalls. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `core_id` | `pto.i64` | Core identifier | -| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import Event - -# Wait for event ID0 from core 0 -pto.wait_flag_dev(0, Event.ID0) -``` - -#### `pto.wait_intra_core(block_id: pto.i64, event_id: Event) -> None` - -**Description**: Wait for event within block (A5). Specifies which pipeline should wait — only that pipe stalls, SU and other pipes continue. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `block_id` | `pto.i64` | Block/pipeline identifier specifying which pipeline should wait | -| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import Event - -# Wait for event ID0 on block/pipeline 0 -pto.wait_intra_core(0, Event.ID0) -``` - -### DMA Programming [Advanced Tier] - -This section covers Direct Memory Access (DMA) operations for transferring data between Global Memory (GM) and Unified Buffer (UB). DMA operations are performance-critical and require careful configuration of stride parameters and transfer sizes. - -**Key Concepts:** -- **DMA Configuration**: Set stride parameters and loop sizes using `set_loop*_stride_*` and `set_loop_size_*` operations. -- **DMA Execution**: Perform transfers using `copy_gm_to_ubuf`, `copy_ubuf_to_gm`, and `copy_ubuf_to_ubuf` operations. -- **GM→UB Padding**: Optionally fill out-of-bounds regions with a specified value when copying from GM to UB. See [Pad Fill Semantics](#pad-fill-semantics) for details. - -**Usage Flow:** -1. Configure DMA parameters (strides, loop sizes) -2. Execute the DMA transfer operation -3. Optionally enable padding for GM→UB transfers - -**Note**: All DMA operations in this section are part of the **Advanced Tier** and require explicit buffer management and pointer arithmetic. For basic tile-based authoring, refer to the [Basic Authoring Mode](01-introduction.md#basic-vs-advanced-authoring-modes) documentation. - -#### Manual Configuration Example - -```python -# DMA configuration example (requires careful parameter tuning) -pto.set_loop2_stride_outtoub(src_stride=32, dst_stride=128) # Outer loop strides -pto.set_loop1_stride_outtoub(src_stride=1, dst_stride=32) # Inner loop strides -pto.set_loop_size_outtoub(loop1=16, loop2=16) # Transfer size -pto.copy_gm_to_ubuf(src=gm_ptr, dst=ub_ptr, n_burst=16, len_burst=128, gm_stride=128, ub_stride=128) - -``` - -#### Pad Fill Semantics - -When copying data from Global Memory (GM) to Unified Buffer (UB), you can enable padding to fill out-of-bounds regions with a specified value. This is useful when the source data dimensions don't perfectly match the destination tile allocation, or when you need to handle boundary conditions in tiled computations. - -##### How Padding Works - -1. **Configure the hardware pad register**: Call `pto.set_mov_pad_val` to set the pad value in the hardware register. This must be done before any `pto.copy_gm_to_ubuf` operation with padding enabled. - -2. **Enable padding in the DMA operation**: Set `enable_ub_pad=True` in the `pto.copy_gm_to_ubuf` call to activate the padded transfer path. The pad value from the hardware register will be used for filling out-of-bounds regions. - -3. **Hardware mapping**: The `pto.set_mov_pad_val` operation corresponds directly to the low-level VPTO instruction that configures the hardware pad register. There is no automatic translation from tile `PadValue` descriptors—you must explicitly set the pad register before padded DMA transfers. - -##### Example Workflow - -Configure the hardware pad register using `pto.set_mov_pad_val`, then perform the DMA transfer with padding enabled: - -```python -# First, configure the hardware pad register with a scalar value -# For zero fill, use an appropriate scalar type based on your data -pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float32 data - -# Then perform the DMA transfer with padding enabled -pto.copy_gm_to_ubuf( - src=gm_ptr, - dst=ub_ptr, - n_burst=32, - len_burst=200, - gm_stride=200, - ub_stride=256, - enable_ub_pad=True, # Enable padded transfer -) -``` - -##### Accessing Pad Values in Kernel Code - -Tile `PadValue` descriptors can be used within kernel code for computation purposes (e.g., initializing vectors with a specific fill value). However, note that **these descriptors are not automatically used for DMA padding**—you must still call `pto.set_mov_pad_val` explicitly to configure the hardware pad register for GM→UB transfers. - -To access a pad value from a tile descriptor in kernel code: - -```python -# Get the pad descriptor from the destination tile -pad_desc = dst.pad_value - -# Check if a valid pad value is configured -if pto.constexpr(pad_desc != pto.PadValue.NULL): - # Materialize the scalar value - pad_scalar = pad_desc.eval() - - # Use the scalar value (e.g., for vector duplication) - mask = pto.make_mask(pto.f32, PAT.ALL) - pad_vector = pto.vdup(pad_scalar, mask) -``` - -##### Important Notes - -- The `PadValue.NULL` descriptor indicates no pad value is configured. Attempting to call `.eval()` on `PadValue.NULL` will raise a frontend error. -- Custom pad values currently support only 32-bit float payloads (`PadValue.custom_f32(...)`). -- Padding only affects GM→UB transfers (`pto.copy_gm_to_ubuf`). UB→GM and UB→UB transfers do not support padding. -- The padded region is determined by the difference between the tile's `valid_shape` and its full `shape`. Ensure your tile is configured with appropriate dimensions. -- Tile `PadValue` descriptors are not automatically used for DMA padding. You must call `pto.set_mov_pad_val` explicitly to configure the hardware pad register for padded GM→UB transfers. - -##### `pto.set_mov_pad_val` Operation [Advanced Tier] - -The `pto.set_mov_pad_val` operation configures the hardware pad register used for GM→UB transfers when padding is enabled. This operation must be called explicitly before any `pto.copy_gm_to_ubuf` operation with `enable_ub_pad=True`, as the TileLang DSL v1 does not automatically translate tile `PadValue` descriptors to hardware register configurations. - -**Operation Signature**: -```python -pto.set_mov_pad_val(pad_value: ScalarType) -> None -``` - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pad_value` | `ScalarType` | Scalar value used for padding. Supported types: any 8/16/32-bit integer scalar (`pto.i8`, `pto.si8`, `pto.ui8`, `pto.i16`, `pto.si16`, `pto.ui16`, `pto.i32`, `pto.si32`, `pto.ui32`) plus `pto.f16`, `pto.bf16`, and `pto.f32`. The value's bit pattern is encoded into the hardware pad register. Integer inputs are automatically normalized to the corresponding signless hardware operand width during lowering, so no manual cast is required before calling `pto.set_mov_pad_val`. For standard pad values, use `PadValue.eval(...)` to obtain the appropriate scalar: `0` or `0.0` for `PadValue.ZERO`, dtype-aware maximum for `PadValue.MAX`, dtype-aware minimum for `PadValue.MIN`. | - -**Returns**: None (side-effect operation) - -**Example**: - -Using a scalar value directly: -```python -# Configure the hardware pad register for zero fill using an integer scalar -pto.set_mov_pad_val(pto.i32(0)) # Zero fill for integer types - -# Or using a float scalar for floating-point padding -pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float types - -# Perform DMA transfer with padding enabled -pto.copy_gm_to_ubuf( - src=gm_ptr, - dst=ub_ptr, - n_burst=32, - len_burst=200, - gm_stride=200, - ub_stride=256, - enable_ub_pad=True, -) -``` - -Using a tile's pad value descriptor: -```python -# Get the pad value from a tile configuration -pad_desc = tile.pad_value # PadValue enum -if pto.constexpr(pad_desc != pto.PadValue.NULL): - pad_scalar = pad_desc.eval() # Materializes to a scalar value - pto.set_mov_pad_val(pad_scalar) - - # Perform padded DMA transfer - pto.copy_gm_to_ubuf( - src=gm_ptr, - dst=ub_ptr, - n_burst=32, - len_burst=200, - gm_stride=200, - ub_stride=256, - enable_ub_pad=True, - ) -``` - -Using a standalone `PadValue` with an explicit dtype: -```python -pad_scalar = pto.PadValue.MAX.eval(pto.f32) -pto.set_mov_pad_val(pad_scalar) -``` - -For integer tile dtypes such as `pto.ui16` or `pto.si32`, `pad_desc.eval()` can be passed directly to `pto.set_mov_pad_val`. TileLang DSL v1 will automatically insert the required same-width bitcast to the signless hardware operand type during lowering. - -**Important**: You are responsible for ensuring the pad register is properly configured before any `pto.copy_gm_to_ubuf` operation with `enable_ub_pad=True`. The pad register configuration persists until changed by another `pto.set_mov_pad_val` call. - -**Future Improvement**: Future versions of TileLang DSL may provide an implicit approach that automatically translates `PadValue` descriptors from tile configurations to hardware register configurations, similar to DMA syntax sugar features. - -#### `pto.set_loop2_stride_outtoub(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] - -**Description**: Configures DMA stride parameters for GM → UB transfers (loop2). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src_stride` | `pto.i64` | Source-side stride | -| `dst_stride` | `pto.i64` | Destination-side stride | - -**Returns**: None (side-effect operation) - -#### `pto.set_loop1_stride_outtoub(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] - -**Description**: Configures DMA stride parameters for GM → UB transfers (loop1). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src_stride` | `pto.i64` | Source-side stride | -| `dst_stride` | `pto.i64` | Destination-side stride | - -**Returns**: None (side-effect operation) - -#### `pto.set_loop_size_outtoub(loop1: pto.i64, loop2: pto.i64) -> None` [Advanced Tier] - -**Description**: Configures DMA transfer size for GM → UB transfers. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `loop1` | `pto.i64` | Inner loop trip count | -| `loop2` | `pto.i64` | Outer loop trip count | - -**Returns**: None (side-effect operation) - -**Example**: -```python -pto.set_loop_size_outtoub(loop1=1, loop2=1) -``` - -#### `pto.set_loop2_stride_ubtoout(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] - -**Description**: Configures DMA stride parameters for UB → GM transfers (loop2). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src_stride` | `pto.i64` | Source-side stride | -| `dst_stride` | `pto.i64` | Destination-side stride | - -**Returns**: None (side-effect operation) - -#### `pto.set_loop1_stride_ubtoout(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] - -**Description**: Configures DMA stride parameters for UB → GM transfers (loop1). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src_stride` | `pto.i64` | Source-side stride | -| `dst_stride` | `pto.i64` | Destination-side stride | - -**Returns**: None (side-effect operation) - -#### `pto.set_loop_size_ubtoout(loop1: pto.i64, loop2: pto.i64) -> None` [Advanced Tier] - -**Description**: Configures DMA transfer size for UB → GM transfers. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `loop1` | `pto.i64` | Inner loop trip count | -| `loop2` | `pto.i64` | Outer loop trip count | - -**Returns**: None (side-effect operation) - -#### `pto.set_loop(loop_id: pto.i32, src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] - -**Description**: Configures DMA stride parameters for a generic loop. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `loop_id` | `pto.i32` | Loop identifier (e.g., 1 for inner loop, 2 for outer loop) | -| `src_stride` | `pto.i64` | Source-side stride | -| `dst_stride` | `pto.i64` | Destination-side stride | - -**Returns**: None (side-effect operation) - -**Example**: -```python -pto.set_loop(1, src_stride=32, dst_stride=64) -``` - -#### `pto.set_loop_size(loop_id: pto.i32, size: pto.i64) -> None` [Advanced Tier] - -**Description**: Configures DMA transfer size for a generic loop. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `loop_id` | `pto.i32` | Loop identifier (e.g., 1 for inner loop, 2 for outer loop) | -| `size` | `pto.i64` | Loop trip count | - -**Returns**: None (side-effect operation) - -**Example**: -```python -pto.set_loop_size(1, 16) -``` - -#### DMA Execution Operations - -**Note**: These operations execute DMA transfers but require manual configuration of DMA parameters (loop strides, loop sizes) using the `set_loop*_stride_*` and `set_loop_size_*` operations described above. - -The following operations provide direct control over DMA transfers but require manual stride and size configuration. - -#### `pto.copy_gm_to_ubuf(src: GMPtr, dst: UBPtr, sid: pto.i64 = 0, n_burst: pto.i64, len_burst: pto.i64, left_padding_count: pto.i64 = 0, right_padding_count: pto.i64 = 0, enable_ub_pad: pto.i1 = False, l2_cache_ctl: pto.i64 = 0, gm_stride: pto.i64, ub_stride: pto.i64) -> None` [Advanced Tier] - -**Description**: Copies data from Global Memory (GM) to Unified Buffer (UB). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `GMPtr` | Source GM pointer | -| `dst` | `UBPtr` | Destination UB pointer | -| `sid` | `pto.i64` | DMA stream/control operand, defaults to `0` | -| `n_burst` | `pto.i64` | Number of bursts | -| `len_burst` | `pto.i64` | Bytes copied by each burst | -| `left_padding_count` | `pto.i64` | Left padding count, defaults to `0` | -| `right_padding_count` | `pto.i64` | Right padding count, defaults to `0` | -| `enable_ub_pad` | `pto.i1` | Convenience alias for `data_select_bit`, defaults to `False` | -| `l2_cache_ctl` | `pto.i64` | L2 cache control operand, defaults to `0` | -| `gm_stride` | `pto.i64` | GM-side stride in bytes | -| `ub_stride` | `pto.i64` | UB-side stride in bytes | - -**Returns**: None (side-effect operation) - -**Notes**: -- **Keyword arguments**: The keyword form shown above is the recommended public API surface. Use named arguments for clarity. -- **Padding control**: Set `enable_ub_pad=True` to enable padded GM→UB transfers. The pad value must be configured separately using `pto.set_mov_pad_val` before the DMA operation (see [Pad Fill Semantics](#pad-fill-semantics) for details). -- **Pad value source**: When padding is enabled, the fill scalar comes from the hardware pad register configured by `pto.set_mov_pad_val`. You must call this operation explicitly before the DMA transfer. -- **ABI compatibility**: The lowering preserves the underlying PTO operand order while providing a more ergonomic keyword interface. - -**Example**: -```python -pto.copy_gm_to_ubuf( - src=gm_ptr, - dst=ub_ptr, - n_burst=32, - len_burst=128, - gm_stride=128, - ub_stride=128, - enable_ub_pad=False, -) -``` - -**Padding Example**: -```python -# First configure the hardware pad register with a scalar value -pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float32 data - -# Then perform padded DMA transfer -pto.copy_gm_to_ubuf( - src=gm_ptr, - dst=ub_ptr, - n_burst=32, - len_burst=200, - gm_stride=200, - ub_stride=256, - enable_ub_pad=True, -) -``` - -#### `pto.copy_ubuf_to_ubuf(src: UBPtr, dst: UBPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, dst_stride1: pto.i64) -> None` [Advanced Tier] - -**Description**: Copies data within Unified Buffer (UB → UB). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `UBPtr` | Source UB pointer | -| `dst` | `UBPtr` | Destination UB pointer | -| `src_offset` | `pto.i64` | Source offset | -| `src_stride0` | `pto.i64` | Source stride dimension 0 | -| `src_stride1` | `pto.i64` | Source stride dimension 1 | -| `dst_offset` | `pto.i64` | Destination offset | -| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | -| `dst_stride1` | `pto.i64` | Destination stride dimension 1 | - -**Returns**: None (side-effect operation) - -#### `pto.copy_ubuf_to_gm(src: UBPtr, dst: GMPtr, sid: pto.i64 = 0, n_burst: pto.i64, len_burst: pto.i64, reserved: pto.i64 = 0, gm_stride: pto.i64, ub_stride: pto.i64) -> None` [Advanced Tier] - -**Description**: Copies data from Unified Buffer (UB) to Global Memory (GM). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `UBPtr` | Source UB pointer | -| `dst` | `GMPtr` | Destination GM pointer | -| `sid` | `pto.i64` | DMA stream/control operand, defaults to `0` | -| `n_burst` | `pto.i64` | Number of bursts | -| `len_burst` | `pto.i64` | Bytes copied by each burst | -| `reserved` | `pto.i64` | Reserved operand, defaults to `0` | -| `gm_stride` | `pto.i64` | GM-side stride in bytes | -| `ub_stride` | `pto.i64` | UB-side stride in bytes | - -**Returns**: None (side-effect operation) - -**Notes**: -- In TileLang DSL, the keyword form above is the recommended public surface. -- `gm_stride`/`ub_stride` are ergonomic aliases for the low-level `burst_dst_stride`/`burst_src_stride` operands. -- The lowering still maps to the underlying low-level PTO operand ABI in positional order. - -**Example**: -```python -pto.copy_ubuf_to_gm( - src=ub_ptr, - dst=gm_ptr, - n_burst=32, - len_burst=128, - gm_stride=128, - ub_stride=128, -) -``` diff --git a/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md b/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md new file mode 100644 index 000000000..e8cc6bf6b --- /dev/null +++ b/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md @@ -0,0 +1,392 @@ +# 9. Predicate and Mask Operations + +Vector operations on the SIMD unit execute across many lanes in parallel — but not all lanes always hold valid data. The last chunk of a row may be shorter than the hardware vector width; a row-wise reduction may need to skip padding elements. **Predicate masks** are the mechanism that gates which lanes participate in an operation. + +This chapter covers mask types, mask creation, logical manipulation, reorganization, and load/store. Comparison operations that *produce* masks from vector data (`vcmp`, `vcmps`) are also covered here, since masks are their primary output. + +## 9.1 Mask types + +The hardware predicate register is a 256-bit register. PTODSL exposes three typed views of it, differing in how many elements each bit represents: + +| Mask type | ALU width | Lanes | Used with vector types | +|-----------|-----------|-------|----------------------| +| `pto.mask_b8` | 8-bit | 256 | `i8` vectors | +| `pto.mask_b16` | 16-bit | 128 | `f16`, `bf16`, `i16` vectors | +| `pto.mask_b32` | 32-bit | 64 | `f32`, `i32` vectors | + +A mask and the vector it gates must share the same granularity: a `mask_b32` gates an `f32` vector (64 lanes), not an `f16` vector (128 lanes). + +**Zeroing predication**: when a lane is masked off, the operation produces zero in that lane. This is the gating model for all vector compute ops in Chapter 8. + +## 9.2 Mask creation: `pto.make_mask` + +The recommended front door for creating masks is `pto.make_mask`. It dispatches to the right underlying op based on its arguments. + +#### `pto.make_mask(dtype: Type, value: int | MaskPattern) -> MaskType | (MaskType, int)` + +**Description**: Creates a predicate mask of the granularity matching `dtype`. When `value` is an `int` (typically a remaining-element count in a chunked loop), returns a tuple `(mask, remaining)`. When `value` is a `MaskPattern`, returns just the mask. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Element type to infer mask granularity from (e.g., `pto.f32` → `mask_b32`, `pto.f16` → `mask_b16`) | +| `value` | `int` or `MaskPattern` | Either a remaining-element count or a pattern token | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | The created mask | +| `remained` | `int` | Updated remaining count (only when `value` is `int`) | + +**Example** — chunked SIMD loop with tail handling: + +```python +VEC = pto.elements_per_vreg(pto.f32) +col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) +with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(tile[r, c:]) + # ... operate under mask ... + pto.vsts(vec, out_tile[r, c:], mask) + col_loop.update(remained=remained) +``` + +`make_mask` generates a tail mask from the remaining count: the first `min(remained, VL)` lanes are active, and `remained` is decremented by `VL` for the next iteration. On the final partial chunk, fewer than `VL` lanes are active. + +--- + +When the mask pattern is known at compile time, pass a `MaskPattern` instead: + +```python +full_mask = pto.make_mask(pto.f32, pto.MaskPattern.ALL) +``` + +This is equivalent to calling the granularity-specific ops described below. + +--- + +## 9.3 Granularity-specific creation ops + +When you need explicit control over the mask granularity, use these ops directly. + +### 9.3.1 Pattern-based: `pset_b*` and `pge_b*` + +`pset` generates a mask from a named pattern. `pge` generates a tail mask where the first N lanes are active (N encoded in the pattern). + +#### `pto.pset_b8(pattern: MaskPattern) -> pto.mask_b8` +#### `pto.pset_b16(pattern: MaskPattern) -> pto.mask_b16` +#### `pto.pset_b32(pattern: MaskPattern) -> pto.mask_b32` + +**Description**: Creates a mask from a pattern token. `PAT_ALL` sets all lanes active; other patterns set a subset. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `MaskPattern` | Pattern token: `ALL`, `ALLF`, `H`, `Q`, `VL1`–`VL128`, `M3`, `M4` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Mask with lanes set per the pattern | + +--- + +#### `pto.pge_b8(pattern: MaskPattern) -> pto.mask_b8` +#### `pto.pge_b16(pattern: MaskPattern) -> pto.mask_b16` +#### `pto.pge_b32(pattern: MaskPattern) -> pto.mask_b32` + +**Description**: Tail mask — `mask[i] = (i < N) ? 1 : 0`, where N is encoded in the pattern. Typically uses `VL*` patterns. + +--- + +### 9.3.2 Scalar-driven: `plt_b*` + +`plt` generates a tail mask from a live `i32` scalar — the idiomatic choice for dynamic tail handling when not using `make_mask`. + +#### `pto.plt_b8(scalar: pto.i32) -> (pto.mask_b8, pto.i32)` +#### `pto.plt_b16(scalar: pto.i32) -> (pto.mask_b16, pto.i32)` +#### `pto.plt_b32(scalar: pto.i32) -> (pto.mask_b32, pto.i32)` + +**Description**: Generates a tail mask where the first `min(scalar, VL)` lanes are active, and returns `scalar - min(scalar, VL)` as the updated remaining count. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `pto.i32` | Remaining element count | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Tail mask (first N lanes active) | +| `scalar_out` | `pto.i32` | Updated remaining = `max(0, scalar - VL)` | + +`VL` is 256 for `b8`, 128 for `b16`, and 64 for `b32`. + +--- + +## 9.4 Mask logical operations + +Once created, masks can be combined with bitwise logical ops. All take a gating mask that selects which lanes participate; inactive lanes are zeroed in the result. + +#### `pto.pand(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` +#### `pto.por(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` +#### `pto.pxor(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` + +**Description**: Bitwise AND / OR / XOR of two masks, gated by a third mask: `dst[i] = mask[i] ? (src0[i] src1[i]) : 0`. All three masks must share the same granularity. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First source mask | +| `src1` | `MaskType` | Second source mask | +| `mask` | `MaskType` | Gating mask (lanes where false produce 0) | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Combined mask | + +--- + +#### `pto.pnot(src: MaskType, mask: MaskType) -> MaskType` + +**Description**: Bitwise NOT under gate: `dst[i] = mask[i] ? (~src[i]) : 0`. + +--- + +#### `pto.psel(src0: MaskType, src1: MaskType, sel: MaskType) -> MaskType` + +**Description**: Per-lane mask select: `dst[i] = sel[i] ? src0[i] : src1[i]`. All lanes participate directly — there is no additional gating beyond `sel` itself. + +--- + +## 9.5 Mask reorganization + +These ops reshape masks between granularities and layouts without changing the underlying 256-bit register image (except pack/unpack, which remap bits). + +#### `pto.pbitcast(mask: MaskType, to_type: MaskType) -> MaskType` + +**Description**: Bitwise reinterpretation of a mask at a different granularity. The 256-bit predicate register image is unchanged; only the lane count and element-width interpretation change. + +**Example**: + +```python +# Reinterpret a b16 mask as b32 +mask32 = pto.pbitcast(mask16, pto.mask_b32) +``` + +--- + +#### `pto.ppack(mask: MaskType, part: PredicatePart) -> MaskType` + +**Description**: Narrowing pack — keeps one bit out of each adjacent 2-bit group from the source, packing them into the selected half (`LOWER` or `HIGHER`) of the result. The other half is zero-filled. + +#### `pto.punpack(mask: MaskType, part: PredicatePart) -> MaskType` + +**Description**: Widening unpack — reads the selected half of the source, zero-extends each 1-bit element into a 2-bit group in the result. + +--- + +#### `pto.pintlv_b8(src0: pto.mask_b8, src1: pto.mask_b8) -> (pto.mask_b8, pto.mask_b8)` +#### `pto.pintlv_b16(src0: pto.mask_b16, src1: pto.mask_b16) -> (pto.mask_b16, pto.mask_b16)` +#### `pto.pintlv_b32(src0: pto.mask_b32, src1: pto.mask_b32) -> (pto.mask_b32, pto.mask_b32)` + +**Description**: Interleave two masks element-wise. Returns `(low, high)` where `low[i] = src0[i]` and `high[i] = src1[i]` at each interleaved position. + +#### `pto.pdintlv_b8(src0: pto.mask_b8, src1: pto.mask_b8) -> (pto.mask_b8, pto.mask_b8)` +#### `pto.pdintlv_b16(src0: pto.mask_b16, src1: pto.mask_b16) -> (pto.mask_b16, pto.mask_b16)` +#### `pto.pdintlv_b32(src0: pto.mask_b32, src1: pto.mask_b32) -> (pto.mask_b32, pto.mask_b32)` + +**Description**: Deinterleave — the inverse of `pintlv`. Takes interleaved data in two masks and separates even/odd elements. + +--- + +## 9.6 Comparisons: producing masks from vectors + +Vector comparisons produce predicate masks from vector data. The result can feed into mask logical ops, `vsel`, or gated stores. + +#### `pto.vcmp(v0: VRegType, v1: VRegType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Element-wise vector-vector comparison: `dst[i] = seed_mask[i] ? (v0[i] v1[i]) : 0`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `v0` | `VRegType` | First operand vector | +| `v1` | `VRegType` | Second operand vector | +| `seed_mask` | `MaskType` | Seed mask gating which lanes participate | +| `cmp_mode` | `CmpMode` | `EQ`, `NE`, `LT`, `LE`, `GT`, `GE` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `pred` | `MaskType` | Result predicate mask (inherits granularity from operands) | + +--- + +#### `pto.vcmps(vec: VRegType, scalar: ScalarType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Vector-scalar comparison: `dst[i] = seed_mask[i] ? (vec[i] scalar) : 0`. The scalar is broadcast to all lanes. + +**Example** — threshold a vector: + +```python +big = pto.vcmps(scores, threshold, seed, pto.CmpMode.GT) +# big[i] = 1 where scores[i] > threshold +``` + +--- + +**Tile-level comparisons** (`pto.tcmp`, `pto.tcmps`) compare two tiles and write packed predicate bytes into an `i8` destination tile. They are used when the comparison result needs to be stored to UB for later selection (`tsel`) or cross-kernel communication. + +--- + +## 9.7 Mask load and store + +Masks can be persisted to and loaded from UB memory, enabling cross-stage predicate communication. + +### 9.7.1 Predicate loads + +#### `pto.plds(buf: PtrType, offset: Index, *, dist: PredicateDist = PredicateDist.NORM) -> MaskType` + +**Description**: Load a predicate mask from UB memory at the given byte offset. The mask granularity is inferred from context. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `PtrType` (UB) | Source buffer | +| `offset` | `Index` | Byte offset | +| `dist` | `PredicateDist` | `NORM` (load VL/8 packed bytes), `US` (upsample), `DS` (downsample) | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Loaded predicate mask | + +--- + +### 9.7.2 Predicate stores + +#### `pto.psts(mask: MaskType, buf: PtrType, offset: Index, *, dist: PredicateDist = PredicateDist.NORM) -> None` + +**Description**: Store a predicate mask to UB memory. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `PtrType` (UB) | Destination buffer | +| `offset` | `Index` | Byte offset | +| `dist` | `PredicateDist` | `NORM` (store VL/8 packed bytes) or `PK` (pack to VL/16 bytes) | + +**Returns**: None. + +--- + +### 9.7.3 Unaligned predicate store + +#### `pto.pstu(align_in: AlignType, mask: MaskType, buf: PtrType) -> (AlignType, PtrType)` + +**Description**: Unaligned predicate store with alignment state threading. Threads the `align` state through a stream of stores, ensuring tail bytes are correctly buffered. The base pointer type is determined by the mask granularity (`ui16` for `b16`, `ui32` for `b32`). + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming alignment state (from `init_align` or previous `pstu`) | +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `PtrType` (UB) | Destination buffer | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `AlignType` | Updated alignment state | +| `base_out` | `PtrType` | Post-update base pointer | + + +## 9.8 How masks gate vector operations + +Every vector compute op in Chapter 8 takes a mask as its last operand. The contract is consistent: + +- For **unary ops** (`vexp`, `vabs`, etc.): `dst[i] = mask[i] ? f(src[i]) : 0` +- For **binary ops** (`vadd`, `vmul`, etc.): `dst[i] = mask[i] ? (lhs[i] rhs[i]) : 0` +- For **vector stores** (`vsts`): `dst[i] = mask[i] ? src[i]` — masked-off lanes are not written +- For **reductions** (`vcadd`, `vcgmax`, etc.): only lanes where `mask[i]` is true contribute to the result + +The mask granularity must match the vector element type. Using a `mask_b16` with an `f32` vector (or vice versa) is an error. + +**Typical pattern** — tail-safe vector processing: + +```python +VEC = pto.elements_per_vreg(pto.f32) +with pto.for_(0, rows, step=1) as r: + col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) + with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + + vec = pto.vlds(tile[r, c:]) + vec = pto.vexp(vec, mask) + pto.vsts(vec, out_tile[r, c:], mask) + + col_loop.update(remained=remained) +``` + +The `mask` gates the `vexp` (masked-off lanes produce 0) and the `vsts` (masked-off lanes are not written). `col_loop` carries the remaining count across iterations, so the final partial chunk correctly masks only the valid tail elements. + +--- + +## 9.9 Tile-level mask operations + +When working at the tile level (L1, `@pto.jit`), masks are carried in `i8` tile buffers holding packed predicate bytes. The key consumer of tile-level masks is `tsel`. + +#### `pto.tsel(mask: Tile, src0: Tile, src1: Tile, tmp: Tile, dst: Tile) -> None` + +**Description**: Element-wise ternary select: `dst[i,j] = mask[i,j] ? src0[i,j] : src1[i,j]`. `mask` is an integer tile (typically `i8`) where zero means false. `tmp` is a scratch tile for the underlying implementation. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `Tile` | Integer mask tile (zero = false) | +| `src0` | `Tile` | True-branch source tile | +| `src1` | `Tile` | False-branch source tile | +| `tmp` | `Tile` | Scratch tile | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +--- + +#### `pto.tsels(mask: Tile, src: Tile, scalar: ScalarType, tmp: Tile, dst: Tile) -> None` + +**Description**: Element-wise select with scalar fallback: `dst[i,j] = mask[i,j] ? src[i,j] : scalar`. + +--- + +## 9.10 Enum reference + +| Enum | Values | Used with | +|------|--------|-----------| +| `MaskPattern` | `ALL`, `ALLF`, `H`, `Q`, `VL1`–`VL128`, `M3`, `M4` | `pset_b*`, `pge_b*`, `make_mask` | +| `CmpMode` | `EQ`, `NE`, `LT`, `LE`, `GT`, `GE` | `vcmp`, `vcmps` | +| `PredicateDist` (load) | `NORM`, `US`, `DS` | `plds` | +| `PredicateDist` (store) | `NORM`, `PK` | `psts` | +| `PredicatePart` | `LOWER`, `HIGHER` | `ppack`, `punpack` | diff --git a/ptodsl/docs/user_guide/09-vector-memory-operations.md b/ptodsl/docs/user_guide/09-vector-memory-operations.md deleted file mode 100644 index f7a20fd76..000000000 --- a/ptodsl/docs/user_guide/09-vector-memory-operations.md +++ /dev/null @@ -1,1058 +0,0 @@ -### Enum Types for Vector Memory Operations - -The current DSL exposes type-safe Enum operands for the dual load/store -distribution families: - -- **`VLoadDist`** for `pto.vlds` - - `VLoadDist.NORM`: ordinary load - - `VLoadDist.UNPK_B8`, `VLoadDist.UNPK_B16`, `VLoadDist.UNPK_B32`: unpacking loads - - `VLoadDist.BRC_B8`, `VLoadDist.BRC_B16`, `VLoadDist.BRC_B32`: broadcast loads - - `VLoadDist.US_B8`, `VLoadDist.US_B16`, `VLoadDist.DS_B8`, `VLoadDist.DS_B16`: strided/narrow load families - -- **`VStoreDist`** for `pto.vsts` - - `VStoreDist.NORM_B8`, `VStoreDist.NORM_B16`, `VStoreDist.NORM_B32`: ordinary stores - - `VStoreDist.ONE_POINT_B8`, `VStoreDist.ONE_POINT_B16`, `VStoreDist.ONE_POINT_B32`: one-point stores - - `VStoreDist.PK_B16`, `VStoreDist.PK_B32`, `VStoreDist.PK_B64`: packed stores - - `VStoreDist.PK4_B32`, `VStoreDist.MRG4CHN_B8`, `VStoreDist.MRG2CHN_B8`, `VStoreDist.MRG2CHN_B16`: merged packed stores - -- **`DeinterleaveDist`** for `pto.vldsx2` - - `DeinterleaveDist.DINTLV`: alternating-element deinterleave - - `DeinterleaveDist.BDINTLV`: block deinterleave - - compatibility aliases: `DeinterleaveDist.B8`, `DeinterleaveDist.B16`, - `DeinterleaveDist.B32`, `DeinterleaveDist.BD` - -- **`InterleaveDist`** for `pto.vstsx2` - - `InterleaveDist.INTLV`: interleave two vectors into one destination stream - - compatibility aliases: `InterleaveDist.B8`, `InterleaveDist.B16`, - `InterleaveDist.B32` - -- **`PostUpdateMode`** for `pto.vstur` - - `PostUpdateMode.NO_POST_UPDATE`: preserve the current hardware AR state - - `PostUpdateMode.POST_UPDATE`: advance the hardware AR state after the store - -The canonical VPTO v0.3 spellings are the enum values: - -- `VLoadDist.UNPK_B16.value == "UNPK_B16"` -- `VStoreDist.PK_B32.value == "PK_B32"` -- `DeinterleaveDist.DINTLV.value == "DINTLV"` -- `DeinterleaveDist.BDINTLV.value == "BDINTLV"` -- `InterleaveDist.INTLV.value == "INTLV"` -- `PostUpdateMode.NO_POST_UPDATE.value == "NO_POST_UPDATE"` -- `PostUpdateMode.POST_UPDATE.value == "POST_UPDATE"` - -`pto.vstur` mode is intentionally Enum-only in the DSL. Unlike the legacy -distribution-token compatibility retained for some older load/store families, -raw strings such as `"POST_UPDATE"` are not accepted for `PostUpdateMode`. - -For migration convenience, the implementation still accepts legacy raw strings -such as `"DINTLV_B32"` and `"INTLV_B32"`, but new DSL code should prefer the -Enum operands. - -- **`StrideMode`**: Stride modes for `pto.vsld` - - `S3_B16`: Stride 3, block size 16 - - `S4_B64`: Stride 4, block size 64 - - `S8_B32`: Stride 8, block size 32 - - `S2_B64`: Stride 2, block size 64 - -### Address Generation Syntax Sugar - -To simplify address calculation and reduce manual byte offset computation errors, TileLang DSL provides syntactic sugar for vector load/store operations using element-based indexing. This syntax automatically computes the byte offset based on tile shape, element type, and layout. - -#### Indexing Syntax - -The syntax supports two indexing modes for different operations: - -1. **Vector-range indexing** (for vector load/store operations): - - **Row-major layout (default)**: `tile[row_index, col_start:]` - - `row_index`: Row index (0-based) - - `col_start:`: Starting column index followed by colon, indicating a vector-width contiguous region starting from this column - - The colon (`:`) indicates an implicit vector-width range determined by hardware vector size (256 bytes) and element type - - - **Column-major layout**: `tile[row_start:, col_index]` - - `row_start:`: Starting row index followed by colon, indicating a vector-width contiguous region starting from this row - - `col_index`: Column index (0-based) - - Used for column-major tiles (`BLayout.COL_MAJOR`) where elements are stored column-wise - - - **1D tile indexing**: `tile[start:]` (or equivalently `tile[0, start:]` for row-major or `tile[start:, 0]` for column-major) - - `start:`: Starting element index followed by colon - - Tile indexing sugar only accepts an open-ended vector slice. Python slice - forms with an explicit `stop` or `step` are not supported for `Tile` - indexing. For example, `tile[row, col:col_end]`, `tile[row, col::2]`, - `tile[row_start:row_end, col]`, and `tile[start:stop:step]` are invalid. - -2. **Single-element indexing** (for scalar load operations like `pto.vsld`): - - **Row-major layout (default)**: `tile[row_index, col_index]` - - `row_index`: Row index (0-based) - - `col_index`: Column index (0-based) - - Loads a single element at the specified position and broadcasts it to all vector lanes - - - **Column-major layout**: `tile[row_index, col_index]` (same syntax) - - `row_index`: Row index (0-based) - - `col_index`: Column index (0-based) - - Same syntax as row-major; the layout determines how the offset is computed - - - **1D tile indexing**: `tile[pos]` - - `pos`: Element index (0-based) - - Loads a single element at the specified position and broadcasts it to all vector lanes - -#### Vector Width Calculation - -The number of elements loaded/stored in a single vector operation is determined by: - -``` -vector_lanes = 256 // element_size_bytes(element_type) -``` - -**Convenience API**: Use `pto.elements_per_vreg(dtype)` to compute the number of elements per vector register for a given element type (e.g., `pto.elements_per_vreg(pto.f32)` returns 64, `pto.elements_per_vreg(pto.f16)` returns 128). See [Type Query Operations](07-frontend-operations.md#type-query-operations) for full documentation. - -Where `element_size_bytes` is: -- 1 byte for `i8`, `si8`, `ui8` -- 2 bytes for `i16`, `si16`, `ui16`, `f16`, `bf16` -- 4 bytes for `i32`, `si32`, `ui32`, `f32` -- 8 bytes for `i64`, `si64`, `ui64` - -#### Offset Computation - -The byte offset is automatically computed based on tile layout: - -- **Row-major layout** (`BLayout.ROW_MAJOR`): - ``` - offset = (row_index * stride_row + col_start) * element_size_bytes - ``` - where `stride_row` is the row stride in elements (typically `tile.shape[1]` for contiguous tiles). - -- **Column-major layout** (`BLayout.COL_MAJOR`): - - For syntax `tile[row_start:, col_index]`: - ``` - offset = (col_index * stride_col + row_start) * element_size_bytes - ``` - - For backward compatibility with traditional offset calculation: - ``` - offset = (col_start * stride_col + row_index) * element_size_bytes - ``` - where `stride_col` is the column stride in elements (typically `tile.shape[0]` for contiguous tiles), `row_start` is the starting row index, and `col_index` is the column index. - -**Note**: -- For single-element indexing (`tile[row, col]` or `tile[pos]`), the same offset formulas apply with `col_start` replaced by `col_index` (or `start` replaced by `pos` for 1D tiles). -- For column-major vector-range indexing (`tile[row_start:, col_index]`), the offset formula uses `row_start` as the starting position along the contiguous dimension. -- The compiler automatically handles the appropriate substitution based on the indexing syntax and tile layout. - -#### Constraints - -1. **Boundary checks**: The requested region must be within tile bounds: - - **For vector-range indexing** (`:` syntax): - - **Row-major layout** (`tile[row_index, col_start:]`): - - `row_index < tile.shape[0]` and `col_start + vector_lanes <= tile.shape[1]` - - **Column-major layout** (`tile[row_start:, col_index]`): - - `row_start + vector_lanes <= tile.shape[0]` and `col_index < tile.shape[1]` - - **1D tile indexing**: `tile[start:]` - - `start + vector_lanes <= tile.shape[0]` (or `tile.shape[1]` for 1D tiles) - - **For single-element indexing** (no `:` syntax): - - 2D: `row_index < tile.shape[0]` and `col_index < tile.shape[1]` (same for both layouts) - - 1D: `pos < tile.shape[0]` (or `tile.shape[1]` for 1D tiles) - -2. **Alignment**: The computed offset must satisfy hardware alignment requirements for the operation. - -3. **Full vectors only**: The `:` syntax always loads/stores a full vector width. For partial vectors, use the traditional byte offset approach with explicit mask handling. - -4. **Single-element operations**: The single-element indexing syntax (`tile[row, col]` or `tile[pos]`) is only supported for scalar load operations like `pto.vsld`. For other operations, use vector-range indexing with `:` syntax. - -5. **No explicit slice bounds/stride for `Tile` indexing**: `Tile` vector-range - indexing only supports the open-ended forms `tile[start:]`, - `tile[row, col:]`, and `tile[row_start:, col_index]` (for column-major - layout). `stop` and `step` syntax are not accepted in user-guide Tile - indexing. - -#### Supported Operations - -The indexing syntax is supported for all vector load and store operations with the following syntax mapping: - -- **Vector-range indexing** (`tile[row, col:]` or `tile[start:]`): - - Load operations: `vlds`, `vldas`, `vldus`, `vldsx2` - - Store operations: `vsts`, `vsta`, `psts`, `vsst`, `vstsx2` - -- **Single-element indexing** (`tile[row, col]` or `tile[pos]`): - - Load operations: `vsld` (scalar load with broadcast) - -#### Examples - -The following examples use row-major layout syntax. For column-major tiles, use `tile[row_start:, col_index]` syntax instead of `tile[row_index, col_start:]`. - -```python -# 2D tile indexing (row-major layout) -vec = pto.vlds(tile[i, j:]) # Load vector from row i, columns j to j+vector_lanes-1 -pto.vsts(vec, tile[i, j:], mask) # Store vector with mask - -# 1D tile indexing -vec = pto.vlds(tile[k:]) # Load vector from elements k to k+vector_lanes-1 -pto.vsts(vec, tile[k:], mask) # Store vector with mask - -# Dual load with deinterleave -low, high = pto.vldsx2(tile[i, j:], "DINTLV") - -# Aligned load with indexing -vec = pto.vldas(tile[i, j:], align) - -# Scalar load (broadcast) -vec = pto.vsld(tile[i, j]) # Load scalar at tile[i,j] and broadcast to vector -``` - -#### Comparison with Manual Offset Calculation - -**Traditional approach (error-prone):** -```python -# Manual byte offset calculation for f32 tile -rows, cols = tile.shape -row_offset = i * cols * 4 # Hard-coded 4 bytes for f32 -col_offset = j * 4 -offset = row_offset + col_offset -vec = pto.vlds(tile, offset) -``` - -**New syntax (type-safe):** -```python -# Automatic offset calculation -vec = pto.vlds(tile[i, j:]) # Compiler computes correct offset for any element type -``` - -The syntax sugar eliminates manual byte calculations, reduces errors, and makes code generic across different element types (e.g., the same kernel works for both `f16` and `f32` without modification). - -### Vector Load Operations - -Operations for loading data from memory into vector registers. - -#### `pto.vlds(buf: ptr, offset: Index, dist: pto.VLoadDist | None = None) -> VRegType` [Advanced Tier] -#### `pto.vlds(tile[row, col:], dist: pto.VLoadDist | None = None) -> VRegType` [Basic Tier] -#### `pto.vlds(tile[start:], dist: pto.VLoadDist | None = None) -> VRegType` [Basic Tier] - -**Description**: Stateless vector load from buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. - -**Parameters (pointer syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | -| `offset` | `Index` | Byte offset | -| `dist` | `pto.VLoadDist \| None` | Optional load distribution enum such as `pto.VLoadDist.NORM` or `pto.VLoadDist.UNPK_B16` | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | -| `dist` | `pto.VLoadDist \| None` | Optional load distribution enum such as `pto.VLoadDist.NORM` or `pto.VLoadDist.UNPK_B16` | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Loaded vector register | - -**Constraints**: -- Buffer must be in UB memory space -- For byte-offset syntax: offset must be properly aligned based on element type -- For element-indexing syntax: the requested vector region must be within tile bounds and satisfy alignment requirements -- `dist` is optional. When omitted, the load uses the backend default layout for the vector family. -- `dist` must be a `pto.VLoadDist` enum value. - -**Examples**: -```python -# Traditional byte-offset syntax -vec = pto.vlds(ub_ptr, lane * 256) -vec_unpacked = pto.vlds(ub_ptr, lane * 128, dist=pto.VLoadDist.UNPK_B16) - -# New element-indexing syntax -vec = pto.vlds(tile[i, j:]) # Load from row i, columns j to j+vector_lanes-1 -vec = pto.vlds(tile[k:]) # Load from 1D tile, elements k to k+vector_lanes-1 - -# Generic kernel that works for both f16 and f32 -@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) -def generic_scale(src: pto.Tile, dst: pto.Tile, scale: pto.f32): - rows, cols = src.shape - all_mask = pto.make_mask(src.element_type, PAT.ALL) - for i in range(0, rows): - for j in range(0, cols, vector_lanes): # vector_lanes computed from element type - # No manual byte calculation needed! - vec = pto.vlds(src[i, j:]) - scaled = pto.vmuls(vec, scale, all_mask) - pto.vsts(scaled, dst[i, j:], all_mask) -``` - -#### `pto.vldas(buf: ptr) -> pto.align` [Advanced Tier] -#### `pto.vldas(tile[row, col:]) -> pto.align` [Basic Tier] -#### `pto.vldas(tile[start:]) -> pto.align` [Basic Tier] - -**Description**: Prime alignment buffer for subsequent unaligned load. Returns alignment state for use with `pto.vldus`. Supports both pointer syntax and element-indexing syntax. - -**Parameters (pointer syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `align` | `pto.align` | Alignment state for use with `pto.vldus` | - -**Examples**: -```python -# Pointer syntax -align = pto.vldas(ub_ptr) - -# Element-indexing syntax -align = pto.vldas(tile[i, j:]) -align = pto.vldas(tile[k:]) -``` - -#### `pto.vldus(buf: ptr, align: pto.align) -> (VRegType, pto.align, ptr)` [Advanced Tier] -#### `pto.vldus(tile[row, col:], align: pto.align) -> (VRegType, pto.align, ptr)` [Basic Tier] -#### `pto.vldus(tile[start:], align: pto.align) -> (VRegType, pto.align, ptr)` [Basic Tier] - -**Description**: Unaligned load using primed align state. Requires alignment state from `pto.vldas` or previous `pto.vldus`. Updates alignment state and base pointer for subsequent loads. Supports both pointer syntax and element-indexing syntax. - -**Parameters (pointer syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | -| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | -| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | -| _or_ | | | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | -| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Assembled vector value | -| `align_out` | `pto.align` | Updated alignment state for next load | -| `base_out` | `ptr` | Post-update base pointer state | - -**Constraints**: -- A matching `pto.vldas` must appear before the first dependent `pto.vldus` stream in the same vector loop -- Both alignment state and base address advance across the stream -- If DSL authoring uses explicit byte/element offsets, the frontend first rewrites them into pointer/index expressions before lowering to this VPTO form. - -**Examples**: -```python -# Pointer syntax - requires alignment state priming -align = pto.vldas(ub_ptr) -vec, align_out, base_out = pto.vldus(ub_ptr, align) - -# Element-indexing syntax -align = pto.vldas(tile[i, j:]) -vec, align_out, base_out = pto.vldus(tile[i, j:], align) - -# Multiple unaligned loads in a stream -align = pto.vldas(tile[k:]) -for n in range(4): - vec, align, base = pto.vldus(tile[k:], align) # alignment state updates -``` - - -#### `pto.vldsx2(buf: ptr, offset: Index, dist: DeinterleaveDist) -> (VRegType, VRegType)` [Advanced Tier] -#### `pto.vldsx2(tile[row, col:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] -#### `pto.vldsx2(tile[start:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] - -**Description**: Dual vector load with deinterleave (AoS → SoA conversion). Loads interleaved data from a single buffer and deinterleaves into two vectors. Supports both byte-offset and element-indexing syntax. - -**Parameters (pointer syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Pointer to source buffer in UB memory space (Advanced mode only - requires explicit pointer) | -| `offset` | `Index` | Byte offset | -| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | -| _or_ | | | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | -| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `low` | `VRegType` | First vector (even elements in interleaved stream) | -| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | - -**Constraints**: -- Source buffer must be in UB memory space -- Offset must satisfy alignment requirements for the selected distribution mode -- The requested vector region must be within tile bounds (for element-indexing syntax) -- Distribution mode must match element type (e.g., `"DINTLV"` for 32-bit elements) - -**Examples**: -```python -# Byte-offset syntax -low, high = pto.vldsx2(ub_ptr, offset, pto.DeinterleaveDist.DINTLV) - -# Element-indexing syntax -low, high = pto.vldsx2(tile[i, j:], pto.DeinterleaveDist.DINTLV) -low, high = pto.vldsx2(tile[k:], pto.DeinterleaveDist.DINTLV) - -# Example: Load interleaved XY pairs into separate X/Y vectors -x_vec, y_vec = pto.vldsx2(xy_tile[i, j:], pto.DeinterleaveDist.DINTLV) -``` - -#### `pto.vsld(buf: ptr, offset: Index, stride: StrideMode) -> VRegType` [Advanced Tier] -#### `pto.vsld(tile[row, col], stride: StrideMode) -> VRegType` [Basic Tier] -#### `pto.vsld(tile[pos], stride: StrideMode) -> VRegType` [Basic Tier] - -**Description**: Strided load with fixed stride pattern. Loads elements from memory with regular stride pattern. The offset parameter encodes displacement with selected stride mode. This is a deprecated compatibility family. - -**Parameters (pointer syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | -| `offset` | `Index` | Byte displacement encoded with selected stride mode | -| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. Determines which sub-elements are read from each source block. | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile[row, col]` | `Tile` with indexing | 2D tile with row and column indices (single element) | -| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. | -| _or_ | | | -| `tile[pos]` | `Tile` with indexing | 1D tile with element index (single element) | -| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Loaded vector with strided pattern | - -**Constraints**: -- The selected stride token determines which sub-elements are read from each source block -- This operation family is deprecated; prefer other load patterns when possible - -**Examples**: -```python -from pto import StrideMode - -# Byte-offset syntax -vec = pto.vsld(ub_ptr, offset, StrideMode.S4_B64) - -# Element-indexing syntax -vec = pto.vsld(tile[i, j], StrideMode.S3_B16) -vec = pto.vsld(tile[k], StrideMode.S8_B32) -``` - -#### `pto.vgather2(buf: ptr, offsets: Index, active_lanes: Index) -> VRegType` [Advanced Tier] - -**Description**: Indexed gather from UB. Gathers elements from a single buffer using per-lane offsets, with participation bounded by active lanes count. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Pointer to source buffer in UB memory space | -| `offsets` | `Index` | Per-lane element offsets (vector register) | -| `active_lanes` | `Index` | Number of lanes that participate (bounds gather operation) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Gathered vector | - -**Constraints**: -- Only the first `active_lanes` offsets participate in the gather -- Index element width and interpretation must match selected gather form -- Each effective address must satisfy the gather form's alignment rules - -**Example**: -```python -vec = pto.vgather2(buf, offsets, active_lanes) -``` - -#### `pto.vgather2_bc(buf: ptr, offsets: Index, mask: MaskType) -> VRegType` [Advanced Tier] - -**Description**: Gather with broadcast, conditioned by mask. Gathers elements from a single buffer using per-lane offsets, with mask gating lane participation. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Pointer to source buffer in UB memory space | -| `offsets` | `Index` | Per-lane element offsets (vector register) | -| `mask` | `MaskType` | Mask gating which lanes participate | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Gathered vector | - -**Constraints**: -- Masked-off lanes do not participate in address coalescing and do not trigger address overflow exceptions -- Destination lanes for masked-off lanes are zero-filled -- This is a backward-compatible operation family - -**Example**: -```python -vec = pto.vgather2_bc(buf, offsets, mask) -``` - -#### `pto.vgatherb(buf: ptr, offsets: Index) -> VRegType` [Advanced Tier] - -**Description**: Byte‑granularity gather load. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Pointer to buffer | -| `offsets` | `Index` | Byte offsets | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Gathered vector | - -**Example**: -```python -vec = pto.vgatherb(buf, offsets) -``` - -#### `pto.vsldb(buf: ptr, offset: Index, mask: MaskType) -> VRegType` [Advanced Tier] -#### `pto.vsldb(tile[row, col], offset: Index, mask: MaskType) -> VRegType` [Basic Tier] -#### `pto.vsldb(tile[pos], offset: Index, mask: MaskType) -> VRegType` [Basic Tier] - -**Description**: Block-strided load for 2D tile access. Loads elements with block stride pattern controlled by packed offset word and mask. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Pointer to buffer in UB memory space | -| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | -| `mask` | `MaskType` | Mask controlling which blocks participate | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile[row, col]` | `Tile` with indexing | 2D tile with row and column indices (single element) | -| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | -| `mask` | `MaskType` | Mask controlling which blocks participate | -| _or_ | | | -| `tile[pos]` | `Tile` with indexing | 1D tile with element index (single element) | -| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | -| `mask` | `MaskType` | Mask controlling which blocks participate | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Loaded vector with block-strided pattern | - -**Constraints**: -- The offset encodes block stride and repeat pattern, not a plain byte displacement -- If a block is masked off, the corresponding destination block is zeroed -- Masked-off blocks must not raise address overflow exceptions - -**Example**: -```python -# Byte-offset syntax -vec = pto.vsldb(ub_ptr, control_word, mask) - -# Element-indexing syntax -vec = pto.vsldb(tile[i, j], control_word, mask) -vec = pto.vsldb(tile[k], control_word, mask) -``` - -### Vector Store Operations - -Operations for storing data from vector registers to memory. - -#### `pto.vsts(vec: VRegType, buf: ptr, offset: Index, mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Advanced Tier] -#### `pto.vsts(vec: VRegType, tile[row, col:], mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Basic Tier] -#### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Basic Tier] - -**Description**: Stateless vector store to buffer. Supports both byte-offset and element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Vector to store | -| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | -| `offset` | `Index` | Byte offset | -| `mask` | `MaskType` | Predicate mask | -| `dist` | `pto.VStoreDist \| None` | Optional store distribution enum such as `pto.VStoreDist.NORM_B32` or `pto.VStoreDist.PK_B32` | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Vector to store | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | -| `mask` | `MaskType` | Predicate mask | -| `dist` | `pto.VStoreDist \| None` | Optional store distribution enum such as `pto.VStoreDist.NORM_B32` or `pto.VStoreDist.PK_B32` | - -**Returns**: None (side-effect operation) - -**Constraints**: -- Buffer must be in UB memory space -- For byte-offset syntax: offset must be properly aligned based on element type -- For element-indexing syntax: the destination vector region must be within tile bounds and satisfy alignment requirements -- `dist` is optional. When omitted, the store uses the backend default layout for the vector family. -- Current TileLang DSL v1 accepts exactly one keyword attr on `pto.vsts`: `dist=...`. -- `dist` must be a `pto.VStoreDist` enum value. -- `mask` must match the effective store payload granularity, which may differ from the vector element family when `dist` repacks lanes. -- Common width-changing cases: - default / `NORM_B32` stores expect `mask_b32` for `f32`/`i32`-family vectors; - `PK_B32` also expects `mask_b32` and is used by narrow stores such as `f32 -> f16` `tcvt`; - `PK_B16` expects `mask_b16`. - -**Examples**: -```python -# Byte-offset syntax -pto.vsts(vec_f32, ub_ptr, lane * 256, mask32) - -# Element-indexing syntax -pto.vsts(vec, tile[i, j:], mask) # Store to row i, columns j to j+vector_lanes-1 -pto.vsts(vec, tile[k:], mask) # Store to 1D tile, elements k to k+vector_lanes-1 - -# VPTO-aligned packed store -vec_f16 = pto.vcvt( - vec_f32, - pto.f16, - mask32, - rnd=pto.VcvtRoundMode.R, - sat=pto.VcvtSatMode.SAT, - part=pto.VcvtPartMode.EVEN, -) -pto.vsts(vec_f16, tile[i, j:], mask32, dist=pto.VStoreDist.PK_B32) - -# In a generic kernel -@pto.vkernel(target="a5", op="copy", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) -def generic_store(src: pto.Tile, dst: pto.Tile): - rows, cols = src.shape - all_mask = pto.make_mask(src.element_type, PAT.ALL) - for i in range(0, rows): - for j in range(0, cols, vector_lanes): - vec = pto.vlds(src[i, j:]) - pto.vsts(vec, dst[i, j:], all_mask) # No manual offset calculation -``` - -#### `pto.psts(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] - -**Description**: Predicate store (`pto.psts`) writes the packed payload represented by -`MaskType` to UB memory. This is the dynamic-offset form of the VPTO predicate-store -family (`psts` vs `psti`): the payload semantics are identical, and only the offset -delivery form differs. - -**Parameters (advanced byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Predicate payload to store | -| `buf` | `ptr` | Pointer to destination UB buffer (Advanced mode only - requires explicit pointer) | -| `offset` | `Index` | Runtime offset (`index`) | -| `dist` | `PredicateDist` | Predicate distribution mode. Use `PredicateDist.NORM` or `PredicateDist.PK` (default: `PredicateDist.NORM`). | - -**Returns**: None (side-effect operation) - -**DIST semantics (VPTO-aligned)**: -- `PredicateDist.NORM`: store packed predicate payload into a normal destination space of size `VL/8`. -- `PredicateDist.PK`: store packed predicate payload into a destination space of size `VL/16`, keeping one bit out of every two bits. - -**Notes**: -- `pto.psts` is intentionally documented as explicit `buf + offset` surface in DSL v1. -- Packed predicate payload layout is bit-level (`VL/8` or `VL/16`), so tile element-indexing is not part of the stable Basic Tier contract. -- The pointer + offset form maps directly to explicit `base[offset]`. -- Authoritative predicate-memory-family semantics are documented in `10-predicate-operations.md`. - -#### `pto.vsst(scalar: ScalarType, buf: ptr, offset: Index, mask: MaskType) -> None` [Advanced Tier] -#### `pto.vsst(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` -#### `pto.vsst(scalar: ScalarType, tile[start:], mask: MaskType) -> None` - -**Description**: Scalar to vector store (broadcast scalar to all lanes). Supports both traditional byte-offset syntax and new element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `ScalarType` | Scalar value | -| `buf` | `ptr` | Pointer to destination buffer (Advanced mode only - requires explicit pointer) | -| `offset` | `Index` | Byte offset | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `ScalarType` | Scalar value | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (1D element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `ScalarType` | Scalar value | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: None (side-effect operation) - -#### `pto.vstsx2(low: VRegType, high: VRegType, buf: ptr, offset: Index, dist: InterleaveDist, mask: MaskType) -> None` [Advanced Tier] -#### `pto.vstsx2(low: VRegType, high: VRegType, tile[row, col:], dist: InterleaveDist, mask: MaskType) -> None` -#### `pto.vstsx2(low: VRegType, high: VRegType, tile[start:], dist: InterleaveDist, mask: MaskType) -> None` - -**Description**: Dual interleaved store (SoA → AoS conversion). Stores two vectors interleaved into a single buffer. Supports both byte-offset and element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `low` | `VRegType` | First vector (even elements in interleaved stream) | -| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | -| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | -| `offset` | `Index` | Byte offset | -| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `low` | `VRegType` | First vector (even elements in interleaved stream) | -| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (1D element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `low` | `VRegType` | First vector (even elements in interleaved stream) | -| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | -| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: None (side-effect operation) - -**Constraints**: -- Destination buffer must be in UB memory space -- Offset must satisfy alignment requirements for the selected distribution mode -- The destination vector region must be within tile bounds (for element-indexing syntax) -- Distribution mode must match element type (e.g., `"INTLV"` for 32-bit elements) -- The two source vectors form an ordered pair; interleave semantics must be preserved - -**Examples**: -```python -# Byte-offset syntax -pto.vstsx2(x_vec, y_vec, ub_ptr, offset, pto.InterleaveDist.INTLV, mask) - -# Element-indexing syntax -pto.vstsx2(x_vec, y_vec, tile[i, j:], pto.InterleaveDist.INTLV, mask) -pto.vstsx2(x_vec, y_vec, tile[k:], pto.InterleaveDist.INTLV, mask) - -# Example: Store separate X/Y vectors as interleaved XY pairs -pto.vstsx2(x_vec, y_vec, xy_tile[i, j:], pto.InterleaveDist.INTLV, all_mask) -``` - -#### `pto.vsta(align: pto.align, buf: ptr, offset: Index) -> None` [Advanced Tier] -#### `pto.vsta(align: pto.align, tile[row, col:]) -> None` [Basic Tier] -#### `pto.vsta(align: pto.align, tile[start:]) -> None` [Basic Tier] - -**Description**: Flush alignment state to memory. Writes buffered tail bytes from alignment state to UB memory. Consumes the alignment state after flush. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align` | `pto.align` | Pending store-alignment state | -| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | -| `offset` | `Index` | Flush displacement | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align` | `pto.align` | Pending store-alignment state | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| _or_ | | | -| `align` | `pto.align` | Pending store-alignment state | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | - -**Returns**: None (side-effect operation) - -**Constraints**: -- The flush address must match the post-updated address expected by the preceding unaligned-store stream -- After the flush, the corresponding store alignment state is consumed -- A final flush operation is required to commit buffered bytes after unaligned-store sequences -- The `align` input should come from the latest `vstu`/`vstus`/`vstur` in the same stream - -**Example**: -```python -# Byte-offset syntax -pto.vsta(align, ub_ptr, offset) - -# Element-indexing syntax -pto.vsta(align, tile[i, j:]) -pto.vsta(align, tile[k:]) -``` - -#### `pto.vscatter(vec: VRegType, buf: ptr, offsets: Index, active_lanes: Index) -> None` [Advanced Tier] - -**Description**: Indexed scatter to UB. Stores vector elements to irregular locations using per-lane offsets, with participation bounded by active lanes count. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Source vector to scatter | -| `buf` | `ptr` | Pointer to destination buffer in UB memory space | -| `offsets` | `Index` | Per-lane element offsets (vector register) | -| `active_lanes` | `Index` | Number of lanes that participate (bounds scatter operation) | - -**Returns**: None (side-effect operation) - -**Constraints**: -- Only `b8`, `b16`, and `b32` element sizes are supported -- Current TileLang DSL / VPTO path requires `i32` index vectors -- Each computed address must be element-aligned -- If indices alias, only one write is guaranteed (winning lane is implementation-defined) -- Only the first `active_lanes` offsets participate in the scatter - -**Example**: -```python -pto.vscatter(vec, buf, offsets, active_lanes) -``` - -#### `pto.vsstb(scalar: ScalarType, buf: ptr, offset: Index, mask: MaskType) -> None` [Advanced Tier] -#### `pto.vsstb(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` [Basic Tier] -#### `pto.vsstb(scalar: ScalarType, tile[start:], mask: MaskType) -> None` [Basic Tier] - -**Description**: Scalar to vector store with broadcast (enhanced version of `vsst`). Supports both byte‑offset and element‑indexing syntax. - -**Parameters (pointer syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `ScalarType` | Scalar value | -| `buf` | `ptr` | Pointer to destination buffer | -| `offset` | `Index` | Byte offset | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `ScalarType` | Scalar value | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (1D element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `ScalarType` | Scalar value | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: None (side-effect operation) - -**Example**: -```python -# Byte-offset syntax -pto.vsstb(pto.f32(0.0), ub_ptr, offset, mask) - -# Element-indexing syntax -pto.vsstb(pto.f32(1.0), tile[i, j:], mask) -``` - -#### `pto.vstar(align: pto.align, buf: ptr) -> None` [Advanced Tier] -#### `pto.vstar(align: pto.align, tile[row, col:]) -> None` [Basic Tier] -#### `pto.vstar(align: pto.align, tile[start:]) -> None` [Basic Tier] - -**Description**: Flush alignment state using the register-update form. Writes buffered tail bytes from alignment state to UB memory. The implicit update state must correspond to the same store stream that produced the alignment state. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align` | `pto.align` | Pending store-alignment state | -| `buf` | `ptr` | Pointer to destination buffer in UB memory space | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align` | `pto.align` | Pending store-alignment state | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| _or_ | | | -| `align` | `pto.align` | Pending store-alignment state | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | - -**Returns**: None (side-effect operation) - -**Constraints**: -- The implicit update state consumed by this flush must correspond to the same store stream that produced the alignment state -- A final flush operation is required to commit buffered bytes after unaligned-store sequences -- The `align` input should come from the latest `vstu`/`vstus`/`vstur` in the same stream - -**Example**: -```python -# Byte-offset syntax -pto.vstar(align, ub_ptr) - -# Element-indexing syntax -pto.vstar(align, tile[i, j:]) -pto.vstar(align, tile[k:]) -``` - -#### `pto.vstas(align: pto.align, buf: ptr, offset: Index) -> None` [Advanced Tier] -#### `pto.vstas(align: pto.align, tile[row, col:], offset: Index) -> None` [Basic Tier] -#### `pto.vstas(align: pto.align, tile[start:], offset: Index) -> None` [Basic Tier] - -**Description**: Scalar-register-offset form of alignment-state flush. Writes buffered tail bytes from alignment state to UB memory with explicit scalar offset. Uses same buffered-tail semantics as `pto.vsta`. - -**Parameters (pointer syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align` | `pto.align` | Pending store-alignment state | -| `buf` | `ptr` | Pointer to destination buffer in UB memory space | -| `offset` | `Index` | Scalar-register style displacement | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align` | `pto.align` | Pending store-alignment state | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| `offset` | `Index` | Scalar-register style displacement | -| _or_ | | | -| `align` | `pto.align` | Pending store-alignment state | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | -| `offset` | `Index` | Scalar-register style displacement | - -**Returns**: None (side-effect operation) - -**Example**: -```python -# Byte-offset syntax -pto.vstas(align, ub_ptr, offset) - -# Element-indexing syntax -pto.vstas(align, tile[i, j:], offset) -pto.vstas(align, tile[k:], offset) -``` - -### Stateful Store Operations - -Operations for storing data with stateful semantics. - -#### `pto.pstu(align_in: pto.align, mask: MaskType, buf: ptr) -> (pto.align, ptr)` [Advanced Tier] - -**Description**: Predicate unaligned store with align state update. Stores predicate mask with alignment state threading. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align_in` | `pto.align` | Incoming store-alignment state | -| `mask` | `MaskType` | Predicate mask to store | -| `buf` | `ptr` | Pointer to destination buffer in UB memory space | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `align_out` | `pto.align` | Updated alignment state | -| `base_out` | `ptr` | Post-update base pointer state | - -**Constraints**: -- Part of stateful unaligned-store sequence with alignment state threading - -#### `pto.vstu(align_in: pto.align, base_in: ptr, vec: VRegType, buf: ptr, mode: Index) -> (pto.align, ptr)` [Advanced Tier] - -**Description**: Unaligned store with explicit threaded alignment/base state. Models a stateful unaligned-store sequence in SSA form. Requires a final `pto.vsta`/`pto.vstas`/`pto.vstar` to flush trailing buffered bytes. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align_in` | `pto.align` | Incoming store-alignment state | -| `base_in` | `ptr` | Current stream base pointer | -| `vec` | `VRegType` | Vector to store | -| `buf` | `ptr` | Destination buffer in UB memory space | -| `mode` | `Index` | Mode selecting post-update behavior | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `align_out` | `pto.align` | Updated buffered-tail state | -| `base_out` | `ptr` | Post-update base pointer state | - -**Constraints**: -- Models stateful unaligned-store sequence in SSA form -- Final flush operation required to commit buffered bytes - -**Example**: -```python -# Stateful unaligned store + final flush (vsta form) -align1, base1 = pto.vstu(align0, base0, vec0, ub_ptr, mode) -align2, base2 = pto.vstu(align1, base1, vec1, ub_ptr, mode) -pto.vsta(align2, ub_ptr, tail_offset) -``` - -#### `pto.vstus(align_in: pto.align, base_in: ptr, vec: VRegType, buf: ptr, offset: Index) -> (pto.align, ptr)` [Advanced Tier] - -**Description**: Scalar-offset unaligned store with threaded state. Same roles as `pto.vstu` but with explicit scalar displacement. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align_in` | `pto.align` | Incoming store-alignment state | -| `base_in` | `ptr` | Current stream base pointer | -| `vec` | `VRegType` | Vector to store | -| `buf` | `ptr` | Destination buffer in UB memory space | -| `offset` | `Index` | Scalar displacement | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `align_out` | `pto.align` | Updated buffered-tail state | -| `base_out` | `ptr` | Post-update base pointer state | - -**Constraints**: -- Same final flush requirement and state-threading constraints as `pto.vstu` - -**Example**: -```python -# Scalar-offset threaded form + final flush (vstas form) -align1, base1 = pto.vstus(align0, base0, vec0, ub_ptr, offset0) -align2, base2 = pto.vstus(align1, base1, vec1, ub_ptr, offset1) -pto.vstas(align2, ub_ptr, flush_offset) -``` - -#### `pto.vstur(align_in: pto.align, vec: VRegType, buf: ptr, mode: PostUpdateMode = pto.PostUpdateMode.NO_POST_UPDATE) -> pto.align` [Advanced Tier] - -**Description**: Register-update unaligned store form. Updates only the residual alignment state without base pointer update. Requires matching flush operation to emit trailing bytes. The optional `mode` operand is a typed Enum and controls whether the hardware performs post-update on the implicit AR state. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align_in` | `pto.align` | Incoming store-alignment state | -| `vec` | `VRegType` | Vector to store | -| `buf` | `ptr` | Destination buffer in UB memory space | -| `mode` | `PostUpdateMode` | Optional post-update mode. Defaults to `pto.PostUpdateMode.NO_POST_UPDATE`. | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `align_out` | `pto.align` | Updated buffered-tail state | - -**Constraints**: -- Updates only residual alignment state (no base pointer update) -- Matching flush operation still required to emit trailing bytes - -**Example**: -```python -# Residual-state form + final flush (vstar form) -align1 = pto.vstur(align0, vec0, ub_ptr) -align2 = pto.vstur(align1, vec1, ub_ptr) -pto.vstar(align2, ub_ptr) - -# Explicit post-update mode with typed Enum -align3 = pto.vstur(align2, vec2, ub_ptr, pto.PostUpdateMode.POST_UPDATE) -``` - -#### Align-State Store Closed Loop - -For unaligned store families, the state must form a closed loop: - -1. Start from an incoming `align` state. -2. Thread state through one or more `vstu` / `vstus` / `vstur` operations. -3. Terminate the stream with exactly one flush op: `vsta` or `vstas` or `vstar`. -4. Do not reuse a flushed `align` state in another stream. diff --git a/ptodsl/docs/user_guide/10-predicate-operations.md b/ptodsl/docs/user_guide/10-predicate-operations.md deleted file mode 100644 index 8cc92da2c..000000000 --- a/ptodsl/docs/user_guide/10-predicate-operations.md +++ /dev/null @@ -1,637 +0,0 @@ -### Predicate Operations - -Operations for creating and manipulating typed masks. - -**Recommended API**: For most use cases, prefer the unified `pto.make_mask()` function which automatically selects the appropriate mask granularity based on element type and supports both tail processing (remaining element count) and pattern-based mask generation. This eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` (tail processing) and `pset_b8`/`pset_b16`/`pset_b32` (pattern generation) operations. - -**Pattern alias**: For brevity in examples, the documentation uses `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.ALL`). In practice, you can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. - -**Predicate Part Enum**: `pto.ppack` and `pto.punpack` require the `PredicatePart` enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`; these lower to the VPTO canonical `PART` tokens `"LOWER"` and `"HIGHER"`. - -**Predicate Dist Enum**: The `PredicateDist` enum provides type-safe distribution mode selection for predicate memory families. Load families (`plds`, `pld`, `pldi`) use `NORM`, `US`, and `DS`. Store families (`psts`, `pst`, `psti`) use `NORM` and `PK`. - -**Pattern coverage**: The VPTO canonical predicate-generation families use `PAT_*` tokens such as `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, `PAT_VL*`, `PAT_M3`, and `PAT_M4`. The Python DSL surface may expose only a subset through `pto.MaskPattern`; check the enum for currently available values. - -#### `pto.pset_b8(pattern: pto.MaskPattern) -> pto.mask_b8` - -**Description**: Creates an 8-bit granularity mask from a pattern. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b8` | 8-bit granularity mask | - -**Constraints**: -- Used with `i8` vector operations - -**Example**: -```python -mask8 = pto.pset_b8(PAT.ALL) -``` - -#### `pto.pset_b16(pattern: pto.MaskPattern) -> pto.mask_b16` - -**Description**: Creates a 16-bit granularity mask from a pattern. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b16` | 16-bit granularity mask | - -**Constraints**: -- Used with `f16`/`bf16`/`i16` vector operations - -**Example**: -```python -mask16 = pto.pset_b16(PAT.ALL) -``` - -#### `pto.pset_b32(pattern: pto.MaskPattern) -> pto.mask_b32` - -**Description**: Creates a 32-bit granularity mask from a pattern. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b32` | 32-bit granularity mask | - -**Constraints**: -- Used with `f32`/`i32` vector operations - -**Example**: -```python -mask32 = pto.pset_b32(PAT.ALL) -``` - -#### `pto.pge_b8(pattern: pto.MaskPattern) -> pto.mask_b8` - -**Description**: Generate tail mask — first N lanes active based on pattern. Creates an 8-bit granularity mask where the first N lanes are active according to the specified pattern. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b8` | 8-bit granularity tail mask | - -**Constraints**: -- Used with `i8` vector operations -- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) - -**Example**: -```python -# Tail mask pattern lowered as `PAT_VL16` -tail_mask = pto.pge_b8(PAT.VL16) -``` - -#### `pto.pge_b16(pattern: pto.MaskPattern) -> pto.mask_b16` - -**Description**: Generate tail mask — first N lanes active based on pattern. Creates a 16-bit granularity mask where the first N lanes are active according to the specified pattern. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b16` | 16-bit granularity tail mask | - -**Constraints**: -- Used with `f16`/`bf16`/`i16` vector operations -- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) - -**Example**: -```python -# Tail mask for first 16 lanes -tail_mask = pto.pge_b16(PAT.VL16) -``` - -#### `pto.pge_b32(pattern: pto.MaskPattern) -> pto.mask_b32` - -**Description**: Generate tail mask — first N lanes active based on pattern. Creates a 32-bit granularity mask where the first N lanes are active according to the specified pattern. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b32` | 32-bit granularity tail mask | - -**Constraints**: -- Used with `f32`/`i32` vector operations -- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) - -**Example**: -```python -# Tail mask for first 32 lanes -tail_mask = pto.pge_b32(PAT.VL32) -``` - -#### `pto.plt_b8(scalar: pto.i32) -> (pto.mask_b8, pto.i32)` - -**Description**: Generate predicate state together with updated scalar state (tail processing). Creates an 8-bit granularity mask and returns updated scalar value for state progression. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b8` | 8-bit granularity mask | -| `scalar_out` | `pto.i32` | Updated scalar state | - -**Constraints**: -- Used with `i8` vector operations for tail processing -- The scalar input is typically a remaining element count that decrements across successive calls - -**Example**: -```python -remaining: pto.i32 = 64 -mask, remaining = pto.plt_b8(remaining) # generates mask for next chunk, updates remaining count -``` - -#### `pto.plt_b16(scalar: pto.i32) -> (pto.mask_b16, pto.i32)` - -**Description**: Generate predicate state together with updated scalar state (tail processing). Creates a 16-bit granularity mask and returns updated scalar value for state progression. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b16` | 16-bit granularity mask | -| `scalar_out` | `pto.i32` | Updated scalar state | - -**Constraints**: -- Used with `f16`/`bf16`/`i16` vector operations for tail processing -- The scalar input is typically a remaining element count that decrements across successive calls - -**Example**: -```python -remaining: pto.i32 = 64 -mask, remaining = pto.plt_b16(remaining) # generates mask for next chunk, updates remaining count -``` - -#### `pto.plt_b32(scalar: pto.i32) -> (pto.mask_b32, pto.i32)` - -**Description**: Generate predicate state together with updated scalar state (tail processing). Creates a 32-bit granularity mask and returns updated scalar value for state progression. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b32` | 32-bit granularity mask | -| `scalar_out` | `pto.i32` | Updated scalar state | - -**Constraints**: -- Used with `f32`/`i32` vector operations for tail processing -- The scalar input is typically a remaining element count that decrements across successive calls - -**Example**: -```python -remaining: pto.i32 = 64 -mask, remaining = pto.plt_b32(remaining) # generates mask for next chunk, updates remaining count -``` - -#### `pto.make_mask(element_type: Type, value: pto.i32 | pto.MaskPattern) -> MaskType | (MaskType, pto.i32)` - -**Description**: Creates a mask with appropriate bitwidth (8, 16, or 32) based on element type, automatically inferring whether to perform tail processing or pattern-based mask generation based on the `value` parameter type. This convenience function eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` and `pset_b8`/`pset_b16`/`pset_b32` operations. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `element_type` | `Type` | Element type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | -| `value` | `pto.i32` \| `pto.MaskPattern` | Either:
- Remaining element count (as `pto.i32`) for tail processing
- Mask pattern enum value for fixed mask generation (for example `pto.MaskPattern.ALL` or `pto.MaskPattern.VL32`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `MaskType` | Generated mask with appropriate granularity | -| `remaining` | `pto.i32` | Updated remaining element count (only returned when `value` is a `pto.i32` for tail processing) | - -**Constraints**: -- The `element_type` must be one of: `f32`, `f16`, `bf16`, or an 8/16/32-bit integer family member (`i*`, `si*`, `ui*`) -- The returned mask granularity matches the element type: 32-bit for `f32`/`i32`/`si32`/`ui32`, 16-bit for `f16`/`bf16`/`i16`/`si16`/`ui16`, and 8-bit for `i8`/`si8`/`ui8` -- The function infers the operation mode from the `value` parameter type at compile time: - - `pto.i32` value → tail processing mode (returns `(mask, updated_remaining)`) - - `pto.MaskPattern` enum value → pattern mode (returns `mask` only) - -**Implementation Note**: This function is a DSL macro that performs type-based dispatch at compile time: -- When `value` is a `pto.i32` expression: expands to corresponding `plt_b` instruction (`plt_b32`, `plt_b16`, or `plt_b8`) -- When `value` is a `pto.MaskPattern` enum value: expands to corresponding `pset_b` instruction (`pset_b32`, `pset_b16`, or `pset_b8`) - -**Example**: -```python -# Tail processing with f32 vectors: value is pto.i32 → expands to plt_b32 -mask_f32, remaining_f32 = pto.make_mask(pto.f32, remaining_elements) - -# Tail processing with f16 vectors: value is pto.i32 → expands to plt_b16 -mask_f16, remaining_f16 = pto.make_mask(pto.f16, remaining_elements) - -# Tail processing with i8 vectors: value is pto.i32 → expands to plt_b8 -mask_i8, remaining_i8 = pto.make_mask(pto.i8, remaining_elements) - -# Pattern-based mask with f32 vectors: value is MaskPattern enum → expands to pset_b32 -mask_all_f32 = pto.make_mask(pto.f32, PAT.ALL) - -# Pattern-based mask with f16 vectors: value is MaskPattern enum → expands to pset_b16 -mask_even_f16 = pto.make_mask(pto.f16, PAT.EVEN) - -# Pattern-based mask with i8 vectors: value is MaskPattern enum → expands to pset_b8 -mask_all_i8 = pto.make_mask(pto.i8, PAT.ALL) - -# Type annotations help clarify expected parameter types -remaining: pto.i32 = 1024 -mask1, updated = pto.make_mask(pto.f32, remaining) # tail processing -mask2 = pto.make_mask(pto.f32, PAT.ALL) # pattern mode -``` - -#### `pto.ppack(mask: MaskType, part: PredicatePart) -> MaskType` - -**Description**: Narrowing pack of a predicate register. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | -| `part` | `PredicatePart` | Part selector enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`. | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `packed` | `MaskType` | Packed mask | - -**Example**: -```python -packed = pto.ppack(mask, pto.PredicatePart.LOWER) -``` - -#### `pto.punpack(mask: MaskType, part: PredicatePart) -> MaskType` - -**Description**: Widening unpack of a predicate register. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Input mask | -| `part` | `PredicatePart` | Part selector enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`. | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `MaskType` | Unpacked mask | - -**Example**: -```python -unpacked = pto.punpack(mask, pto.PredicatePart.HIGHER) -``` - -#### `pto.pbitcast(mask: MaskType, to_type: MaskType) -> MaskType` - -**Description**: Reinterprets a typed predicate mask as another typed mask granularity without changing the underlying predicate bit image. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | -| `to_type` | `MaskType` | Target mask type marker such as `pto.mask_b16` or `pto.mask_b32` | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `MaskType` | Reinterpreted mask with the requested target granularity | - -**Constraints**: -- `mask` must already be a typed predicate value -- `to_type` must be one of the DSL mask type markers: `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32` -- this is a bit reinterpretation helper, not a logical predicate transform; it does not insert packing, unpacking, interleaving, or deinterleaving by itself -- use `pto.ppack`, `pto.punpack`, `pto.pdintlv_b8`, or `pto.pintlv_b16` when the predicate image itself must be rearranged - -**Example**: -```python -mask_b8 = pto.plds(mask_ptr, offset, pto.PredicateDist.US) -mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) - -mask0_b16, mask1_b16 = pto.pintlv_b16(mask_b16, pto.pset_b16(PAT.ALL)) -mask0_b32 = pto.pbitcast(mask0_b16, pto.mask_b32) -``` - -#### `pto.pnot(mask: MaskType, gate: MaskType) -> MaskType` - -**Description**: Predicate negation under a same-granularity mask gate. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Input mask | -| `gate` | `MaskType` | Gating mask with the same granularity | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `negated` | `MaskType` | Negated mask | - -#### `pto.psel(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` - -**Description**: Selects between two masks using a third mask as selector. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src0` | `MaskType` | First input mask | -| `src1` | `MaskType` | Second input mask | -| `mask` | `MaskType` | Selection mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `MaskType` | Selected mask | - -#### `pto.plds(buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> MaskType` [Advanced Tier] - -**Description**: Predicate load with scalar-index style offset form. This is the default DSL surface for loading predicate masks from UB memory. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Source pointer in UB memory space | -| `offset` | `Index` | Scalar/index-style offset | -| `dist` | `PredicateDist` | Distribution mode (default: `PredicateDist.NORM`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `MaskType` | Loaded predicate mask | - -**Example**: -```python -mask = pto.plds(buf, offset, pto.PredicateDist.NORM) -``` - -#### `pto.pld(buf: ptr, offset: Index, dist: PredicateDist) -> MaskType` [Advanced Tier] - -**Description**: Predicate load with areg/index register style offset encoding. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Source pointer in UB memory space | -| `offset` | `Index` | Areg/index-style offset | -| `dist` | `PredicateDist` | Distribution mode | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `MaskType` | Loaded predicate mask | - -**Example**: -```python -mask = pto.pld(buf, offset, pto.PredicateDist.NORM) -``` - -#### `pto.pldi(buf: ptr, imm_offset: pto.i32, dist: PredicateDist) -> MaskType` [Advanced Tier] - -**Description**: Predicate load with immediate-offset encoding form. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `ptr` | Source pointer in UB memory space | -| `imm_offset` | `pto.i32` | Immediate-offset operand | -| `dist` | `PredicateDist` | Distribution mode | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `MaskType` | Loaded predicate mask | - -**Example**: -```python -mask = pto.pldi(buf, 0, pto.PredicateDist.NORM) -``` - -#### `pto.psts(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] - -**Description**: Stores a predicate mask to UB memory using the VPTO dynamic-offset -`psts` form. This is the dynamic counterpart of `psti`: both encode the same -predicate payload semantics, while offset delivery differs (runtime `index` vs -constant immediate). - -**Parameters (Advanced Tier: explicit pointer surface)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Predicate mask to store | -| `buf` | `ptr` | Pointer to destination UB buffer | -| `offset` | `Index` | Runtime offset (`index`) | -| `dist` | `PredicateDist` | Distribution mode. Use `PredicateDist.NORM` or `PredicateDist.PK` (default: `PredicateDist.NORM`). | - -**DIST semantics (VPTO-aligned)**: -- `NORM`: stores packed predicate payload into destination space of size `VL/8`. -- `PK`: stores packed predicate payload into destination space of size `VL/16`, - keeping one bit out of every two bits. - -**Returns**: None (side-effect operation) - -**Example**: -```python -pto.psts(mask, buf, offset, pto.PredicateDist.NORM) -``` - -#### `pto.pst(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] - -**Description**: Stores a predicate mask to UB memory using areg/index offset encoding. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Predicate mask to store | -| `buf` | `ptr` | Pointer to destination UB buffer | -| `offset` | `Index` | Areg/index-style offset | -| `dist` | `PredicateDist` | Distribution mode for predicate store. Use `PredicateDist.NORM` or `PredicateDist.PK`. Default is `PredicateDist.NORM`. | - -**Returns**: None (side-effect operation) - -**Example**: -```python -pto.pst(mask, buf, offset, pto.PredicateDist.NORM) -``` - -#### `pto.psti(mask: MaskType, buf: ptr, imm_offset: pto.i32, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] - -**Description**: Stores a predicate mask to UB memory using immediate-offset encoding. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Predicate mask to store | -| `buf` | `ptr` | Pointer to destination UB buffer | -| `imm_offset` | `pto.i32` | Immediate-offset operand | -| `dist` | `PredicateDist` | Distribution mode for predicate store. Use `PredicateDist.NORM` or `PredicateDist.PK`. Default is `PredicateDist.NORM`. | - -**Returns**: None (side-effect operation) - -**Example**: -```python -pto.psti(mask, buf, pto.i32(8), pto.PredicateDist.PK) -``` - -#### `pto.pstu(align_in: pto.align, mask: MaskType, buf: ptr) -> (pto.align, ptr)` [Advanced Tier] - -**Description**: Unaligned predicate store with align-state update. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align_in` | `pto.align` | Input alignment state | -| `mask` | `MaskType` | Predicate mask to store | -| `buf` | `ptr` | Pointer to destination UB buffer | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `align_out` | `pto.align` | Updated alignment state | -| `base_out` | `ptr` | Updated destination pointer | - -**Example**: -```python -align_out, base_out = pto.pstu(align_in, mask, buf) -``` - -#### `pto.pand(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` - -**Description**: Bitwise AND of two predicate masks under a gating mask. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src0` | `MaskType` | First input mask | -| `src1` | `MaskType` | Second input mask | -| `mask` | `MaskType` | Gating mask with the same granularity | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `MaskType` | Bitwise AND result | - -**Example**: -```python -result = pto.pand(mask1, mask2, gate) -``` - -#### `pto.por(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` - -**Description**: Bitwise OR of two predicate masks under a gating mask. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src0` | `MaskType` | First input mask | -| `src1` | `MaskType` | Second input mask | -| `mask` | `MaskType` | Gating mask with the same granularity | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `MaskType` | Bitwise OR result | - -**Example**: -```python -result = pto.por(mask1, mask2, gate) -``` - -#### `pto.pxor(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` - -**Description**: Bitwise XOR of two predicate masks under a gating mask. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src0` | `MaskType` | First input mask | -| `src1` | `MaskType` | Second input mask | -| `mask` | `MaskType` | Gating mask with the same granularity | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `MaskType` | Bitwise XOR result | - -**Example**: -```python -result = pto.pxor(mask1, mask2, gate) -``` - -#### `pto.pdintlv_b8(src0: pto.mask_b8, src1: pto.mask_b8) -> (pto.mask_b8, pto.mask_b8)` - -**Description**: Predicate deinterleave for 8-bit masks. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src0` | `pto.mask_b8` | First input mask | -| `src1` | `pto.mask_b8` | Second input mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `low` | `pto.mask_b8` | First result mask | -| `high` | `pto.mask_b8` | Second result mask | - -**Example**: -```python -low8, high8 = pto.pdintlv_b8(mask_a, mask_b) -``` - -#### `pto.pintlv_b16(src0: pto.mask_b16, src1: pto.mask_b16) -> (pto.mask_b16, pto.mask_b16)` - -**Description**: Predicate interleave for 16-bit masks. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src0` | `pto.mask_b16` | First input mask | -| `src1` | `pto.mask_b16` | Second input mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `low` | `pto.mask_b16` | First result mask | -| `high` | `pto.mask_b16` | Second result mask | - -**Example**: -```python -low16, high16 = pto.pintlv_b16(mask_a, mask_b) -``` - -**Note**: Prefer `pto.make_mask()` for automatic bitwidth selection and unified tail/pattern mask generation. diff --git a/ptodsl/docs/user_guide/10-sync-ops.md b/ptodsl/docs/user_guide/10-sync-ops.md new file mode 100644 index 000000000..7124f33dd --- /dev/null +++ b/ptodsl/docs/user_guide/10-sync-ops.md @@ -0,0 +1,447 @@ +# 10. Synchronization Operations + +Chapters 7 and 8 covered data movement and computation. This chapter covers the synchronization primitives that keep those operations correctly ordered across the NPU's concurrent hardware pipelines. + +The Ascend NPU executes work across multiple independent pipelines — MTE (DMA), Vector, and Cube — each with its own instruction stream. Synchronization operations coordinate these pipelines: a DMA must finish loading data before the vector unit starts computing on it; a matrix multiply must complete before the result is stored. Without explicit synchronization, pipelines race, and results are undefined. + +## 10.1 Enum types for synchronization + +PTODSL provides three enum types for type-safe specification of synchronization parameters. + +### `BarrierType` + +Memory barrier types used with `pto.mem_bar`. Each value specifies which category of prior instruction must complete before which category of subsequent instruction may proceed. + +| Member | Meaning | +|--------|---------| +| `VV_ALL` | All vector ops before → all vector ops after | +| `VST_VLD` | Vector stores before → vector loads after | +| `VLD_VST` | Vector loads before → vector stores after | +| `VST_VST` | Vector stores before → vector stores after | +| `VS_ALL` | All vector ops before → all scalar ops after | +| `VST_LD` | Vector stores before → scalar loads after | +| `VLD_ST` | Vector loads before → scalar stores after | +| `VST_ST` | Vector stores before → scalar stores after | +| `SV_ALL` | All scalar ops before → all vector ops after | +| `ST_VLD` | Scalar stores before → vector loads after | +| `LD_VST` | Scalar loads before → vector stores after | +| `ST_VST` | Scalar stores before → vector stores after | +| `SYNC` | Full ordering — all prior memory operations (all pipes) complete before any subsequent operation | + +`SYNC` is a convenience value equivalent to a full pipeline barrier. It is the idiomatic choice for separating compute phases inside a ukernel when fine-grained barrier types are not needed. + +The naming convention: `V` = vector, `S` = scalar, `ST` = store, `LD` = load. `VST_VLD` reads "Vector STore before Vector LoaD." + +### `Pipe` + +Hardware pipeline identifiers used with `pto.set_flag`, `pto.wait_flag`, and `pto.pipe_barrier`. + +| Member | Pipeline | +|--------|----------| +| `S` | Scalar / control pipeline | +| `V` | Vector pipeline (SIMD) | +| `M` | Matrix / Cube pipeline | +| `MTE1` | Memory Transfer Engine 1 | +| `MTE2` | Memory Transfer Engine 2 | +| `MTE3` | Memory Transfer Engine 3 | +| `MTE4` | Memory Transfer Engine 4 | +| `ALL` | All pipelines (for barrier operations) | + +The most commonly used pipes in synchronization are `MTE2` (GM ↔ UB DMA), `MTE3` (UB ↔ UB DMA), `V` (vector compute), and `M` (matrix compute). + +### `Event` + +Event identifiers for pipeline synchronization flags. The hardware provides 8 event IDs (0–7) per pipeline pair, supporting up to 8 concurrent in-flight DMA/compute sequences. + +| Member | Value | +|--------|-------| +| `ID0` | Event 0 | +| `ID1` | Event 1 | +| `ID2` | Event 2 | +| `ID3` | Event 3 | +| `ID4` | Event 4 | +| `ID5` | Event 5 | +| `ID6` | Event 6 | +| `ID7` | Event 7 | + +Events are per-pipeline-pair: the same `ID0` used between `MTE2 → V` is independent from `ID0` used between `MTE3 → V`. + +--- + +## 10.2 Pipeline synchronization: `set_flag`, `wait_flag`, `pipe_barrier` + +Pipeline synchronization is the primary mechanism for ordering work across pipelines. The pattern is always **signal then wait**: the producer pipeline sets a flag when its work is done; the consumer pipeline waits on that flag before proceeding. + +### `pto.set_flag(pipe_from, pipe_to, event_id)` + +**Description**: Sets a synchronization flag between two hardware pipelines. The producing pipeline signals that work up to this point is complete. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `Pipe` | Source pipeline — the pipeline that has completed its work | +| `pipe_to` | `Pipe` | Destination pipeline — the pipeline being notified | +| `event_id` | `Event` | Event identifier for this specific synchronization point | + +**Returns**: None (side-effect operation). + +**Example**: + +```python +from pto import Pipe, Event + +# MTE2 has finished loading tile data — signal Vector pipeline +pto.set_flag(Pipe.MTE2, Pipe.V, Event.ID0) +``` + +### `pto.wait_flag(pipe_from, pipe_to, event_id)` + +**Description**: Waits for a synchronization flag. The consuming pipeline blocks until the flag is set by the producing pipeline. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `Pipe` | Source pipeline that set the flag | +| `pipe_to` | `Pipe` | Destination pipeline — the pipeline that is waiting | +| `event_id` | `Event` | Event identifier matching the corresponding `set_flag` | + +**Returns**: None (side-effect operation). + +**Example**: + +```python +from pto import Pipe, Event + +# Vector pipeline waits for MTE2 to finish loading +pto.wait_flag(Pipe.MTE2, Pipe.V, Event.ID0) +``` + +### `pto.pipe_barrier(pipes)` + +**Description**: Executes a barrier across the specified pipelines. All work before the barrier in the named pipelines must complete before any work after the barrier may begin. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipes` | `Pipe` | Pipeline specification — typically `Pipe.ALL` for a full barrier | + +**Returns**: None (side-effect operation). + +**Example**: + +```python +from pto import Pipe + +# Full hardware barrier — all pipelines synchronize +pto.pipe_barrier(Pipe.ALL) +``` + +### Typical usage pattern + +A common ukernel pattern interleaves DMA and compute with `set_flag` / `wait_flag` pairs: + +```python +@pto.ukernel +def gemm_block(q_tile, k_tile, v_tile, o_tile, ...): + # DMA: load K and V tiles from GM to UB + # mte_load derives strides, burst sizes, etc. from k_part / k_tile types + pto.mte_load(k_part, k_tile) + pto.mte_load(v_part, v_tile) + + # Signal: DMA done, UB data ready + pto.set_flag(Pipe.MTE2, Pipe.V, Event.ID0) + + # Wait: vector pipeline stalls until data arrives + pto.wait_flag(Pipe.MTE2, Pipe.V, Event.ID0) + + # Compute: now safe to use k_tile and v_tile + qk_matmul(q_tile, k_tile, ...) + pv_matmul(p_tile, v_tile, ...) + + # Signal: compute done, results ready for store + pto.set_flag(Pipe.V, Pipe.MTE3, Event.ID1) + pto.wait_flag(Pipe.V, Pipe.MTE3, Event.ID1) + + # DMA: store results back to GM + pto.mte_store(o_tile, o_part) +``` + +--- + +## 10.3 Buffer management: `get_buf`, `rls_buf` + +Double-buffering is a common optimization in NPU kernels: while one buffer is being computed on, the other is being loaded with the next block of data. The `get_buf` / `rls_buf` pair coordinates buffer ownership between pipelines. + +### `pto.get_buf(pipe, buf_id, mode=0)` + +**Description**: Acquire a buffer slot for inter-pipeline double-buffering coordination. The calling pipeline claims ownership of the buffer, blocking if the buffer is still in use by another pipeline. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Pipeline identifier of the acquiring pipeline | +| `buf_id` | `pto.i64` | Buffer identifier (0-based index into the buffer pool) | +| `mode` | `pto.i64` | Acquisition mode (default 0) | + +**Returns**: None (side-effect operation). + +### `pto.rls_buf(pipe, buf_id, mode=0)` + +**Description**: Release a buffer slot, allowing another pipeline to acquire it. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Pipeline identifier of the releasing pipeline | +| `buf_id` | `pto.i64` | Buffer identifier matching the corresponding `get_buf` | +| `mode` | `pto.i64` | Release mode (default 0) | + +**Returns**: None (side-effect operation). + +### Double-buffering example + +```python +from pto import Pipe + +# Pipeline V acquires buffer 0 for compute +pto.get_buf(Pipe.V, 0, 0) + +# ... compute into buffer 0 ... + +# Release buffer 0 — DMA can now refill it +pto.rls_buf(Pipe.V, 0, 0) + +# Pipeline MTE2 acquires buffer 0 for reload +pto.get_buf(Pipe.MTE2, 0, 0) + +# ... DMA loads next block into buffer 0 ... + +pto.rls_buf(Pipe.MTE2, 0, 0) +``` + +--- + +## 10.4 Memory barriers: `mem_bar` + +Within a single pipeline, load and store instructions may be reordered by the hardware. `mem_bar` enforces ordering when UB addresses alias between operations — for example, when a store to a region must be visible to a subsequent load from the same region. + +### `pto.mem_bar(barrier_type)` + +**Description**: Inserts a memory barrier that enforces ordering of prior and subsequent instructions within the same pipeline. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `barrier_type` | `BarrierType` | Barrier type controlling which categories of prior instructions must complete before which categories of subsequent instructions may proceed | + +**Returns**: None (side-effect operation). + +**Example**: + +```python +from pto import BarrierType + +# Ensure all prior vector stores are visible before any subsequent vector loads +pto.mem_bar(BarrierType.VST_VLD) +``` + +The most commonly used barrier types in practice: + +| Use case | Barrier type | +|----------|--------------| +| General vector ordering | `BarrierType.VV_ALL` | +| Store-then-load to same UB region | `BarrierType.VST_VLD` | +| Vector → scalar handoff | `BarrierType.VS_ALL` | +| Scalar → vector handoff | `BarrierType.SV_ALL` | + +### Usage in ukernel blocks + +In flash attention, `mem_bar` separates logically independent computation phases within the same ukernel: + +```python +@pto.ukernel +def flash_attention_block(q_tile, k_tile, v_tile, ...): + # Phase 1: load K/V + pto.mte_load(k_part, k_tile) + pto.mte_load(v_part, v_tile) + pto.mem_bar(BarrierType.SYNC) + + # Phase 2: S = Q @ K^T + qk_matmul(q_tile, k_tile, ...) + pto.mem_bar(BarrierType.SYNC) + + # Phase 3: softmax(S) + online_softmax(s_tile, ...) + pto.mem_bar(BarrierType.SYNC) + + # Phase 4: PV = P @ V + pv_matmul(p_tile, v_tile, ...) + pto.mem_bar(BarrierType.SYNC) + + # Phase 5: blend output + blend_output(o_prev_tile, pv_tile, ...) + pto.mem_bar(BarrierType.SYNC) +``` + +--- + +## 10.5 Cross-core and intra-block synchronization + +Section 10.2 covers the general pipe-to-pipe sync mechanism (`set_flag`/`wait_flag`). This section covers two additional sync domains that the pipe-flag mechanism does not address: **cross-core** communication between separate NPU cores, and **intra-block** synchronization between the Cube and Vector units within a block. + +### 10.5.1 Cross-core sync: `set_cross_core`, `wait_flag_dev` + +When a kernel spans multiple cores, cores need to coordinate through shared resources. `set_cross_core` sends a signal to another core; `wait_flag_dev` blocks the calling core until the expected signal arrives. + +These are core-level (SU) operations — `wait_flag_dev` stalls the entire core, not just a single pipeline. Use them sparingly: splitting work so that each core operates independently for as long as possible minimises cross-core sync overhead. + +#### `pto.set_cross_core(core_id, event_id)` + +**Description**: Signal an event to another core, indicating that shared data or a pipeline stage is ready. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `core_id` | `pto.i64` | Target core identifier (platform-specific mapping) | +| `event_id` | `Event` | Cross-core event identifier | + +**Returns**: None (side-effect operation). + +**Example**: + +```python +from pto import Event + +# Signal core 0 that our computation is complete +pto.set_cross_core(0, Event.ID0) +``` + +#### `pto.wait_flag_dev(core_id, event_id)` + +**Description**: Wait for an event from another core. Core-level (SU) blocking — the entire core stalls until the event is received. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `core_id` | `pto.i64` | Source core identifier | +| `event_id` | `Event` | Event identifier to wait on | + +**Returns**: None (side-effect operation). + +**Example**: + +```python +from pto import Event + +# Core 1 waits for core 0 to signal event ID0 +pto.wait_flag_dev(0, Event.ID0) +``` + +### 10.5.2 Intra-block sync: `set_intra_block`, `wait_intra_core` + +The Cube unit (matrix pipeline) has a dedicated synchronization channel separate from the standard pipe-flag mechanism used by MTE and Vector pipelines. `set_intra_block` and `wait_intra_core` synchronize Cube and Vector within the same block, ensuring that shared UB tile data is not accessed before the producer finishes. + +Unlike `wait_flag_dev`, `wait_intra_core` only stalls the specified pipeline — the SU and other pipelines continue executing. + +#### `pto.set_intra_block(block_id, event_id)` + +**Description**: Signal a synchronization event within a block. Specifies which trigger pipe fires the event. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `block_id` | `pto.i64` | Block or pipeline identifier for the trigger source | +| `event_id` | `Event` | Event identifier | + +**Returns**: None (side-effect operation). + +**Example**: + +```python +from pto import Event + +# Signal event ID0 on block/pipeline 0 +pto.set_intra_block(0, Event.ID0) +``` + +#### `pto.wait_intra_core(block_id, event_id)` + +**Description**: Wait for an intra-block event. Only the specified pipeline stalls — the SU and other pipelines continue executing independently. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `block_id` | `pto.i64` | Block or pipeline identifier specifying which pipeline waits | +| `event_id` | `Event` | Event identifier to wait on | + +**Returns**: None (side-effect operation). + +**Example**: + +```python +from pto import Event + +# Pipeline 1 waits for event ID0 from pipeline 0 within the same block +pto.wait_intra_core(1, Event.ID0) +``` + +### 10.5.3 Intra-core configuration: `set_intra_core` + +#### `pto.set_intra_core(config)` + +**Description**: Configures intra-core synchronization parameters. The meaning of `config` is hardware-specific. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `config` | `pto.i32` | Hardware-specific configuration value | + +**Returns**: None (side-effect operation). + +**Example**: + +```python +pto.set_intra_core(3) +``` + +--- + +## 10.6 Synchronization in the abstraction hierarchy + +Where do sync operations belong in PTODSL's layered model? + +| Layer | Sync responsibility | +|-------|---------------------| +| L1 `@pto.jit` | Tile ops require sync, but PTOAS **auto-inserts** `set_flag`/`wait_flag` pairs based on op-to-pipe mapping — the user does not write sync explicitly | +| L2 `@pto.ukernel` | User writes micro-instructions directly and takes full responsibility for sync: `set_flag`/`wait_flag` between DMA and compute, `mem_bar` between compute phases, `pipe_barrier` at block boundaries | +| L3 `@pto.cube` / `@pto.simd` | Cross-pipeline sync (`set_flag`/`wait_flag`) is managed by the calling ukernel. Sub-kernels may still use `mem_bar` for intra-pipeline ordering (e.g., store-then-load to the same UB region) | + +**Rule of thumb**: at L1, sync can be manual or auto-inserted (`--enable-insert-sync`). At L2, sync is always explicit. + +### Auto-sync at the tile level + +When writing `@pto.jit` code with tile ops (`tload`, `tstore`, `tadd`, etc.), each op carries a pipe assignment (e.g., `tload` → `PIPE_MTE2`, `tadd` → `PIPE_V`). PTOAS's sync-insertion pass analyzes the op sequence, infers the necessary `set_flag`/`wait_flag` pairs from the pipe transitions, and injects them into the lowered code. The tile ops themselves still require synchronization — the difference is that the compiler, not the user, writes it. + +### Quick reference: which sync for which scenario + +| Scenario | Sync primitive | +|----------|----------------| +| DMA load must finish before compute | `set_flag(MTE2, V, id)` + `wait_flag(MTE2, V, id)` | +| Compute must finish before DMA store | `set_flag(V, MTE3, id)` + `wait_flag(V, MTE3, id)` | +| Two compute phases must not overlap | `mem_bar(BarrierType.VV_ALL)` | +| Store must be visible to later load (same UB) | `mem_bar(BarrierType.VST_VLD)` | +| Full pipeline sync point | `pipe_barrier(Pipe.ALL)` | +| Double-buffer handoff (compute → DMA) | `rls_buf(V, id)` + `get_buf(MTE2, id)` | +| Double-buffer handoff (DMA → compute) | `rls_buf(MTE2, id)` + `get_buf(V, id)` | +| Core A notifies core B | `set_cross_core(B, id)` + `wait_flag_dev(A, id)` | diff --git a/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md new file mode 100644 index 000000000..4d7b05c5b --- /dev/null +++ b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md @@ -0,0 +1,527 @@ +# 11. Flash Attention Complete Walkthrough + +This chapter walks through `demos/flash_attention_sketch.py` layer by layer, tracing a complete flash attention implementation from the user-facing Python wrapper down to hardware-bound sub-kernels. Every API discussed in Chapters 1–10 appears in context here. + +The sketch computes **online-softmax flash attention** for one `(batch, head)` slice per launch instance. It partitions Q into blocks along the sequence dimension, iterates over KV blocks for each Q block, and maintains rolling softmax state across KV iterations. + +## 11.1 Architecture overview + +``` +flash_attention(...) L0 user-facing wrapper + └─ @pto.jit flash_attention_kernel + ├─ Tile Ops tload / tstore at the GM↔UB boundary + └─ @pto.ukernel kv_block_process + ├─ @pto.simt materialize_tile_bounds + ├─ @pto.cube qk_matmul + ├─ @pto.simd online_softmax_rows + ├─ @pto.cube pv_matmul + └─ @pto.simt blend_output_rows +``` + +The dataflow for one KV block: + +``` +ukernel loads K/V block and sequences the pipeline + │ + ├─ cube: Q + K ───────────────► S + ├─ simd: S + (m_prev, l_prev) ─► P, (m_next, l_next), alpha, beta + ├─ cube: P + V ───────────────► PV + └─ simt: (o_prev, PV, alpha, beta) ─► o_next + +After each KV block: + (m_prev, l_prev, o_prev) := (m_next, l_next, o_next) +``` + +## 11.2 L0 — Python wrapper + +```python +def flash_attention(Q, K, V, *, O=None, causal=False, + block_q=128, block_kv=128, stream=None): + if O is None: + O = pto.empty_like(Q) + + batch, seq_q, heads, dim = Q.shape + _, seq_k, _, _ = K.shape + + compiled = flash_attention_kernel.compile( + BLOCK_Q=block_q, BLOCK_KV=block_kv, CAUSAL=causal, + ) + compiled[batch * heads, stream](Q, K, V, O) + return O +``` + +This is plain Python — no PTO types, no IR. It handles ergonomic runtime concerns: + +- **Output allocation**: `pto.empty_like(Q)` when the caller doesn't provide one. +- **Shape extraction**: reads `batch`, `seq_q`, `heads`, `dim` from the framework tensors. +- **Compile + launch**: `flash_attention_kernel.compile(...)` JIT-compiles the kernel with the given constexpr parameters, then launches it with a `[batch * heads]` grid — one block per `(batch, head)` slice. + +L0 knows nothing about tiles, UB, or pipelines. It is the boundary between the user's tensor world and the PTO device world. + +## 11.3 L1 — `@pto.jit` kernel entry + +```python +@pto.jit(target="a5") +def flash_attention_kernel( + Q, K, V, O, *, + BLOCK_Q: pto.constexpr = 128, + BLOCK_KV: pto.constexpr = 128, + CAUSAL: pto.constexpr = False, + NUM_STAGES: pto.constexpr = 2, +): +``` + +The `@pto.jit` decorator marks the compile + launch boundary. Inputs are Python-native tensors; outputs are written in-place to `O`. Keyword-only `constexpr` parameters (`BLOCK_Q`, `BLOCK_KV`, `CAUSAL`) are baked at compile time. + +### 11.3.1 TensorView construction + +```python +q_view = pto.make_tensor_view(Q, shape=[batch, seq_q, heads, dim], + strides=Q.strides) +k_view = pto.make_tensor_view(K, shape=[batch, seq_k, heads, dim], + strides=K.strides) +v_view = pto.make_tensor_view(V, shape=[batch, seq_k, heads, dim], + strides=V.strides) +o_view = pto.make_tensor_view(O, shape=[batch, seq_q, heads, dim], + strides=O.strides) +``` + +`make_tensor_view` wraps each framework tensor with a PTO TensorView descriptor — a GM pointer paired with shape and stride metadata. These descriptors are what the rest of the kernel uses to address global memory. No data moves yet. + +### 11.3.2 SPMD launch contract + +```python +block_idx = pto.get_block_idx() +block_num = pto.get_block_num() +subblock_idx = pto.get_subblock_idx() +subblock_num = pto.get_subblock_num() + +batch_idx = block_idx // heads +head_idx = block_idx % heads +``` + +The launch grid is `[batch * heads]`. Each block computes one `(batch, head)` slice. `get_block_idx()` returns the current block's linear index; dividing by `heads` recovers the batch and head indices. + +### 11.3.3 Per-head view selection + +```python +q_head = pto.select_head_view(q_view, batch=batch_idx, head=head_idx, + shape=[seq_q, dim]) +k_head = pto.select_head_view(k_view, batch=batch_idx, head=head_idx, + shape=[seq_k, dim]) +v_head = pto.select_head_view(v_view, batch=batch_idx, head=head_idx, + shape=[seq_k, dim]) +o_head = pto.select_head_view(o_view, batch=batch_idx, head=head_idx, + shape=[seq_q, dim]) +``` + +`select_head_view` extracts a 2D slice `[seq, dim]` from the 4D tensor view for the current head. The resulting views are the working set for this block's entire computation. + +### 11.3.4 Tile allocation + +Two categories of tiles are allocated: + +**UB-resident tiles** — data tiles that live in the Unified Buffer: + +```python +q_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) +k_tile = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f32) +v_tile = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f32) + +o_prev_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) +o_next_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) +m_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) +m_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) +l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) +l_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) + +s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32) +p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32) +pv_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) +alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) +beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) +``` + +The online-softmax algorithm requires **ping-pong state tiles**: `m_prev`/`m_next`, `l_prev`/`l_next`, `o_prev`/`o_next`. After each KV block, `next` becomes `prev` for the following iteration. + +**Cube-local scratch tiles** — allocated in specific memory spaces: + +```python +q_l0a = pto.alloc_tile(shape=[Br, dim], dtype=pto.f16, + memory_space=pto.MemorySpace.LEFT) +p_l0a = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f16, + memory_space=pto.MemorySpace.LEFT) +rhs_l0b = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f16, + memory_space=pto.MemorySpace.RIGHT) +qk_acc_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, + memory_space=pto.MemorySpace.ACC) +pv_acc_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32, + memory_space=pto.MemorySpace.ACC) +``` + +Cube scratch tiles are NOT UB buffers. `LEFT`, `RIGHT`, and `ACC` are distinct hardware memory spaces inside the Cube unit. They serve as staging for matrix operands and accumulators. + +### 11.3.5 SIMT metadata buffer + +```python +meta_tile = pto.alloc_tile(shape=[3, 1], dtype=pto.i32) +meta_ptr = pto.tile_buf_addr(meta_tile) +``` + +A small UB tile stores three scalar loop bounds (`row_start`, `row_stop`, `valid_cols`). `tile_buf_addr` materializes a typed UB pointer into it, which is passed to the ukernel as scalar control metadata. + +### 11.3.6 Outer Q loop + inner KV loop + +```python +with pto.for_(0, q_blocks, step=1) as qi: + q_part = pto.partition_view(q_head, offsets=[qi * Br, 0], + sizes=[Br, dim]) + o_part = pto.partition_view(o_head, offsets=[qi * Br, 0], + sizes=[Br, dim]) + + pto.tload(q_part, q_tile) + + m_prev_tile.fill(float("-inf")) + l_prev_tile.fill(0.0) + o_prev_tile.fill(0.0) + + kv_loop = pto.for_(0, kv_blocks, step=1).carry( + m=m_prev_tile, l=l_prev_tile, o=o_prev_tile, + ) + with kv_loop: + kj = kv_loop.iv + m_cur = kv_loop.m + l_cur = kv_loop.l + o_cur = kv_loop.o + k_part = pto.partition_view(k_head, + offsets=[kj * Bc, 0], sizes=[Bc, dim]) + v_part = pto.partition_view(v_head, + offsets=[kj * Bc, 0], sizes=[Bc, dim]) + + kv_block_process( + q_tile, k_part, v_part, k_tile, v_tile, + o_cur, o_next_tile, + m_cur, l_cur, m_next_tile, l_next_tile, + s_tile, p_tile, pv_tile, + alpha_tile, beta_tile, + q_l0a, p_l0a, rhs_l0b, + qk_acc_tile, pv_acc_tile, + meta_ptr, + ) + + kv_loop.update(m=m_next_tile, l=l_next_tile, o=o_next_tile) + + o_final_tile = kv_loop.final("o") + pto.tstore(o_final_tile, o_part) +``` + +Key points: + +- **`tload` at the L1 boundary**: Q is loaded once per Q block using a tile op. The compiler auto-inserts the necessary `set_flag`/`wait_flag` pairs. +- **State initialization**: `fill(float("-inf"))` and `fill(0.0)` initialize the online-softmax accumulators before the first KV block. +- **Carry state**: the inner `kv_loop` carries three ping-pong tiles (`m`, `l`, `o`) across iterations using `.carry(...)` / `.update(...)` / `.final(...)`. After each KV block, the loop updates the carried values to the `_next` tiles. After the loop, `.final("o")` extracts the final output accumulator. +- **`tstore` at the L1 boundary**: writes the final result for this Q block back to GM. + +## 11.4 L2 — `@pto.ukernel` + +```python +@pto.ukernel +def kv_block_process( + q_tile, k_part, v_part, k_tile, v_tile, + o_prev_tile, o_next_tile, + m_prev_tile, l_prev_tile, m_next_tile, l_next_tile, + s_tile, p_tile, pv_tile, + alpha_tile, beta_tile, + q_l0a, p_l0a, rhs_l0b, + qk_acc_tile, pv_acc_tile, + meta_ptr, +): +``` + +The ukernel processes one KV block against an already-loaded Q tile. It owns the execution sandwich: + +### Phase 0 — Stage K/V data + +```python +pto.mte_load(k_part, k_tile) +pto.mte_load(v_part, v_tile) +pto.mem_bar(pto.BarrierType.SYNC) +``` + +`mte_load` copies the current K and V block from GM to UB. `mem_bar` ensures the DMA stores are visible before the cube unit reads `k_tile`/`v_tile`. + +### Phase 0b — Materialize loop bounds + +```python +materialize_tile_bounds(meta_ptr, + pto.tile_valid_rows(q_tile), + pto.tile_valid_rows(k_tile)) +row_start = scalar.load(meta_ptr + 0) +row_stop = scalar.load(meta_ptr + 4) +valid_cols = scalar.load(meta_ptr + 8) +``` + +The SIMT sub-kernel `materialize_tile_bounds` writes `{0, valid_rows, valid_cols}` into the metadata buffer. The ukernel then loads these scalars. They control the row iteration range in subsequent sub-kernels, handling partial tail blocks. + +### Phase 1 — `S = Q @ K^T` + +```python +qk_matmul(q_tile, k_tile, q_l0a, rhs_l0b, qk_acc_tile, s_tile) +pto.mem_bar(pto.BarrierType.SYNC) +``` + +Dispatches the cube sub-kernel. `mem_bar` separates the matrix multiply from the subsequent softmax. + +### Phase 2 — Online softmax + +```python +online_softmax_rows( + s_tile, p_tile, + m_prev_tile, l_prev_tile, + m_next_tile, l_next_tile, + alpha_tile, beta_tile, + row_start, row_stop, valid_cols, +) +pto.mem_bar(pto.BarrierType.SYNC) +``` + +The simd sub-kernel computes per-row softmax on `S`, updates the running `m`/`l` state, and writes `P`, `alpha`, and `beta`. + +### Phase 3 — `PV = P @ V` + +```python +pv_matmul(p_tile, v_tile, p_l0a, rhs_l0b, pv_acc_tile, pv_tile) +pto.mem_bar(pto.BarrierType.SYNC) +``` + +Second cube dispatch. `rhs_l0b` is reused for `V` (it previously held `K`). `pv_acc_tile` is reused from the QK^T accumulator. + +### Phase 4 — Blend output + +```python +blend_output_rows( + o_prev_tile, pv_tile, alpha_tile, beta_tile, + o_next_tile, row_start, row_stop, + pto.tile_valid_cols(v_tile), +) +pto.mem_bar(pto.BarrierType.SYNC) +``` + +The simt sub-kernel blends the old output accumulator with the new PV contribution, weighted by `alpha` and `beta`. + +### Why the ukernel owns sync + +Each `mem_bar` between phases is explicit in the ukernel body. This is intentional: at the L2 micro-instruction level, the user controls pipeline ordering. There is no auto-sync insertion — the ukernel is the single place where the hardware execution sequence is spelled out. + +## 11.5 L3a — `@pto.cube` sub-kernels + +### `qk_matmul` — `S = Q @ K^T` + +```python +@pto.cube +def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): + m = pto.tile_valid_rows(q_tile) + k = pto.tile_valid_cols(q_tile) + n = pto.tile_valid_rows(k_tile) + + pto.mte_l1_l0a(q_tile, q_l0a, m, k) + pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) + pto.mad(q_l0a, k_l0b, s_acc) + pto.mte_l0c_ub(s_acc, s_tile, m, n) +``` + +Four cube ops: + +1. **`mte_l1_l0a`**: load Q tile from UB into LEFT scratch (`q_l0a`). +2. **`mte_l1_l0b`**: load K tile from UB into RIGHT scratch (`k_l0b`), with `transpose=True` for K^T. +3. **`mad`**: matrix multiply-accumulate — `s_acc = q_l0a @ k_l0b`. +4. **`mte_l0c_ub`**: write the accumulator result to the UB output tile `s_tile`. + +The cube kernel does not allocate scratch — the caller (L1) owns scratch lifetime. The cube kernel only expresses dataflow. + +### `pv_matmul` — `PV = P @ V` + +```python +@pto.cube +def pv_matmul(p_tile, v_tile, p_l0a, v_l0b, pv_acc, pv_tile): + m = pto.tile_valid_rows(p_tile) + k = pto.tile_valid_cols(p_tile) + n = pto.tile_valid_cols(v_tile) + + pto.mte_l1_l0a(p_tile, p_l0a, m, k) + pto.mte_l1_l0b(v_tile, v_l0b, k, n) + pto.mad(p_l0a, v_l0b, pv_acc) + pto.mte_l0c_ub(pv_acc, pv_tile, m, n) +``` + +Structurally identical to `qk_matmul`, but without transposition and with different input/output tiles. The scratch tiles `p_l0a`, `v_l0b`, and `pv_acc` are reused across KV blocks — the caller (L1) allocates them once. + +## 11.6 L3b — `@pto.simd` online softmax + +```python +@pto.simd +def online_softmax_rows( + s_tile, p_tile, + m_prev_tile, l_prev_tile, + m_next_tile, l_next_tile, + alpha_tile, beta_tile, + row_start, row_stop, valid_cols, +): +``` + +The simd kernel iterates over rows with `pto.for_`, processing one row per iteration: + +```python +with pto.for_(row_start, row_stop, step=1) as row: + col_mask = pto.make_mask(pto.f32, valid_cols) + + s_row = pto.vlds(s_tile[row, 0:]) + m_prev = scalar.load(m_prev_tile[row, 0]) + l_prev = scalar.load(l_prev_tile[row, 0]) +``` + +- **Mask creation**: `make_mask(pto.f32, valid_cols)` generates a tail mask for the column dimension. On the last KV block, `valid_cols` may be less than the full block width. +- **Vector load**: `vlds(s_tile[row, 0:])` loads one entire row of `S` from UB into a vector register. The slice syntax `[row, 0:]` selects the full row. +- **Scalar load**: `lds` reads per-row scalars (`m_prev`, `l_prev`) from the state tiles. + +### Softmax computation + +```python + row_max = pto.vcgmax(s_row, col_mask) + m_next = scalar.max(m_prev, row_max) + + s_shifted = pto.vsubs(s_row, m_next, col_mask) + p_row = pto.vexp(s_shifted, col_mask) + + row_sum = pto.vcgadd(p_row, col_mask) + l_scaled = l_prev * scalar.exp(m_prev - m_next) + l_next = l_scaled + row_sum + + alpha = l_scaled / l_next + beta = 1.0 / l_next +``` + +This implements the online-softmax update from the Flash Attention paper: + +- `vcgmax` (cross-lane max reduction) finds the row maximum. +- `max(m_prev, m_next)` combines with the running maximum. +- `vsubs` subtracts the scalar `m_next` from every lane (stabilized softmax). +- `vexp` computes `exp(s_shifted)` element-wise. +- `vcgadd` (cross-lane sum reduction) computes the row sum. +- `l_scaled` rescales the previous sum with the running-max correction factor. +- `alpha` and `beta` are the blending coefficients for the output update. + +### Store results + +```python + pto.vsts(p_row, p_tile[row, 0:], col_mask) + scalar.sts(m_next_tile[row, 0], m_next) + scalar.sts(l_next_tile[row, 0], l_next) + scalar.sts(alpha_tile[row, 0], alpha) + scalar.sts(beta_tile[row, 0], beta) +``` + +- `vsts` stores the vector `p_row` back to UB under the column mask. +- `sts` stores each scalar to its respective UB tile. + +**Boundary contract**: vreg values (`s_row`, `p_row`, `row_max`, `row_sum`) never escape the simd kernel. All persistent state is written to UB tiles. + +## 11.7 L3c — `@pto.simt` sub-kernels + +### `materialize_tile_bounds` — scalar metadata + +```python +@pto.simt +def materialize_tile_bounds(meta_ptr, valid_rows, valid_cols): + scalar.sts(meta_ptr + 0, 0) + scalar.sts(meta_ptr + 4, valid_rows) + scalar.sts(meta_ptr + 8, valid_cols) +``` + +Three scalar stores write the loop bounds into the metadata buffer. `meta_ptr` is a typed UB pointer; `+ 0`, `+ 4`, `+ 8` are byte offsets (three `i32` values). This is the simplest sub-kernel in the sketch — it handles scalar control metadata, not vector math. + +### `blend_output_rows` — output accumulation + +```python +@pto.simt +def blend_output_rows(o_prev_tile, pv_tile, alpha_tile, beta_tile, + o_next_tile, row_start, row_stop, valid_dim): + with pto.for_(row_start, row_stop, step=1) as row: + alpha = scalar.load(alpha_tile[row, 0]) + beta = scalar.load(beta_tile[row, 0]) + + with pto.for_(0, valid_dim, step=1) as col: + o_prev = scalar.load(o_prev_tile[row, col]) + pv_val = scalar.load(pv_tile[row, col]) + o_next = alpha * o_prev + beta * pv_val + scalar.sts(o_next_tile[row, col], o_next) +``` + +This is a scalar element-wise blend over the tile domain: + +``` +O_next[row, col] = alpha[row] * O_prev[row, col] + beta[row] * PV[row, col] +``` + +The SIMT kernel walks the tile element by element with nested `pto.for_` loops. Each iteration loads two scalars (`o_prev` and `pv_val`), computes the weighted sum, and stores the result. The `alpha`/`beta` coefficients are per-row (loaded once per row), while the blend is per-element. + +**Why SIMT instead of SIMD?** The intent is to contrast with `online_softmax_rows`: softmax is dominated by row-wise vector reductions and exponentials — natural SIMD work. The final blend is a simple linear combination with per-row coefficients — expressing it as explicit scalar work-items makes the per-element access pattern explicit and leaves the compiler free to vectorize or fuse as it sees fit. + +### Context manager alternative + +For trivial sub-kernels like `materialize_tile_bounds`, a named function is overkill — the context manager form keeps the logic inline where it's used. Here is how the ukernel body would look with `materialize_tile_bounds` inlined: + +```python +@pto.ukernel +def kv_block_process(...): + pto.mte_load(k_part, k_tile) + pto.mte_load(v_part, v_tile) + pto.mem_bar(pto.BarrierType.SYNC) + + # Inline SIMT: materialize loop bounds (replaces the named @pto.simt function) + with pto.simt(): + scalar.sts(meta_ptr + 0, 0) + scalar.sts(meta_ptr + 4, valid_rows) + scalar.sts(meta_ptr + 8, valid_cols) + + pto.mem_bar(pto.BarrierType.SYNC) + + qk_matmul(q_tile, k_tile, ...) + ... +``` + +The `with pto.simt():` block is semantically identical to calling a `@pto.simt` function — the compiler treats it as an anonymous sub-kernel. For 3-line helpers that have no reuse, the context manager avoids the indirection of a separate function. For complex, reusable logic like `online_softmax_rows` or `qk_matmul`, the named decorator form remains the better fit. + +## 11.8 Putting it all together: one KV block execution + +For one KV block, the full execution sequence is: + +| Step | Layer | Operation | Hardware | +|------|-------|-----------|----------| +| 1 | L1 | `tload(q_part, q_tile)` | MTE2 → UB | +| 2 | L2 | `mte_load(k_part, k_tile)` | MTE2 → UB | +| 3 | L2 | `mte_load(v_part, v_tile)` | MTE2 → UB | +| 4 | L2 | `mem_bar(SYNC)` | — | +| 5 | L3c | `materialize_tile_bounds` | SIMT | +| 6 | L3a | `qk_matmul` (mte_l1_l0a, mte_l1_l0b, mad, mte_l0c_ub) | Cube | +| 7 | L2 | `mem_bar(SYNC)` | — | +| 8 | L3b | `online_softmax_rows` (vlds, vcgmax, vexp, vcgadd, vsts, ...) | SIMD | +| 9 | L2 | `mem_bar(SYNC)` | — | +| 10 | L3a | `pv_matmul` | Cube | +| 11 | L2 | `mem_bar(SYNC)` | — | +| 12 | L3c | `blend_output_rows` | SIMT | +| 13 | L2 | `mem_bar(SYNC)` | — | + +After all KV blocks: L1 issues `tstore(o_final_tile, o_part)` to write the result back to GM. + +## 11.9 Design patterns in this sketch + +**Ping-pong state for online accumulators**: `m_prev`/`m_next`, `l_prev`/`l_next`, `o_prev`/`o_next` make the state transition explicit. After each KV block, the caller swaps the ping-pong pair (via `kv_loop.update(...)`) rather than aliasing in place. + +**Scratch reuse**: `rhs_l0b` serves both `K` (in `qk_matmul`) and `V` (in `pv_matmul`). `pv_acc_tile` reuses the accumulator from QK^T. The caller (L1) allocates once; the ukernel passes them to both cube sub-kernels. + +**Tile-level boundary vs micro-instruction boundary**: `tload`/`tstore` appear only in `@pto.jit`. `mte_load`/`mte_store` appear only in `@pto.ukernel`. This is the key abstraction split: L1 operates on tiles, L2 operates on micro-instructions. + +**No vreg across sub-kernel boundaries**: vector registers are local to each `@pto.simd` kernel. Data crosses sub-kernel boundaries through UB tiles — the boundary contract is enforced by the type system. + +**L3 invocation flexibility**: This sketch uses the explicit `@pto.ukernel` → L3 path for full control over MTE and sync. For simpler kernels that don't need that control, L3 sub-kernels can be called directly from `@pto.jit` (the compiler handles MTE + sync) or written inline as context managers (`with pto.simd():`, etc.). See Chapter 3 for details. diff --git a/ptodsl/docs/user_guide/11-vector-arithmetic-operations.md b/ptodsl/docs/user_guide/11-vector-arithmetic-operations.md deleted file mode 100644 index ede8388df..000000000 --- a/ptodsl/docs/user_guide/11-vector-arithmetic-operations.md +++ /dev/null @@ -1,1611 +0,0 @@ -### Unary Vector Operations - -Element-wise unary operations on vector registers. - -#### `pto.vabs(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Absolute value of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask (granularity must match vector element type) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Absolute values | - -**Constraints**: -- Mask granularity must match vector element type (e.g., `f32` requires `mask_b32`) - -**Example**: -```python -abs_vec = pto.vabs(vec_f32, mask32) -``` - -#### `pto.vexp(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Exponential of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Exponential values | - -#### `pto.vln(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Natural logarithm of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Natural logarithm values | - -#### `pto.vsqrt(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Square root of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Square root values | - -#### `pto.vrec(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Reciprocal of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Reciprocal values | - -#### `pto.vrelu(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: ReLU activation (max(0, x)) of vector elements. - -**Supported dtypes**: `si32`, `i32`, `f16`, `f32` - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | ReLU-activated values | - -#### `pto.vnot(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Bitwise NOT of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bitwise NOT values | - -#### `pto.vcadd(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Reduction add of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Reduction result vector | - -**Type Rules**: -- For floating-point inputs and `i32/ui32`, the result vector type matches the input vector type. -- For `i8/ui8` inputs, `pto.vcadd` returns a widened `i16/ui16` vector. -- For `i16/ui16` inputs, `pto.vcadd` returns a widened `i32/ui32` vector. -- The result mask granularity follows the result vector element type. - -#### `pto.vcmax(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Complex maximum of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Complex maximum result | - -#### `pto.vbcnt(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Bit count (population count) of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bit count values | - -#### `pto.vneg(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Negation of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask (granularity must match vector element type) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Negated values | - -**Constraints**: -- Mask granularity must match vector element type - -**Example**: -```python -neg_vec = pto.vneg(vec_f32, mask32) -``` - -#### `pto.vcls(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Count leading sign bits of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Count of leading sign bits | - -**Constraints**: -- Operates on integer vector types only - -#### `pto.vcmin(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Complex minimum of vector elements (treating pairs as complex numbers). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Complex minimum result | - -#### `pto.vrsqrt(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Reciprocal square root of vector elements (1/√x). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Reciprocal square root values | - -**Constraints**: -- For floating-point vector types only - -#### `pto.vprelu(vec: VRegType, alpha: VRegType, mask: MaskType) -> VRegType` - -**Description**: Parametric ReLU activation of vector elements: `x if x >= 0 else alpha * x`. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `alpha` | `VRegType` | Slope parameter for negative values | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Parametric ReLU activated values | - -#### `pto.vmov(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Vector move (data movement). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Copied vector | - -#### `pto.vsunpack(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Signed unpack of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Unpacked signed values | - -**Constraints**: -- Operates on integer vector types only - -#### `pto.vzunpack(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Zero-extended unpack of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Unpacked zero-extended values | - -**Constraints**: -- Operates on integer vector types only - -#### `pto.vusqz(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Unsigned squeeze (compression) of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Compressed unsigned values | - -**Constraints**: -- Operates on integer vector types only - -#### `pto.vsqz(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Signed squeeze (compression) of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Compressed signed values | - -**Constraints**: -- Operates on integer vector types only - -#### `pto.vexpdif(vec: VRegType, max_vec: VRegType, mask: MaskType, part: pto.VcvtPartMode) -> VRegType` - -**Description**: Fused exponential difference `exp(vec - max_vec)` for numerically stable softmax lowering. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `max_vec` | `VRegType` | Per-lane max vector subtracted before exponentiation | -| `mask` | `MaskType` | Predicate mask. Use `b16` for `f16` inputs and `b32` for `f32` inputs. | -| `part` | `pto.VcvtPartMode` | Output part selector enum. Use `pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD`. | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Exponential difference values; result element type is `f32` | - -**Constraints**: -- Supports `f16` and `f32` input vectors only -- `vec` and `max_vec` must use the same vector type -- `mask` granularity must match the input vector element width -- `part` should use `pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD` -- Canonical strings `"EVEN"` / `"ODD"` are still accepted for compatibility - -### Binary Vector Operations - -Element-wise binary operations on vector registers. - -#### `pto.vadd(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise addition of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Sum of vectors | - -**Example**: -```python -sum_vec = pto.vadd(vec_a, vec_b, mask32) -``` - -#### `pto.vsub(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise subtraction of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Difference of vectors | - -#### `pto.vmul(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise multiplication of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Product of vectors | - -#### `pto.vdiv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise division of two vectors. - -- Supported element types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16` and `f32`. -- `f16`/`f32` authoring code stays on the public `pto.vdiv` VPTO path. -- Integer `pto.vdiv` also uses the same public surface, but lowers through an internal soft-helper path. -- For `i8`/`ui8`, the integer lowering widens to 16-bit lanes, computes the soft division, then narrows back to 8-bit lanes. -- Internal helper names such as `_tl_soft_vdiv_*` are implementation details and are not part of the supported DSL call surface. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Quotient of vectors | - -#### `pto.vmod(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise modulo of two vectors. - -- Supported element types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`). -- Floating-point `vmod` is not part of the current TileLang DSL v1 public surface. -- `pto.vmod` is the only public vector modulo entry point in TileLang DSL v1. -- The current implementation lowers through an internal soft-helper path; helper names such as `_tl_soft_vmod_*` are intentionally hidden implementation details. -- For `i8`/`ui8`, the modulo path uses an explicit widen-to-16-bit, soft-compute, narrow-back-to-8-bit profile. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | Dividend vector | -| `vec2` | `VRegType` | Divisor vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Remainder vector | - -#### `pto.vmax(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise maximum of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Element-wise maximum | - -#### `pto.vmin(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise minimum of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Element-wise minimum | - -#### `pto.vand(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise bitwise AND of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bitwise AND result | - -#### `pto.vor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise bitwise OR of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bitwise OR result | - -#### `pto.vxor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise bitwise XOR of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bitwise XOR result | - -#### `pto.vshl(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise shift left (vector shift amounts). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `shift` | `VRegType` | Shift amounts (per element) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Shifted values | - -#### `pto.vshr(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise shift right (vector shift amounts). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `shift` | `VRegType` | Shift amounts (per element) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Shifted values | - -#### `pto.vaddrelu(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Addition with ReLU activation (max(0, vec1 + vec2)). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | ReLU-activated sum of vectors | - -#### `pto.vaddreluconv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Convolution addition with ReLU activation (convolution-specific fused operation). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | ReLU-activated convolution sum | - -**Constraints**: -- Optimized for convolution-specific patterns - -#### `pto.vsubrelu(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Subtraction with ReLU activation (max(0, vec1 - vec2)). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | ReLU-activated difference of vectors | - -#### `pto.vaxpy(alpha: VRegType, x: VRegType, y: VRegType, mask: MaskType) -> VRegType` - -**Description**: BLAS AXPY operation (αx + y). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `alpha` | `VRegType` | Scaling factor | -| `x` | `VRegType` | Input vector x | -| `y` | `VRegType` | Input vector y | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Result of αx + y | - -#### `pto.vmulconv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Convolution multiplication (convolution-specific multiplication). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Convolution product | - -**Constraints**: -- Optimized for convolution-specific patterns - -#### `pto.vmull(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, VRegType)` - -**Description**: Widening multiply with split low/high results (extended arithmetic). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `low` | `VRegType` | Low part of widened product (`r & 0xFFFFFFFF`) | -| `high` | `VRegType` | High part of widened product (`r >> 32`) | - -**Constraints**: -- Current A5 documented form is native `i32/u32` 32x32->64 widening multiply -- Result is split into two vector outputs instead of a single widened vector - -**Example**: -```python -low, high = pto.vmull(lhs_i32, rhs_i32, mask32) -``` - -#### `pto.vmula(vec1: VRegType, vec2: VRegType, vec3: VRegType, mask: MaskType) -> VRegType` - -**Description**: Fused multiply-add (vec1 * vec2 + vec3). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector (multiplier) | -| `vec2` | `VRegType` | Second input vector (multiplicand) | -| `vec3` | `VRegType` | Third input vector (addend) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Result of vec1 * vec2 + vec3 | - -### Vector-Scalar Operations - -Operations between vectors and scalars. - -#### `pto.vmuls(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Vector multiplied by scalar (broadcast). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar multiplier | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Scaled vector | - -**Example**: -```python -scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) -``` - -#### `pto.vadds(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Vector plus scalar (broadcast). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar addend | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Result vector | - -#### `pto.vmaxs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Element-wise maximum of vector and scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar value | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Maximum values | - -#### `pto.vmins(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Element-wise minimum of vector and scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar value | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Minimum values | - -#### `pto.vlrelu(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Leaky ReLU activation (max(αx, x)). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Alpha coefficient | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Leaky ReLU activated values | - -#### `pto.vshls(vec: VRegType, shift: i16, mask: MaskType) -> VRegType` - -**Description**: Vector shift left by scalar (uniform shift). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `shift` | `i16` | Shift amount (same for all elements) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Shifted values | - -#### `pto.vshrs(vec: VRegType, shift: i16, mask: MaskType) -> VRegType` - -**Description**: Vector shift right by scalar (uniform shift). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `shift` | `i16` | Shift amount (same for all elements) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Shifted values | - -#### `pto.vands(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Element-wise bitwise AND of vector and scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar operand | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bitwise AND result | - -**Constraints**: -- Operates on integer vector types only - -#### `pto.vors(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Element-wise bitwise OR of vector and scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar operand | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bitwise OR result | - -**Constraints**: -- Operates on integer vector types only - -#### `pto.vxors(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Element-wise bitwise XOR of vector and scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar operand | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bitwise XOR result | - -**Constraints**: -- Operates on integer vector types only - -#### `pto.vsubs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Vector minus scalar (broadcast). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar subtrahend | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Difference vector | - -#### `pto.vbr(value: ScalarType) -> VRegType` - -**Description**: Broadcast scalar to all vector lanes. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `value` | `ScalarType` | Scalar source | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Vector whose active lanes all carry `value` | - -**Constraints**: -- Supported scalar types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32`. -- For integer types, only the low bits of the scalar source are consumed according to the bit width (8, 16, or 32 bits). - -**Example**: -```python -# Broadcast scalar constant to vector -zero_vec = pto.vbr(0.0) -one_vec = pto.vbr(1.0) - -# Reduction seed with explicit floating dtype -rowmax_seed_f32 = pto.vbr(pto.f32("-inf")) -rowmax_seed_f16 = pto.vbr(pto.f16("0xFC00")) -``` - -#### `pto.vdup(input: ScalarType, mask: MaskType) -> VRegType` -#### `pto.vdup(input: VRegType, mask: MaskType, position: PositionMode = PositionMode.LOWEST) -> VRegType` - -**Description**: Duplicate a scalar value or one selected vector element into -the active lanes of a destination vector. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `input` | `ScalarType` or `VRegType` | Input scalar or source vector | -| `mask` | `MaskType` | Predicate mask controlling which lanes are written | -| `position` | `PositionMode` | Optional enum for the vector-input overload, selecting the source vector element to duplicate (default: `PositionMode.LOWEST`) | - -**Position Mode Enum**: The `PositionMode` enum provides type-safe source-lane -selection for `pto.vdup`. `LOWEST` selects the lowest-index element of the -source vector and `HIGHEST` selects the highest-index element. The enum is only -used by the vector-input overload. - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Vector whose active lanes receive the duplicated value | - -**Constraints**: -- `mask` granularity must match the destination vector element type. For - example, `f32`/`i32`/`si32`/`ui32` vectors require `mask_b32`. -- When `input` is a scalar, the scalar value is duplicated to every active lane. -- When `input` is a vector, `position` selects a single source element and that - value is duplicated to every active lane. -- The scalar overload does not accept `position`. -- Inactive lanes follow VPTO predicate semantics and are not guaranteed to carry - meaningful values for subsequent masked-off use. -- Supported scalar types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32`. -- `position` is only meaningful for vector input. TileLang DSL currently exposes - `PositionMode.LOWEST` and `PositionMode.HIGHEST`, matching VPTO v0.3. - -**Example**: -```python -mask32 = pto.make_mask(pto.f32, pto.PAT.ALL) - -# Duplicate a scalar into all active lanes. -broadcast = pto.vdup(3.14, mask32) - -# Use dtype constructors for floating-point special values. -seed = pto.vdup(pto.f32("-inf"), mask32) -seed_f16 = pto.vdup(pto.f16("0xFC00"), pto.make_mask(pto.f16, pto.PAT.ALL)) - -# Assume `vec` is an existing `f32` vector register value. -vec = pto.vlds(src, 0) - -# Duplicate the lowest source lane to all active lanes. -dup_lowest = pto.vdup(vec, mask32) # position defaults to "LOWEST" - -# Duplicate the highest source lane to all active lanes. -dup_highest = pto.vdup(vec, mask32, pto.PositionMode.HIGHEST) -``` - -**Type Safety Note**: -- For floating-point seeds, prefer `pto.f16(...)` / `pto.bf16(...)` / `pto.f32(...)` constructors. -- Do not pass integer bit-pattern literals directly (for example `0xFF800000`) when a floating vector type is intended. - -### Carry & Select Operations - -Operations with carry propagation and selection. - -**Comparison Mode Enum**: The `CmpMode` enum provides type-safe comparison mode specification for `pto.vcmp` and `pto.vcmps` operations. It includes the following values: `EQ` (equal), `NE` (not equal), `LT` (less than), `LE` (less than or equal), `GT` (greater than), `GE` (greater than or equal). - -Implemented current-package carry/select surface also includes: -- `pto.vselr(vec0, vec1) -> VRegType` -- `pto.vselrv2(vec0, vec1) -> VRegType` -- `pto.vaddcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` -- `pto.vsubcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` - -#### `pto.vcmp(vec0: VRegType, vec1: VRegType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` - -**Description**: Element-wise vector comparison with seed mask. Compares two vectors element-wise and generates a predicate mask based on the specified comparison mode. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec0` | `VRegType` | First input vector | -| `vec1` | `VRegType` | Second input vector | -| `seed_mask` | `MaskType` | Seed mask that determines which lanes participate in the comparison | -| `cmp_mode` | `CmpMode` | Comparison mode enum: `CmpMode.EQ` (equal), `CmpMode.NE` (not equal), `CmpMode.LT` (less than), `CmpMode.LE` (less than or equal), `CmpMode.GT` (greater than), `CmpMode.GE` (greater than or equal) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `MaskType` | Generated predicate mask based on element-wise comparison | - -**Constraints**: -- Only lanes enabled by `seed_mask` participate in the comparison -- The two input vectors must have the same element type and vector length -- The output mask granularity matches the input vector element type - -**Example**: -```python -# Compare two vectors for less-than relation -all_mask = pto.make_mask(pto.f32, PAT.ALL) -lt_mask = pto.vcmp(vec_a, vec_b, all_mask, CmpMode.LT) -``` - -#### `pto.vcmps(vec: VRegType, scalar: ScalarType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` - -**Description**: Vector-scalar comparison with seed mask. Compares each element of a vector against a scalar value and generates a predicate mask based on the specified comparison mode. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar value to compare against (must match vector element type) | -| `seed_mask` | `MaskType` | Seed mask that determines which lanes participate in the comparison | -| `cmp_mode` | `CmpMode` | Comparison mode enum: `CmpMode.EQ`, `CmpMode.NE`, `CmpMode.LT`, `CmpMode.LE`, `CmpMode.GT`, `CmpMode.GE` | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `MaskType` | Generated predicate mask based on vector-scalar comparison | - -**Constraints**: -- Only lanes enabled by `seed_mask` participate in the comparison -- The scalar type must match the vector element type -- The output mask granularity matches the input vector element type - -**Example**: -```python -# Check which elements are greater than zero -all_mask = pto.make_mask(pto.f32, PAT.ALL) -positive_mask = pto.vcmps(values, pto.f32(0.0), all_mask, CmpMode.GT) -``` - -#### `pto.vaddc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` - -**Description**: Vector addition with carry output. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Sum vector | -| `carry_out` | `MaskType` | Output carry mask | - -#### `pto.vsubc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` - -**Description**: Vector subtraction with borrow output. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Difference vector | -| `borrow_out` | `MaskType` | Output borrow mask | - -#### `pto.vsel(true_vec: VRegType, false_vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Vector select based on mask. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `true_vec` | `VRegType` | Vector selected when mask bit is 1 | -| `false_vec` | `VRegType` | Vector selected when mask bit is 0 | -| `mask` | `MaskType` | Selection mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Selected vector | - -**Example**: -```python -result = pto.vsel(scaled_vec, original_vec, mask32) -``` - -### Reduction Operations - -Reduction operations across vector lanes or channels. - -#### `pto.vcgadd(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Cross-group addition reduction (reduction across VLanes). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Reduced sum across groups | - -#### `pto.vcgmax(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Cross-group maximum reduction (reduction across VLanes). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Reduced maximum across groups | - -#### `pto.vcgmin(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Cross-group minimum reduction (reduction across VLanes). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Reduced minimum across groups | - -#### `pto.vcpadd(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Cross-channel addition reduction (reduction across channels). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Reduced sum across channels | - -### Data Rearrangement - -Operations for rearranging data within vectors. - -Predicate rearrangement ops `pto.pdintlv_b8` and `pto.pintlv_b16` are documented in `10-predicate-operations.md` because they operate on predicate masks rather than vector registers. - -Implemented current-package rearrangement surface also includes: -- `pto.vintlvv2(vec0, vec1, part) -> VRegType` -- `pto.vdintlvv2(vec0, vec1, part) -> VRegType` - -#### `pto.vintlv(vec1: VRegType, vec2: VRegType) -> (VRegType, VRegType)` - -**Description**: Interleave two vectors and return the low/high results. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `low` | `VRegType` | Low interleaved result | -| `high` | `VRegType` | High interleaved result | - -#### `pto.vdintlv(vec0: VRegType, vec1: VRegType) -> (VRegType, VRegType)` - -**Description**: Deinterleave a pair of vectors into low/high results. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec0` | `VRegType` | First input vector | -| `vec1` | `VRegType` | Second input vector | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec1` | `VRegType` | First deinterleaved vector | -| `vec2` | `VRegType` | Second deinterleaved vector | - -#### `pto.vpack(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Vector packing (combine elements from two vectors). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Packed vector | - -#### `pto.vperm(vec: VRegType, indices: VRegType, mask: MaskType) -> VRegType` - -**Description**: Vector permutation (reorder elements according to index vector). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `indices` | `VRegType` | Permutation indices | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Permuted vector | - -#### `pto.vshift(vec: VRegType, shift_amount: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Generic vector shift (shift all elements by same amount). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `shift_amount` | `ScalarType` | Shift amount (same for all elements) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Shifted vector | - -#### `pto.vslide(vec: VRegType, window_size: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Vector sliding window (create overlapping windows). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `window_size` | `ScalarType` | Size of sliding window | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Sliding window result | - -#### `pto.vsort32(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: 32-element sorting of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector (32 elements) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Sorted vector | - -**Constraints**: -- Input vector must have exactly 32 elements - -#### `pto.vmrgsort(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Merge sort of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Merged and sorted vector | - -#### `pto.vtranspose(dest: ptr, src: ptr, config: pto.i64) -> None` [Advanced Tier] - -**Description**: UB-to-UB transpose operation. This op works on UB memory directly (not `vreg -> vreg`). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `dest` | `ptr` | Destination pointer in UB memory space | -| `src` | `ptr` | Source pointer in UB memory space | -| `config` | `pto.i64` | ISA control/config operand that encodes transpose layout behavior | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `None` | `None` | Side-effect operation that writes transposed data to `dest` | - -**Constraints**: -- `dest` and `src` must be UB pointers -- Correctness depends on the `config` encoding and UB layout contract - -**Example**: -```python -pto.vtranspose(dst_ub_ptr, src_ub_ptr, config_word) -``` - -### Conversion & Special Operations - -Type conversion and specialized operations. - -#### `pto.vtrc(vec: VRegType, mask: MaskType, rnd: pto.VcvtRoundMode | None = None) -> VRegType` - -**Description**: Truncate/round float to integer-valued float (stays in float type). This is the TileLang DSL surface for the VPTO `pto.vtrc` operation. - -**Attribute Enums**: -- `pto.VcvtRoundMode`: `R`, `A`, `F`, `C`, `Z`, `O` (note: `vtrc` does not support `O`) - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | -| `rnd` | `pto.VcvtRoundMode` \| `None` | Optional rounding-mode attribute lowered to VPTO `round_mode`. Defaults to `R` if not specified. | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Truncated vector with integer-valued float elements | - -**Constraints**: -- Current TileLang DSL v1 accepts exactly two positional arguments: `pto.vtrc(vec, mask)`. Optional `rnd` attribute is exposed as keyword argument: `rnd=...`. -- The underlying VPTO op syntax is `pto.vtrc %input, %mask, "RND"`. -- Supported rounding modes are `R` (round to nearest), `A` (round away from zero), `F` (floor), `C` (ceil), `Z` (truncate toward zero). -- The enum form is preferred. For compatibility, canonical strings such as `"R"`, `"A"`, `"F"`, `"C"`, `"Z"` are also accepted. -- This op does not change the element type; input and output have the same vector type. -- Only floating-point element types are supported: `f16`, `bf16`, `f32`. - -#### `pto.vcvt(vec: VRegType, to_type: Type, mask: MaskType, rnd: pto.VcvtRoundMode | None = None, sat: pto.VcvtSatMode | None = None, part: pto.VcvtPartMode | None = None) -> VRegType` - -**Description**: Convert vector elements between supported float and integer -families. This is the TileLang DSL surface for the VPTO `pto.vcvt` conversion -family. - -**Attribute Enums**: -- `pto.VcvtRoundMode`: `R`, `A`, `F`, `C`, `Z`, `O` -- `pto.VcvtSatMode`: `SAT`, `NOSAT` -- `pto.VcvtPartMode`: `EVEN`, `ODD`, `P0`, `P1`, `P2`, `P3` - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `to_type` | `Type` | Target scalar dtype symbol for the result vector element type | -| `mask` | `MaskType` | Predicate mask selecting active source lanes. Its granularity must match the source vector family, not the destination family | -| `rnd` | `pto.VcvtRoundMode` \| `None` | Optional rounding-mode attribute lowered to VPTO `rnd` | -| `sat` | `pto.VcvtSatMode` \| `None` | Optional saturation attribute lowered to VPTO `sat` | -| `part` | `pto.VcvtPartMode` \| `None` | Optional width-changing lane-placement selector lowered to VPTO `part` | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Converted vector with the vreg shape implied by `to_type` | - -**Constraints**: -- Current TileLang DSL v1 accepts exactly three positional arguments: - `pto.vcvt(vec, to_type, mask)`. Optional attributes are exposed as keyword - arguments: `rnd=...`, `sat=...`, `part=...`. -- The underlying VPTO op family is the fuller - `pto.vcvt %input, %mask {rnd, sat, part}` surface, and the DSL keywords map - directly to those VPTO attributes. -- `mask` always follows the source vector family: - `f32`/`i32`/`si32`/`ui32` use `mask_b32`; - `f16`/`bf16`/`i16`/`si16`/`ui16` use `mask_b16`; - `i8`/`si8`/`ui8` use `mask_b8`. -- The enum form is preferred. For compatibility, canonical strings such as - `"R"`, `"SAT"`, and `"EVEN"` are also accepted. -- VPTO `part` supports two families: `Part` (`EVEN`/`ODD`) for ordinary - width-changing conversions (e.g. `32 -> 16`, `16 -> 32`), and `Part_T` - (`P0`–`P3`) for 4-way packed placement (e.g. `32 -> 8`, fp8/fp4 flows). - - | Mode | VPTO spelling | Family | Description | TileLang DSL v1 status | - |------|---------------|--------|-------------|------------------------| - | `EVEN` | `PART_EVEN` | `Part` | Output to even-indexed lanes | Exposed as `pto.VcvtPartMode.EVEN` | - | `ODD` | `PART_ODD` | `Part` | Output to odd-indexed lanes | Exposed as `pto.VcvtPartMode.ODD` | - | `P0` | `PART_P0` | `Part_T` | Output to sub-part 0 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P0` | - | `P1` | `PART_P1` | `Part_T` | Output to sub-part 1 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P1` | - | `P2` | `PART_P2` | `Part_T` | Output to sub-part 2 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P2` | - | `P3` | `PART_P3` | `Part_T` | Output to sub-part 3 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P3` | -- Only backend-supported source/destination type pairs are legal. For the full - A5 `vcvt` type matrix, width-changing packing rules, and attribute-sensitive - forms, refer to - [`../vpto_spec/vpto-spec-current.md`](../vpto_spec/vpto-spec-current.md). -- Attribute requirements are type-pair specific. The DSL enforces the same - per-form contract as VPTO, so some pairs require attributes while others - reject them. -- Examples: - `f32 -> si32` requires `rnd` and `sat`; - `f16 -> si32` requires `rnd` and `part`, and rejects `sat`; - `bf16 -> f16` requires `rnd` and `sat`; - `f16 -> f32` requires `part`; - `f32 -> f16` requires `rnd`, `sat`, and `part`; - `si32 -> f32` requires `rnd`. -- VPTO does not define a `mask_b64` form. Conversions that produce `si64` - results still use the typed mask granularity of the source vector family. -- Width-changing conversions continue to follow VPTO packing semantics even on - the simplified DSL surface. For example, `f16 -> f32` uses an `f16`-family - `mask_b16`, because the mask is attached to the source vector family. -- A common `tcvt`-style pair is: - `f16 -> f32`: `pto.vlds(..., dist=pto.VLoadDist.UNPK_B16)` + `pto.vcvt(..., part=pto.VcvtPartMode.EVEN)`; - `f32 -> f16`: `pto.vcvt(..., rnd=..., sat=..., part=pto.VcvtPartMode.EVEN)` + `pto.vsts(..., dist=pto.VStoreDist.PK_B32)`. -- In those `tcvt` flows, the `vcvt` mask still follows the source vector family: - `f16 -> f32` uses `mask_b16`, while `f32 -> f16` uses `mask_b32`. -- The follow-on `vsts` mask is checked against the store `dist`, not the narrowed element dtype alone. For example, `pto.vsts(vec_f16, ..., mask32, dist=pto.VStoreDist.PK_B32)` is valid and expected for `f32 -> f16` rowwise `tcvt`. - -**Example**: -```python -mask16 = pto.make_mask(pto.f16, PAT.ALL) -vec_f16 = pto.vlds(src, 0) -vec_f32 = pto.vcvt(vec_f16, pto.f32, mask16) - -mask32 = pto.make_mask(pto.f32, PAT.ALL) -vec_i32 = pto.vcvt(vec_f32, pto.si32, mask32) - -vec_i32_wide = pto.vcvt( - vec_f16, - pto.si32, - mask16, - rnd=pto.VcvtRoundMode.R, - part=pto.VcvtPartMode.EVEN, -) - -vec_f16_from_bf16 = pto.vcvt( - vec_bf16, - pto.f16, - mask16, - rnd=pto.VcvtRoundMode.R, - sat=pto.VcvtSatMode.SAT, -) - -vec_f16_narrow = pto.vcvt( - vec_f32, - pto.f16, - mask32, - rnd=pto.VcvtRoundMode.R, - sat=pto.VcvtSatMode.SAT, - part=pto.VcvtPartMode.ODD, -) - -# Rowwise tcvt-style widening from f16 to f32 -vec_f16_unpacked = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) -vec_f32_from_f16 = pto.vcvt( - vec_f16_unpacked, - pto.f32, - mask16, - part=pto.VcvtPartMode.EVEN, -) - -# Rowwise tcvt-style narrowing from f32 to f16 -vec_f16_packed = pto.vcvt( - vec_f32, - pto.f16, - mask32, - rnd=pto.VcvtRoundMode.R, - sat=pto.VcvtSatMode.SAT, - part=pto.VcvtPartMode.EVEN, -) -pto.vsts(vec_f16_packed, dst, 0, mask32, dist=pto.VStoreDist.PK_B32) -``` - -#### `pto.vbitsort(dest: ptr, src: ptr, indices: ptr, repeat_times: index) -> None` [Advanced Tier] - -**Description**: Sort 32 region proposals by score and materialize sorted proposal -records into UB memory. This is a UB helper and not a `vreg -> vreg` operation. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `dest` | `ptr` | Destination pointer in UB memory space | -| `src` | `ptr` | Source score pointer in UB memory space | -| `indices` | `ptr` | Source index pointer in UB memory space | -| `repeat_times` | `index` | Repeat count; each repeat processes the next adjacent group of 32 scores and 32 indices | - -**Returns**: -None. The op writes UB memory directly. - -**Constraints**: -- `dest`, `src`, and `indices` must be UB-backed pointers -- Scores are sorted in descending order -- Equal-score ties preserve the earlier input proposal first -- Output records occupy 8 bytes each: upper 4 bytes for the index and lower 4 bytes for the score - -#### `pto.vmrgsort4(dest: ptr, src0: ptr, src1: ptr, src2: ptr, src3: ptr, count: pto.i64, config: pto.i64) -> None` [Advanced Tier] - -**Description**: Merge-sort 4 pre-sorted UB inputs. This op writes UB memory -directly and does not return a vector SSA value. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `dest` | `ptr` | Destination pointer in UB memory space | -| `src0` | `ptr` | First pre-sorted input pointer in UB memory space | -| `src1` | `ptr` | Second pre-sorted input pointer in UB memory space | -| `src2` | `ptr` | Third pre-sorted input pointer in UB memory space | -| `src3` | `ptr` | Fourth pre-sorted input pointer in UB memory space | -| `count` | `pto.i64` | Number of valid input elements participating in the merge | -| `config` | `pto.i64` | Operation control word encoding sort behavior | - -**Returns**: -None. The op writes UB memory directly. - -**Constraints**: -- `dest` and `src0` through `src3` must be UB-backed pointers -- Inputs must already be sorted according to the order encoded by `config` - -#### `pto.get_vms4_sr() -> (pto.i16, pto.i16, pto.i16, pto.i16)` [Advanced Tier] - -**Description**: Read `VMS4_SR` after exhausted `pto.vmrgsort4` and return the -finished element counts for source lists 0 through 3. - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `list0` | `pto.i16` | Finished count from `VMS4_SR[15:0]` | -| `list1` | `pto.i16` | Finished count from `VMS4_SR[31:16]` | -| `list2` | `pto.i16` | Finished count from `VMS4_SR[47:32]` | -| `list3` | `pto.i16` | Finished count from `VMS4_SR[63:48]` | - -**Example**: -```python -list0, list1, list2, list3 = pto.get_vms4_sr() -``` - -**Order Mode Enum**: The `OrderMode` enum provides type-safe order selection for `pto.vci` operations. `ASC` and `DESC` are supported. - -#### `pto.vci(index: ScalarType, order: OrderMode = OrderMode.ASC) -> VRegType` - -**Description**: Generate a lane-index vector from a scalar seed/index value (DSA/SFU operation). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `index` | `ScalarType` | Scalar seed or base index value | -| `order` | `OrderMode` | Order mode enum (default: `OrderMode.ASC`; supported values: `ASC`, `DESC`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Generated index vector | - -**Constraints**: -- This is an index-generation family, not a numeric conversion -- The `order` parameter and result element type together determine how indices are generated -- Supported order modes are ascending (`OrderMode.ASC`) and descending (`OrderMode.DESC`) - -**Example**: -```python -# Generate ascending indices starting from 0 -indices = pto.vci(pto.i32(0), OrderMode.ASC) - -# Generate descending indices starting from the seed value -indices_desc = pto.vci(pto.i32(63), OrderMode.DESC) - -# Keyword form for the optional order argument is also supported -indices_kw = pto.vci(pto.i32(0), order=OrderMode.ASC) -``` diff --git a/ptodsl/docs/user_guide/12-additional-examples.md b/ptodsl/docs/user_guide/12-additional-examples.md new file mode 100644 index 000000000..234d2981c --- /dev/null +++ b/ptodsl/docs/user_guide/12-additional-examples.md @@ -0,0 +1,400 @@ +# 12. Additional Examples + +This chapter presents four self-contained examples that build on the concepts introduced in Chapters 1–11. Each example demonstrates a specific pattern: blocked 2D processing, tail handling with masks, matrix multiplication on the Cube unit, and loop-carried state for online normalization. + +## 12.1 Blocked 2D elementwise addition + +Chapter 2 showed a 1D vector add with a single blocking dimension. Real workloads often involve 2D tensors — matrices — where blocking happens along both rows and columns. + +```python +@pto.jit(target="a5") +def mat_add(A, B, O, *, BLOCK_M: pto.constexpr = 64, BLOCK_N: pto.constexpr = 128): + M, N_ = A.shape + + a_view = pto.make_tensor_view(A, shape=[M, N_], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[M, N_], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[M, N_], strides=O.strides) + + a_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32) + + num_m = (M + BLOCK_M - 1) // BLOCK_M + num_n = (N_ + BLOCK_N - 1) // BLOCK_N + + with pto.for_(0, num_m, step=1) as mi: + m_off = mi * BLOCK_M + with pto.for_(0, num_n, step=1) as ni: + n_off = ni * BLOCK_N + + a_part = pto.partition_view(a_view, offsets=[m_off, n_off], sizes=[BLOCK_M, BLOCK_N]) + b_part = pto.partition_view(b_view, offsets=[m_off, n_off], sizes=[BLOCK_M, BLOCK_N]) + o_part = pto.partition_view(o_view, offsets=[m_off, n_off], sizes=[BLOCK_M, BLOCK_N]) + + pto.tload(a_part, a_tile) + pto.tload(b_part, b_tile) + pto.tadd(a_tile, b_tile, o_tile) + pto.tstore(o_tile, o_part) +``` + +**Key points**: + +- Nested `pto.for_` loops produce a 2D block traversal. Both loops are recorded as device-side control flow — they adapt to the runtime shape `M`. +- Tile shape `[BLOCK_M, BLOCK_N]` is 2D; all three tiles use the same shape so `tadd` is elementwise. +- `partition_view` takes 2D offsets and sizes. +- `BLOCK_M` and `BLOCK_N` are `constexpr` — the compiler specializes the kernel per tile shape. + +The L0 wrapper follows the same pattern as Chapter 2: + +```python +def mat_add_wrapper(A, B, O=None, stream=None): + if O is None: + O = pto.empty_like(A) + compiled = mat_add.compile(BLOCK_M=64, BLOCK_N=128) + m, n = A.shape[1], A.shape[2] # assuming batch-first: [batch, M, N] + compiled[A.shape[0], stream](A, B, O) + return O +``` + +The grid is `A.shape[0]` so each SPMD block processes one slice of the leading batch dimension. + +## 12.2 Vector operations with tail handling + +When a data dimension is not evenly divisible by the tile size or the hardware vector width, the last iteration must operate on fewer elements. PTODSL provides masks for this — `make_mask` produces a predicate that guards loads, computes, and stores so out-of-bounds lanes are not touched. + +### 12.2.1 Tail handling in a SIMD kernel + +Below is a self-contained `@pto.simd` kernel that adds two tiles row by row, handling column tails with `make_mask`: + +```python +@pto.simd +def add_rows_with_tail(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, + rows: pto.i32, cols: pto.i32): + VEC = pto.elements_per_vreg(pto.f32) # 64 for f32 + + with pto.for_(0, rows, step=1) as r: + col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) + with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + + a_vec = pto.vlds(a_tile[r, c:]) # load under mask + b_vec = pto.vlds(b_tile[r, c:]) + o_vec = pto.vadd(a_vec, b_vec, mask) # compute under mask + pto.vsts(o_vec, o_tile[r, c:], mask) # store under mask + + col_loop.update(remained=remained) +``` + +The pattern: + +1. **Chunk**: Each iteration processes `VEC` elements (one vector register's worth). +2. **Mask**: `make_mask` returns a predicate and the updated remainder. On the last iteration, where `remained < VEC`, the mask has `remained` valid lanes followed by inactive lanes. +3. **Guard**: `vlds`, `vadd`, and `vsts` all accept the mask — inactive lanes are neither loaded, computed, nor stored. +4. **Carry**: `.carry(remained=cols)` carries the remaining column count across iterations. `col_loop.update(remained=remained)` feeds the updated count to the next iteration. + +### 12.2.2 Tile-level tail handling + +At the Tile Op level, tail handling is built into `tload` and `tstore`. When a partition size along a dimension is smaller than the tile size, the tile's `valid_shape` tracks the actual data extent: + +```python +@pto.jit(target="a5") +def vec_add_with_tail(A, B, O, *, BLOCK: pto.constexpr): + N = A.shape[0] + + a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + + num_blocks = (N + BLOCK - 1) // BLOCK + + with pto.for_(0, num_blocks, step=1) as i: + offset = i * BLOCK + this_block = min(BLOCK, N - offset) + + a_part = pto.partition_view(a_view, offsets=[offset], sizes=[this_block]) + b_part = pto.partition_view(b_view, offsets=[offset], sizes=[this_block]) + o_part = pto.partition_view(o_view, offsets=[offset], sizes=[this_block]) + + pto.tload(a_part, a_tile) + pto.tload(b_part, b_tile) + + a_tile.valid_shape = [this_block] + b_tile.valid_shape = [this_block] + o_tile.valid_shape = [this_block] + + pto.tadd(a_tile, b_tile, o_tile) + pto.tstore(o_tile, o_part) +``` + +- `this_block = min(BLOCK, N - offset)` computes the actual block size for the tail iteration. +- `sizes=[this_block]` on the partition and `valid_shape` on the tile tell `tload`/`tadd`/`tstore` how many elements are live. + +### 12.2.3 The general rule + +| Tail scenario | Mechanism | +|---------------|-----------| +| Tile Op boundary (tload/tstore) | `valid_shape` on tile + smaller `sizes` on partition | +| SIMD vector boundary (vlds/vadd/vsts) | `make_mask` + mask parameter on op | +| SIMT scalar loop boundary | `min(BLOCK, N - offset)` in loop bound | + +## 12.3 GEMM: matrix multiplication on the Cube unit + +This example demonstrates a complete GEMM kernel: `C = A @ B` where A is `[M, K]` and B is `[K, N]`. It uses `@pto.jit` for tile allocation and loop scheduling, and `@pto.cube` for the actual matrix multiply. + +### 12.3.1 L3: Cube sub-kernel + +```python +@pto.cube +def gemm_tile(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, + a_l0a: pto.Tile, b_l0b: pto.Tile, o_acc: pto.Tile): + m = pto.tile_valid_rows(a_tile) + k = pto.tile_valid_cols(a_tile) + n = pto.tile_valid_rows(b_tile) + + pto.mte_l1_l0a(a_tile, a_l0a, m, k) + pto.mte_l1_l0b(b_tile, b_l0b, k, n, transpose=True) + pto.mad(a_l0a, b_l0b, o_acc) + pto.mte_l0c_ub(o_acc, o_tile, m, n) +``` + +The cube sub-kernel consumes UB tiles and cube-local scratch buffers. The four-step sequence — stage left operand, stage right operand, multiply, writeback — is the canonical cube compute pattern. + +### 12.3.2 L1: Tile orchestration + +```python +@pto.jit(target="a5") +def gemm(A, B, O, *, BLOCK_M: pto.constexpr = 64, + BLOCK_K: pto.constexpr = 64, BLOCK_N: pto.constexpr = 64): + M, K_ = A.shape + _, N_ = B.shape + + a_view = pto.make_tensor_view(A, shape=[M, K_], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[K_, N_], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[M, N_], strides=O.strides) + + a_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[BLOCK_K, BLOCK_N], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32) + + a_l0a = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f32, + memory_space=pto.MemorySpace.LEFT) + b_l0b = pto.alloc_tile(shape=[BLOCK_K, BLOCK_N], dtype=pto.f32, + memory_space=pto.MemorySpace.RIGHT) + o_acc = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32, + memory_space=pto.MemorySpace.ACC) + + num_m = (M + BLOCK_M - 1) // BLOCK_M + num_n = (N_ + BLOCK_N - 1) // BLOCK_N + num_k = (K_ + BLOCK_K - 1) // BLOCK_K + + with pto.for_(0, num_m, step=1) as mi: + m_off = mi * BLOCK_M + with pto.for_(0, num_n, step=1) as ni: + n_off = ni * BLOCK_N + + o_tile.fill(0.0) + + with pto.for_(0, num_k, step=1) as ki: + k_off = ki * BLOCK_K + + a_part = pto.partition_view(a_view, offsets=[m_off, k_off], + sizes=[BLOCK_M, BLOCK_K]) + b_part = pto.partition_view(b_view, offsets=[k_off, n_off], + sizes=[BLOCK_K, BLOCK_N]) + o_part = pto.partition_view(o_view, offsets=[m_off, n_off], + sizes=[BLOCK_M, BLOCK_N]) + + pto.tload(a_part, a_tile) + pto.tload(b_part, b_tile) + + gemm_tile(a_tile, b_tile, o_tile, a_l0a, b_l0b, o_acc) + + pto.tstore(o_tile, o_part) +``` + +**Key points**: + +- **Triply nested loops**: M, N, and K dimensions are all blocked. The K loop accumulates partial results into `o_tile`. +- **Accumulation**: `o_tile.fill(0.0)` resets the accumulator before the K loop. Each K-block calls `gemm_tile` which writes its partial product back to `o_tile`. The Cube unit accumulates implicitly via `mad` — each K-block's partial result is added to the running total in `o_acc`. +- **Cube-local scratch**: `a_l0a`, `b_l0b`, and `o_acc` are allocated with explicit `memory_space` parameters (`LEFT`, `RIGHT`, `ACC`). Cube-local state does not leak into UB. +- **Direct L3 call**: `gemm_tile` is called directly from `@pto.jit` — no ukernel needed. The compiler handles sync between `tload` and the Cube sub-kernel. +- **Cube sub-kernel reuse**: the same `gemm_tile` function is called for every K-block — the named decorator form enables reuse. + +### 12.3.3 L0 wrapper + +```python +def gemm_wrapper(A, B, O=None, stream=None): + if O is None: + O = pto.empty([A.shape[0], B.shape[1]], dtype=A.dtype) + compiled = gemm.compile(BLOCK_M=64, BLOCK_K=64, BLOCK_N=64) + compiled[1, stream](A, B, O) + return O +``` + +This pattern extends directly to batch-GEMM: pass a grid of `batch` and use `pto.get_block_idx()` to select the per-batch slice from `A` and `B`. + +### 12.3.4 Comparison with ukernel path + +For reference, the same GEMM could be written using `@pto.ukernel` for explicit MTE control. The ukernel would replace the inner `tload`/`tstore` calls with `mte_load`/`mte_store` and add `mem_bar` synchronization between DMA and compute. The direct-call path used above is recommended for most users — the ukernel path is for cases that need hand-tuned DMA scheduling. + +## 12.4 Online normalization with loop-carried state + +Chapter 11 demonstrated online softmax with ping-pong state tiles. A simpler but instructive case is **online layer normalization** — computing mean and variance incrementally across blocks without a second pass. + +Given a vector `X` of length `N`, the streaming Welford algorithm updates the running mean `mu` and variance `var` as each new element `x` arrives: + +``` +n_next = n_prev + 1 +delta = x - mu_prev +mu_next = mu_prev + delta / n_next +m2_next = m2_prev + delta * (x - mu_next) +``` + +The example below applies this pattern block by block, using a ukernel for the per-block SIMD work and `pto.for_` carry state to shuttle the running statistics between blocks. + +### 12.4.1 L3: SIMD block statistics + +```python +@pto.simd +def block_mean_var(x_tile: pto.Tile, block_size: pto.i32, + mu_prev: pto.f32, n_prev: pto.f32, m2_prev: pto.f32, + mu_next_tile: pto.Tile, n_next_tile: pto.Tile, + m2_next_tile: pto.Tile): + VEC = pto.elements_per_vreg(pto.f32) + + # Per-row cross-lane reductions to compute the block sum and sum-of-squares + row_sum = pto.vdup(0.0, pto.f32) + row_sum2 = pto.vdup(0.0, pto.f32) + + col_loop = pto.for_(0, block_size, step=VEC).carry(row_sum=row_sum, row_sum2=row_sum2) + with col_loop: + c = col_loop.iv + remained = pto.i32(block_size) - c + mask, _ = pto.make_mask(pto.f32, remained) + + x_vec = pto.vlds(x_tile[0, c:]) + row_sum = pto.vcadd(x_vec, mask) + row_sum2 = pto.vcadd(pto.vmul(x_vec, x_vec, mask), mask) + col_loop.update(row_sum=row_sum, row_sum2=row_sum2) + + block_n = pto.cvt(block_size, pto.f32) + block_mean = pto.vdiv(col_loop.final("row_sum"), block_n) + block_mean_sq = pto.vdiv(col_loop.final("row_sum2"), block_n) + + # Welford update: merge block statistics into running state + n_next = n_prev + block_n + delta = block_mean - mu_prev + mu_next = mu_prev + delta * block_n / n_next + m2_next = m2_prev + pto.vdiv(row_sum2, block_n) * block_n # simplified + + scalar.store(n_next, n_next_tile[0, 0]) + scalar.store(mu_next, mu_next_tile[0, 0]) + scalar.store(m2_next, m2_next_tile[0, 0]) +``` + +### 12.4.2 L2: Ukernel with carry orchestration + +```python +@pto.ukernel +def norm_block(x_part: pto.PartitionTensorView, x_tile: pto.Tile, + block_size: pto.i32, + mu_prev: pto.f32, n_prev: pto.f32, m2_prev: pto.f32, + mu_next_tile: pto.Tile, n_next_tile: pto.Tile, + m2_next_tile: pto.Tile): + pto.mte_load(x_part, x_tile) + pto.mem_bar(pto.BarrierType.SYNC) + + block_mean_var(x_tile, block_size, + mu_prev, n_prev, m2_prev, + mu_next_tile, n_next_tile, m2_next_tile) + pto.mem_bar(pto.BarrierType.SYNC) +``` + +### 12.4.3 L1: JIT entry with carry state + +```python +@pto.jit(target="a5") +def online_layernorm(X, O, *, BLOCK: pto.constexpr): + N = X.shape[0] + x_view = pto.make_tensor_view(X, shape=[N], strides=X.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + x_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + + mu_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) + n_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) + m2_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) + + num_blocks = (N + BLOCK - 1) // BLOCK + + # Carry: running statistics across blocks + block_loop = pto.for_(0, num_blocks, step=1).carry( + mu=pto.f32(0.0), n=pto.f32(0.0), m2=pto.f32(0.0) + ) + with block_loop: + i = block_loop.iv + offset = i * BLOCK + this_block = min(BLOCK, N - offset) + + x_part = pto.partition_view(x_view, offsets=[offset], sizes=[this_block]) + + mu_prev = block_loop.mu + n_prev = block_loop.n + m2_prev = block_loop.m2 + + norm_block(x_part, x_tile, pto.i32(this_block), + mu_prev, n_prev, m2_prev, + mu_tile, n_tile, m2_tile) + + n_next = scalar.load(n_tile[0, 0]) + mu_next = scalar.load(mu_tile[0, 0]) + m2_next = scalar.load(m2_tile[0, 0]) + + block_loop.update(mu=mu_next, n=n_next, m2=m2_next) + + # After all blocks: finalize normalization with the running stats + global_var = m2_next / n_next + + # Second pass: normalize each block (using same tiling) + with pto.for_(0, num_blocks, step=1) as i: + offset = i * BLOCK + this_block = min(BLOCK, N - offset) + x_part = pto.partition_view(x_view, offsets=[offset], sizes=[this_block]) + o_part = pto.partition_view(o_view, offsets=[offset], sizes=[this_block]) + + pto.tload(x_part, x_tile) + pto.tnormalize(x_tile, mu_next, global_var, o_tile) + pto.tstore(o_tile, o_part) +``` + +**Key points**: + +- **Carry state**: `.carry(mu=..., n=..., m2=...)` on the `pto.for_` declares three loop-carried values. Each iteration reads the previous values via `block_loop.mu` etc. and feeds the updated values via `block_loop.update(...)`. +- **Ping-pong implicit**: The carry mechanism produces a clean SSA-style handoff between iterations — no explicit swap of tile pairs needed. +- **Two-pass algorithm**: The first pass accumulates statistics; the second pass applies the normalization. For a single-pass online version, the normalized output would be written block-by-block inside the first loop, but that requires storing the running statistics per element — a tradeoff between memory and passes. +- **Compare to flash attention**: The flash attention carry in Chapter 11 carries six values (`m_prev`/`m_next`, `l_prev`/`l_next`, `o_prev`/`o_next`) and uses ping-pong tiles. This example shows that for simpler scalar carries, direct values (no tile swap) suffice. + +## 12.5 Design guidelines + +**Start simple, refine later.** Begin with `@pto.jit` + Tile Ops. If Tile Ops don't cover the computation (e.g., custom softmax, specialized activation), add an L3 sub-kernel. If you need explicit DMA scheduling or inter-pipeline sync, drop to `@pto.ukernel`. + +**Choose the right entry for each piece:** + +| Goal | Use | +|------|-----| +| Whole-kernel orchestration, GM↔UB boundary | `@pto.jit` | +| Tile-level data movement | `tload` / `tstore` | +| Custom row-wise vector math | `@pto.simd` | +| Custom per-element logic | `@pto.simt` | +| Matrix multiply | `@pto.cube` | +| Explicit DMA + sync ordering | `@pto.ukernel` | +| Inline L3 for quick prototyping | `with pto.simd():` etc. | + +**Respect boundary contracts.** Vregs don't cross `@pto.simd` boundaries. Cube-local state doesn't leak into UB. Tile Ops and MTE Ops live at different abstraction levels — keep them in their respective layers. diff --git a/ptodsl/docs/user_guide/12-cube-operations.md b/ptodsl/docs/user_guide/12-cube-operations.md deleted file mode 100644 index 275039838..000000000 --- a/ptodsl/docs/user_guide/12-cube-operations.md +++ /dev/null @@ -1,454 +0,0 @@ -# Cube Matrix Multiply Operations - -Cube operations target the AIC (Cube) hardware unit for matrix multiplication and -staged data movement. They are only available inside `@pto.ckernel` function -bodies. All Cube operands use `pto.ptr` raw pointers — no -`vecscope` execution scope is used. - -## Address Spaces - -Cube operations use the following address spaces via the `MemorySpace` enum. -The IR type column shows the canonical `!pto.ptr` spelling. Older -`mat`/`left`/`right`/`acc`/`bias`/`scaling` pointer spellings are accepted as -parser aliases and print back as `l1`/`l0a`/`l0b`/`l0c`/`bt`/`fb`. - -| Address Space | Enum Value | Canonical IR Type | Legacy ptr alias | Description | -|--------------|------------|-------------------|------------------|-------------| -| `GM` | `MemorySpace.GM` | `!pto.ptr` | - | Global memory | -| `MAT` | `MemorySpace.MAT` | `!pto.ptr` | `mat` | L1 buffer (cbuf) | -| `LEFT` | `MemorySpace.LEFT` | `!pto.ptr` | `left` | L0A left-operand buffer | -| `RIGHT` | `MemorySpace.RIGHT` | `!pto.ptr` | `right` | L0B right-operand buffer | -| `ACC` | `MemorySpace.ACC` | `!pto.ptr` | `acc` | L0C accumulator buffer | -| `BIAS` | `MemorySpace.BIAS` | `!pto.ptr` | `bias` | Bias table | -| `UB` | `MemorySpace.UB` | `!pto.ptr` | `vec` | Unified buffer (Vector side) | - -## Shared Infrastructure - -Cube operations reuse general tile and pointer facilities documented elsewhere: - -| Facility | Description | Reference | -|----------|-------------|-----------| -| `pto.Tile` | Allocate a tile buffer with address space | [Type System — Tile Type Definition](05-type-system.md#tile-type-definition) | -| `.as_ptr()` | Get raw pointer from Tile / TensorView | [Frontend Operations — Pointer Construction](07-frontend-operations.md#pointer-construction-advanced-tier) | -| `pto.addptr` | Element-offset a pointer | [Frontend Operations — Pointer Construction](07-frontend-operations.md#pointer-construction-advanced-tier) | - ---- - -## Matrix Compute Operations - -### `pto.mad` — zero-init matmul - -#### `pto.mad(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` - -**Description**: Zero-init cube matrix multiply. Clears the accumulator and computes -`dst = lhs * rhs`. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `lhs` | `pto.ptr` | L0A left operand | -| `rhs` | `pto.ptr` | L0B right operand | -| `dst` | `pto.ptr` | L0C accumulator destination | -| `m` | `int` | M dimension size | -| `n` | `int` | N dimension size | -| `k` | `int` | K dimension size | -| `unit_flag_ctrl` | `int` | Accumulator control flag (0 / 2 / 3) | -| `disable_gemv` | `bool` | GEMV disable control | - -**Constraints**: -- `lhs` must be in `l0a` address space. -- `rhs` must be in `l0b` address space. -- `dst` must be in `l0c` address space. - -**Example**: -```python -pto.mad(l0a, l0b, l0c, 16, 16, 64) -``` - ---- - -### `pto.mad_acc` — accumulating matmul - -#### `pto.mad_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` - -**Description**: Accumulating cube matrix multiply. Computes `dst += lhs * rhs`. - -**Parameters**: Same as `pto.mad`. - -**Example**: -```python -pto.mad_acc(l0a, l0b, l0c, 16, 16, 64, unit_flag_ctrl=2) -``` - ---- - -### `pto.mad_bias` — bias-init matmul - -#### `pto.mad_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` - -**Description**: Bias-init cube matrix multiply. Computes `dst = lhs * rhs + bias`. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `bias` | `pto.ptr` | Bias table pointer | - -Other parameters are the same as `pto.mad`. - -**Constraints**: -- `bias` must be in `bt` address space. - -**Example**: -```python -pto.mad_bias(l0a, l0b, l0c, bt, 16, 16, 64) -``` - ---- - -### `pto.mad_mx` — zero-init MX matmul - -#### `pto.mad_mx(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` - -**Description**: Zero-init MX (micro-scaling) cube matrix multiply. Same semantics -as `pto.mad`, for MX-capable dtypes such as `f8E4M3FN`. - -**Parameters**: Same as `pto.mad`. - -**Example**: -```python -pto.mad_mx(l0a, l0b, l0c, 16, 16, 64) -``` - ---- - -### `pto.mad_mx_acc` — accumulating MX matmul - -#### `pto.mad_mx_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` - -**Description**: Accumulating MX cube matrix multiply. Computes `dst += lhs * rhs`. - -**Parameters**: Same as `pto.mad`. - ---- - -### `pto.mad_mx_bias` — MX bias-init matmul - -#### `pto.mad_mx_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` - -**Description**: MX bias-init cube matrix multiply. Computes `dst = lhs * rhs + bias`. - -**Parameters**: Same as `pto.mad_bias`. - ---- - -## Data Movement Operations - -### `pto.cube_load` — GM → L1 (cbuf) - -#### `pto.cube_load(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0), loops: list[tuple[int, int, int]] | None = None) -> None` - -**Description**: Structured GM-to-L1 (`cbuf` / `l1`) data movement wrapper. Lowers -to loop/stride setup plus `pto.copy_gm_to_cbuf`. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `pto.ptr` | Global memory source pointer | -| `dst` | `pto.ptr` | L1 (cbuf) destination pointer | -| `len_burst` | `int` | Burst length in bytes | -| `nburst` | `tuple[int, int, int]` | `(count, src_stride, dst_stride)` | -| `loops` | `list[tuple[int, int, int]]` or `None` | Optional nested loop params, each `(count_i, src_stride_i, dst_stride_i)` | - -**Constraints**: -- `src` must be in `gm` address space. -- `dst` must be in `l1` address space. - -**Example**: -```python -pto.cube_load(a_ptr, l1_a.as_ptr(), 16, nburst=(1, 0, 0)) -``` - ---- - -### `pto.cube_store` — L1 (cbuf) → UB - -#### `pto.cube_store(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0), loops: list[tuple[int, int, int]] | None = None) -> None` - -**Description**: Structured L1 (`cbuf`) to UB data movement wrapper. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `pto.ptr` | L1 source pointer | -| `dst` | `pto.ptr` | UB destination pointer | -| `len_burst` | `int` | Burst length in bytes | -| `nburst` | `tuple[int, int, int]` | `(count, src_stride, dst_stride)` | -| `loops` | `list[tuple[int, int, int]]` or `None` | Optional nested loop params | - -**Example**: -```python -pto.cube_store(l1_src.as_ptr(), ub_dst.as_ptr(), 16, nburst=(1, 0, 0)) -``` - ---- - -### `pto.cube_load_frac` — fractal load - -#### `pto.cube_load_frac(src: PtrType, dst: PtrType, mode: pto.FractalMode, *, shape: tuple[int, int], src_layout: tuple[int, int], dst_group: tuple[int, int, int, int], ctrl: tuple[int, bool]) -> None` - -**Description**: Structured fractal-load wrapper for `nd2nz` and `dn2nz` modes. -Lowers to `set_mte2_nz_para` plus `copy_gm_to_cbuf_multi_nd2nz` or -`copy_gm_to_cbuf_multi_dn2nz`. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `pto.ptr` | Global memory source pointer | -| `dst` | `pto.ptr` | L1 destination pointer | -| `mode` | `pto.FractalMode` | `pto.FractalMode.ND2NZ` or `pto.FractalMode.DN2NZ` | -| `shape` | `tuple[int, int]` | `(n_value, d_value)` | -| `src_layout` | `tuple[int, int]` | `(inner_stride, outer_stride)` | -| `dst_group` | `tuple[int, int, int, int]` | `(group_count, loop2_stride, loop3_stride, loop4_stride)` | -| `ctrl` | `tuple[int, bool]` | `(l2_cache_ctrl, smallc0_en)` | - -**Constraints**: -- `src` must be in `gm` address space. -- `dst` must be in `l1` address space. - -**Example**: -```python -pto.cube_load_frac(a_ptr, l1_a.as_ptr(), pto.FractalMode.ND2NZ, - shape=(16, 16), src_layout=(4, 8), - dst_group=(1, 0, 0, 0), ctrl=(0, False)) -``` - ---- - -### `pto.bias_load` — L1 (cbuf) → bias table - -#### `pto.bias_load(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0)) -> None` - -**Description**: Structured L1 (`cbuf`) to bias-table load wrapper. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `pto.ptr` | L1 source pointer | -| `dst` | `pto.ptr` | Bias table destination pointer | -| `len_burst` | `int` | Burst length in bytes | -| `nburst` | `tuple[int, int, int]` | `(count, src_gap, dst_gap)` | - -**Constraints**: -- Supported source/destination type pairs: `f32→f32`, `i32→i32`, `f16→f32`, `bf16→f32`. - -**Example**: -```python -pto.bias_load(l1_bias.as_ptr(), bt.as_ptr(), 16, nburst=(1, 0, 0)) -``` - ---- - -### `pto.left_load` — L1 (cbuf) → L0A - -#### `pto.left_load(src: PtrType, dst: PtrType, m: int, k: int) -> None` - -**Description**: Structured L1-to-L0A wrapper. Lowers to `pto.load_cbuf_to_ca`. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `pto.ptr` | L1 source pointer | -| `dst` | `pto.ptr` | L0A destination pointer | -| `m` | `int` | M dimension size | -| `k` | `int` | K dimension size | - -**Constraints**: -- `src` must be in `l1` address space. -- `dst` must be in `l0a` address space. - -**Example**: -```python -pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), 16, 64) -``` - ---- - -### `pto.right_load` — L1 (cbuf) → L0B - -#### `pto.right_load(src: PtrType, dst: PtrType, k: int, n: int) -> None` - -**Description**: Structured L1-to-L0B wrapper. Lowers to `pto.load_cbuf_to_cb`. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `pto.ptr` | L1 source pointer | -| `dst` | `pto.ptr` | L0B destination pointer | -| `k` | `int` | K dimension size | -| `n` | `int` | N dimension size | - -**Constraints**: -- `src` must be in `l1` address space. -- `dst` must be in `l0b` address space. - -**Example**: -```python -pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), 64, 16) -``` - ---- - -### `pto.left_load_mx` — MX L1 → L0A - -#### `pto.left_load_mx(src: PtrType, dst: PtrType, m: int, k: int) -> None` - -**Description**: MX-mode L1-to-L0A wrapper. Lowers to `pto.load_cbuf_to_ca_mx`. - -**Parameters**: Same as `pto.left_load`. - ---- - -### `pto.right_load_mx` — MX L1 → L0B - -#### `pto.right_load_mx(src: PtrType, dst: PtrType, k: int, n: int) -> None` - -**Description**: MX-mode L1-to-L0B wrapper. Lowers to `pto.load_cbuf_to_cb_mx`. - -**Parameters**: Same as `pto.right_load`. - ---- - -## Result Writeback Operations - -### `pto.acc_store` — L0C (acc) → L1 (cbuf) - -#### `pto.acc_store(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, mode: pto.FractalMode = pto.FractalMode.NZ2ND, loop0_src_stride: int | None = None, split: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` - -**Description**: Structured L0C (`l0c`) to L1 (`cbuf`) writeback wrapper. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `pto.ptr` | L0C source pointer | -| `dst` | `pto.ptr` | L1 (cbuf) destination pointer | -| `m` | `int` | M dimension size | -| `n` | `int` | N dimension size | -| `src_stride` | `int` | Source stride | -| `dst_stride` | `int` | Destination stride | -| `mode` | `pto.FractalMode` | Layout mode: `NZ2ND` / `NZ2DN` / `NZ2NZ` | - -Mode-dependent parameters: - -| Mode | Required | Not Accepted | -|------|----------|--------------| -| `pto.FractalMode.NZ2ND` | (none) | — | -| `pto.FractalMode.NZ2DN` | `loop0_src_stride` | — | -| `pto.FractalMode.NZ2NZ` | `split` | `loop3` | - -Optional for `pto.FractalMode.NZ2ND` and `pto.FractalMode.NZ2DN`: -`loop3=(count, src_stride3, dst_stride3)`. - -**Example**: -```python -pto.acc_store(l0c.as_ptr(), l1_out.as_ptr(), - 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND) -``` - ---- - -### `pto.acc_store_gm` — L0C (acc) → GM - -#### `pto.acc_store_gm(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, sid: int = 0, l2_cache_ctrl: int = 0, mode: pto.FractalMode = pto.FractalMode.NZ2ND, loop0_src_stride: int | None = None, split: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` - -**Description**: Structured L0C (`l0c`) to GM writeback wrapper. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `pto.ptr` | L0C source pointer | -| `dst` | `pto.ptr` | GM destination pointer | -| `sid` | `int` | Stream ID | -| `l2_cache_ctrl` | `int` | L2 cache control | - -Other parameters are the same as `pto.acc_store`. - -**Example**: -```python -pto.acc_store_gm(l0c.as_ptr(), c_ptr, 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND) -``` - ---- - -### `pto.acc_store_ub` — L0C (acc) → UB - -#### `pto.acc_store_ub(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, dual_dst_mode: int = 0, sub_blockid: int = 0, mode: pto.FractalMode = pto.FractalMode.NZ2ND, loop0_src_stride: int | None = None, channel_split_en: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` - -**Description**: Structured L0C (`l0c`) to UB writeback wrapper. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `pto.ptr` | L0C source pointer | -| `dst` | `pto.ptr` | UB destination pointer | -| `dual_dst_mode` | `int` | Dual destination mode | -| `sub_blockid` | `int` | Sub-block ID | -| `channel_split_en` | `int` or `None` | Channel split enable (required for `mode=pto.FractalMode.NZ2NZ`) | - -Other parameters are the same as `pto.acc_store`. - -**Example**: -```python -pto.acc_store_ub(l0c.as_ptr(), ub_out.as_ptr(), - 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND) -``` - ---- - -## Quick Reference - -### By Data Flow - -| Data Flow | Operation | Src Space | Dst Space | -|-----------|-----------|-----------|-----------| -| GM → L1 | `pto.cube_load` | gm | l1 | -| GM → L1 (fractal) | `pto.cube_load_frac` | gm | l1 | -| L1 → UB | `pto.cube_store` | l1 | ub | -| L1 → L0A | `pto.left_load` | l1 | l0a | -| L1 → L0B | `pto.right_load` | l1 | l0b | -| L1 → L0A (MX) | `pto.left_load_mx` | l1 | l0a | -| L1 → L0B (MX) | `pto.right_load_mx` | l1 | l0b | -| L1 → Bias | `pto.bias_load` | l1 | bt | -| L0A×L0B → L0C | `pto.mad` | l0a, l0b | l0c | -| L0A×L0B → L0C (acc) | `pto.mad_acc` | l0a, l0b | l0c | -| L0A×L0B+Bias → L0C | `pto.mad_bias` | l0a, l0b, bt | l0c | -| L0C → L1 | `pto.acc_store` | l0c | l1 | -| L0C → GM | `pto.acc_store_gm` | l0c | gm | -| L0C → UB | `pto.acc_store_ub` | l0c | ub | - -### MX Variants - -| Base Op | MX Variant | Description | -|---------|------------|-------------| -| `pto.mad` | `pto.mad_mx` | Zero-init MX matmul | -| `pto.mad_acc` | `pto.mad_mx_acc` | Accumulating MX matmul | -| `pto.mad_bias` | `pto.mad_mx_bias` | Bias-init MX matmul | - ---- - -## Template Slot Support - -Cube operations support `pto.tpl()` template-slot dispatch, consistent with the -Vector DSL mechanism. See [Template Kernels](04-template-kernels.md) for general -`pto.tpl()` usage. - -**Constraints**: Variants within the same slot must have identical parameter -signatures. For example, `mad` and `mad_acc` can share a slot, but `mad_bias` -(which adds a `bias` parameter) requires a separate slot. - ---- - -## See Also - -- [Kernel Declaration](03-kernel-declaration.md) — `@pto.ckernel` decorator specification -- [Examples](13-examples.md) — full Cube kernel code examples -- [Design doc](../../../docs/designs/tilelang-cube-dsl-design.md) — Cube DSL design details diff --git a/ptodsl/docs/user_guide/13-examples.md b/ptodsl/docs/user_guide/13-examples.md deleted file mode 100644 index 16105b853..000000000 --- a/ptodsl/docs/user_guide/13-examples.md +++ /dev/null @@ -1,417 +0,0 @@ -## Examples - -### Template-based Kernel Examples - -#### Unified Arithmetic Operations - -A single kernel implementing multiple arithmetic operations using templates: - -```python -T = pto.TypeVar('T') - -@pto.vkernel( - target="a5", - ops=["tadd", "tsub", "tmul", "tdiv"], - dtypes=[(T, T, T)], - advanced=True, - templates={ - "core": { - "tadd": "vadd", - "tsub": "vsub", - "tmul": "vmul", - "tdiv": "vdiv", - } - } -) -def elementwise_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): - """Single implementation for four arithmetic operations.""" - dtype = dst.element_type - rows, cols = dst.valid_shape - - for row in range(0, rows, 1): - remained = cols - for col in range(0, cols, pto.elements_per_vreg(dtype)): - mask, remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - out = pto.tpl("core", lhs, rhs, mask) - pto.vsts(out, dst[row, col:], mask) -``` - -#### Multiple Templates with Postprocess - -Kernel using separate templates for arithmetic and postprocess operations: - -```python -@pto.vkernel( - target="a5", - ops=["add_relu", "sub_relu", "add_abs", "sub_abs"], - dtypes=[(T, T, T)], - templates={ - "arithmetic": { - "add_relu": "vadd", - "sub_relu": "vsub", - "add_abs": "vadd", - "sub_abs": "vsub", - }, - "postprocess": { - "add_relu": "vrelu", - "sub_relu": "vrelu", - "add_abs": "vabs", - "sub_abs": "vabs", - } - } -) -def elementwise_with_postprocess(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): - dtype = dst.element_type - rows, cols = dst.valid_shape - - for row in range(0, rows, 1): - remained = cols - for col in range(0, cols, pto.elements_per_vreg(dtype)): - mask, remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - - # Use arithmetic template - arith_result = pto.tpl("arithmetic", lhs, rhs, mask) - - # Apply postprocess template - activated = pto.tpl("postprocess", arith_result, mask) - - pto.vsts(activated, dst[row, col:], mask) -``` - -#### Compile-time Substitution - -Template substitution happens before semantic analysis and lowering: - -```python -selected = pto.select_kernel("a5", "tadd", (ptype, ptype, ptype)) -# frontend resolves: -# pto.tpl("core", lhs, rhs, mask) -# into: -# pto.vadd(lhs, rhs, mask) -``` - -#### Benefits of Template-based Authoring - -1. **Code Reuse**: Single implementation serves multiple operations -2. **Maintenance**: Bug fixes and optimizations apply to all related operations -3. **Consistency**: Ensures uniform behavior across operation families -4. **Reduced Boilerplate**: Eliminates duplicate control flow and data movement code -5. **Type Safety**: Type variables ensure consistent operand types - -### Simple Vector Copy - -```python -@pto.vkernel(...) -def vector_copy(src: pto.Tile, dst: pto.Tile): - all_mask: pto.mask_b32 = pto.make_mask(pto.f32, PAT.ALL) - for offset in range(0, 256, 64): - vec = pto.vlds(src, offset) - pto.vsts(vec, dst, offset, all_mask) -``` - -### Conditional Computation - -```python -@pto.vkernel(...) -def conditional_scale(src: pto.ptr(pto.f32, MemorySpace.GM), - dst: pto.ptr(pto.f32, MemorySpace.GM), - threshold: pto.f32): - # ... setup ... - - with pto.strict_vecscope(ub_in, ub_out, threshold) as (vin, vout, thresh): - for i in range(0, 1024, 64): - vec = pto.vlds(vin, i) - - # Compare with threshold - mask = pto.pge_b32(vec, thresh) - - # Scale values above threshold - scaled = pto.vmuls(vec, pto.f32(2.0), mask) - - # Keep original values below threshold - result = pto.vsel(scaled, vec, mask) - - pto.vsts(result, vout, i, all_mask) -``` - -### Loop with Carry - -```python -@pto.vkernel(...) -def prefix_sum(src: pto.ptr(pto.i32, MemorySpace.UB), - dst: pto.ptr(pto.i32, MemorySpace.UB)): - all_mask = pto.make_mask(pto.i32, PAT.ALL) - carry = all_mask - - for i in range(0, 256, 64): - vec = pto.vlds(src, i) - result, carry = pto.vaddcs(vec, vec, carry, all_mask) - pto.vsts(result, dst, i, all_mask) -``` - ---- - -## Cube Kernel Examples - -Cube kernels target the AIC (Cube) hardware unit for matrix multiplication. GM data is expressed through `PartitionTensorView`, while hardware buffers in specific address spaces are constructed via `pto.Tile`. - -### Basic GEMM - -A full-pipeline matrix multiplication C = A × B: - -```python -from tilelang_dsl import ckernel, Tile, MemorySpace - -@pto.ckernel( - target="a5", - op="pto.mad", - dtypes=[(pto.f16, pto.f16, pto.f32)], - name="gemm", -) -def gemm(a_tv: pto.PartitionTensorView, # [M, K] in GM - b_tv: pto.PartitionTensorView, # [K, N] in GM - c_tv: pto.PartitionTensorView, # [M, N] in GM, output - M: int, K: int, N: int): - # Get GM pointers from PartitionTensorViews - a_ptr = a_tv.as_ptr() - b_ptr = b_tv.as_ptr() - c_ptr = c_tv.as_ptr() - - # Allocate L1 (MAT) tile buffers - l1_a_tile = pto.Tile([M, K], pto.f16, MemorySpace.MAT) - l1_b_tile = pto.Tile([K, N], pto.f16, MemorySpace.MAT) - - # Allocate L0 tile buffers - l0a_tile = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) - l0b_tile = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) - l0c_tile = pto.Tile([M, N], pto.f32, MemorySpace.ACC) - - # GM → L1 - pto.cube_load(a_ptr, l1_a_tile.as_ptr(), K, nburst=(1, 0, 0)) - pto.cube_load(b_ptr, l1_b_tile.as_ptr(), N, nburst=(1, 0, 0)) - - # L1 → L0 - pto.left_load(l1_a_tile.as_ptr(), l0a_tile.as_ptr(), M, K) - pto.right_load(l1_b_tile.as_ptr(), l0b_tile.as_ptr(), K, N) - - # Compute: C = A × B - pto.mad(l0a_tile.as_ptr(), l0b_tile.as_ptr(), l0c_tile.as_ptr(), M, N, K) - - # L0C → GM writeback - pto.acc_store_gm(l0c_tile.as_ptr(), c_ptr, M, N, - src_stride=N, dst_stride=N, mode="nz2nd") -``` - -### Split-K GEMM - -Matrix multiplication with K-dimension splitting for large K values: - -```python -@pto.ckernel( - target="a5", - op="pto.mad", - dtypes=[(pto.f16, pto.f16, pto.f32)], - name="gemm_splitk", -) -def gemm_splitk(a_tv: pto.PartitionTensorView, # [M, K] - b_tv: pto.PartitionTensorView, # [K, N] - c_tv: pto.PartitionTensorView, # [M, N] - M: int, K: int, N: int, BASEK: int): - iters = K // BASEK - - a_ptr = a_tv.as_ptr() - b_ptr = b_tv.as_ptr() - c_ptr = c_tv.as_ptr() - - # Allocate buffers sized for one split-K step - l1_a = pto.Tile([M, BASEK], pto.f16, MemorySpace.MAT) - l1_b = pto.Tile([BASEK, N], pto.f16, MemorySpace.MAT) - l0a = pto.Tile([M, BASEK], pto.f16, MemorySpace.LEFT) - l0b = pto.Tile([BASEK, N], pto.f16, MemorySpace.RIGHT) - l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) - - for k_step in range(iters): - k_off = k_step * BASEK - - # Offset GM pointers for this K-slice - a_k = pto.addptr(a_ptr, k_off) - b_k = pto.addptr(b_ptr, k_off) - - # GM → L1 → L0 - pto.cube_load(a_k, l1_a.as_ptr(), BASEK, nburst=(1, 0, 0)) - pto.cube_load(b_k, l1_b.as_ptr(), N, nburst=(1, 0, 0)) - pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, BASEK) - pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), BASEK, N) - - # First step: zero-init; subsequent steps: accumulate - if k_step == 0: - pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, BASEK) - else: - pto.mad_acc(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, BASEK) - - # L0C → GM - pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, - src_stride=N, dst_stride=N, mode="nz2nd") -``` - -### GEMM with Bias - -Matrix multiplication with bias addition C = A × B + bias: - -```python -@pto.ckernel( - target="a5", - op="pto.mad_bias", - dtypes=[(pto.f16, pto.f16, pto.f32)], - name="gemm_bias", -) -def gemm_bias(a_tv: pto.PartitionTensorView, - b_tv: pto.PartitionTensorView, - c_tv: pto.PartitionTensorView, - bias_tv: pto.PartitionTensorView, - M: int, K: int, N: int): - a_ptr = a_tv.as_ptr() - b_ptr = b_tv.as_ptr() - c_ptr = c_tv.as_ptr() - bias_ptr = bias_tv.as_ptr() - - # L1 buffers - l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) - l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) - l1_bias = pto.Tile([1, N], pto.f32, MemorySpace.MAT) - - # L0 buffers - l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) - l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) - l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) - - # Bias table - bt = pto.Tile([1, N], pto.f32, MemorySpace.BIAS) - - # Data movement - pto.cube_load(a_ptr, l1_a.as_ptr(), K, nburst=(1, 0, 0)) - pto.cube_load(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) - pto.cube_load(bias_ptr, l1_bias.as_ptr(), N, nburst=(1, 0, 0)) - pto.bias_load(l1_bias.as_ptr(), bt.as_ptr(), N, nburst=(1, 0, 0)) - - # L1 → L0 - pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, K) - pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), K, N) - - # Compute: C = A × B + bias - pto.mad_bias(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), bt.as_ptr(), M, N, K) - - # Writeback - pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, - src_stride=N, dst_stride=N, mode="nz2nd") -``` - -### Fractal Load (nd2nz) Example - -Using fractal load for ND-layout to NZ-fractal data loading: - -```python -@pto.ckernel( - target="a5", - op="pto.mad", - dtypes=[(pto.f16, pto.f16, pto.f32)], - name="gemm_frac", -) -def gemm_frac(a_tv: pto.PartitionTensorView, - b_tv: pto.PartitionTensorView, - c_tv: pto.PartitionTensorView, - M: int, K: int, N: int): - a_ptr = a_tv.as_ptr() - b_ptr = b_tv.as_ptr() - c_ptr = c_tv.as_ptr() - - l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) - l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) - l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) - l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) - l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) - - # Fractal load: ND → NZ - pto.cube_load_frac(a_ptr, l1_a.as_ptr(), "nd2nz", - shape=(M, K), - src_layout=(K,), - dst_group=(1, 0, 0, 0), - ctrl=(0, False)) - pto.cube_load(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) - - pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, K) - pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), K, N) - pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, K) - - pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, - src_stride=N, dst_stride=N, mode="nz2nd") -``` - -### Pure-Compute Kernel (Pre-Allocated Tiles) - -When tiles are pre-allocated externally, the kernel only performs computation: - -```python -@pto.ckernel( - target="a5", - op="pto.mad", - dtypes=[(pto.f16, pto.f16, pto.f32)], - name="matmul_compute", -) -def matmul_compute(a_left: pto.Tile, # Pre-allocated LEFT tile (L0A) - b_right: pto.Tile, # Pre-allocated RIGHT tile (L0B) - c_acc: pto.Tile, # Pre-allocated ACC tile (L0C) - M: int, K: int, N: int): - pto.mad(a_left.as_ptr(), b_right.as_ptr(), c_acc.as_ptr(), M, N, K) -``` - -### Template-based Multi-Op Cube Kernel - -Reusing a single template body for multiple Cube matmul variants: - -```python -@pto.ckernel( - target="a5", - ops=["mad", "mad_acc"], - dtypes=[(pto.f16, pto.f16, pto.f32)], - name="gemm_template", - templates={ - "compute": {"mad": "mad", "mad_acc": "mad_acc"}, - }, -) -def gemm_template(a_tv: pto.PartitionTensorView, - b_tv: pto.PartitionTensorView, - c_tv: pto.PartitionTensorView, - M: int, K: int, N: int): - a_ptr = a_tv.as_ptr() - b_ptr = b_tv.as_ptr() - c_ptr = c_tv.as_ptr() - - l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) - l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) - l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) - l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) - l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) - - pto.cube_load(a_ptr, l1_a.as_ptr(), K, nburst=(1, 0, 0)) - pto.cube_load(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) - pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, K) - pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), K, N) - - # Template slot: resolved at specialization time - pto.tpl("compute", l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, K) - - pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, - src_stride=N, dst_stride=N, mode="nz2nd") -``` - -Usage: - -```python -k_mad = pto.select_kernel("a5", "gemm_template", selected_op="mad") -k_acc = pto.select_kernel("a5", "gemm_template", selected_op="mad_acc") -``` diff --git a/ptodsl/docs/user_guide/14-common-errors.md b/ptodsl/docs/user_guide/14-common-errors.md deleted file mode 100644 index 46abe09b9..000000000 --- a/ptodsl/docs/user_guide/14-common-errors.md +++ /dev/null @@ -1,51 +0,0 @@ -## Common Errors - -### Typed Mask Mismatch - -``` -Error: f32 vector operation cannot consume mask_b16 -``` - -**Solution:** Ensure mask granularity matches vector element size: -- `f32` vectors use `mask_b32` -- `f16` vectors use `mask_b16` -- `i8` vectors use `mask_b8` - -### Strict Scope Implicit Capture - -``` -Error: strict_vecscope body cannot capture outer value 'ub_in' implicitly -``` - -**Solution:** Pass all required values in the capture list: - -```python -# Wrong: -with pto.strict_vecscope() as (): - vec = pto.vlds(ub_in, offset) # ub_in from outer scope - -# Correct: -with pto.strict_vecscope(ub_in) as (ub): - vec = pto.vlds(ub, offset) -``` - -### Untyped Loop Carried State - -``` -Error: loop-carried value must have explicit machine type -``` - -**Solution:** Add type annotations to loop-carried variables: - -```python -# Wrong: -remaining = 1024 # Plain Python int -for i in range(0, N, step): - mask, remaining = pto.make_mask(pto.f32, remaining) - -# Correct: -remaining: pto.i32 = 1024 -# or -remaining = pto.i32(1024) -``` - diff --git a/ptodsl/docs/user_guide/15-compatibility-notes.md b/ptodsl/docs/user_guide/15-compatibility-notes.md deleted file mode 100644 index defcf704c..000000000 --- a/ptodsl/docs/user_guide/15-compatibility-notes.md +++ /dev/null @@ -1,9 +0,0 @@ -## Compatibility Notes - -The current experimental implementation in `python/pto/dialects/pto.py` differs from this specification in several ways: - -1. **Mask types**: The experimental version uses untyped `mask` instead of `mask_b8`/`mask_b16`/`mask_b32` -2. **Barrier operation**: Uses `pto.barrier()` instead of `pto.pipe_barrier()` -3. **Operation coverage**: Implements only a subset of operations - -When implementing new code, follow this specification. The experimental implementation will be updated to match over time. diff --git a/ptodsl/docs/user_guide/16-next-steps.md b/ptodsl/docs/user_guide/16-next-steps.md deleted file mode 100644 index 2fe63b9a4..000000000 --- a/ptodsl/docs/user_guide/16-next-steps.md +++ /dev/null @@ -1,7 +0,0 @@ -## Next Steps - -- Explore the ISA documentation in `docs/isa/` for detailed operation semantics -- Check `test/samples/` for example kernels -- Refer to `docs/vpto-spec.md` for the underlying VPTO instruction specification - -For compiler developers, see `docs/PTO_IR_manual.md` for MLIR-level details. From d8db04ee54ad40d4943b1cf3c96b764449d7bbe0 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 18 May 2026 23:07:42 +0800 Subject: [PATCH 15/31] Complete the mlir text emission of the FA demo --- lib/Bindings/Python/PTOModule.cpp | 137 ++- ptodsl/README.md | 68 +- .../demos/flash_attention_sketch.py | 397 +++++--- ptodsl/docs/user_guide/01-introduction.md | 11 +- ptodsl/docs/user_guide/02-quick-start.md | 6 +- .../03-kernel-entry-and-subkernels.md | 48 +- .../user_guide/04-type-system-and-buffer.md | 6 + ptodsl/docs/user_guide/05-control-flow.md | 8 +- .../user_guide/06-scalar-and-pointer-ops.md | 11 +- .../docs/user_guide/07-data-movement-ops.md | 22 +- .../docs/user_guide/08-compute-operations.md | 20 +- ptodsl/docs/user_guide/10-sync-ops.md | 16 +- .../11-flash-attention-walkthrough.md | 189 ++-- .../docs/user_guide/12-additional-examples.md | 20 +- ptodsl/examples/softmax_dsl.py | 26 +- ptodsl/examples/tadd_dsl.py | 10 +- ptodsl/ptodsl/__init__.py | 4 +- ptodsl/ptodsl/_bootstrap.py | 11 +- ptodsl/ptodsl/_control_flow.py | 189 +++- ptodsl/ptodsl/_diagnostics.py | 99 ++ ptodsl/ptodsl/_host_tensors.py | 238 +++++ ptodsl/ptodsl/_jit.py | 91 ++ ptodsl/ptodsl/_kernel_compilation.py | 83 ++ ptodsl/ptodsl/_kernel_signature.py | 191 ++++ ptodsl/ptodsl/_module.py | 159 --- ptodsl/ptodsl/_ops.py | 929 +++++++++++++++++- ptodsl/ptodsl/_runtime_index_ops.py | 43 + ptodsl/ptodsl/_runtime_scalar_ops.py | 134 +++ ptodsl/ptodsl/_scalar_coercion.py | 97 ++ ptodsl/ptodsl/_subkernels.py | 168 ++++ ptodsl/ptodsl/_surface_types.py | 99 ++ ptodsl/ptodsl/_surface_values.py | 851 ++++++++++++++++ ptodsl/ptodsl/_tensor_factories.py | 42 + .../{vpto.py => _tile_template_tracing.py} | 240 ++--- ptodsl/ptodsl/_tracing/__init__.py | 40 + ptodsl/ptodsl/_tracing/active.py | 86 ++ ptodsl/ptodsl/_tracing/artifacts.py | 58 ++ ptodsl/ptodsl/_tracing/control_flow.py | 92 ++ ptodsl/ptodsl/_tracing/module_builder.py | 81 ++ ptodsl/ptodsl/_tracing/runtime.py | 131 +++ ptodsl/ptodsl/_tracing/session.py | 205 ++++ ptodsl/ptodsl/_types.py | 74 +- ptodsl/ptodsl/pto.py | 38 +- ptodsl/ptodsl/scalar.py | 98 +- python/pto/dialects/pto.py | 22 +- test/python/ptodsl_jit_compile.py | 618 ++++++++++++ test/python/ptodsl_jit_diagnostics.py | 166 ++++ test/python/ptodsl_subkernel_diagnostics.py | 108 ++ 48 files changed, 5693 insertions(+), 787 deletions(-) rename ptodsl/{docs => }/demos/flash_attention_sketch.py (65%) create mode 100644 ptodsl/ptodsl/_diagnostics.py create mode 100644 ptodsl/ptodsl/_host_tensors.py create mode 100644 ptodsl/ptodsl/_jit.py create mode 100644 ptodsl/ptodsl/_kernel_compilation.py create mode 100644 ptodsl/ptodsl/_kernel_signature.py delete mode 100644 ptodsl/ptodsl/_module.py create mode 100644 ptodsl/ptodsl/_runtime_index_ops.py create mode 100644 ptodsl/ptodsl/_runtime_scalar_ops.py create mode 100644 ptodsl/ptodsl/_scalar_coercion.py create mode 100644 ptodsl/ptodsl/_subkernels.py create mode 100644 ptodsl/ptodsl/_surface_types.py create mode 100644 ptodsl/ptodsl/_surface_values.py create mode 100644 ptodsl/ptodsl/_tensor_factories.py rename ptodsl/ptodsl/{vpto.py => _tile_template_tracing.py} (76%) create mode 100644 ptodsl/ptodsl/_tracing/__init__.py create mode 100644 ptodsl/ptodsl/_tracing/active.py create mode 100644 ptodsl/ptodsl/_tracing/artifacts.py create mode 100644 ptodsl/ptodsl/_tracing/control_flow.py create mode 100644 ptodsl/ptodsl/_tracing/module_builder.py create mode 100644 ptodsl/ptodsl/_tracing/runtime.py create mode 100644 ptodsl/ptodsl/_tracing/session.py create mode 100644 test/python/ptodsl_jit_compile.py create mode 100644 test/python/ptodsl_jit_diagnostics.py create mode 100644 test/python/ptodsl_subkernel_diagnostics.py diff --git a/lib/Bindings/Python/PTOModule.cpp b/lib/Bindings/Python/PTOModule.cpp index a13b39cad..1c8ae1c0d 100644 --- a/lib/Bindings/Python/PTOModule.cpp +++ b/lib/Bindings/Python/PTOModule.cpp @@ -18,6 +18,7 @@ #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/CAPI/IR.h" #include "pto-c/Dialect/PTO.h" #include "mlir-c/IR.h" #include "PTO/IR/PTO.h" @@ -27,6 +28,8 @@ #include "mlir/IR/BuiltinTypes.h" namespace py = pybind11; using namespace mlir::python::adaptors; +using llvm::cast; +using llvm::isa; static std::vector toInt64Vector(const py::sequence &seq) { std::vector out; @@ -63,6 +66,14 @@ static py::list shapeToPyList(const int64_t *data, intptr_t n) { return lst; } +static py::object wrapAttributeAs(const py::module_ &m, const char *className, + MlirAttribute attr) { + if (mlirAttributeIsNull(attr)) + return py::none(); + py::object cls = m.attr(className); + return cls.attr("__call__")(attr); +} + void populatePTODialectSubmodule(pybind11::module &m); void populatePTODialectSubmodule(pybind11::module &m) { (void)m; @@ -705,6 +716,61 @@ static void bindPTOModule(pybind11::module &m) { return mlirPTOPtrTypeGetMemorySpace(self); }); + mlir_type_subclass( + m, "VRegType", + [](MlirType type) -> bool { return isa(unwrap(type)); }) + .def_classmethod( + "get", + [](py::object cls, int64_t elementCount, MlirType elementType, + MlirContext context) -> py::object { + context = inferContextFromElementType(context, elementType); + MlirType t = wrap( + mlir::pto::VRegType::get( + unwrap(context), elementCount, unwrap(elementType))); + return cls.attr("__call__")(t); + }, + py::arg("cls"), py::arg("element_count"), py::arg("element_type"), + py::arg("context") = py::none()) + .def_property_readonly( + "element_count", + [](MlirType self) -> int64_t { + return cast(unwrap(self)).getElementCount(); + }) + .def_property_readonly( + "element_type", + [](MlirType self) -> MlirType { + return wrap(cast(unwrap(self)).getElementType()); + }); + + mlir_type_subclass( + m, "MaskType", + [](MlirType type) -> bool { return isa(unwrap(type)); }) + .def_classmethod( + "get", + [](py::object cls, std::string granularity, MlirContext context) -> py::object { + MlirType t = wrap( + mlir::pto::MaskType::get(unwrap(context), granularity)); + return cls.attr("__call__")(t); + }, + py::arg("cls"), py::arg("granularity"), + py::arg("context") = py::none()) + .def_property_readonly( + "granularity", + [](MlirType self) -> std::string { + return cast(unwrap(self)).getGranularity().str(); + }); + + mlir_type_subclass( + m, "AlignType", + [](MlirType type) -> bool { return isa(unwrap(type)); }) + .def_classmethod( + "get", + [](py::object cls, MlirContext context) -> py::object { + MlirType t = wrap(mlir::pto::AlignType::get(unwrap(context))); + return cls.attr("__call__")(t); + }, + py::arg("cls"), py::arg("context") = py::none()); + mlir_type_subclass( m, "AsyncSessionType", [](MlirType type) -> bool { return mlirPTOTypeIsAAsyncSessionType(type); }) @@ -976,7 +1042,76 @@ static void bindPTOModule(pybind11::module &m) { if (mlirPTOTypeIsATileBufType(t)) return cls(t); return py::none(); }, - py::arg("cls"), py::arg("type")); + py::arg("cls"), py::arg("type")) + .def_property_readonly( + "rank", + [](MlirType self) -> intptr_t { + return static_cast( + cast(unwrap(self)).getRank()); + }) + .def_property_readonly( + "element_type", + [](MlirType self) -> MlirType { + return wrap(cast(unwrap(self)).getElementType()); + }) + .def_property_readonly( + "memory_space", + [m](MlirType self) -> py::object { + MlirAttribute attr = + wrap(cast(unwrap(self)).getMemorySpace()); + return wrapAttributeAs(m, "AddressSpaceAttr", attr); + }) + .def_property_readonly( + "shape", + [](MlirType self) -> py::list { + auto shape = cast(unwrap(self)).getShape(); + return shapeToPyList(shape.data(), static_cast(shape.size())); + }) + .def_property_readonly( + "valid_shape", + [](MlirType self) -> py::list { + auto validShape = cast(unwrap(self)).getValidShape(); + return shapeToPyList(validShape.data(), static_cast(validShape.size())); + }) + .def_property_readonly( + "blayout_attr", + [m](MlirType self) -> py::object { + MlirAttribute attr = + wrap(cast(unwrap(self)).getBLayoutAttr()); + return wrapAttributeAs(m, "BLayoutAttr", attr); + }) + .def_property_readonly( + "slayout_attr", + [m](MlirType self) -> py::object { + MlirAttribute attr = + wrap(cast(unwrap(self)).getSLayoutAttr()); + return wrapAttributeAs(m, "SLayoutAttr", attr); + }) + .def_property_readonly( + "blayout_value", + [](MlirType self) -> int32_t { + return cast(unwrap(self)).getBLayoutValueI32(); + }) + .def_property_readonly( + "slayout_value", + [](MlirType self) -> int32_t { + return cast(unwrap(self)).getSLayoutValueI32(); + }) + .def_property_readonly( + "pad_value", + [](MlirType self) -> int32_t { + return cast(unwrap(self)).getPadValueI32(); + }) + .def_property_readonly( + "compact_mode", + [](MlirType self) -> int32_t { + return cast(unwrap(self)).getCompactModeI32(); + }) + .def_property_readonly( + "s_fractal_size", + [](MlirType self) -> int32_t { + return cast(unwrap(self)).getSFractalSizeI32(); + }); populatePTODialectSubmodule(m); } diff --git a/ptodsl/README.md b/ptodsl/README.md index f1510bf1e..58809fcba 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -1,8 +1,8 @@ # ptodsl — PTO Python IR Builders A lightweight, pip-installable DSL package for building PTO MLIR IR modules -in Python. The API is inspired by Triton / CuteDSL: kernels are ordinary -Python functions decorated with `@pto.to_ir`, type annotations carry PTO +in Python. PTODSL kernels are ordinary Python functions decorated with +`@pto.jit`. Type annotations carry PTO types as lazy descriptors, and control-flow maps 1-to-1 to MLIR operations. --- @@ -19,12 +19,14 @@ ptodsl/ │ ├── _types.py # lazy dtype descriptors and type constructors │ ├── _ops.py # PTO operation wrappers │ ├── _control_flow.py # vecscope, for_, if_, yield_ context managers -│ └── _module.py # @pto.to_ir decorator + module builders +│ ├── _jit.py # @pto.jit decorator +│ ├── _tracing/ # shared tracing runtime building blocks +│ └── _tile_template_tracing.py # internal tile-template tracing implementation ├── examples/ │ ├── tadd_lowlevel.py # TADD – raw MLIR Python binding calls -│ ├── tadd_dsl.py # TADD – @pto.to_ir DSL style +│ ├── tadd_dsl.py # TADD – @pto.jit DSL style │ ├── softmax_lowlevel.py # Softmax – raw MLIR Python binding calls -│ └── softmax_dsl.py # Softmax – @pto.to_ir DSL style +│ └── softmax_dsl.py # Softmax – @pto.jit DSL style ├── pyproject.toml # pip install -e . ├── check_ir.py # IR correctness test runner └── README.md @@ -92,22 +94,25 @@ s = pto.scalar # arith shorthand alias ### Kernel decorator ```python -@pto.to_ir(name="MyKernel", kernel_kind="vector", arch="a5") +@pto.jit(name="MyKernel", kernel_kind="vector", target="a5") def MyKernel(): ... -@pto.to_ir(name="Softmax", kernel_kind="vector", arch="a5", func_attr="pto.aicore") +@pto.jit(name="Softmax", kernel_kind="vector", target="a5", func_attr="pto.aicore") def Softmax(arg0: pto.ptr(pto.float32, "gm"), n: pto.int32): ... print(MyKernel) # prints MLIR text -mod = MyKernel.build() # returns mlir.ir.Module +mod = MyKernel.mlir_module() # returns mlir.ir.Module ``` `func_attr="pto.aicore"` selects a flat single-module structure with the `pto.aicore` function attribute (softmax style). Without it, a nested double-module is emitted (TADD style). +Additional layered kernel decorators are also exported on the public surface: +`@pto.ukernel`, `@pto.cube`, `@pto.simd`, and `@pto.simt`. + ### Type descriptors (lazy – safe to use in annotations) | Expression | MLIR type | @@ -190,57 +195,16 @@ pto.vadd(a, b, mask) # infers result type from a.type pto.vmul / vmax / vdiv / vcmax / vcadd / vdup / vexpdif # similarly pto.make_tensor_view(ptr, shape=…, strides=…) # type inferred pto.partition_view(tv, offsets=…, sizes=…) # type inferred -pto.alloc_tile(tile_type, addr=…, valid_row=…, valid_col=…) +pto.alloc_tile(shape=…, dtype=…, memory_space=…) # authored surface pto.tload(part, tile) pto.tstore(tile, part) -pto.tile_ptr(tile, ptr_type) +tile.as_ptr() / view.as_ptr() pto.get_block_idx() # → i64 pto.set_flag("MTE2", "V", event_id=0) pto.wait_flag("MTE2", "V", event_id=0) -pto.barrier_all() -``` - -### Experimental `vpto` POC - -For early experiments around AST-free tracing of TileLang-style tile templates, -`ptodsl` also exposes an experimental namespace: - -```python -from ptodsl import vpto as pto - -@pto.vkernel(target="a5", op="pto.tadd") -def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): - dtype = dst.element_type - valid_rows, valid_cols = dst.valid_shape - with pto.for_(0, valid_rows, step=1) as row: - remained0 = pto.scalar_const(64, pto.i32) - with pto.for_(0, valid_cols, step=pto.get_lanes(dtype), state={"remained": remained0}) as loop: - col = loop.iv - remained = loop.state.remained - mask, next_remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - out = pto.vadd(lhs, rhs, mask) - pto.vsts(out, dst[row, col:], mask) - loop.yield_state(remained=next_remained) -``` - -Current limitations: - -- pybinding-backed POC only; it still covers a narrow TileLang-shaped subset -- supports only static 2D `Tile` parameters -- supports only a narrow vector subset needed by `tadd_template.py` -- currently uses explicit structured `for_()` builders rather than Python `for range(...)` -- `vecscope()` remains available, but it is no longer required by the POC - -Reference script: - -```bash -python3 lib/TileOps/tadd_template_tracing_poc.py +pto.pipe_barrier(pto.Pipe.ALL) ``` ---- - ## How the IR check works ``` diff --git a/ptodsl/docs/demos/flash_attention_sketch.py b/ptodsl/demos/flash_attention_sketch.py similarity index 65% rename from ptodsl/docs/demos/flash_attention_sketch.py rename to ptodsl/demos/flash_attention_sketch.py index 39db1af0e..039ca3c9a 100644 --- a/ptodsl/docs/demos/flash_attention_sketch.py +++ b/ptodsl/demos/flash_attention_sketch.py @@ -6,13 +6,13 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. """ -Flash Attention redesign sketch. +Flash Attention compile-only demo. -This file is intentionally a design demo rather than runnable ``ptodsl`` code. -The goal is to make the *proposed* API layering explicit and keep the semantic -contracts clean: +This file is a compileable PTODSL demo whose current milestone is MLIR +emission, inspection, and API review. The goal is to make the intended API +layering explicit and keep the semantic contracts clean: - flash_attention(...) user-facing wrapper + emit_flash_attention_mlir(...) compile/inspect wrapper └─ @pto.jit flash_attention_kernel ├─ Tile Ops tload / tstore at the GM↔UB boundary └─ @pto.ukernel one KV-block worth of MTE/sync orchestration @@ -23,10 +23,10 @@ Design rules illustrated here: 1. ``@pto.jit`` marks a launchable kernel template. It owns JIT compilation, - cache lookup, and runtime launch binding, instead of forcing users to hop - through extra builder objects for common cases. -2. The Python wrapper owns ergonomic runtime concerns such as output allocation, - default stream handling, and extracting shape/stride metadata from tensors. + cache lookup, and artifact emission, instead of forcing users to hop through + extra builder objects for common cases. +2. The Python wrapper owns compile/inspection concerns such as selecting + specialization knobs and returning the emitted MLIR text for review. 3. ``@pto.jit`` also owns the top-level logical tiling, tile allocation, and loop scheduling for one already-selected per-head 2D slice. It should not manually spell low-level DMA details for every micro step. @@ -52,29 +52,53 @@ Hiding these dependencies with in-place aliases makes the algorithm harder to read and obscures what the DSL needs to express. -The API spellings below are approximate and intentionally favor the redesign -surface over today's exact binding details. - -Because this sketch targets a tracing-style frontend, any control flow that +Because this demo targets a tracing-style frontend, any control flow that must reach MLIR is expressed with structured DSL constructs such as ``pto.for_`` instead of native Python ``for`` loops. -Scalar literals and simple index/integer conversions are also shown in their -authored form. The intended frontend behavior is to lift Python ``int`` -literals and obvious scalar arithmetic into the corresponding MLIR scalar ops -implicitly, rather than forcing authors to spell ``pto.const(...)`` or -``index_cast(...)`` at every use site. +Scalar literals and simple index/integer conversions are also written in the +authored PTODSL surface. The current frontend lowers these through tracing +instead of forcing authors to spell ``pto.const(...)`` or ``index_cast(...)`` +at every use site. """ +from pathlib import Path +import sys + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from flash_attention_sketch.py" + ) + from ptodsl import pto +scalar = pto.scalar + + +def _min_index(lhs, rhs): + return pto.scalar.select( + pto.scalar.cmpi("slt", lhs, rhs), + lhs, + rhs, + ) + + +def _block_valid_extent(total, block_index, block_size): + return _min_index(total - block_index * block_size, pto.const(block_size)) + # ═══════════════════════════════════════════════════════════════════════════════ # Public API sketch # ═══════════════════════════════════════════════════════════════════════════════ # -# This section intentionally sketches the *desired* public surface, not today's -# exact implementation details. The split follows the common industry pattern: +# This section shows the current compile-only public surface. The split follows +# the common industry pattern: # # - a user-facing tensor wrapper # - a launchable JIT kernel entry @@ -82,66 +106,54 @@ # # The low-level kernel body should not double as the user-facing runtime API. # -# Two intended usage styles: +# Two intended usage styles for the current compile-only milestone: # -# 1. Direct call (most users): -# out = flash_attention(Q, K, V, causal=True) +# 1. One-shot MLIR emission: +# mlir_text = emit_flash_attention_mlir(head_dim=128, causal=True) # -# 2. Compile first, then launch repeatedly: +# 2. Compile first, then inspect: # compiled = flash_attention_kernel.compile(BLOCK_Q=128, BLOCK_KV=128, CAUSAL=True) -# compiled[batch * heads, stream]( -# Q, K, V, O, -# ) - -def flash_attention( - Q, - K, - V, +# mlir_text = compiled.mlir_text() + +def emit_flash_attention_mlir( *, - O=None, + head_dim=128, causal=False, block_q=128, block_kv=128, - stream=None, ): """ - User-facing convenience wrapper. + Compile the flash-attention sketch and return its MLIR text. - This is the API most end users should call. It mirrors mainstream tensor - libraries: infer runtime metadata from tensors, allocate the output when the - caller does not provide one, then compile and launch the JIT kernel. + The current milestone for this demo is compile / inspect / review, not + runtime launch. The wrapper therefore only specializes the JIT kernel and + returns the emitted MLIR text. """ - if O is None: - O = pto.empty_like(Q) - - batch, seq_q, heads, dim = Q.shape - _, seq_k, _, _ = K.shape - compiled = flash_attention_kernel.compile( BLOCK_Q=block_q, BLOCK_KV=block_kv, + HEAD_DIM=head_dim, CAUSAL=causal, ) - - compiled[batch * heads, stream](Q, K, V, O) - return O + return compiled.mlir_text() @pto.jit(target="a5") def flash_attention_kernel( - Q, # Python/framework tensor, logical [batch, seq_q, heads, dim] - K, # Python/framework tensor, logical [batch, seq_k, heads, dim] - V, # Python/framework tensor, logical [batch, seq_k, heads, dim] - O, # Python/framework tensor, logical [batch, seq_q, heads, dim] + Q: pto.tensor_spec(rank=4, dtype=pto.f32), # Python/framework tensor, logical [batch, seq_q, heads, dim] + K: pto.tensor_spec(rank=4, dtype=pto.f32), # Python/framework tensor, logical [batch, seq_k, heads, dim] + V: pto.tensor_spec(rank=4, dtype=pto.f32), # Python/framework tensor, logical [batch, seq_k, heads, dim] + O: pto.tensor_spec(rank=4, dtype=pto.f32), # Python/framework tensor, logical [batch, seq_q, heads, dim] *, BLOCK_Q: pto.constexpr = 128, BLOCK_KV: pto.constexpr = 128, + HEAD_DIM: pto.constexpr = 128, CAUSAL: pto.constexpr = False, NUM_STAGES: pto.constexpr = 2, ): """ Launchable device entry. - ``@pto.jit`` is the compile + launch boundary. Inputs/outputs at this + ``@pto.jit`` is the compile boundary. Inputs/outputs at this boundary are Python-native tensor objects; PTO-specific ``TensorView`` descriptors are materialized inside the JIT body rather than exposed in the public signature. Tile sizes and specialization knobs remain constexpr @@ -153,10 +165,10 @@ def flash_attention_kernel( batch, seq_q, heads, dim = Q.shape _, seq_k, _, _ = K.shape - q_view = pto.make_tensor_view(Q, shape=[batch, seq_q, heads, dim], strides=Q.strides) - k_view = pto.make_tensor_view(K, shape=[batch, seq_k, heads, dim], strides=K.strides) - v_view = pto.make_tensor_view(V, shape=[batch, seq_k, heads, dim], strides=V.strides) - o_view = pto.make_tensor_view(O, shape=[batch, seq_q, heads, dim], strides=O.strides) + q_view = pto.make_tensor_view(Q) + k_view = pto.make_tensor_view(K) + v_view = pto.make_tensor_view(V) + o_view = pto.make_tensor_view(O) # Make the SPMD launch contract explicit in the authored surface. # This sketch uses one block per (batch, head) slice and does not further @@ -179,72 +191,116 @@ def flash_attention_kernel( batch_idx = block_idx // heads head_idx = block_idx % heads - q_head = pto.select_head_view( + q_head = pto.partition_view( q_view, - batch=batch_idx, - head=head_idx, - shape=[seq_q, dim], + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_q, 1, dim], ) - k_head = pto.select_head_view( + k_head = pto.partition_view( k_view, - batch=batch_idx, - head=head_idx, - shape=[seq_k, dim], + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_k, 1, dim], ) - v_head = pto.select_head_view( + v_head = pto.partition_view( v_view, - batch=batch_idx, - head=head_idx, - shape=[seq_k, dim], + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_k, 1, dim], ) - o_head = pto.select_head_view( + o_head = pto.partition_view( o_view, - batch=batch_idx, - head=head_idx, - shape=[seq_q, dim], + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_q, 1, dim], ) Br = BLOCK_Q Bc = BLOCK_KV + D = HEAD_DIM + full_br = pto.const(Br) + full_bc = pto.const(Bc) + one = pto.const(1) q_blocks = (seq_q + Br - 1) // Br kv_blocks = (seq_k + Bc - 1) // Bc - # UB resident logical tiles for one selected (batch, head) slice. - q_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) - k_tile = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f32) - v_tile = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f32) - - o_prev_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) - o_next_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) - m_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) - m_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) - l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) - l_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) + # Physical tile shape remains static. Runtime tails live in valid_shape. + # Cube bridge sources are MAT-backed so they can feed LEFT/RIGHT staging. + q_mat = pto.alloc_tile( + shape=[Br, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_br, dim], + blayout="ColMajor", + slayout="RowMajor", + ) + k_mat = pto.alloc_tile( + shape=[Bc, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_bc, dim], + blayout="ColMajor", + slayout="RowMajor", + ) + v_mat = pto.alloc_tile( + shape=[Bc, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_bc, dim], + blayout="ColMajor", + slayout="RowMajor", + ) - s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32) - p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32) - pv_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) - alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) - beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) + o_prev_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) + o_next_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) + m_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + m_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + l_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + + s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) + p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) + p_mat = pto.alloc_tile( + shape=[Br, Bc], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_br, full_bc], + blayout="ColMajor", + slayout="RowMajor", + ) + pv_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) + alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") # Cube-local scratch is explicit; it should not be conflated with UB tiles. - q_l0a = pto.alloc_tile(shape=[Br, dim], dtype=pto.f16, memory_space=pto.MemorySpace.LEFT) - p_l0a = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f16, memory_space=pto.MemorySpace.LEFT) - rhs_l0b = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f16, memory_space=pto.MemorySpace.RIGHT) - qk_acc_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, memory_space=pto.MemorySpace.ACC) - pv_acc_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32, memory_space=pto.MemorySpace.ACC) + q_l0a = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, dim]) + p_l0a = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, full_bc]) + rhs_l0b = pto.alloc_tile(shape=[Bc, D], dtype=pto.f32, memory_space=pto.MemorySpace.RIGHT, valid_shape=[full_bc, dim]) + qk_acc_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, memory_space=pto.MemorySpace.ACC, valid_shape=[full_br, full_bc]) + pv_acc_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, memory_space=pto.MemorySpace.ACC, valid_shape=[full_br, dim]) # SIMT metadata buffer. A tiny raw-pointer island is acceptable at the # ukernel boundary because this is scalar control data, not user-facing math. - meta_tile = pto.alloc_tile(shape=[3, 1], dtype=pto.i32) - meta_ptr = pto.tile_buf_addr(meta_tile) + meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 3]) + meta_ptr = meta_tile.as_ptr() with pto.for_(0, q_blocks, step=1) as qi: - q_part = pto.partition_view(q_head, offsets=[qi * Br, 0], sizes=[Br, dim]) - o_part = pto.partition_view(o_head, offsets=[qi * Br, 0], sizes=[Br, dim]) - - pto.tload(q_part, q_tile) + q_rows = _block_valid_extent(seq_q, qi, Br) + q_part = pto.partition_view(q_head, offsets=[0, qi * Br, 0, 0], sizes=[1, q_rows, 1, dim]) + o_part = pto.partition_view(o_head, offsets=[0, qi * Br, 0, 0], sizes=[1, q_rows, 1, dim]) + + q_mat.valid_shape = [q_rows, dim] + o_prev_tile.valid_shape = [q_rows, dim] + o_next_tile.valid_shape = [q_rows, dim] + m_prev_tile.valid_shape = [q_rows, one] + m_next_tile.valid_shape = [q_rows, one] + l_prev_tile.valid_shape = [q_rows, one] + l_next_tile.valid_shape = [q_rows, one] + alpha_tile.valid_shape = [q_rows, one] + beta_tile.valid_shape = [q_rows, one] + p_mat.valid_shape = [q_rows, full_bc] + pv_tile.valid_shape = [q_rows, dim] + q_l0a.valid_shape = [q_rows, dim] + + pto.tload(q_part, q_mat) # Initial online-softmax state for this Q block. # ``CAUSAL`` is threaded at the API boundary even though the masking @@ -263,15 +319,27 @@ def flash_attention_kernel( m_cur = kv_loop.m l_cur = kv_loop.l o_cur = kv_loop.o - k_part = pto.partition_view(k_head, offsets=[kj * Bc, 0], sizes=[Bc, dim]) - v_part = pto.partition_view(v_head, offsets=[kj * Bc, 0], sizes=[Bc, dim]) + kv_rows = _block_valid_extent(seq_k, kj, Bc) + k_part = pto.partition_view(k_head, offsets=[0, kj * Bc, 0, 0], sizes=[1, kv_rows, 1, dim]) + v_part = pto.partition_view(v_head, offsets=[0, kj * Bc, 0, 0], sizes=[1, kv_rows, 1, dim]) + + k_mat.valid_shape = [kv_rows, dim] + v_mat.valid_shape = [kv_rows, dim] + s_tile.valid_shape = [q_rows, kv_rows] + p_tile.valid_shape = [q_rows, kv_rows] + p_mat.valid_shape = [q_rows, kv_rows] + pv_tile.valid_shape = [q_rows, dim] + p_l0a.valid_shape = [q_rows, kv_rows] + rhs_l0b.valid_shape = [kv_rows, dim] + qk_acc_tile.valid_shape = [q_rows, kv_rows] + pv_acc_tile.valid_shape = [q_rows, dim] kv_block_process( - q_tile, + q_mat, k_part, v_part, - k_tile, - v_tile, + k_mat, + v_mat, o_cur, o_next_tile, m_cur, @@ -280,6 +348,7 @@ def flash_attention_kernel( l_next_tile, s_tile, p_tile, + p_mat, pv_tile, alpha_tile, beta_tile, @@ -316,8 +385,8 @@ def flash_attention_kernel( @pto.cube def qk_matmul( - q_tile: pto.Tile, # UB, [Br, dim] - k_tile: pto.Tile, # UB, [Bc, dim] + q_mat: pto.Tile, # MAT, [Br, dim] + k_mat: pto.Tile, # MAT, [Bc, dim] q_l0a: pto.Tile, # LEFT scratch k_l0b: pto.Tile, # RIGHT scratch s_acc: pto.Tile, # ACC scratch @@ -326,25 +395,25 @@ def qk_matmul( """ Compute ``S = Q @ K^T`` for one attention block. - The key point for the redesign is that the cube kernel consumes UB tiles and - explicit cube-local scratch, rather than pretending a UB tile can also stand + The key point for the redesign is that the cube kernel consumes MAT tiles and + explicit cube-local scratch, rather than pretending a logical scheduling tile can also stand in for LEFT/RIGHT/ACC state. """ - m = pto.tile_valid_rows(q_tile) - k = pto.tile_valid_cols(q_tile) - n = pto.tile_valid_rows(k_tile) + m = q_mat.valid_shape[0] + k = q_mat.valid_shape[1] + n = k_mat.valid_shape[0] # Caller owns scratch lifetime. The cube kernel only expresses dataflow. - pto.mte_l1_l0a(q_tile, q_l0a, m, k) - pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) - pto.mad(q_l0a, k_l0b, s_acc) - pto.mte_l0c_ub(s_acc, s_tile, m, n) + pto.mte_l1_l0a(q_mat.as_ptr(), q_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(k_mat.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) @pto.cube def pv_matmul( - p_tile: pto.Tile, # UB, [Br, Bc] - v_tile: pto.Tile, # UB, [Bc, dim] + p_mat: pto.Tile, # MAT, [Br, Bc] + v_mat: pto.Tile, # MAT, [Bc, dim] p_l0a: pto.Tile, # LEFT scratch (reused) v_l0b: pto.Tile, # RIGHT scratch (reused) pv_acc: pto.Tile, # ACC scratch (reused) @@ -356,14 +425,14 @@ def pv_matmul( This keeps the second matrix product on the cube path as well, instead of accidentally collapsing it into an elementwise vector expression. """ - m = pto.tile_valid_rows(p_tile) - k = pto.tile_valid_cols(p_tile) - n = pto.tile_valid_cols(v_tile) + m = p_mat.valid_shape[0] + k = p_mat.valid_shape[1] + n = v_mat.valid_shape[1] - pto.mte_l1_l0a(p_tile, p_l0a, m, k) - pto.mte_l1_l0b(v_tile, v_l0b, k, n) - pto.mad(p_l0a, v_l0b, pv_acc) - pto.mte_l0c_ub(pv_acc, pv_tile, m, n) + pto.mte_l1_l0a(p_mat.as_ptr(), p_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(v_mat.as_ptr(), v_l0b.as_ptr(), k, n) + pto.mad(p_l0a.as_ptr(), v_l0b.as_ptr(), pv_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(pv_acc.as_ptr(), pv_tile.as_ptr(), m, n, n, n, 0) @pto.simd @@ -415,10 +484,10 @@ def online_softmax_rows( beta = 1.0 / l_next pto.vsts(p_row, p_tile[row, 0:], col_mask) - scalar.sts(m_next_tile[row, 0], m_next) - scalar.sts(l_next_tile[row, 0], l_next) - scalar.sts(alpha_tile[row, 0], alpha) - scalar.sts(beta_tile[row, 0], beta) + scalar.store(m_next, m_next_tile[row, 0]) + scalar.store(l_next, l_next_tile[row, 0]) + scalar.store(alpha, alpha_tile[row, 0]) + scalar.store(beta, beta_tile[row, 0]) @pto.simt @@ -451,7 +520,7 @@ def blend_output_rows( pv_val = scalar.load(pv_tile[row, col]) o_next = alpha * o_prev + beta * pv_val - scalar.sts(o_next_tile[row, col], o_next) + scalar.store(o_next, o_next_tile[row, col]) @pto.simt @@ -466,9 +535,9 @@ def materialize_tile_bounds( The SIMT kernel stays intentionally small here: it is responsible for scalar control metadata, not for rewriting the vector or cube logic. """ - scalar.sts(meta_ptr + 0, 0) - scalar.sts(meta_ptr + 4, valid_rows) - scalar.sts(meta_ptr + 8, valid_cols) + scalar.store(0, meta_ptr + 0) + scalar.store(valid_rows, meta_ptr + 1) + scalar.store(valid_cols, meta_ptr + 2) # ═══════════════════════════════════════════════════════════════════════════════ @@ -478,11 +547,11 @@ def materialize_tile_bounds( @pto.ukernel def kv_block_process( - q_tile: pto.Tile, # UB, reused across inner KV loop + q_mat: pto.Tile, # MAT, reused across inner KV loop k_part: pto.PartitionTensorView, # GM view for current K block v_part: pto.PartitionTensorView, # GM view for current V block - k_tile: pto.Tile, # UB scratch - v_tile: pto.Tile, # UB scratch + k_mat: pto.Tile, # MAT scratch + v_mat: pto.Tile, # MAT scratch o_prev_tile: pto.Tile, # UB state o_next_tile: pto.Tile, # UB state m_prev_tile: pto.Tile, # UB state @@ -491,6 +560,7 @@ def kv_block_process( l_next_tile: pto.Tile, # UB state s_tile: pto.Tile, # UB scratch for QK^T p_tile: pto.Tile, # UB scratch for probabilities + p_mat: pto.Tile, # MAT scratch for probabilities pv_tile: pto.Tile, # UB scratch for P@V alpha_tile: pto.Tile, # UB scratch beta_tile: pto.Tile, # UB scratch @@ -511,23 +581,23 @@ def kv_block_process( - wiring together the explicit state transition (prev -> next for m/l/o). """ - # Current-block GM->UB staging via MTE micro-instructions. - pto.mte_load(k_part, k_tile) - pto.mte_load(v_part, v_tile) - pto.mem_bar(pto.BarrierType.SYNC) + # Current-block GM->MAT staging via MTE micro-instructions. + pto.mte_load(k_part, k_mat) + pto.mte_load(v_part, v_mat) + pto.pipe_barrier(pto.Pipe.ALL) materialize_tile_bounds( meta_ptr, - pto.tile_valid_rows(q_tile), - pto.tile_valid_rows(k_tile), + q_mat.valid_shape[0], + k_mat.valid_shape[0], ) row_start = scalar.load(meta_ptr + 0) - row_stop = scalar.load(meta_ptr + 4) - valid_cols = scalar.load(meta_ptr + 8) + row_stop = scalar.load(meta_ptr + 1) + valid_cols = scalar.load(meta_ptr + 2) # 1. S = Q @ K^T - qk_matmul(q_tile, k_tile, q_l0a, rhs_l0b, qk_acc_tile, s_tile) - pto.mem_bar(pto.BarrierType.SYNC) + qk_matmul(q_mat, k_mat, q_l0a, rhs_l0b, qk_acc_tile, s_tile) + pto.pipe_barrier(pto.Pipe.ALL) # 2. Row-wise online softmax over S online_softmax_rows( @@ -543,11 +613,15 @@ def kv_block_process( row_stop, valid_cols, ) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) + + # Stage the probability tile onto the cube MAT path. + pto.tmov(p_tile, p_mat) + pto.pipe_barrier(pto.Pipe.ALL) # 3. PV = P @ V - pv_matmul(p_tile, v_tile, p_l0a, rhs_l0b, pv_acc_tile, pv_tile) - pto.mem_bar(pto.BarrierType.SYNC) + pv_matmul(p_mat, v_mat, p_l0a, rhs_l0b, pv_acc_tile, pv_tile) + pto.pipe_barrier(pto.Pipe.ALL) # 4. O_next = alpha * O_prev + beta * PV blend_output_rows( @@ -558,9 +632,9 @@ def kv_block_process( o_next_tile, row_start, row_stop, - pto.tile_valid_cols(v_tile), + v_mat.valid_shape[1], ) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) # ═══════════════════════════════════════════════════════════════════════════════ @@ -568,15 +642,15 @@ def kv_block_process( # ═══════════════════════════════════════════════════════════════════════════════ # # ┌──────────────────────────────────────────────────────────────────────────┐ -# │ L0 Python wrapper flash_attention(...) │ +# │ L0 Python wrapper emit_flash_attention_mlir(...) │ # │ │ -# │ output allocation, shape/stride extraction, compile, launch │ +# │ specialize kernel parameters, compile, emit MLIR text │ # │ │ -# │ Key idea: user-facing tensor API, not IR authoring. │ +# │ Key idea: current demo goal is compile/inspect, not runtime launch. │ # ├──────────────────────────────────────────────────────────────────────────┤ -# │ L1 @pto.jit compile + cache + launch + top-level orchestration │ +# │ L1 @pto.jit compile + cache + top-level orchestration │ # │ │ -# │ flash_attention_kernel[grid, stream](...) │ +# │ flash_attention_kernel.compile(...).mlir_text() │ # │ TensorView metadata / alloc_tile / partition_view / tload / tstore │ # │ outer Q loop + inner KV loop + ping-pong state ownership │ # │ │ @@ -628,5 +702,14 @@ def kv_block_process( # After each KV block: # (m_prev, l_prev, o_prev) := (m_next, l_next, o_next) # -# The important part for the redesign is not the exact helper spelling, but -# that every cross-stage dependency is visible in the surface language. +# The important part for the demo is that every cross-stage dependency is +# visible in the surface language and the whole kernel can already be traced to +# MLIR for review. + + +def main(): + print(emit_flash_attention_mlir()) + + +if __name__ == "__main__": + main() diff --git a/ptodsl/docs/user_guide/01-introduction.md b/ptodsl/docs/user_guide/01-introduction.md index cc6e8f134..4837a675f 100644 --- a/ptodsl/docs/user_guide/01-introduction.md +++ b/ptodsl/docs/user_guide/01-introduction.md @@ -84,15 +84,18 @@ def flash_attention(Q, K, V, *, O=None, causal=False): Decorating a function with `@pto.jit` marks it as a launchable PTO kernel. This decoration means: - **Compilation**: the function body is traced once to record all PTO instructions, then lowered through the PTOAS compiler pipeline into an optimized NPU executable. -- **Caching**: compiled kernels are cached by key (function identity + constexpr parameter values), so repeated calls with the same configuration skip recompilation. +- **Caching**: compiled kernels are cached by specialization key (function identity + tensor ABI signature + constexpr parameter values), so repeated calls with the same configuration skip recompilation. - **Launch binding**: the compiled kernel can be invoked with a grid and stream — `compiled[grid, stream](args...)` — which launches the executable on the NPU with the given SPMD grid. -The parameters of a `@pto.jit` function are Python-native tensors (not PTODSL-specific descriptors). The kernel body materializes `TensorView` descriptors from them via `make_tensor_view`, then partitions the problem with `partition_view`. Compile-time constants are declared as keyword-only arguments with `pto.constexpr`: +The parameters of a `@pto.jit` function are Python-native tensors (not PTODSL-specific descriptors). In PTODSL v1, their ABI contract is declared with `pto.tensor_spec(...)` in the function signature; this is a compile-time annotation, not a runtime object the Python wrapper must construct. The kernel body materializes `TensorView` descriptors from the runtime tensors via `make_tensor_view`, then partitions the problem with `partition_view`. Compile-time constants are declared as keyword-only arguments with `pto.constexpr`: ```python @pto.jit(target="a5") def flash_attention_kernel( - Q, K, V, O, + Q: pto.tensor_spec(rank=4, dtype=pto.f32), + K: pto.tensor_spec(rank=4, dtype=pto.f32), + V: pto.tensor_spec(rank=4, dtype=pto.f32), + O: pto.tensor_spec(rank=4, dtype=pto.f32), *, BLOCK_Q: pto.constexpr = 128, BLOCK_KV: pto.constexpr = 128, @@ -154,7 +157,7 @@ The flash attention kernel from Section 1.2 is not just an architectural diagram **L1 (`@pto.jit`)** allocates tiles for the Q block, KV block, online-softmax state (m/l/o ping-pong tiles), and cube-local scratch. It loops over Q blocks (outer `pto.for_`) and KV blocks (inner `pto.for_` with carry state), calling `kv_block_process` for each KV block and using `tload`/`tstore` at the GM boundary. -**L2 (`@pto.ukernel`)** stages the current K and V blocks with `mte_load`, issues `mem_bar` for synchronization, then sequences four sub-kernel calls: `qk_matmul` (cube), `online_softmax_rows` (simd), `pv_matmul` (cube), `blend_output_rows` (simt). +**L2 (`@pto.ukernel`)** stages the current K and V blocks with `mte_load`, issues `pipe_barrier(Pipe.ALL)` at phase boundaries, then sequences four sub-kernel calls: `qk_matmul` (cube), `online_softmax_rows` (simd), `pv_matmul` (cube), `blend_output_rows` (simt). **L3a (`@pto.cube`)** performs `mte_l1_l0a` / `mte_l1_l0b` / `mad` / `mte_l0c_ub` for both QK^T and P@V products. diff --git a/ptodsl/docs/user_guide/02-quick-start.md b/ptodsl/docs/user_guide/02-quick-start.md index af0ec6298..6830fbe73 100644 --- a/ptodsl/docs/user_guide/02-quick-start.md +++ b/ptodsl/docs/user_guide/02-quick-start.md @@ -142,7 +142,7 @@ O = np.empty_like(A) compiled[1, None](A, B, O) ``` -- `.compile(**constexprs)` traces the kernel body, lowers it through the PTOAS pipeline, and returns a compiled handle. Repeated calls with the same configuration hit the cache. +- `.compile(**constexprs)` traces the kernel body, lowers it through the PTOAS pipeline, and returns a compiled handle. Repeated calls with the same tensor ABI contract and constexpr configuration hit the cache. - `compiled[grid, stream](args...)` launches the compiled kernel. `grid` is the number of SPMD blocks; `stream` is the NPU stream (or `None` for the default). ## 2.4 SPMD launch @@ -195,10 +195,10 @@ def add_block(a_part: pto.PartitionTensorView, rows: pto.i32, cols: pto.i32): pto.mte_load(a_part, a_tile) pto.mte_load(b_part, b_tile) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) add_rows(a_tile, b_tile, o_tile, rows, cols) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) pto.mte_store(o_tile, o_part) diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index c7cd157eb..41f6cb564 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -53,7 +53,7 @@ compiled = kernel_name.compile(CONST_A=128, CONST_B=64) compiled[grid, stream](tensor_1, tensor_2, ...) ``` -- `.compile(**constexprs)` — traces the kernel body with the given constexpr values, lowers the IR, and returns a compiled handle. Subsequent calls with the same (function identity, constexpr values) hit the cache. +- `.compile(**constexprs)` — traces the kernel body with the given constexpr values, lowers the IR, and returns a compiled handle. Subsequent calls with the same specialization key (function identity, tensor ABI signature, constexpr values) hit the cache. - `compiled[grid, stream](args...)` — launches the compiled kernel. `grid` is the number of SPMD blocks (an integer); `stream` is the NPU stream (`None` for default). ### SPMD built-ins @@ -101,13 +101,13 @@ When you call an L3 sub-kernel directly from `@pto.jit`, data movement is handle ```python @pto.cube def my_matmul(a_tile, b_tile, l0a, l0b, acc, o_tile): - m = pto.tile_valid_rows(a_tile) - k = pto.tile_valid_cols(a_tile) - n = pto.tile_valid_rows(b_tile) - pto.mte_l1_l0a(a_tile, l0a, m, k) - pto.mte_l1_l0b(b_tile, l0b, k, n, transpose=True) - pto.mad(l0a, l0b, acc) - pto.mte_l0c_ub(acc, o_tile, m, n) + m = a_tile.valid_shape[0] + k = a_tile.valid_shape[1] + n = b_tile.valid_shape[0] + pto.mte_l1_l0a(a_tile.as_ptr(), l0a.as_ptr(), m, k) + pto.mte_l1_l0b(b_tile.as_ptr(), l0b.as_ptr(), k, n, transpose=True) + pto.mad(l0a.as_ptr(), l0b.as_ptr(), acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(acc.as_ptr(), o_tile.as_ptr(), m, n, n, n, 0) @pto.jit(target="a5") def my_kernel(A, B, O, *, BLOCK: pto.constexpr): @@ -172,14 +172,14 @@ def process_block(k_part, v_part, k_tile, v_tile, # Stage current block from GM to UB pto.mte_load(k_part, k_tile) pto.mte_load(v_part, v_tile) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) # Dispatch sub-kernels qk_matmul(q_tile, k_tile, s_tile) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) online_softmax(s_tile, o_tile, rows, cols) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) # Write result back pto.mte_store(o_tile, o_part) @@ -220,14 +220,14 @@ def qk_matmul( s_acc: pto.Tile, s_tile: pto.Tile, ): - m = pto.tile_valid_rows(q_tile) - k = pto.tile_valid_cols(q_tile) - n = pto.tile_valid_rows(k_tile) - - pto.mte_l1_l0a(q_tile, q_l0a, m, k) - pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) - pto.mad(q_l0a, k_l0b, s_acc) - pto.mte_l0c_ub(s_acc, s_tile, m, n) + m = q_tile.valid_shape[0] + k = q_tile.valid_shape[1] + n = k_tile.valid_shape[0] + + pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) ``` Cube-local state (LEFT, RIGHT, ACC, BIAS) never leaks into UB — it is the caller's responsibility to allocate scratch buffers and pass them in explicitly. @@ -353,10 +353,10 @@ with pto.simt(): ```python with pto.cube(): - pto.mte_l1_l0a(q_tile, q_l0a, m, k) - pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) - pto.mad(q_l0a, k_l0b, s_acc) - pto.mte_l0c_ub(s_acc, s_tile, m, n) + pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) ``` ### Semantics @@ -394,7 +394,7 @@ Data crosses decorator boundaries only through UB-backed tiles or typed UB point ## 3.9 `pto.constexpr` -`pto.constexpr` marks a `@pto.jit` keyword-only parameter as a compile-time constant. The compiler specializes the kernel for each combination of constexpr values, and the compiled artifact is cached by those values. +`pto.constexpr` marks a `@pto.jit` keyword-only parameter as a compile-time constant. The compiler specializes the kernel for each combination of constexpr values, and the compiled artifact is cached by specialization key together with the kernel's tensor ABI contract. ```python @pto.jit(target="a5") diff --git a/ptodsl/docs/user_guide/04-type-system-and-buffer.md b/ptodsl/docs/user_guide/04-type-system-and-buffer.md index f0944cbf4..22804daab 100644 --- a/ptodsl/docs/user_guide/04-type-system-and-buffer.md +++ b/ptodsl/docs/user_guide/04-type-system-and-buffer.md @@ -175,12 +175,18 @@ A `Tile` is an on-chip buffer allocated in UB or cube-local memory. Allocate til # UB tile a_tile = pto.alloc_tile(shape=[BLOCK, dim], dtype=pto.f32) +# Logical column tile +m_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") + # Cube-local scratch with explicit memory space q_l0a = pto.alloc_tile(shape=[Br, dim], dtype=pto.f16, memory_space=pto.MemorySpace.LEFT) s_acc = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, memory_space=pto.MemorySpace.ACC) ``` `alloc_tile` returns a `Tile` object. The `shape` must be a compile-time constant. The default memory space is UB. +For narrow logical column tiles such as `[Br, 1]`, author them with +`blayout="ColMajor"`. Row-major none-box tiles are validated against a 32-byte +physical row-alignment rule. ### Tile attributes diff --git a/ptodsl/docs/user_guide/05-control-flow.md b/ptodsl/docs/user_guide/05-control-flow.md index 2dc4ce142..0ef1f60c6 100644 --- a/ptodsl/docs/user_guide/05-control-flow.md +++ b/ptodsl/docs/user_guide/05-control-flow.md @@ -107,10 +107,10 @@ This pattern is central to algorithms like online softmax, where each KV block u ```python # Allocate ping-pong state tiles -m_prev = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) -m_next = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) -l_prev = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) -l_next = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) +m_prev = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") +m_next = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") +l_prev = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") +l_next = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") # Initialize prev tiles m_prev.fill(float("-inf")) diff --git a/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md index ba428a313..783e210db 100644 --- a/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md +++ b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md @@ -4,9 +4,9 @@ Chapter 5 established the rule: Python constructs are resolved at trace time, PT ## 6.1 Python scalars vs PTO scalars -A **Python scalar** is any value computed by Python during tracing: a literal (`3.14159`), a shape dimension (`A.shape[0]`), a constexpr parameter (`BLOCK`), or an arithmetic expression built from these (`1.0 / sqrt(dim)`). These are evaluated at trace time and their results are baked into the device code as constants. +A **Python scalar** is any value computed by Python during tracing: a literal (`3.14159`), a constexpr parameter (`BLOCK`), or an arithmetic expression built only from compile-time-known values (`1.0 / sqrt(128)`). These are evaluated at trace time and their results are baked into the device code as constants. -A **PTO scalar** is a value that lives on the device at runtime. It comes from a `scalar.load` read, a device-side computation (`scalar.max`, `scalar.exp`), or a runtime query (`pto.get_block_idx()`). PTO scalars flow through the recorded program and are not resolved until the kernel executes. +A **PTO scalar** is a value that lives on the device at runtime. It comes from a `scalar.load` read, a device-side computation (`scalar.max`, `scalar.exp`), a runtime query (`pto.get_block_idx()`), or `@pto.jit` tensor metadata such as `A.shape[0]` / `A.strides[1]`. PTO scalars flow through the recorded program and are not resolved until the kernel executes. ### The mixed expression @@ -25,9 +25,10 @@ alpha * o_prev + beta * pv_val | If the value... | Use... | Example | |-----------------|--------|---------| -| Is known at compile time | Python scalar | `BLOCK`, `1.0 / sqrt(dim)`, `A.shape[0]` | +| Is known at compile time | Python scalar | `BLOCK`, `1.0 / sqrt(128)` | | Comes from device memory | PTO scalar | `scalar.load(tile[r, c])` | | Depends on a runtime value | PTO scalar | `scalar.max(m_prev, row_max)` | +| Comes from tensor metadata at the `@pto.jit` boundary | PTO scalar | `A.shape[0]`, `Q.strides[2]` | | Is a block/subblock index | PTO scalar | `pto.get_block_idx()` | When in doubt, ask: *can this value change between launches of the same compiled kernel?* If yes, it must be a PTO scalar. @@ -125,9 +126,9 @@ When writing to a raw pointer (e.g., a small metadata buffer obtained via `as_pt ```python meta_ptr = meta_tile.as_ptr() scalar.store(0, meta_ptr, 0) # store at element offset 0 -scalar.store(valid_rows, meta_ptr, 4) # store at element offset 4 +scalar.store(valid_rows, meta_ptr, 1) # store at element offset 1 row_start = scalar.load(meta_ptr, 0) -row_stop = scalar.load(meta_ptr, 4) +row_stop = scalar.load(meta_ptr, 1) ``` ## 6.3 Scalar arithmetic and comparisons diff --git a/ptodsl/docs/user_guide/07-data-movement-ops.md b/ptodsl/docs/user_guide/07-data-movement-ops.md index bf6cf7dec..40c1f02c7 100644 --- a/ptodsl/docs/user_guide/07-data-movement-ops.md +++ b/ptodsl/docs/user_guide/07-data-movement-ops.md @@ -258,11 +258,11 @@ def process_block(k_part, v_part, k_tile, v_tile, o_tile, o_part, cols * pto.bytewidth(pto.f16), nburst=(rows, cols * pto.bytewidth(pto.f16), cols * pto.bytewidth(pto.f16))) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) # ... compute on tiles ... - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) pto.mte_ub_gm(o_tile.as_ptr(), o_part.as_ptr(), cols * pto.bytewidth(pto.f32), nburst=(rows, cols * pto.bytewidth(pto.f32), @@ -958,7 +958,7 @@ Inside `@pto.cube`, data flows through a hierarchy of private buffers: GM → L1 --- -#### `pto.mte_l0c_ub(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, dual_dst_mode: int = 0, sub_blockid: int = 0, mode: FractalMode = FractalMode.NZ2ND, loop0_src_stride: int | None = None, channel_split_en: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` +#### `pto.mte_l0c_ub(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, sub_blockid: int = 0, *, dst_mode: str = "single") -> None` **Description**: Structured L0C (acc) directly to UB. This is the most common writeback path for cube kernels that feed results into subsequent processing. @@ -972,12 +972,8 @@ Inside `@pto.cube`, data flows through a hierarchy of private buffers: GM → L1 | `n` | `int` | N dimension size | | `src_stride` | `int` | Source stride | | `dst_stride` | `int` | Destination stride | -| `dual_dst_mode` | `int` | Dual destination mode (default 0) | | `sub_blockid` | `int` | Sub-block ID (default 0) | -| `mode` | `FractalMode` | `NZ2ND` (default), `NZ2DN`, or `NZ2NZ` | -| `loop0_src_stride` | `int` or `None` | Loop level 0 source stride | -| `channel_split_en` | `int` or `None` | Channel split enable (required for `NZ2NZ` mode) | -| `loop3` | `tuple[int, int, int]` or `None` | Loop level 3 parameters | +| `dst_mode` | `str` | Destination mode, currently `"single"` by default | **Returns**: None (side-effect operation). @@ -1010,10 +1006,10 @@ def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): k = q_tile.valid_shape[1] n = k_tile.valid_shape[0] - pto.mte_l1_l0a(q_tile, q_l0a, m, k) # UB tile → L0A - pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) # UB tile → L0B - pto.mad(q_l0a, k_l0b, s_acc) # L0A × L0B → L0C - pto.mte_l0c_ub(s_acc, s_tile, m, n) # L0C → UB tile + pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) # UB tile → L0A + pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) # UB tile → L0B + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) # L0A × L0B → L0C + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) # L0C → UB tile ``` -The `mte_l1_l0a`/`mte_l1_l0b` operations take UB `Tile` references directly (not raw pointers) — the tile-to-cube-local transfer is implicit. `mad` performs the matrix multiply. `mte_l0c_ub` writes the result back to a UB tile. +At the cube micro-op boundary, PTODSL currently uses explicit typed pointers. `tile.as_ptr()` materializes the pointer view for UB and cube-local scratch buffers, while the surrounding sub-kernel surface still uses `Tile` values for metadata such as `valid_shape`. diff --git a/ptodsl/docs/user_guide/08-compute-operations.md b/ptodsl/docs/user_guide/08-compute-operations.md index 6fddecdf4..75ec8e511 100644 --- a/ptodsl/docs/user_guide/08-compute-operations.md +++ b/ptodsl/docs/user_guide/08-compute-operations.md @@ -588,7 +588,7 @@ The Cube unit performs matrix multiplication. Its operands are typed pointers in ### 8.3.1 Matrix multiply: `pto.mad` -#### `pto.mad(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, k: int, n: int) -> None` +#### `pto.mad(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int) -> None` **Description**: Zero-initialized matrix multiply: `dst[M×N] = lhs[M×K] * rhs[K×N]`. `lhs` is an L0A pointer, `rhs` is an L0B pointer, `dst` is an L0C pointer. @@ -626,22 +626,22 @@ A full cube matmul follows a three-stage pattern: stage operands into L0A/L0B, c ```python @pto.cube def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): - m = pto.tile_valid_rows(q_tile) - k = pto.tile_valid_cols(q_tile) - n = pto.tile_valid_rows(k_tile) + m = q_tile.valid_shape[0] + k = q_tile.valid_shape[1] + n = k_tile.valid_shape[0] # Stage: UB → L0A / L0B - pto.mte_l1_l0a(q_tile, q_l0a, m, k) - pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) + pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) # Compute: L0A × L0B → L0C - pto.mad(q_l0a, k_l0b, s_acc, m, k, n) + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) # Writeback: L0C → UB - pto.mte_l0c_ub(s_acc, s_tile, m, n) + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) ``` -The `mte_l1_l0a`/`mte_l1_l0b` stage operands from UB into cube-local buffers. `mad` performs the matrix multiply into L0C. `mte_l0c_ub` writes the result back to a UB tile for downstream processing. +The `mte_l1_l0a`/`mte_l1_l0b` stage operands from UB into cube-local buffers. `mad` performs the matrix multiply into L0C. `mte_l0c_ub` writes the result back to a UB tile for downstream processing. At this micro-op layer, the operands are explicit pointer views obtained with `.as_ptr()`. --- @@ -649,7 +649,7 @@ The `mte_l1_l0a`/`mte_l1_l0b` stage operands from UB into cube-local buffers. `m | Operation | Semantics | |-----------|-----------| -| `pto.mad(lhs, rhs, dst, m, k, n)` | `dst = lhs * rhs` (zero-init) | +| `pto.mad(lhs, rhs, dst, m, n, k)` | `dst = lhs * rhs` (zero-init) | | `pto.mad_acc(lhs, rhs, dst, m, k, n)` | `dst += lhs * rhs` (accumulating) | | `pto.mad_bias(lhs, rhs, dst, bias, m, k, n)` | `dst = lhs * rhs + bias` | | `pto.mad_mx(lhs, rhs, dst, m, k, n)` | MX-format zero-init matmul | diff --git a/ptodsl/docs/user_guide/10-sync-ops.md b/ptodsl/docs/user_guide/10-sync-ops.md index 7124f33dd..a0a9b3a03 100644 --- a/ptodsl/docs/user_guide/10-sync-ops.md +++ b/ptodsl/docs/user_guide/10-sync-ops.md @@ -26,9 +26,6 @@ Memory barrier types used with `pto.mem_bar`. Each value specifies which categor | `ST_VLD` | Scalar stores before → vector loads after | | `LD_VST` | Scalar loads before → vector stores after | | `ST_VST` | Scalar stores before → vector stores after | -| `SYNC` | Full ordering — all prior memory operations (all pipes) complete before any subsequent operation | - -`SYNC` is a convenience value equivalent to a full pipeline barrier. It is the idiomatic choice for separating compute phases inside a ukernel when fine-grained barrier types are not needed. The naming convention: `V` = vector, `S` = scalar, `ST` = store, `LD` = load. `VST_VLD` reads "Vector STore before Vector LoaD." @@ -262,7 +259,8 @@ The most commonly used barrier types in practice: ### Usage in ukernel blocks -In flash attention, `mem_bar` separates logically independent computation phases within the same ukernel: +In flash attention, phase boundaries use `pipe_barrier(Pipe.ALL)`, while +`mem_bar` remains the tool for narrower intra-pipeline ordering: ```python @pto.ukernel @@ -270,23 +268,23 @@ def flash_attention_block(q_tile, k_tile, v_tile, ...): # Phase 1: load K/V pto.mte_load(k_part, k_tile) pto.mte_load(v_part, v_tile) - pto.mem_bar(BarrierType.SYNC) + pto.pipe_barrier(Pipe.ALL) # Phase 2: S = Q @ K^T qk_matmul(q_tile, k_tile, ...) - pto.mem_bar(BarrierType.SYNC) + pto.pipe_barrier(Pipe.ALL) # Phase 3: softmax(S) online_softmax(s_tile, ...) - pto.mem_bar(BarrierType.SYNC) + pto.pipe_barrier(Pipe.ALL) # Phase 4: PV = P @ V pv_matmul(p_tile, v_tile, ...) - pto.mem_bar(BarrierType.SYNC) + pto.pipe_barrier(Pipe.ALL) # Phase 5: blend output blend_output(o_prev_tile, pv_tile, ...) - pto.mem_bar(BarrierType.SYNC) + pto.pipe_barrier(Pipe.ALL) ``` --- diff --git a/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md index 4d7b05c5b..154a88eac 100644 --- a/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md +++ b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md @@ -102,20 +102,32 @@ head_idx = block_idx % heads The launch grid is `[batch * heads]`. Each block computes one `(batch, head)` slice. `get_block_idx()` returns the current block's linear index; dividing by `heads` recovers the batch and head indices. -### 11.3.3 Per-head view selection +### 11.3.3 Per-head view partitioning ```python -q_head = pto.select_head_view(q_view, batch=batch_idx, head=head_idx, - shape=[seq_q, dim]) -k_head = pto.select_head_view(k_view, batch=batch_idx, head=head_idx, - shape=[seq_k, dim]) -v_head = pto.select_head_view(v_view, batch=batch_idx, head=head_idx, - shape=[seq_k, dim]) -o_head = pto.select_head_view(o_view, batch=batch_idx, head=head_idx, - shape=[seq_q, dim]) +q_head = pto.partition_view( + q_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_q, 1, dim], +) +k_head = pto.partition_view( + k_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_k, 1, dim], +) +v_head = pto.partition_view( + v_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_k, 1, dim], +) +o_head = pto.partition_view( + o_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_q, 1, dim], +) ``` -`select_head_view` extracts a 2D slice `[seq, dim]` from the 4D tensor view for the current head. The resulting views are the working set for this block's entire computation. +There is no dedicated `select_head_view` public helper anymore. Each `(batch, head)` working set is sliced from the 4D TensorView with the standard `partition_view(...)` surface, and further logical slicing composes on top of the same primitive. ### 11.3.4 Tile allocation @@ -124,22 +136,22 @@ Two categories of tiles are allocated: **UB-resident tiles** — data tiles that live in the Unified Buffer: ```python -q_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) -k_tile = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f32) -v_tile = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f32) - -o_prev_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) -o_next_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) -m_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) -m_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) -l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) -l_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) - -s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32) -p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32) -pv_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32) -alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) -beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32) +q_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +k_tile = pto.alloc_tile(shape=[Bc, D], dtype=pto.f32, valid_shape=[full_bc, dim]) +v_tile = pto.alloc_tile(shape=[Bc, D], dtype=pto.f32, valid_shape=[full_bc, dim]) + +o_prev_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +o_next_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +m_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +m_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +l_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + +s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) +p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) +pv_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") ``` The online-softmax algorithm requires **ping-pong state tiles**: `m_prev`/`m_next`, `l_prev`/`l_next`, `o_prev`/`o_next`. After each KV block, `next` becomes `prev` for the following iteration. @@ -147,16 +159,16 @@ The online-softmax algorithm requires **ping-pong state tiles**: `m_prev`/`m_nex **Cube-local scratch tiles** — allocated in specific memory spaces: ```python -q_l0a = pto.alloc_tile(shape=[Br, dim], dtype=pto.f16, - memory_space=pto.MemorySpace.LEFT) +q_l0a = pto.alloc_tile(shape=[Br, D], dtype=pto.f16, + memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, dim]) p_l0a = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f16, - memory_space=pto.MemorySpace.LEFT) -rhs_l0b = pto.alloc_tile(shape=[Bc, dim], dtype=pto.f16, - memory_space=pto.MemorySpace.RIGHT) + memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, full_bc]) +rhs_l0b = pto.alloc_tile(shape=[Bc, D], dtype=pto.f16, + memory_space=pto.MemorySpace.RIGHT, valid_shape=[full_bc, dim]) qk_acc_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, - memory_space=pto.MemorySpace.ACC) -pv_acc_tile = pto.alloc_tile(shape=[Br, dim], dtype=pto.f32, - memory_space=pto.MemorySpace.ACC) + memory_space=pto.MemorySpace.ACC, valid_shape=[full_br, full_bc]) +pv_acc_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, + memory_space=pto.MemorySpace.ACC, valid_shape=[full_br, dim]) ``` Cube scratch tiles are NOT UB buffers. `LEFT`, `RIGHT`, and `ACC` are distinct hardware memory spaces inside the Cube unit. They serve as staging for matrix operands and accumulators. @@ -164,20 +176,25 @@ Cube scratch tiles are NOT UB buffers. `LEFT`, `RIGHT`, and `ACC` are distinct h ### 11.3.5 SIMT metadata buffer ```python -meta_tile = pto.alloc_tile(shape=[3, 1], dtype=pto.i32) -meta_ptr = pto.tile_buf_addr(meta_tile) +meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 3]) +meta_ptr = meta_tile.as_ptr() ``` -A small UB tile stores three scalar loop bounds (`row_start`, `row_stop`, `valid_cols`). `tile_buf_addr` materializes a typed UB pointer into it, which is passed to the ukernel as scalar control metadata. +A small UB tile stores three scalar loop bounds (`row_start`, `row_stop`, `valid_cols`). `meta_tile.as_ptr()` materializes a typed UB pointer into it, which is passed to the ukernel as scalar control metadata. + +Notice that the row-wise softmax state tiles (`m_*`, `l_*`, `alpha_tile`, +`beta_tile`) are authored as `blayout="ColMajor"`. This is the intended public +surface for logical column vectors; it avoids forcing users to manufacture a +row-major padded physical width just to satisfy row-byte alignment. ### 11.3.6 Outer Q loop + inner KV loop ```python with pto.for_(0, q_blocks, step=1) as qi: - q_part = pto.partition_view(q_head, offsets=[qi * Br, 0], - sizes=[Br, dim]) - o_part = pto.partition_view(o_head, offsets=[qi * Br, 0], - sizes=[Br, dim]) + q_part = pto.partition_view(q_head, offsets=[0, qi * Br, 0, 0], + sizes=[1, Br, 1, dim]) + o_part = pto.partition_view(o_head, offsets=[0, qi * Br, 0, 0], + sizes=[1, Br, 1, dim]) pto.tload(q_part, q_tile) @@ -194,9 +211,9 @@ with pto.for_(0, q_blocks, step=1) as qi: l_cur = kv_loop.l o_cur = kv_loop.o k_part = pto.partition_view(k_head, - offsets=[kj * Bc, 0], sizes=[Bc, dim]) + offsets=[0, kj * Bc, 0, 0], sizes=[1, Bc, 1, dim]) v_part = pto.partition_view(v_head, - offsets=[kj * Bc, 0], sizes=[Bc, dim]) + offsets=[0, kj * Bc, 0, 0], sizes=[1, Bc, 1, dim]) kv_block_process( q_tile, k_part, v_part, k_tile, v_tile, @@ -245,20 +262,20 @@ The ukernel processes one KV block against an already-loaded Q tile. It owns the ```python pto.mte_load(k_part, k_tile) pto.mte_load(v_part, v_tile) -pto.mem_bar(pto.BarrierType.SYNC) +pto.pipe_barrier(pto.Pipe.ALL) ``` -`mte_load` copies the current K and V block from GM to UB. `mem_bar` ensures the DMA stores are visible before the cube unit reads `k_tile`/`v_tile`. +`mte_load` copies the current K and V block from GM to UB. `pipe_barrier(Pipe.ALL)` makes the phase boundary explicit before the cube unit reads `k_tile`/`v_tile`. ### Phase 0b — Materialize loop bounds ```python materialize_tile_bounds(meta_ptr, - pto.tile_valid_rows(q_tile), - pto.tile_valid_rows(k_tile)) + q_tile.valid_shape[0], + k_tile.valid_shape[0]) row_start = scalar.load(meta_ptr + 0) -row_stop = scalar.load(meta_ptr + 4) -valid_cols = scalar.load(meta_ptr + 8) +row_stop = scalar.load(meta_ptr + 1) +valid_cols = scalar.load(meta_ptr + 2) ``` The SIMT sub-kernel `materialize_tile_bounds` writes `{0, valid_rows, valid_cols}` into the metadata buffer. The ukernel then loads these scalars. They control the row iteration range in subsequent sub-kernels, handling partial tail blocks. @@ -267,10 +284,10 @@ The SIMT sub-kernel `materialize_tile_bounds` writes `{0, valid_rows, valid_cols ```python qk_matmul(q_tile, k_tile, q_l0a, rhs_l0b, qk_acc_tile, s_tile) -pto.mem_bar(pto.BarrierType.SYNC) +pto.pipe_barrier(pto.Pipe.ALL) ``` -Dispatches the cube sub-kernel. `mem_bar` separates the matrix multiply from the subsequent softmax. +Dispatches the cube sub-kernel. `pipe_barrier(Pipe.ALL)` separates the matrix multiply from the subsequent softmax. ### Phase 2 — Online softmax @@ -282,7 +299,7 @@ online_softmax_rows( alpha_tile, beta_tile, row_start, row_stop, valid_cols, ) -pto.mem_bar(pto.BarrierType.SYNC) +pto.pipe_barrier(pto.Pipe.ALL) ``` The simd sub-kernel computes per-row softmax on `S`, updates the running `m`/`l` state, and writes `P`, `alpha`, and `beta`. @@ -291,7 +308,7 @@ The simd sub-kernel computes per-row softmax on `S`, updates the running `m`/`l` ```python pv_matmul(p_tile, v_tile, p_l0a, rhs_l0b, pv_acc_tile, pv_tile) -pto.mem_bar(pto.BarrierType.SYNC) +pto.pipe_barrier(pto.Pipe.ALL) ``` Second cube dispatch. `rhs_l0b` is reused for `V` (it previously held `K`). `pv_acc_tile` is reused from the QK^T accumulator. @@ -302,16 +319,16 @@ Second cube dispatch. `rhs_l0b` is reused for `V` (it previously held `K`). `pv_ blend_output_rows( o_prev_tile, pv_tile, alpha_tile, beta_tile, o_next_tile, row_start, row_stop, - pto.tile_valid_cols(v_tile), + v_tile.valid_shape[1], ) -pto.mem_bar(pto.BarrierType.SYNC) +pto.pipe_barrier(pto.Pipe.ALL) ``` The simt sub-kernel blends the old output accumulator with the new PV contribution, weighted by `alpha` and `beta`. ### Why the ukernel owns sync -Each `mem_bar` between phases is explicit in the ukernel body. This is intentional: at the L2 micro-instruction level, the user controls pipeline ordering. There is no auto-sync insertion — the ukernel is the single place where the hardware execution sequence is spelled out. +Each `pipe_barrier(Pipe.ALL)` between phases is explicit in the ukernel body. This is intentional: at the L2 micro-instruction level, the user controls pipeline ordering. There is no auto-sync insertion — the ukernel is the single place where the hardware execution sequence is spelled out. ## 11.5 L3a — `@pto.cube` sub-kernels @@ -320,14 +337,14 @@ Each `mem_bar` between phases is explicit in the ukernel body. This is intention ```python @pto.cube def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): - m = pto.tile_valid_rows(q_tile) - k = pto.tile_valid_cols(q_tile) - n = pto.tile_valid_rows(k_tile) - - pto.mte_l1_l0a(q_tile, q_l0a, m, k) - pto.mte_l1_l0b(k_tile, k_l0b, k, n, transpose=True) - pto.mad(q_l0a, k_l0b, s_acc) - pto.mte_l0c_ub(s_acc, s_tile, m, n) + m = q_tile.valid_shape[0] + k = q_tile.valid_shape[1] + n = k_tile.valid_shape[0] + + pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) ``` Four cube ops: @@ -344,14 +361,14 @@ The cube kernel does not allocate scratch — the caller (L1) owns scratch lifet ```python @pto.cube def pv_matmul(p_tile, v_tile, p_l0a, v_l0b, pv_acc, pv_tile): - m = pto.tile_valid_rows(p_tile) - k = pto.tile_valid_cols(p_tile) - n = pto.tile_valid_cols(v_tile) - - pto.mte_l1_l0a(p_tile, p_l0a, m, k) - pto.mte_l1_l0b(v_tile, v_l0b, k, n) - pto.mad(p_l0a, v_l0b, pv_acc) - pto.mte_l0c_ub(pv_acc, pv_tile, m, n) + m = p_tile.valid_shape[0] + k = p_tile.valid_shape[1] + n = v_tile.valid_shape[1] + + pto.mte_l1_l0a(p_tile.as_ptr(), p_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(v_tile.as_ptr(), v_l0b.as_ptr(), k, n) + pto.mad(p_l0a.as_ptr(), v_l0b.as_ptr(), pv_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(pv_acc.as_ptr(), pv_tile.as_ptr(), m, n, n, n, 0) ``` Structurally identical to `qk_matmul`, but without transposition and with different input/output tiles. The scratch tiles `p_l0a`, `v_l0b`, and `pv_acc` are reused across KV blocks — the caller (L1) allocates them once. @@ -415,10 +432,10 @@ This implements the online-softmax update from the Flash Attention paper: ```python pto.vsts(p_row, p_tile[row, 0:], col_mask) - scalar.sts(m_next_tile[row, 0], m_next) - scalar.sts(l_next_tile[row, 0], l_next) - scalar.sts(alpha_tile[row, 0], alpha) - scalar.sts(beta_tile[row, 0], beta) + scalar.store(m_next, m_next_tile[row, 0]) + scalar.store(l_next, l_next_tile[row, 0]) + scalar.store(alpha, alpha_tile[row, 0]) + scalar.store(beta, beta_tile[row, 0]) ``` - `vsts` stores the vector `p_row` back to UB under the column mask. @@ -433,12 +450,12 @@ This implements the online-softmax update from the Flash Attention paper: ```python @pto.simt def materialize_tile_bounds(meta_ptr, valid_rows, valid_cols): - scalar.sts(meta_ptr + 0, 0) - scalar.sts(meta_ptr + 4, valid_rows) - scalar.sts(meta_ptr + 8, valid_cols) + scalar.store(0, meta_ptr + 0) + scalar.store(valid_rows, meta_ptr + 1) + scalar.store(valid_cols, meta_ptr + 2) ``` -Three scalar stores write the loop bounds into the metadata buffer. `meta_ptr` is a typed UB pointer; `+ 0`, `+ 4`, `+ 8` are byte offsets (three `i32` values). This is the simplest sub-kernel in the sketch — it handles scalar control metadata, not vector math. +Three scalar stores write the loop bounds into the metadata buffer. `meta_ptr` is a typed UB pointer; `+ 0`, `+ 1`, `+ 2` are element offsets into `i32` storage, not byte offsets. This is the simplest sub-kernel in the sketch — it handles scalar control metadata, not vector math. ### `blend_output_rows` — output accumulation @@ -454,7 +471,7 @@ def blend_output_rows(o_prev_tile, pv_tile, alpha_tile, beta_tile, o_prev = scalar.load(o_prev_tile[row, col]) pv_val = scalar.load(pv_tile[row, col]) o_next = alpha * o_prev + beta * pv_val - scalar.sts(o_next_tile[row, col], o_next) + scalar.store(o_next, o_next_tile[row, col]) ``` This is a scalar element-wise blend over the tile domain: @@ -476,15 +493,15 @@ For trivial sub-kernels like `materialize_tile_bounds`, a named function is over def kv_block_process(...): pto.mte_load(k_part, k_tile) pto.mte_load(v_part, v_tile) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) # Inline SIMT: materialize loop bounds (replaces the named @pto.simt function) with pto.simt(): - scalar.sts(meta_ptr + 0, 0) - scalar.sts(meta_ptr + 4, valid_rows) - scalar.sts(meta_ptr + 8, valid_cols) + scalar.store(0, meta_ptr + 0) + scalar.store(valid_rows, meta_ptr + 1) + scalar.store(valid_cols, meta_ptr + 2) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) qk_matmul(q_tile, k_tile, ...) ... diff --git a/ptodsl/docs/user_guide/12-additional-examples.md b/ptodsl/docs/user_guide/12-additional-examples.md index 234d2981c..e7b730dc3 100644 --- a/ptodsl/docs/user_guide/12-additional-examples.md +++ b/ptodsl/docs/user_guide/12-additional-examples.md @@ -153,14 +153,14 @@ This example demonstrates a complete GEMM kernel: `C = A @ B` where A is `[M, K] @pto.cube def gemm_tile(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, a_l0a: pto.Tile, b_l0b: pto.Tile, o_acc: pto.Tile): - m = pto.tile_valid_rows(a_tile) - k = pto.tile_valid_cols(a_tile) - n = pto.tile_valid_rows(b_tile) - - pto.mte_l1_l0a(a_tile, a_l0a, m, k) - pto.mte_l1_l0b(b_tile, b_l0b, k, n, transpose=True) - pto.mad(a_l0a, b_l0b, o_acc) - pto.mte_l0c_ub(o_acc, o_tile, m, n) + m = a_tile.valid_shape[0] + k = a_tile.valid_shape[1] + n = b_tile.valid_shape[0] + + pto.mte_l1_l0a(a_tile.as_ptr(), a_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(b_tile.as_ptr(), b_l0b.as_ptr(), k, n, transpose=True) + pto.mad(a_l0a.as_ptr(), b_l0b.as_ptr(), o_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(o_acc.as_ptr(), o_tile.as_ptr(), m, n, n, n, 0) ``` The cube sub-kernel consumes UB tiles and cube-local scratch buffers. The four-step sequence — stage left operand, stage right operand, multiply, writeback — is the canonical cube compute pattern. @@ -308,12 +308,12 @@ def norm_block(x_part: pto.PartitionTensorView, x_tile: pto.Tile, mu_next_tile: pto.Tile, n_next_tile: pto.Tile, m2_next_tile: pto.Tile): pto.mte_load(x_part, x_tile) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) block_mean_var(x_tile, block_size, mu_prev, n_prev, m2_prev, mu_next_tile, n_next_tile, m2_next_tile) - pto.mem_bar(pto.BarrierType.SYNC) + pto.pipe_barrier(pto.Pipe.ALL) ``` ### 12.4.3 L1: JIT entry with carry state diff --git a/ptodsl/examples/softmax_dsl.py b/ptodsl/examples/softmax_dsl.py index ec8311c64..decdaa4d6 100644 --- a/ptodsl/examples/softmax_dsl.py +++ b/ptodsl/examples/softmax_dsl.py @@ -11,7 +11,7 @@ Generates the same IR as test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto -using the ``@pto.to_ir`` decorator and the ``pto.*`` namespace. +using the ``@pto.jit`` decorator and the ``pto.*`` namespace. The Python maps almost line-for-line to the target MLIR: @@ -34,7 +34,7 @@ } # } # } # - pto.barrier # pto.barrier_all() + pto.barrier # pto.pipe_barrier(pto.Pipe.ALL) """ from ptodsl import pto @@ -42,10 +42,10 @@ s = pto.scalar # arith shorthand alias -@pto.to_ir( +@pto.jit( name="online_softmax_update_kernel_2d", kernel_kind="vector", - arch="a5", + target="a5", func_attr="pto.aicore", ) def online_softmax_update_kernel_2d( @@ -155,13 +155,13 @@ def online_softmax_update_kernel_2d( ptr_ub = pto.ptr(pto.float32, "ub") vf32 = pto.vreg_type(64, pto.float32) - ub_om = pto.tile_ptr(oldmax_tile, ptr_ub) - ub_os = pto.tile_ptr(oldsum_tile, ptr_ub) - ub_qk = pto.tile_ptr(qk_tile, ptr_ub) - ub_out = pto.tile_ptr(out_tile, ptr_ub) - ub_nm = pto.tile_ptr(newmax_tile, ptr_ub) - ub_ns = pto.tile_ptr(newsum_tile, ptr_ub) - ub_em = pto.tile_ptr(expmax_tile, ptr_ub) + ub_om = pto.as_ptr(oldmax_tile, ptr_ub) + ub_os = pto.as_ptr(oldsum_tile, ptr_ub) + ub_qk = pto.as_ptr(qk_tile, ptr_ub) + ub_out = pto.as_ptr(out_tile, ptr_ub) + ub_nm = pto.as_ptr(newmax_tile, ptr_ub) + ub_ns = pto.as_ptr(newsum_tile, ptr_ub) + ub_em = pto.as_ptr(expmax_tile, ptr_ub) active = pto.pset_b32("PAT_ALL") one_mask, _ = pto.plt_b32(c1_i32) @@ -234,11 +234,11 @@ def online_softmax_update_kernel_2d( pto.tstore(expmax_tile, expmax_part) pto.tstore(out_tile, out_part) - pto.barrier_all() + pto.pipe_barrier(pto.Pipe.ALL) def build(): - return online_softmax_update_kernel_2d._ir_module + return online_softmax_update_kernel_2d.mlir_module() if __name__ == "__main__": diff --git a/ptodsl/examples/tadd_dsl.py b/ptodsl/examples/tadd_dsl.py index 96983c2c0..b55a65693 100644 --- a/ptodsl/examples/tadd_dsl.py +++ b/ptodsl/examples/tadd_dsl.py @@ -7,14 +7,14 @@ # See LICENSE in the root of the software repository for the full text of the License. """ -TADD vPTO kernel – DSL-style builder. +TADD kernel – DSL-style builder. Generates the same IR as expand_tileop_to_vpto_result.pto using the -``@pto.to_ir`` decorator and the ``pto.*`` namespace. +``@pto.jit`` decorator and the ``pto.*`` namespace. The Python code maps 1-to-1 to the MLIR IR lines: - func.func @TADD() { # @pto.to_ir(name="TADD", …) + func.func @TADD() { # @pto.jit(name="TADD", …) %c0_i64 = arith.constant 0 : i64 # pto.const(0, dtype=pto.int64) %c16 = arith.constant 16 : index # pto.const(16, dtype=pto.index) … @@ -33,7 +33,7 @@ s = pto.scalar # arith shorthand alias -@pto.to_ir(name="TADD", kernel_kind="vector", arch="a5") +@pto.jit(name="TADD", kernel_kind="vector", target="a5") def TADD(): c0_i64 = pto.const(0, dtype=pto.int64) c16 = pto.const(16, dtype=pto.index) @@ -60,7 +60,7 @@ def TADD(): def build(): - return TADD._ir_module + return TADD.mlir_module() if __name__ == "__main__": diff --git a/ptodsl/ptodsl/__init__.py b/ptodsl/ptodsl/__init__.py index cfd6e6537..a0722d975 100644 --- a/ptodsl/ptodsl/__init__.py +++ b/ptodsl/ptodsl/__init__.py @@ -9,11 +9,11 @@ from importlib import import_module -__all__ = ["pto", "scalar", "vpto"] +__all__ = ["pto", "scalar"] def __getattr__(name): - if name in {"pto", "scalar", "vpto"}: + if name in {"pto", "scalar"}: module = import_module(f".{name}", __name__) globals()[name] = module return module diff --git a/ptodsl/ptodsl/_bootstrap.py b/ptodsl/ptodsl/_bootstrap.py index 50449e312..ec9a6707a 100644 --- a/ptodsl/ptodsl/_bootstrap.py +++ b/ptodsl/ptodsl/_bootstrap.py @@ -37,17 +37,22 @@ def _candidate_python_roots() -> list[Path]: def _bootstrap_python_paths() -> None: - added = set() + ordered_roots: list[str] = [] + seen = set() for root in _candidate_python_roots(): if not root or not root.is_dir(): continue if not (root / "mlir").exists(): continue root_text = str(root) - if root_text in added or root_text in sys.path: + if root_text in seen: continue + ordered_roots.append(root_text) + seen.add(root_text) + for root_text in reversed(ordered_roots): + if root_text in sys.path: + sys.path.remove(root_text) sys.path.insert(0, root_text) - added.add(root_text) _bootstrap_python_paths() diff --git a/ptodsl/ptodsl/_control_flow.py b/ptodsl/ptodsl/_control_flow.py index 5d20b91e6..23f3a12f4 100644 --- a/ptodsl/ptodsl/_control_flow.py +++ b/ptodsl/ptodsl/_control_flow.py @@ -14,12 +14,15 @@ ────────── ``vecscope()`` – ``pto.vecscope { … }`` ``for_(lo, hi, step, *, iter_args)`` - – ``scf.for`` with optional iter_args + – ``scf.for`` with optional iter_args or named carry state ``if_(cond, *, results)`` – ``scf.if`` with optional results + else ``yield_(*vals)`` – ``scf.yield`` """ from ._bootstrap import make_context # noqa: F401 +from ._runtime_index_ops import coerce_runtime_index +from ._tracing.active import current_session +from ._surface_values import unwrap_surface_value, wrap_like_surface_value, wrap_surface_value from ._types import _resolve from mlir.dialects import pto as _pto, scf @@ -60,20 +63,27 @@ class LoopHandle: loop.results – tuple of ForOp results (after loop exit) """ - def __init__(self, for_op): + def __init__(self, for_op, *, iter_arg_templates=()): self._op = for_op + self._iter_arg_templates = tuple(iter_arg_templates) @property def iv(self): - return self._op.induction_variable + return wrap_surface_value(self._op.induction_variable) @property def iter_args(self): - return tuple(self._op.inner_iter_args) + return tuple( + wrap_like_surface_value(template, value) + for template, value in zip(self._iter_arg_templates, self._op.inner_iter_args) + ) @property def results(self): - return tuple(self._op.results) + return tuple( + wrap_like_surface_value(template, value) + for template, value in zip(self._iter_arg_templates, self._op.results) + ) class _ForCM: @@ -81,20 +91,23 @@ def __init__(self, start, stop, step, iter_args): self._start = start self._stop = stop self._step = step - self._iter_args = list(iter_args) if iter_args is not None else [] + self._iter_arg_templates = tuple(iter_args) if iter_args is not None else () + self._iter_args = [unwrap_surface_value(value) for value in self._iter_arg_templates] self._for_op = None self._ip = None def __enter__(self): self._for_op = scf.ForOp( - self._start, self._stop, self._step, + _coerce_index(self._start), + _coerce_index(self._stop), + _coerce_index(self._step), self._iter_args if self._iter_args else None, ) self._ip = InsertionPoint(self._for_op.body) self._ip.__enter__() if not self._iter_args: - return self._for_op.induction_variable - return LoopHandle(self._for_op) + return wrap_surface_value(self._for_op.induction_variable) + return LoopHandle(self._for_op, iter_arg_templates=self._iter_arg_templates) def __exit__(self, *exc): if not self._iter_args: @@ -102,7 +115,7 @@ def __exit__(self, *exc): self._ip.__exit__(*exc) -def for_(start, stop, *, step, iter_args=None) -> _ForCM: +def for_(start, stop, *, step, iter_args=None): """ ``scf.for`` context manager. @@ -120,8 +133,153 @@ def for_(start, stop, *, step, iter_args=None) -> _ForCM: ... pto.yield_(nx, ny) fa, fb = loop.results + + Named carry state is expressed with ``.carry(...)``:: + + loop = pto.for_(c0, c128, step=c64).carry(acc=tile) + with loop: + cur = loop.acc + loop.update(acc=cur) + out = loop.final("acc") """ - return _ForCM(start, stop, step, iter_args) + return _ForBuilder(start, stop, step, iter_args) + + +class _CarryLoopStateView: + def __init__(self, names, values): + self._names = tuple(names) + self._values = dict(zip(self._names, values)) + + def __getattr__(self, name): + try: + return self._values[name] + except KeyError as exc: + raise AttributeError(name) from exc + + +class _CarryForCM(_ForCM): + def __init__(self, start, stop, step, state_items): + self._state_items = tuple(state_items) + self._state_names = tuple(name for name, _ in self._state_items) + self._state_templates = tuple(value for _, value in self._state_items) + self._session = None + self._session_frame = None + super().__init__(start, stop, step, self._state_templates) + self._yield_values = None + self._entered = False + + def __enter__(self): + self._session = current_session() + if self._session is not None: + self._session_frame = self._session.begin_carry_loop( + self._start, + self._stop, + self._step, + self._state_items, + ) + self._for_op = self._session_frame.for_op + handle = LoopHandle(self._for_op, iter_arg_templates=self._state_templates) + else: + handle = super().__enter__() + self._entered = True + self._yield_values = None + self._loop_handle = handle + self._state = _CarryLoopStateView(self._state_names, handle.iter_args) + return self + + def __exit__(self, exc_type, exc, tb): + try: + if self._session_frame is not None: + self._session.finish_carry_loop(self._session_frame, exc_type, exc, tb) + return None + if exc_type is None: + if self._yield_values is None: + raise RuntimeError( + "pto.for_(...).carry(...) requires loop.update(...) before leaving the loop body" + ) + scf.YieldOp(self._yield_values) + return super().__exit__(exc_type, exc, tb) + finally: + self._entered = False + self._session = None + self._session_frame = None + + @property + def iv(self): + if not self._entered: + raise RuntimeError("loop.iv is only available inside an active carry loop body") + return self._loop_handle.iv + + def __getattr__(self, name): + if name in self._state_names: + if not self._entered: + raise RuntimeError(f"loop.{name} is only available inside an active carry loop body") + return getattr(self._state, name) + raise AttributeError(name) + + def update(self, **kwargs): + if not self._entered: + raise RuntimeError("loop.update(...) may only be called inside the loop body") + if self._session_frame is not None: + self._session.update_carry_loop(self._session_frame, **kwargs) + return + missing = [name for name in self._state_names if name not in kwargs] + extra = [name for name in kwargs if name not in self._state_names] + if missing or extra: + pieces = [] + if missing: + pieces.append(f"missing: {', '.join(missing)}") + if extra: + pieces.append(f"unexpected: {', '.join(extra)}") + raise RuntimeError("loop.update(...) must match carry names exactly; " + "; ".join(pieces)) + if self._yield_values is not None: + raise RuntimeError("loop.update(...) may only be called once per loop body") + self._yield_values = [ + unwrap_surface_value(kwargs[name]) + for name in self._state_names + ] + + def final(self, name): + if self._for_op is None: + raise RuntimeError("loop.final(...) is only available after the loop has been built") + try: + index = self._state_names.index(name) + except ValueError as exc: + raise RuntimeError( + f"loop.final(...) requested unknown carry state '{name}'; " + f"expected one of: {', '.join(self._state_names)}" + ) from exc + return wrap_like_surface_value(self._state_templates[index], self._for_op.results[index]) + + +class _ForBuilder: + def __init__(self, start, stop, step, iter_args=None): + self._start = start + self._stop = stop + self._step = step + self._iter_args = iter_args + + def __enter__(self): + self._cm = _ForCM(self._start, self._stop, self._step, self._iter_args) + return self._cm.__enter__() + + def __exit__(self, *exc): + return self._cm.__exit__(*exc) + + def carry(self, **kwargs): + if self._iter_args is not None: + raise RuntimeError("for_(..., iter_args=...) cannot be combined with .carry(...)") + if not kwargs: + raise ValueError("carry(...) requires at least one named loop-carried value") + for name in kwargs: + if not isinstance(name, str) or not name: + raise TypeError("carry(...) names must be non-empty strings") + return _CarryForCM(self._start, self._stop, self._step, tuple(kwargs.items())) + + +def _coerce_index(value): + raw_value = unwrap_surface_value(value) + return coerce_runtime_index(raw_value, context="pto.for_(...) loop bound") # ── if_ ─────────────────────────────────────────────────────────────────────── @@ -163,7 +321,7 @@ def __init__(self, if_op): @property def results(self): - return tuple(self._op.results) + return tuple(wrap_surface_value(result) for result in self._op.results) class _IfCM: @@ -174,14 +332,15 @@ def __init__(self, cond, result_types): self._ip = None def __enter__(self): + cond = unwrap_surface_value(self._cond) if self._result_types: # if/else with results: create IfOp but don't enter any block; # the caller manages blocks via br.then_ / br.else_ - self._if_op = scf.IfOp(self._cond, self._result_types, hasElse=True) + self._if_op = scf.IfOp(cond, self._result_types, hasElse=True) return BranchHandle(self._if_op) else: # simple if without results: enter then_block automatically - self._if_op = scf.IfOp(self._cond) + self._if_op = scf.IfOp(cond) self._ip = InsertionPoint(self._if_op.then_block) self._ip.__enter__() return None @@ -221,7 +380,7 @@ def if_(cond, *, results=None) -> _IfCM: def yield_(*vals): """Emit ``scf.yield`` with the given values.""" - scf.YieldOp(list(vals)) + scf.YieldOp([unwrap_surface_value(value) for value in vals]) __all__ = [ diff --git a/ptodsl/ptodsl/_diagnostics.py b/ptodsl/ptodsl/_diagnostics.py new file mode 100644 index 000000000..48af69166 --- /dev/null +++ b/ptodsl/ptodsl/_diagnostics.py @@ -0,0 +1,99 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Shared user-facing diagnostics for PTODSL tracing misuse.""" + +from __future__ import annotations + + +class PTODSLTracingMisuseError(TypeError): + """Raised when authored Python misuses PTODSL runtime values during tracing.""" + + +def native_python_control_flow_error(usage: str) -> PTODSLTracingMisuseError: + """Return one actionable diagnostic for native Python control-flow misuse.""" + return PTODSLTracingMisuseError( + f"native Python {usage} cannot consume a PTODSL runtime value during tracing. " + "This value is a device-side SSA/runtime-metadata value, not a Python bool/int. " + "Use pto.if_(...) or pto.for_(...) for device-side control flow, or keep the " + "bound/condition in pto.constexpr." + ) + + +def host_tensor_metadata_error(message: str, *, param_name: str | None = None) -> TypeError: + """Return one actionable diagnostic for unsupported host-tensor metadata.""" + prefix = "host tensor metadata is incomplete or unsupported" + if param_name is not None: + prefix = f"@pto.jit host tensor '{param_name}' metadata is incomplete or unsupported" + return TypeError(f"{prefix}: {message}") + + +def subkernel_host_tensor_boundary_error(role: str, name: str) -> TypeError: + """Return one diagnostic for host-tensor usage outside the JIT boundary.""" + return TypeError( + f"@pto.{role} parameter '{name}' uses a host tensor value, but host tensors only belong " + "at the @pto.jit boundary. Pass PTODSL device-side values such as Tile, " + "PartitionTensorView, typed pointers, or PTO scalars instead." + ) + + +def subkernel_signature_boundary_error(role: str, name: str) -> TypeError: + """Return one diagnostic for illegal host-tensor formal annotations on a subkernel.""" + return TypeError( + f"@pto.{role} parameter '{name}' cannot be annotated with pto.tensor_spec(...). " + "Host tensors are only valid as @pto.jit positional parameters." + ) + + +def illegal_subkernel_placement_error(role: str, outer_role: str | None) -> RuntimeError: + """Return one diagnostic for a subkernel call placed outside the supported layer graph.""" + if role == "ukernel": + return RuntimeError( + "@pto.ukernel may only be called from the top-level @pto.jit body; " + f"nested invocation inside @pto.{outer_role} is not part of the PTODSL layer contract." + ) + if role == "simt": + return RuntimeError( + "@pto.simt helper materialization is only supported from the top-level @pto.jit body " + f"or inside @pto.ukernel; it cannot be materialized inside @pto.{outer_role}." + ) + return RuntimeError( + f"@pto.{role} may only be called from the top-level @pto.jit body or inside @pto.ukernel; " + f"nested invocation inside @pto.{outer_role} is not part of the PTODSL layer contract." + ) + + +def simd_value_escape_error(type_text: str) -> RuntimeError: + """Return one diagnostic for transient SIMD values escaping a simd subkernel boundary.""" + return RuntimeError( + f"@pto.simd cannot return transient SIMD values across the subkernel boundary " + f"(got {type_text}). Write the value back to a Tile/UB buffer instead." + ) + + +def tile_row_alignment_error(*, shape, dtype, row_bytes: int, required_alignment: int) -> TypeError: + """Return one diagnostic for authored tile shapes violating row-byte alignment.""" + return TypeError( + "alloc_tile(shape=...) physical row layout is invalid for the current PTODSL tile contract: " + f"shape={list(shape)!r} with dtype={dtype!r} gives a row byte size of {row_bytes}, " + f"but row-major none-box tiles must be {required_alignment}-byte aligned. " + "For logical column tiles such as [Br, 1], prefer blayout='ColMajor' instead of authoring them " + "as row-major narrow tiles. If row-major is truly required, keep the physical tile shape explicitly " + "aligned and express the logical tail with valid_shape=[...]." + ) + + +__all__ = [ + "PTODSLTracingMisuseError", + "host_tensor_metadata_error", + "illegal_subkernel_placement_error", + "native_python_control_flow_error", + "simd_value_escape_error", + "subkernel_host_tensor_boundary_error", + "subkernel_signature_boundary_error", + "tile_row_alignment_error", +] diff --git a/ptodsl/ptodsl/_host_tensors.py b/ptodsl/ptodsl/_host_tensors.py new file mode 100644 index 000000000..9f270642d --- /dev/null +++ b/ptodsl/ptodsl/_host_tensors.py @@ -0,0 +1,238 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Host-tensor boundary helpers for ``@pto.jit``.""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass + +from ._diagnostics import host_tensor_metadata_error +from ._types import _resolve, index, ptr + + +def _normalize_tensor_shape(shape): + try: + return tuple(int(dim) for dim in shape) + except TypeError as exc: + raise host_tensor_metadata_error("missing iterable .shape") from exc + except ValueError as exc: + raise host_tensor_metadata_error(".shape must contain integer-like dimensions") from exc + + +def _normalize_tensor_strides(tensor): + stride_method = getattr(tensor, "stride", None) + if callable(stride_method): + try: + return tuple(int(dim) for dim in stride_method()) + except TypeError as exc: + raise host_tensor_metadata_error(".stride() must return an iterable of integer-like dimensions") from exc + except ValueError as exc: + raise host_tensor_metadata_error(".stride() must return integer-like dimensions") from exc + strides = getattr(tensor, "strides", None) + if strides is None: + raise host_tensor_metadata_error("missing .strides or .stride()") + try: + return tuple(int(dim) for dim in strides) + except TypeError as exc: + raise host_tensor_metadata_error(".strides must be iterable") from exc + except ValueError as exc: + raise host_tensor_metadata_error(".strides must contain integer-like dimensions") from exc + + +def _extract_tensor_data_handle(tensor): + for attr_name in ("data_ptr", "ptr"): + attr = getattr(tensor, attr_name, None) + if callable(attr): + value = attr() + else: + value = attr + if value is not None: + try: + return int(value) + except (TypeError, ValueError) as exc: + raise host_tensor_metadata_error( + f"{attr_name} must return an integer-like data handle" + ) from exc + array_interface = getattr(tensor, "__array_interface__", None) + if array_interface is not None: + data = array_interface.get("data") + if isinstance(data, tuple) and data: + try: + return int(data[0]) + except (TypeError, ValueError) as exc: + raise host_tensor_metadata_error( + "__array_interface__['data'][0] must be an integer-like data handle" + ) from exc + raise host_tensor_metadata_error( + "missing data handle; expected .data_ptr(), .ptr, or __array_interface__" + ) + + +@dataclass(frozen=True) +class HostTensorMetadata: + """Concrete runtime metadata extracted from a Python host tensor.""" + + shape: tuple[int, ...] + strides: tuple[int, ...] + dtype: object + data_handle: int + + +def inspect_host_tensor_metadata(tensor) -> HostTensorMetadata: + """Extract shape / strides / dtype / data-handle from a Python tensor-like object.""" + shape = _normalize_tensor_shape(getattr(tensor, "shape", None)) + strides = _normalize_tensor_strides(tensor) + dtype = getattr(tensor, "dtype", None) + if dtype is None: + raise host_tensor_metadata_error("missing .dtype") + return HostTensorMetadata( + shape=shape, + strides=strides, + dtype=dtype, + data_handle=_extract_tensor_data_handle(tensor), + ) + + +@dataclass(frozen=True) +class TensorSpec: + """Static ABI hint for one Python-native ``@pto.jit`` tensor parameter.""" + + rank: int + dtype: object + address_space: str = "gm" + + def __post_init__(self): + if self.rank <= 0: + raise ValueError("tensor_spec(rank=...) expects a positive rank") + + def entry_arg_types(self): + data_type = _resolve(ptr(self.dtype, self.address_space)) + index_type = _resolve(index) + return ( + data_type, + *([index_type] * self.rank), + *([index_type] * self.rank), + ) + + def abi_signature(self): + return ( + "tensor_spec", + self.rank, + self.dtype, + self.address_space, + ) + + def __repr__(self): + return ( + f"pto.tensor_spec(rank={self.rank}, dtype={self.dtype!r}, " + f"address_space={self.address_space!r})" + ) + + +def tensor_spec(*, rank: int, dtype, address_space: str = "gm") -> TensorSpec: + """Declare the ABI contract of one Python-native ``@pto.jit`` tensor parameter.""" + return TensorSpec(rank=rank, dtype=dtype, address_space=address_space) + + +class HostTensorValue: + """Tracing-time proxy for one Python-native tensor at the ``@pto.jit`` boundary.""" + + def __init__(self, name: str, spec: TensorSpec, data_handle, shape, strides): + from ._surface_values import wrap_surface_value + self.name = name + self.spec = spec + self.data_handle = wrap_surface_value(data_handle) + self.shape = tuple(wrap_surface_value(dim) for dim in shape) + self.strides = tuple(wrap_surface_value(dim) for dim in strides) + self.dtype = spec.dtype + + @property + def rank(self): + return self.spec.rank + + def __repr__(self): + return ( + f"" + ) + + +def bind_host_tensor_argument(name: str, spec: TensorSpec, entry_arguments): + """Bind one flattened entry-ABI slice into a ``HostTensorValue``.""" + expected = 1 + spec.rank + spec.rank + if len(entry_arguments) < expected: + raise RuntimeError( + f"entry ABI for host tensor '{name}' is incomplete: expected {expected} " + f"arguments, got {len(entry_arguments)}" + ) + data_handle = entry_arguments[0] + shape = entry_arguments[1:1 + spec.rank] + strides = entry_arguments[1 + spec.rank:1 + spec.rank + spec.rank] + return ( + HostTensorValue(name, spec, data_handle, shape, strides), + entry_arguments[expected:], + ) + + +def infer_jit_host_tensor_spec(param: inspect.Parameter): + """ + Resolve one ``@pto.jit`` positional parameter to a host-tensor contract. + + V1 cannot infer rank/dtype from an unannotated formal parameter while still + tracing at compile time, so host tensors currently require an explicit + ``pto.tensor_spec(...)`` ABI hint. + """ + if isinstance(param.annotation, TensorSpec): + return param.annotation + if param.annotation is inspect.Parameter.empty: + raise TypeError( + f"@pto.jit positional parameter '{param.name}' uses the host-tensor " + "boundary but does not declare an ABI hint. Add an annotation such " + "as `Q: pto.tensor_spec(rank=4, dtype=pto.f32)`." + ) + return None + + +def resolve_tensor_data_entry(value): + """Return the pointer-like data entry behind a host tensor proxy or raw value.""" + if isinstance(value, HostTensorValue): + return value.data_handle + return value + + +def looks_like_host_tensor(value) -> bool: + """Best-effort predicate for Python-native tensor-like objects at the JIT boundary.""" + if isinstance(value, HostTensorValue): + return True + return ( + getattr(value, "shape", None) is not None + and getattr(value, "dtype", None) is not None + and ( + callable(getattr(value, "stride", None)) + or getattr(value, "strides", None) is not None + ) + and ( + callable(getattr(value, "data_ptr", None)) + or getattr(value, "ptr", None) is not None + or getattr(value, "__array_interface__", None) is not None + ) + ) + + +__all__ = [ + "HostTensorMetadata", + "TensorSpec", + "HostTensorValue", + "bind_host_tensor_argument", + "tensor_spec", + "infer_jit_host_tensor_spec", + "inspect_host_tensor_metadata", + "looks_like_host_tensor", + "resolve_tensor_data_entry", +] diff --git a/ptodsl/ptodsl/_jit.py b/ptodsl/ptodsl/_jit.py new file mode 100644 index 000000000..bc25e2602 --- /dev/null +++ b/ptodsl/ptodsl/_jit.py @@ -0,0 +1,91 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""``@pto.jit`` decorator and compiled-kernel handles.""" + +from ._kernel_compilation import CompiledKernelHandle, KernelCompiler +from ._kernel_signature import parse_jit_kernel_signature +from ._tracing import ( + KernelModuleSpec, + ModuleArtifact, + ModuleStyle, +) + + +def jit( + name=None, + *, + target: str = "a5", + kernel_kind: str = "vector", + func_attr: str = None, +): + """ + Decorator that wraps a Python function as a PTODSL JIT kernel template. + + Parameters + ---------- + name: IR function name (defaults to the Python function name). + target: Target architecture string, e.g. ``"a5"``. + kernel_kind: ``"vector"`` or ``"cube"`` – sets ``pto.kernel_kind``. + func_attr: Optional function attribute. Pass ``"pto.aicore"`` to + select the flat-module structure with the aicore attribute. + + The decorated function is replaced by a :class:`KernelHandle` that: + + - supports ``my_kernel.compile(**constexprs)`` specialization, + - prints as the default-specialization MLIR text, + - exposes ``my_kernel.mlir_module()`` / ``verify()`` / ``emit()`` on the + default specialization for convenience. + """ + + def decorator(fn): + fn_name = name or fn.__name__ + kernel_signature = parse_jit_kernel_signature(fn) + module_style = ( + ModuleStyle.FLAT_AICORE + if func_attr == "pto.aicore" + else ModuleStyle.NESTED + ) + compiler = KernelCompiler( + fn.__name__, + KernelModuleSpec( + function_name=fn_name, + target_arch=target, + kernel_kind=kernel_kind, + module_style=module_style, + ), + kernel_signature, + fn, + ) + return KernelHandle(fn.__name__, compiler) + + return decorator + + +class KernelHandle(ModuleArtifact): + """ + Represents a JIT kernel template plus its compiled specializations. + + ``handle.compile(**constexprs)`` returns one compiled specialization. + ``print(handle)`` emits the default-specialization MLIR module text. + """ + + def __init__(self, py_name: str, compiler: KernelCompiler): + self._compiler = compiler + super().__init__(py_name, module_factory=self._build_default_module) + + def compile(self, **constexpr_bindings) -> CompiledKernelHandle: + return self._compiler.compile(**constexpr_bindings) + + def cached_specializations(self): + return self._compiler.cached_specializations() + + def _build_default_module(self): + return self.compile().build() + + +__all__ = ["CompiledKernelHandle", "jit", "KernelHandle"] diff --git a/ptodsl/ptodsl/_kernel_compilation.py b/ptodsl/ptodsl/_kernel_compilation.py new file mode 100644 index 000000000..b78dd7e2e --- /dev/null +++ b/ptodsl/ptodsl/_kernel_compilation.py @@ -0,0 +1,83 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Kernel specialization and compilation helpers for ``@pto.jit``.""" + +from __future__ import annotations + +from ._tracing import ModuleArtifact, SignatureTracingRuntime + + +class CompiledKernelHandle(ModuleArtifact): + """One compiled ``@pto.jit`` specialization.""" + + def __init__(self, py_name: str, *, specialization_key, constexpr_bindings, module_factory): + super().__init__(py_name, module_factory=module_factory) + self._specialization_key = specialization_key + self._constexpr_bindings = dict(constexpr_bindings) + + @property + def specialization_key(self): + return self._specialization_key + + @property + def constexpr_bindings(self): + return dict(self._constexpr_bindings) + + def __getitem__(self, launch_spec): + raise NotImplementedError( + "PTODSL v1 compiled handles only support compile / inspect / verify / emit. " + "Runtime launch is not implemented yet." + ) + + +class KernelCompiler: + """Per-kernel specialization cache and module builder.""" + + def __init__(self, py_name: str, module_spec, kernel_signature, callback): + self._py_name = py_name + self._module_spec = module_spec + self._kernel_signature = kernel_signature + self._callback = callback + self._kernel_identity = id(callback) + self._compiled_cache = {} + + def compile(self, **constexpr_bindings): + normalized_bindings = self._kernel_signature.bind_constexpr_bindings(constexpr_bindings) + specialization_key = self._kernel_signature.specialization_key( + self._kernel_identity, + normalized_bindings, + ) + + cached = self._compiled_cache.get(specialization_key) + if cached is not None: + return cached + + runtime = SignatureTracingRuntime( + self._module_spec, + self._kernel_signature, + self._callback, + constexpr_bindings=normalized_bindings, + ) + compiled = CompiledKernelHandle( + self._py_name, + specialization_key=specialization_key, + constexpr_bindings=normalized_bindings, + module_factory=runtime.build_module, + ) + compiled.build() + self._compiled_cache[specialization_key] = compiled + return compiled + + def cached_specializations(self): + return tuple(self._compiled_cache.values()) + + +__all__ = [ + "CompiledKernelHandle", + "KernelCompiler", +] diff --git a/ptodsl/ptodsl/_kernel_signature.py b/ptodsl/ptodsl/_kernel_signature.py new file mode 100644 index 000000000..a31f865c3 --- /dev/null +++ b/ptodsl/ptodsl/_kernel_signature.py @@ -0,0 +1,191 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Declarative PTODSL kernel-signature parsing and entry-ABI binding.""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass + +from ._host_tensors import bind_host_tensor_argument, infer_jit_host_tensor_spec +from ._surface_values import wrap_surface_value +from ._surface_types import constexpr as _constexpr_marker +from ._types import _resolve + + +@dataclass(frozen=True) +class KernelSpecializationKey: + kernel_identity: int + abi_signature: tuple + constexpr_signature: tuple[tuple[str, object], ...] + + +@dataclass(frozen=True) +class DeviceParameterSpec: + name: str + annotation: object + + def entry_arg_types(self): + return (_resolve(self.annotation),) + + def bind_entry_arguments(self, entry_arguments): + if not entry_arguments: + raise RuntimeError(f"entry ABI for device parameter '{self.name}' is incomplete") + return wrap_surface_value(entry_arguments[0]), entry_arguments[1:] + + def abi_signature(self): + return ("device", self.name, _hashable_signature_atom(self.annotation)) + + +@dataclass(frozen=True) +class TensorSpecParameterSpec: + name: str + tensor_spec: object + + def entry_arg_types(self): + return tuple(self.tensor_spec.entry_arg_types()) + + def bind_entry_arguments(self, entry_arguments): + return bind_host_tensor_argument(self.name, self.tensor_spec, entry_arguments) + + def abi_signature(self): + return ("tensor", self.name, self.tensor_spec.abi_signature()) + + +@dataclass(frozen=True) +class ConstexprParameterSpec: + name: str + default: object + + def bind_specialization(self, provided_bindings): + value = provided_bindings.get(self.name, self.default) + try: + hash(value) + except TypeError as exc: + raise TypeError( + f"@pto.jit constexpr parameter '{self.name}' must be hashable so it can " + "participate in the specialization cache" + ) from exc + return value + + +def _hashable_signature_atom(value): + try: + hash(value) + except TypeError: + return repr(value) + return value + + +@dataclass(frozen=True) +class KernelSignature: + positional_parameters: tuple + constexpr_parameters: tuple[ConstexprParameterSpec, ...] + + def compute_entry_arg_types(self): + arg_types = [] + for param in self.positional_parameters: + arg_types.extend(param.entry_arg_types()) + return tuple(arg_types) + + def bind_entry_arguments(self, entry_arguments): + remaining = tuple(entry_arguments) + bound_args = [] + for param in self.positional_parameters: + bound_value, remaining = param.bind_entry_arguments(remaining) + bound_args.append(bound_value) + if remaining: + raise RuntimeError(f"unexpected trailing entry arguments in PTODSL kernel ABI: {len(remaining)}") + return tuple(bound_args) + + def default_constexpr_bindings(self): + return {param.name: param.default for param in self.constexpr_parameters} + + def bind_constexpr_bindings(self, provided_bindings): + provided = dict(provided_bindings) + expected_names = {param.name for param in self.constexpr_parameters} + unknown = sorted(name for name in provided if name not in expected_names) + if unknown: + raise TypeError( + f"unknown @pto.jit constexpr parameter(s): {', '.join(unknown)}" + ) + + bound = {} + for param in self.constexpr_parameters: + bound[param.name] = param.bind_specialization(provided) + return bound + + def abi_signature(self): + return tuple(param.abi_signature() for param in self.positional_parameters) + + def specialization_key(self, kernel_identity, constexpr_bindings): + return KernelSpecializationKey( + kernel_identity=kernel_identity, + abi_signature=self.abi_signature(), + constexpr_signature=tuple( + (param.name, constexpr_bindings[param.name]) + for param in self.constexpr_parameters + ), + ) + + +def parse_jit_kernel_signature(py_fn) -> KernelSignature: + """Parse one authored ``@pto.jit`` function signature.""" + sig = inspect.signature(py_fn) + positional_parameters = [] + constexpr_parameters = [] + + for param in sig.parameters.values(): + if param.kind in { + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + }: + host_tensor_spec = infer_jit_host_tensor_spec(param) + if host_tensor_spec is not None: + positional_parameters.append( + TensorSpecParameterSpec(param.name, host_tensor_spec) + ) + else: + positional_parameters.append( + DeviceParameterSpec(param.name, param.annotation) + ) + continue + + if param.kind is inspect.Parameter.KEYWORD_ONLY: + if param.annotation is not _constexpr_marker: + raise TypeError( + f"@pto.jit keyword-only parameter '{param.name}' must be annotated " + "with pto.constexpr in PTODSL v1" + ) + if param.default is inspect.Parameter.empty: + raise TypeError( + f"@pto.jit constexpr parameter '{param.name}' must declare a default " + "value until explicit compile-time specialization is implemented" + ) + constexpr_parameters.append(ConstexprParameterSpec(param.name, param.default)) + continue + + raise TypeError( + f"@pto.jit parameter '{param.name}' uses unsupported parameter kind " + f"{param.kind!r}" + ) + + return KernelSignature( + positional_parameters=tuple(positional_parameters), + constexpr_parameters=tuple(constexpr_parameters), + ) + + +__all__ = [ + "ConstexprParameterSpec", + "DeviceParameterSpec", + "KernelSpecializationKey", + "KernelSignature", + "TensorSpecParameterSpec", + "parse_jit_kernel_signature", +] diff --git a/ptodsl/ptodsl/_module.py b/ptodsl/ptodsl/_module.py deleted file mode 100644 index 2745504f8..000000000 --- a/ptodsl/ptodsl/_module.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -""" -``@pto.to_ir`` decorator and module-level IR builders. - -The decorator: -1. Inspects the function signature – annotations are ``_DType`` lazy - descriptors or concrete ``mlir.ir.Type`` objects. -2. Creates the MLIR context and module. -3. Calls the Python function body with actual MLIR SSA values. -4. Verifies the module and caches it as ``fn._ir_module``. -5. Adds ``__str__`` so ``print(my_kernel)`` prints the MLIR text. - -Module structure is selected by ``func_attr``: -- ``func_attr="pto.aicore"`` → flat module + ``pto.aicore`` function attribute - (used by softmax-style kernels) -- otherwise → nested double-module (used by vPTO TADD-style) -""" - -import inspect - -from ._bootstrap import make_context -from ._types import _resolve - -from mlir.dialects import func, pto as _pto -from mlir.ir import ( - Attribute, - InsertionPoint, - Location, - Module, - Operation, - StringAttr, - UnitAttr, -) - - -def _call_body(ir_fn, fn, arg_types): - """Add entry block to *ir_fn* and call *fn* with the SSA arguments.""" - entry = ir_fn.add_entry_block() - with InsertionPoint(entry): - fn(*entry.arguments) - func.ReturnOp([]) - - -def _build_flat_module(fn_name, arg_types, fn, arch, kernel_kind): - """ - Flat ``module attributes {pto.target_arch, pto.kernel_kind}`` with a - single function that carries ``pto.aicore``. - """ - m = Module.create() - m.operation.attributes["pto.target_arch"] = StringAttr.get(arch) - m.operation.attributes["pto.kernel_kind"] = Attribute.parse( - f"#pto.kernel_kind<{kernel_kind}>" - ) - fn_ty = func.FunctionType.get(arg_types, []) - with InsertionPoint(m.body): - ir_fn = func.FuncOp(fn_name, fn_ty) - ir_fn.attributes["pto.aicore"] = UnitAttr.get() - _call_body(ir_fn, fn, arg_types) - return m - - -def _build_nested_module(fn_name, arg_types, fn, arch, kernel_kind): - """ - Nested ``module { module { func … } }`` structure used by vPTO kernels - without function arguments (e.g. TADD). - """ - outer = Module.create() - outer.operation.attributes["pto.target_arch"] = StringAttr.get(arch) - - with InsertionPoint(outer.body): - # Module.create() ignores the active InsertionPoint, so create - # the inner module via Operation.create("builtin.module") instead. - inner_op = Operation.create("builtin.module", regions=1) - inner_op.attributes["pto.target_arch"] = StringAttr.get(arch) - inner_op.attributes["pto.kernel_kind"] = Attribute.parse( - f"#pto.kernel_kind<{kernel_kind}>" - ) - inner_body = inner_op.regions[0].blocks.append() - - with InsertionPoint(inner_body): - fn_ty = func.FunctionType.get(arg_types, []) - ir_fn = func.FuncOp(fn_name, fn_ty) - - _call_body(ir_fn, fn, arg_types) - return outer - - -def to_ir(name=None, *, kernel_kind: str = "vector", arch: str = "a5", - func_attr: str = None): - """ - Decorator that eagerly lowers a Python function to an MLIR module. - - Parameters - ---------- - name: IR function name (defaults to the Python function name). - kernel_kind: ``"vector"`` or ``"cube"`` – sets ``pto.kernel_kind``. - arch: Target architecture string, e.g. ``"a5"``. - func_attr: Optional function attribute. Pass ``"pto.aicore"`` to - select the flat-module structure with the aicore attribute. - - The decorated function is replaced by a :class:`KernelHandle` that: - - - prints as the MLIR module text (``print(my_kernel)``), - - exposes ``my_kernel.build()`` returning the ``mlir.ir.Module``, - - exposes ``my_kernel._ir_module`` for direct access. - """ - - def decorator(fn): - fn_name = name or fn.__name__ - sig = inspect.signature(fn) - ctx = make_context() - with ctx, Location.unknown(): - arg_types = [ - _resolve(p.annotation) - for p in sig.parameters.values() - if p.annotation is not inspect.Parameter.empty - ] - if func_attr == "pto.aicore": - mod = _build_flat_module(fn_name, arg_types, fn, arch, kernel_kind) - else: - mod = _build_nested_module(fn_name, arg_types, fn, arch, kernel_kind) - mod.operation.verify() - - return KernelHandle(fn.__name__, mod) - - return decorator - - -class KernelHandle: - """ - Represents a compiled PTO kernel. - - ``print(handle)`` emits the MLIR module text. - ``handle.build()`` returns the ``mlir.ir.Module`` (for ``check_ir.py``). - ``handle._ir_module`` is the raw module for direct access. - """ - - def __init__(self, py_name: str, module): - self._py_name = py_name - self._ir_module = module - - def build(self): - """Return the compiled ``mlir.ir.Module``.""" - return self._ir_module - - def __str__(self): - return str(self._ir_module) - - def __repr__(self): - return str(self._ir_module) - - -__all__ = ["to_ir", "KernelHandle"] diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py index 0db485888..8f1cbe7da 100644 --- a/ptodsl/ptodsl/_ops.py +++ b/ptodsl/ptodsl/_ops.py @@ -21,15 +21,37 @@ """ from ._bootstrap import make_context # noqa: F401 – ensure MLIR on sys.path -from ._types import _resolve, mask_type, part_tensor_view_type, tensor_view_type +from ._diagnostics import tile_row_alignment_error +from ._host_tensors import resolve_tensor_data_entry +from ._scalar_coercion import coerce_scalar_to_type, materialize_scalar_literal +from ._runtime_scalar_ops import classify_runtime_scalar_type, emit_runtime_binary_op +from ._surface_values import ( + MaskResultValue, + PartitionTensorViewValue, + TensorViewValue, + TileSliceValue, + TileValue, + _unwrap_sequence, + compose_partition_spec, + emit_as_ptr, + infer_tile_element_type, + parse_tile_type_metadata, + unwrap_surface_value, + wrap_surface_value, +) +from ._types import _resolve, mask_type, part_tensor_view_type, tensor_view_type, vreg_type from mlir.dialects import arith, pto as _pto from mlir.ir import ( Attribute, + BF16Type, + F16Type, + F32Type, + FloatAttr, IndexType, IntegerType, - ShapedType, - StringAttr, + MemRefType, + Type, ) # Pipe name shorthands → canonical PIPE_* names @@ -46,6 +68,8 @@ def _pipe_attr(name: str): + if not isinstance(name, str): + return _pto.PipeAttr.get(name) canonical = _PIPE_ALIASES.get(name, name) if not canonical.startswith("PIPE_"): canonical = "PIPE_" + canonical @@ -67,42 +91,91 @@ def const(value: int, *, dtype=None): """ from ._types import index as _idx_dtype mlir_type = _resolve(dtype) if dtype is not None else _resolve(_idx_dtype) - return arith.ConstantOp(mlir_type, value).result + return wrap_surface_value(arith.ConstantOp(mlir_type, value).result) # ── Pointer ops ─────────────────────────────────────────────────────────────── def castptr(int_addr, result_ptr_type): """``pto.castptr`` – cast an integer address to a typed PTO pointer.""" - return _pto.CastPtrOp(_resolve(result_ptr_type), int_addr).result + return wrap_surface_value( + _pto.CastPtrOp(_resolve(result_ptr_type), unwrap_surface_value(int_addr)).result + ) def addptr(base_ptr, index_offset): """``pto.addptr`` – advance a pointer by an index offset.""" - return _pto.AddPtrOp(base_ptr, index_offset).result + return wrap_surface_value( + _pto.AddPtrOp(unwrap_surface_value(base_ptr), unwrap_surface_value(index_offset)).result + ) # ── Vector load / store ─────────────────────────────────────────────────────── -def vlds(src_ptr, offset, result_vreg_type): - """``pto.vlds`` – vector load from *src_ptr* at *offset*.""" - return _pto.VldsOp(_resolve(result_vreg_type), src_ptr, offset).result +def vlds(src_ptr, offset=None, result_vreg_type=None): + """``pto.vlds`` – vector load from a tile slice or from *src_ptr* at *offset*.""" + if isinstance(src_ptr, TileSliceValue): + if offset is not None or result_vreg_type is not None: + raise TypeError("vlds(tile[row, col:]) infers its memref slice and vreg type; do not pass offset/result_vreg_type") + return wrap_surface_value(_pto.VldsOp( + _infer_vreg_type_from_tile_slice(src_ptr), + unwrap_surface_value(src_ptr), + _index_zero(), + ).result) + + if offset is None or result_vreg_type is None: + raise TypeError("vlds(ptr, offset, result_vreg_type) requires both offset and result_vreg_type") + return wrap_surface_value(_pto.VldsOp( + _resolve(result_vreg_type), + unwrap_surface_value(src_ptr), + unwrap_surface_value(offset), + ).result) def vbrc_load(src_ptr, offset, result_vreg_type): """``pto.vlds {dist="BRC_B32"}`` – broadcast a scalar into all lanes.""" - return _pto.VldsOp(_resolve(result_vreg_type), src_ptr, offset, - dist="BRC_B32").result - - -def vsts(val, dst_ptr, offset, mask): - """``pto.vsts`` – vector store.""" - _pto.VstsOp(val, dst_ptr, offset, mask) + return wrap_surface_value( + _pto.VldsOp( + _resolve(result_vreg_type), + unwrap_surface_value(src_ptr), + unwrap_surface_value(offset), + dist="BRC_B32", + ).result + ) + + +def vsts(val, dst_ptr, offset, mask=None): + """``pto.vsts`` – vector store to a tile slice or to *dst_ptr* at *offset*.""" + if isinstance(dst_ptr, TileSliceValue): + if mask is not None: + raise TypeError("vsts(vec, tile[row, col:], mask) does not accept a separate offset argument") + _pto.VstsOp( + unwrap_surface_value(val), + unwrap_surface_value(dst_ptr), + _index_zero(), + unwrap_surface_value(offset), + ) + return + + if mask is None: + raise TypeError("vsts(vec, ptr, offset, mask) requires an explicit mask") + _pto.VstsOp( + unwrap_surface_value(val), + unwrap_surface_value(dst_ptr), + unwrap_surface_value(offset), + unwrap_surface_value(mask), + ) def vsts_1pt(val, dst_ptr, offset, mask): """``pto.vsts {dist="1PT_B32"}`` – store only the lowest lane.""" - _pto.VstsOp(val, dst_ptr, offset, mask, dist="1PT_B32") + _pto.VstsOp( + unwrap_surface_value(val), + unwrap_surface_value(dst_ptr), + unwrap_surface_value(offset), + unwrap_surface_value(mask), + dist="1PT_B32", + ) # ── Mask / predicate ops ────────────────────────────────────────────────────── @@ -114,13 +187,17 @@ def plt_b32(scalar): Returns ``(mask_value, scalar_out)``. ``scalar_out`` is often unused and can be discarded with ``_``. """ - plt_op = _pto.PltB32Op(mask_type("b32"), IntegerType.get_signless(32), scalar) - return plt_op.mask, plt_op.scalar_out + plt_op = _pto.PltB32Op( + _resolve(mask_type("b32")), + IntegerType.get_signless(32), + unwrap_surface_value(scalar), + ) + return wrap_surface_value(plt_op.mask), wrap_surface_value(plt_op.scalar_out) def pset_b32(pattern: str): """``pto.pset_b32 "PATTERN"`` → ``!pto.mask``.""" - return _pto.PsetB32Op(mask_type("b32"), pattern).result + return wrap_surface_value(_pto.PsetB32Op(_resolve(mask_type("b32")), pattern).result) # ── Vector math (result type inferred from first operand) ───────────────────── @@ -128,32 +205,72 @@ def pset_b32(pattern: str): def vadd(lhs, rhs, mask, result_type=None): """``pto.vadd`` – element-wise add.""" rt = result_type if result_type is not None else lhs.type - return _pto.VaddOp(_resolve(rt), lhs, rhs, mask).result + return wrap_surface_value( + _pto.VaddOp( + _resolve(rt), + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(mask), + ).result + ) def vmul(lhs, rhs, mask): """``pto.vmul`` – element-wise multiply.""" - return _pto.VmulOp(lhs.type, lhs, rhs, mask).result + return wrap_surface_value( + _pto.VmulOp( + unwrap_surface_value(lhs).type, + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(mask), + ).result + ) def vmax(lhs, rhs, mask): """``pto.vmax`` – element-wise maximum.""" - return _pto.VmaxOp(lhs.type, lhs, rhs, mask).result + return wrap_surface_value( + _pto.VmaxOp( + unwrap_surface_value(lhs).type, + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(mask), + ).result + ) def vdiv(lhs, rhs, mask): """``pto.vdiv`` – element-wise divide.""" - return _pto.VdivOp(lhs.type, lhs, rhs, mask).result + return wrap_surface_value( + _pto.VdivOp( + unwrap_surface_value(lhs).type, + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(mask), + ).result + ) def vcmax(v, mask): """``pto.vcmax`` – cross-lane maximum reduction.""" - return _pto.VcmaxOp(v.type, v, mask).result + return wrap_surface_value( + _pto.VcmaxOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + ).result + ) def vcadd(v, mask): """``pto.vcadd`` – cross-lane add (sum reduction).""" - return _pto.VcaddOp(v.type, v, mask).result + return wrap_surface_value( + _pto.VcaddOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + ).result + ) def vdup(v, mask, *, position=None): @@ -161,26 +278,170 @@ def vdup(v, mask, *, position=None): Pass ``position="LOWEST"`` to broadcast the lowest (lane-0) element. """ - return _pto.VdupOp(v.type, v, mask, position=position).result + return wrap_surface_value( + _pto.VdupOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + position=position, + ).result + ) def vexpdif(inp, ref, mask, part: str = "ODD"): """``pto.vexpdif`` – ``exp(inp - ref)`` selecting ODD or EVEN lanes.""" - return _pto.VexpdifOp(inp.type, inp, ref, mask, part).result + return wrap_surface_value( + _pto.VexpdifOp( + unwrap_surface_value(inp).type, + unwrap_surface_value(inp), + unwrap_surface_value(ref), + unwrap_surface_value(mask), + part, + ).result + ) + + +def vexp(inp, mask): + """``pto.vexp`` – element-wise exponential.""" + return wrap_surface_value( + _pto.VexpOp( + unwrap_surface_value(inp).type, + unwrap_surface_value(inp), + unwrap_surface_value(mask), + ).result + ) + + +def vcgmax(v, mask): + """``pto.vcgmax`` – group maximum reduction, surfaced as the lowest-lane scalar.""" + reduced = _pto.VcgmaxOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + ).result + return _extract_lowest_lane_scalar(reduced, mask) + + +def vcgadd(v, mask): + """``pto.vcgadd`` – group sum reduction, surfaced as the lowest-lane scalar.""" + reduced = _pto.VcgaddOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + ).result + return _extract_lowest_lane_scalar(reduced, mask) + + +def vsubs(inp, scalar, mask): + """``pto.vsubs`` – vector minus scalar under mask.""" + raw_scalar = _coerce_scalar_like_vector_element(inp, scalar, context="vsubs") + neg_scalar = _negate_runtime_scalar(raw_scalar) + return wrap_surface_value( + _pto.VaddsOp( + unwrap_surface_value(inp).type, + unwrap_surface_value(inp), + neg_scalar, + unwrap_surface_value(mask), + ).result + ) # ── Tile-domain operations ──────────────────────────────────────────────────── -def make_tensor_view(ptr, *, shape, strides): +def make_tensor_view(ptr, *, shape=None, strides=None): """ ``pto.make_tensor_view`` – wrap a pointer as a tensor view. Type is inferred: rank from ``len(shape)``, element type from ``ptr``. """ + authored_ptr = ptr + if shape is None: + shape = getattr(authored_ptr, "shape", None) + if strides is None: + strides = getattr(authored_ptr, "strides", None) + if shape is None or strides is None: + raise TypeError("make_tensor_view() requires shape= and strides=, or a host tensor proxy carrying both") + ptr = resolve_tensor_data_entry(authored_ptr) rank = len(shape) - elem = _pto.PtrType(ptr.type).element_type + raw_ptr = unwrap_surface_value(ptr) + elem = _pto.PtrType(raw_ptr.type).element_type tv_type = tensor_view_type(rank, elem) - return _pto.MakeTensorViewOp(tv_type, ptr, list(shape), list(strides)).result + value = _pto.MakeTensorViewOp( + tv_type, + raw_ptr, + _unwrap_sequence(shape), + _unwrap_sequence(strides), + ).result + return TensorViewValue(value, shape=tuple(shape), strides=tuple(strides)) + + +def _normalize_static_tile_shape(shape): + static_shape = [] + for dim in shape: + if isinstance(dim, bool) or not isinstance(dim, int): + raise TypeError( + "alloc_tile(shape=...) currently requires a static physical tile shape. " + "Use constexpr/static integers for shape and place runtime metadata in valid_shape." + ) + static_shape.append(dim) + return tuple(static_shape) + + +def _split_valid_shape(shape, valid_shape): + rank = len(shape) + if valid_shape is None: + return tuple(shape), None, None, tuple(shape) + + if len(valid_shape) != rank: + raise TypeError( + f"alloc_tile(valid_shape=...) rank mismatch: expected {rank} dims, got {len(valid_shape)}" + ) + + type_valid_shape = [] + surface_valid_shape = [] + valid_row = None + valid_col = None + for index, dim in enumerate(valid_shape): + surface_valid_shape.append(dim) + if isinstance(dim, bool): + raise TypeError("alloc_tile(valid_shape=...) does not accept bool dimensions") + if isinstance(dim, int): + type_valid_shape.append(dim) + continue + type_valid_shape.append(-1) + if index == 0: + valid_row = dim + continue + if index == 1: + valid_col = dim + continue + raise TypeError( + "alloc_tile(valid_shape=...) currently only supports dynamic runtime metadata " + "for the first two dimensions" + ) + return tuple(type_valid_shape), valid_row, valid_col, tuple(surface_valid_shape) + + +def _uses_row_major_none_box_layout(blayout, slayout) -> bool: + return str(blayout).lower() == "rowmajor" and str(slayout).lower() == "nonebox" + + +def _validate_authored_tile_row_alignment(shape, dtype, *, blayout, slayout): + if not _uses_row_major_none_box_layout(blayout, slayout): + return + if not shape: + return + elem_bytewidth = _element_bytewidth(_resolve(dtype)) + row_bytes = shape[-1] * elem_bytewidth + required_alignment = 32 + if row_bytes % required_alignment == 0: + return + raise tile_row_alignment_error( + shape=shape, + dtype=str(_resolve(dtype)), + row_bytes=row_bytes, + required_alignment=required_alignment, + ) def partition_view(tv, *, offsets, sizes): @@ -189,44 +450,609 @@ def partition_view(tv, *, offsets, sizes): Type is inferred from the source tensor-view type. """ - src_type = _pto.TensorViewType(tv.type) + spec = compose_partition_spec(tv, offsets=offsets, sizes=sizes) + if spec is not None: + source = spec.root_tensor_view + offsets = spec.offsets + sizes = spec.sizes + else: + source = tv + + raw_source = unwrap_surface_value(source) + src_type = _pto.TensorViewType(raw_source.type) rank = src_type.rank elem = src_type.element_type ptv_type = part_tensor_view_type(rank, elem) - return _pto.PartitionViewOp(ptv_type, tv, list(offsets), list(sizes)).result + value = _pto.PartitionViewOp( + ptv_type, + raw_source, + _unwrap_sequence(offsets), + _unwrap_sequence(sizes), + ).result + return wrap_surface_value( + value, + root_tensor_view=source if spec is None else spec.root_tensor_view, + offsets=tuple(offsets), + sizes=tuple(sizes), + ) + + +def alloc_tile( + tile_type=None, + *, + shape=None, + dtype=None, + memory_space="ub", + valid_shape=None, + blayout: str = "RowMajor", + slayout: str = "NoneBox", + fractal_size: int = 512, + pad: str = "Null", + addr=None, + valid_row=None, + valid_col=None, +): + """ + ``pto.alloc_tile``. + + Accepts either the authored surface form: + ``alloc_tile(shape=[...], dtype=..., memory_space=...)`` -def alloc_tile(tile_type, *, addr, valid_row, valid_col=None): - """``pto.alloc_tile``.""" - return _pto.AllocTileOp(_resolve(tile_type), addr=addr, valid_row=valid_row, - valid_col=valid_col).result + or the low-level explicit-type form: + + ``alloc_tile(tile_type, addr=..., valid_row=..., valid_col=...)``. + """ + if tile_type is not None and shape is not None: + raise TypeError("alloc_tile() accepts either tile_type or shape=/dtype=, not both") + + if tile_type is None: + if shape is None or dtype is None: + raise TypeError("alloc_tile() requires either tile_type or both shape= and dtype=") + if addr is not None or valid_row is not None or valid_col is not None: + raise TypeError( + "alloc_tile(shape=..., dtype=...) uses the authored surface form; " + "addr=/valid_row=/valid_col= are only supported with an explicit tile_type" + ) + shape = _normalize_static_tile_shape(shape) + _validate_authored_tile_row_alignment(shape, dtype, blayout=blayout, slayout=slayout) + type_valid_shape, valid_row, valid_col, surface_valid_shape = _split_valid_shape(shape, valid_shape) + from ._types import tile_buf_type + tile_type = tile_buf_type( + shape, + dtype, + type_valid_shape, + blayout=blayout, + address_space=memory_space, + slayout=slayout, + fractal_size=fractal_size, + pad=pad, + ) + else: + surface_valid_shape = None + + value = _pto.AllocTileOp( + _resolve(tile_type), + addr=unwrap_surface_value(addr) if addr is not None else None, + valid_row=unwrap_surface_value(valid_row) if valid_row is not None else None, + valid_col=unwrap_surface_value(valid_col) if valid_col is not None else None, + ).result + if tile_type is not None and (valid_row is not None or valid_col is not None): + parsed_tile_type = parse_tile_type_metadata(_resolve(tile_type)) + rank = len(shape) if shape is not None else len(parsed_tile_type["shape_dims"]) + surface_valid_shape = [None] * rank + if rank >= 1: + surface_valid_shape[0] = valid_row + if rank >= 2: + surface_valid_shape[1] = valid_col + surface_valid_shape = tuple(surface_valid_shape) + return wrap_surface_value( + value, + tile_metadata={ + "shape": shape, + "dtype": dtype, + "memory_space": memory_space, + "valid_shape": surface_valid_shape, + }, + ) + + +def set_tile_valid_shape(tile, valid_shape): + """Update the runtime valid-shape metadata of a rank-2 dynamic tile.""" + if len(valid_shape) != 2: + raise TypeError( + "tile.valid_shape assignment currently expects exactly two dimensions" + ) + + parsed_tile_type = parse_tile_type_metadata(unwrap_surface_value(tile).type) + if parsed_tile_type is None: + raise TypeError("tile.valid_shape assignment expects a tile_buf-backed value") + if len(parsed_tile_type["shape_dims"]) != 2: + raise TypeError("tile.valid_shape assignment currently only supports rank-2 tiles") + if parsed_tile_type["valid_dims"] != (None, None): + raise TypeError( + "tile.valid_shape assignment requires a tile allocated with fully dynamic " + "valid_shape=[..., ...]" + ) + + valid_row, valid_col = _unwrap_sequence(valid_shape) + _pto.SetValidShapeOp( + unwrap_surface_value(tile), + valid_row, + valid_col, + ) def tload(part, tile): """``pto.tload ins(part) outs(tile)``.""" - _pto.TLoadOp(None, part, tile) + _pto.TLoadOp(None, unwrap_surface_value(part), unwrap_surface_value(tile)) def tstore(tile, part): """``pto.tstore ins(tile) outs(part)``.""" - _pto.TStoreOp(None, tile, part) + _pto.TStoreOp(None, unwrap_surface_value(tile), unwrap_surface_value(part)) -def tile_ptr(tile, result_ptr_type): - """``pto.tile_buf_addr`` – materialise a UB pointer from a tile handle.""" - return _pto.TileBufAddrOp(_resolve(result_ptr_type), tile).result +def tmov(src, dst): + """``pto.tmov ins(src) outs(dst)`` – move data between tile domains.""" + _pto.TMovOp(None, unwrap_surface_value(src), unwrap_surface_value(dst)) + + +def as_ptr(value, result_ptr_type=None): + """Materialize a typed pointer from a tile or tensor-view descriptor.""" + wrapped = wrap_surface_value(value) + return emit_as_ptr(wrapped, result_ptr_type) + + +def _constant_like(value, mlir_type): + value = unwrap_surface_value(value) + if hasattr(value, "type"): + return value + if isinstance(value, float): + return arith.ConstantOp(mlir_type, FloatAttr.get(mlir_type, value)).result + return arith.ConstantOp(mlir_type, value).result + + +def _index_zero(): + return arith.ConstantOp(IndexType.get(), 0).result + + +def _infer_vreg_type_from_tile_slice(tile_slice: TileSliceValue): + memref_type = MemRefType(tile_slice.type) + elem_type = memref_type.element_type + lanes = _elements_per_vreg(elem_type) + return _resolve(vreg_type(lanes, elem_type)) + + +def _elements_per_vreg(elem_type): + if F32Type.isinstance(elem_type): + bytewidth = 4 + elif any(cls.isinstance(elem_type) for cls in (F16Type, BF16Type)): + bytewidth = 2 + elif IntegerType.isinstance(elem_type): + width = IntegerType(elem_type).width + if width % 8 != 0: + raise TypeError(f"vlds/vsts tile-slice sugar does not support sub-byte integer element type {elem_type}") + bytewidth = width // 8 + else: + raise TypeError(f"vlds/vsts tile-slice sugar does not support element type {elem_type}") + return 256 // bytewidth + + +def _infer_vreg_metadata(vector_value): + raw_type = unwrap_surface_value(vector_value).type + try: + vreg_type = _pto.VRegType(raw_type) + return vreg_type.lanes, vreg_type.element_type + except Exception: + text = str(raw_type) + if not text.startswith("!pto.vreg<") or "x" not in text: + raise TypeError(f"expected PTO vector-register type, got {raw_type}") + body = text[len("!pto.vreg<"):-1] + lanes_text, elem_text = body.split("x", 1) + return int(lanes_text), Type.parse(elem_text) + + +def _extract_lowest_lane_scalar(vector_value, mask): + lanes, elem_type = _infer_vreg_metadata(vector_value) + tmp_tile = alloc_tile(shape=[1, lanes], dtype=elem_type, valid_shape=[1, 1]) + vsts_1pt(vector_value, tmp_tile.as_ptr(), _index_zero(), mask) + from . import scalar as _scalar + return _scalar.load(tmp_tile[0, 0]) + + +def _element_bytewidth(elem_type): + if F32Type.isinstance(elem_type): + return 4 + if any(cls.isinstance(elem_type) for cls in (F16Type, BF16Type)): + return 2 + if IntegerType.isinstance(elem_type): + width = IntegerType(elem_type).width + if width % 8 != 0: + raise TypeError(f"unsupported sub-byte integer element type {elem_type}") + return width // 8 + raise TypeError(f"unsupported element type {elem_type}") + + +def _mask_bits_for_dtype(dtype): + elem_type = _resolve(dtype) + bytewidth = _element_bytewidth(elem_type) + if bytewidth == 4: + return 32 + if bytewidth == 2: + return 16 + if bytewidth == 1: + return 8 + raise TypeError(f"make_mask(...) does not support dtype {elem_type}") + + +def _pset_op_for_mask_bits(mask_bits: int): + return { + 8: _pto.PsetB8Op, + 16: _pto.PsetB16Op, + 32: _pto.PsetB32Op, + }[mask_bits] + + +def _plt_op_for_mask_bits(mask_bits: int): + return { + 8: _pto.PltB8Op, + 16: _pto.PltB16Op, + 32: _pto.PltB32Op, + }[mask_bits] + + +def _coerce_i32(value, *, context: str): + raw_value = unwrap_surface_value(value) + i32_type = IntegerType.get_signless(32) + if isinstance(raw_value, bool): + raise TypeError(f"{context} does not accept bool values") + if isinstance(raw_value, int): + return arith.ConstantOp(i32_type, raw_value).result + kind = classify_runtime_scalar_type(raw_value.type) + if kind == "float": + raise TypeError(f"{context} expects an integer-like scalar, got {raw_value.type}") + if kind == "index": + return arith.IndexCastOp(i32_type, raw_value).result + if raw_value.type == i32_type: + return raw_value + width = IntegerType(raw_value.type).width + if width < 32: + return arith.ExtSIOp(i32_type, raw_value).result + if width > 32: + return arith.TruncIOp(i32_type, raw_value).result + return raw_value + + +def _coerce_i64(value, *, context: str): + raw_value = unwrap_surface_value(value) + i64_type = IntegerType.get_signless(64) + if isinstance(raw_value, bool): + raise TypeError(f"{context} does not accept bool values") + if isinstance(raw_value, int): + return arith.ConstantOp(i64_type, raw_value).result + kind = classify_runtime_scalar_type(raw_value.type) + if kind == "float": + raise TypeError(f"{context} expects an integer-like scalar, got {raw_value.type}") + if kind == "index": + return arith.IndexCastOp(i64_type, raw_value).result + if raw_value.type == i64_type: + return raw_value + width = IntegerType(raw_value.type).width + if width < 64: + return arith.ExtSIOp(i64_type, raw_value).result + if width > 64: + return arith.TruncIOp(i64_type, raw_value).result + return raw_value + + +def _i64_zero(): + return arith.ConstantOp(IntegerType.get_signless(64), 0).result + + +def _coerce_scalar_like_vector_element(vector_value, scalar_value, *, context: str): + _, elem_type = _infer_vreg_metadata(vector_value) + return coerce_scalar_to_type(scalar_value, elem_type, context=f"{context}(...)") + + +def _negate_runtime_scalar(value): + raw_value = unwrap_surface_value(value) + kind = classify_runtime_scalar_type(raw_value.type) + zero = materialize_scalar_literal(0.0 if kind == "float" else 0, raw_value.type, context="_negate_runtime_scalar(...)") + return emit_runtime_binary_op("sub", zero, raw_value) + + +def _mul_bytes(value, elem_type): + factor = _element_bytewidth(_resolve(elem_type)) + raw_value = unwrap_surface_value(value) + if isinstance(raw_value, int): + return raw_value * factor + return emit_runtime_binary_op("mul", raw_value, factor) + + +def _membar_attr(kind: str): + normalized = str(kind) + supported = { + "VV_ALL", + "VST_VLD", + "VLD_VST", + "VST_VST", + "VS_ALL", + "VST_LD", + "VLD_ST", + "VST_ST", + "SV_ALL", + "ST_VLD", + "LD_VST", + "ST_VST", + "SS_ALL", + "ST_LD", + "LD_ST", + "ST_ST", + } + if normalized not in supported: + raise ValueError(f"unsupported mem_bar kind {kind!r}") + return Attribute.parse(f"#pto.membar<{normalized}>") + + +def _acc_store_ub_dst_mode_attr(mode): + normalized = { + 0: "single", + 1: "split_m", + 2: "split_n", + "single": "single", + "split_m": "split_m", + "split_n": "split_n", + }.get(mode if isinstance(mode, int) else str(mode).lower()) + if normalized is None: + raise ValueError(f"unsupported mte_l0c_ub dst_mode {mode!r}") + return Attribute.parse(f"#pto") + + +def _infer_dma_partition_row_stride(partition: PartitionTensorViewValue): + if partition.shape is None or partition.strides is None: + raise TypeError("mte_load/mte_store require partition view shape/stride metadata") + outer_dims = list(partition.shape[:-1]) + non_unit = [i for i, dim in enumerate(outer_dims) if dim != 1] + if len(non_unit) > 1: + raise TypeError( + "mte_load/mte_store currently only support partitions with at most one non-unit " + "dimension before the contiguous innermost dimension" + ) + if not non_unit: + return 1, 0 + dim_index = non_unit[0] + return partition.shape[dim_index], partition.strides[dim_index] + + +def _infer_dma_tile_geometry(tile: TileValue): + if tile.shape is None: + raise TypeError("mte_load/mte_store require tile shape metadata") + if len(tile.shape) == 1: + valid_cols = tile.valid_shape[0] + return 1, valid_cols, tile.shape[0] + if len(tile.shape) == 2: + return tile.valid_shape[0], tile.valid_shape[1], tile.shape[1] + raise TypeError("mte_load/mte_store currently only support rank-1 or rank-2 tiles") + + +def _infer_dma_2d_copy_signature(partition, tile, *, direction: str): + row_count, src_row_stride = _infer_dma_partition_row_stride(partition) + tile_rows, valid_cols, physical_cols = _infer_dma_tile_geometry(tile) + if direction == "gm_to_ub": + return row_count, valid_cols, _mul_bytes(src_row_stride, infer_tile_element_type(tile)), physical_cols * _element_bytewidth(infer_tile_element_type(tile)) + return row_count, valid_cols, physical_cols * _element_bytewidth(infer_tile_element_type(tile)), _mul_bytes(src_row_stride, infer_tile_element_type(tile)) + + +def fill_tile(tile, value): + """Broadcast a scalar into an entire tile.""" + wrapped_tile = wrap_surface_value(tile) + scalar_value = _constant_like(value, infer_tile_element_type(wrapped_tile)) + _pto.TExpandsOp(scalar_value, unwrap_surface_value(wrapped_tile)) + + +def make_mask(dtype, value): + """Create a predicate mask matching *dtype* granularity.""" + mask_bits = _mask_bits_for_dtype(dtype) + result_type = _resolve(mask_type(f"b{mask_bits}")) + + if isinstance(value, str): + return wrap_surface_value(_pset_op_for_mask_bits(mask_bits)(result_type, value).result) + + raw_value = unwrap_surface_value(value) + raw_value = _coerce_i32(raw_value, context="make_mask(..., value)") + plt_op = _plt_op_for_mask_bits(mask_bits)(result_type, IntegerType.get_signless(32), raw_value) + return MaskResultValue(plt_op.mask, plt_op.scalar_out) # ── Hardware / sync ─────────────────────────────────────────────────────────── +def mte_load(source, destination): + """ + Convenience GM->on-chip load surface. + + Current scope is intentionally narrow: contiguous rank-1 or squeezed-rank-2 + partition views lowering into VEC or MAT tiles. + """ + source = wrap_surface_value(source) + destination = wrap_surface_value(destination) + if not isinstance(source, PartitionTensorViewValue) or not isinstance(destination, TileValue): + raise TypeError("mte_load(source, destination) expects (PartitionTensorView, Tile)") + + src_ptr = emit_as_ptr(source) + dst_ptr = emit_as_ptr(destination) + row_count, valid_cols, src_row_stride, dst_row_stride = _infer_dma_2d_copy_signature( + source, destination, direction="gm_to_ub" + ) + destination_type = parse_tile_type_metadata(unwrap_surface_value(destination).type) + if destination_type is None: + raise TypeError("mte_load(source, destination) expects a tile_buf-backed destination") + destination_space = destination_type["memory_space"] + len_burst = _coerce_i64(_mul_bytes(valid_cols, infer_tile_element_type(destination)), context="mte_load len_burst") + n_burst = _coerce_i64(row_count, context="mte_load n_burst") + src_stride = _coerce_i64(src_row_stride, context="mte_load src_stride") + dst_stride = _coerce_i64(dst_row_stride, context="mte_load dst_stride") + + if destination_space == "vec": + _pto.MteGmUbOp( + unwrap_surface_value(src_ptr), + unwrap_surface_value(dst_ptr), + _i64_zero(), + len_burst, + n_burst, + src_stride, + dst_stride, + [], + [], + [], + ) + return + + if destination_space == "mat": + _pto.MteGmL1Op( + unwrap_surface_value(src_ptr), + unwrap_surface_value(dst_ptr), + len_burst, + n_burst, + src_stride, + dst_stride, + [], + [], + [], + ) + return + + raise TypeError( + "mte_load(source, destination) currently supports VEC or MAT tile destinations, " + f"got memory_space={destination_space!r}" + ) + + +def mte_store(source, destination): + """Convenience UB->GM store surface matching ``mte_load`` scope.""" + source = wrap_surface_value(source) + destination = wrap_surface_value(destination) + if not isinstance(source, TileValue) or not isinstance(destination, PartitionTensorViewValue): + raise TypeError("mte_store(source, destination) expects (Tile, PartitionTensorView)") + + src_ptr = emit_as_ptr(source) + dst_ptr = emit_as_ptr(destination) + row_count, valid_cols, src_row_stride, dst_row_stride = _infer_dma_2d_copy_signature( + destination, source, direction="ub_to_gm" + ) + _pto.MteUbGmOp( + unwrap_surface_value(src_ptr), + unwrap_surface_value(dst_ptr), + _coerce_i64(_mul_bytes(valid_cols, infer_tile_element_type(source)), context="mte_store len_burst"), + _coerce_i64(row_count, context="mte_store n_burst"), + _coerce_i64(src_row_stride, context="mte_store src_stride"), + _coerce_i64(dst_row_stride, context="mte_store dst_stride"), + [], + [], + [], + ) + + +def mem_bar(barrier_type): + """``pto.mem_bar`` with a small authored enum surface.""" + barrier_name = getattr(barrier_type, "value", barrier_type) + _pto.MemBarOp(kind=_membar_attr(barrier_name)) + + +def mte_l1_l0a(source, destination, m, k, *, transpose=False): + """``pto.mte_l1_l0a`` – cube-side LEFT staging.""" + _pto.MteL1L0aOp( + unwrap_surface_value(source), + unwrap_surface_value(destination), + _coerce_i64(m, context="mte_l1_l0a m"), + _coerce_i64(k, context="mte_l1_l0a k"), + transpose=transpose, + ) + + +def mte_l1_l0b(source, destination, k, n, *, transpose=False): + """``pto.mte_l1_l0b`` – cube-side RIGHT staging.""" + _pto.MteL1L0bOp( + unwrap_surface_value(source), + unwrap_surface_value(destination), + _coerce_i64(k, context="mte_l1_l0b k"), + _coerce_i64(n, context="mte_l1_l0b n"), + transpose=transpose, + ) + + +def mte_l0c_ub(source, destination, m, n, src_stride, dst_stride, sub_blockid=0, *, dst_mode="single"): + """``pto.mte_l0c_ub`` – ACC to UB store.""" + _pto.MteL0cUbOp( + unwrap_surface_value(source), + unwrap_surface_value(destination), + _coerce_i64(m, context="mte_l0c_ub m"), + _coerce_i64(n, context="mte_l0c_ub n"), + _coerce_i64(src_stride, context="mte_l0c_ub src_stride"), + _coerce_i64(dst_stride, context="mte_l0c_ub dst_stride"), + _acc_store_ub_dst_mode_attr(dst_mode), + sub_blockid=_coerce_i64(sub_blockid, context="mte_l0c_ub sub_blockid"), + ) + + +def mad(lhs, rhs, dst, m, n, k): + """``pto.mad`` – cube matmul accumulate.""" + _pto.MadOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + _coerce_i64(m, context="mad m"), + _coerce_i64(n, context="mad n"), + _coerce_i64(k, context="mad k"), + ) + def get_block_idx(): """``pto.get_block_idx`` → i64 block index.""" - return _pto.GetBlockIdxOp().result + return wrap_surface_value(_pto.GetBlockIdxOp().result) + + +def get_block_num(): + """``pto.get_block_num`` → i64 block count.""" + return wrap_surface_value(_pto.GetBlockNumOp().result) + + +def get_subblock_idx(): + """``pto.get_subblock_idx`` → i64 subblock index.""" + return wrap_surface_value(_pto.GetSubBlockIdxOp().result) + + +def get_subblock_num(): + """``pto.get_subblock_num`` → i64 subblock count.""" + return wrap_surface_value(_pto.GetSubBlockNumOp().result) + + +def store_vfsimt_info(dim_z, dim_y, dim_x): + """``pto.store_vfsimt_info`` – configure the SIMT VF launch descriptor.""" + _pto.StoreVfSimtInfoOp( + unwrap_surface_value(dim_z), + unwrap_surface_value(dim_y), + unwrap_surface_value(dim_x), + ) + + +def get_tid_x(): + """``pto.get_tid_x`` → i32 SIMT lane X coordinate.""" + return wrap_surface_value(_pto.GetTidXOp().result) + + +def get_tid_y(): + """``pto.get_tid_y`` → i32 SIMT lane Y coordinate.""" + return wrap_surface_value(_pto.GetTidYOp().result) + + +def get_tid_z(): + """``pto.get_tid_z`` → i32 SIMT lane Z coordinate.""" + return wrap_surface_value(_pto.GetTidZOp().result) -def barrier_all(): - """``pto.barrier #pto.pipe``.""" - _pto.BarrierOp(_pipe_attr("ALL")) +def pipe_barrier(pipe): + """``pto.pipe_barrier(pipe)`` – drain the specified hardware pipeline.""" + _pto.BarrierOp(_pipe_attr(pipe)) def set_flag(src: str, dst: str, *, event_id: int = 0): @@ -247,10 +1073,15 @@ def wait_flag(src: str, dst: str, *, event_id: int = 0): "const", "castptr", "addptr", "vlds", "vbrc_load", "vsts", "vsts_1pt", - "plt_b32", "pset_b32", + "plt_b32", "pset_b32", "make_mask", "vadd", "vmul", "vmax", "vdiv", "vcmax", "vcadd", "vdup", "vexpdif", + "vexp", "vcgmax", "vcgadd", "vsubs", "make_tensor_view", "partition_view", - "alloc_tile", "tload", "tstore", "tile_ptr", - "get_block_idx", "barrier_all", "set_flag", "wait_flag", + "alloc_tile", "tload", "tstore", "tmov", "as_ptr", + "mte_load", "mte_store", "mem_bar", + "mte_l1_l0a", "mte_l1_l0b", "mte_l0c_ub", "mad", + "get_block_idx", "get_block_num", "get_subblock_idx", "get_subblock_num", + "store_vfsimt_info", "get_tid_x", "get_tid_y", "get_tid_z", + "pipe_barrier", "set_flag", "wait_flag", ] diff --git a/ptodsl/ptodsl/_runtime_index_ops.py b/ptodsl/ptodsl/_runtime_index_ops.py new file mode 100644 index 000000000..03256c8aa --- /dev/null +++ b/ptodsl/ptodsl/_runtime_index_ops.py @@ -0,0 +1,43 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Tracing-time helpers for coercing authored runtime values to MLIR index.""" + +from __future__ import annotations + +from mlir.dialects import arith +from mlir.ir import IndexType, IntegerType + + +def coerce_runtime_index(value, *, context: str): + """Normalize one authored loop/slice bound to an MLIR index SSA value.""" + if isinstance(value, bool): + raise TypeError(f"{context} does not accept bool values") + + if isinstance(value, int): + return arith.ConstantOp(IndexType.get(), value).result + + if not hasattr(value, "type"): + raise TypeError( + f"{context} expects a Python int, an index value, or an integer runtime scalar; " + f"got {value!r}" + ) + + value_type = value.type + if IndexType.isinstance(value_type): + return value + if IntegerType.isinstance(value_type): + return arith.IndexCastOp(IndexType.get(), value).result + + raise TypeError( + f"{context} expects an index or integer runtime scalar, got {value_type}" + ) + + +__all__ = [ + "coerce_runtime_index", +] diff --git a/ptodsl/ptodsl/_runtime_scalar_ops.py b/ptodsl/ptodsl/_runtime_scalar_ops.py new file mode 100644 index 000000000..f2e5cdfd7 --- /dev/null +++ b/ptodsl/ptodsl/_runtime_scalar_ops.py @@ -0,0 +1,134 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Tracing-time authored scalar operator lowering for runtime values.""" + +from __future__ import annotations + +from mlir.dialects import arith +from mlir.ir import BF16Type, F16Type, F32Type, FloatAttr, IndexType, IntegerType + + +_INTEGER_BINARY_OPS = { + "add": arith.AddIOp, + "sub": arith.SubIOp, + "mul": arith.MulIOp, + "floordiv": arith.FloorDivSIOp, + "mod": arith.RemSIOp, +} + +_FLOAT_BINARY_OPS = { + "add": arith.AddFOp, + "sub": arith.SubFOp, + "mul": arith.MulFOp, + "truediv": arith.DivFOp, +} + + +def emit_runtime_binary_op(op_name: str, lhs, rhs): + """Lower one authored runtime scalar binary operator.""" + lhs, rhs, kind = normalize_runtime_binary_operands(lhs, rhs) + if kind in {"index", "integer"}: + op_cls = _INTEGER_BINARY_OPS.get(op_name) + if op_cls is None: + raise TypeError(f"runtime scalar operator '{op_name}' is not supported for integer/index values") + return op_cls(lhs, rhs).result + if kind == "float": + op_cls = _FLOAT_BINARY_OPS.get(op_name) + if op_cls is None: + raise TypeError(f"runtime scalar operator '{op_name}' is not supported for floating-point values") + return op_cls(lhs, rhs).result + raise TypeError(f"unsupported runtime scalar operand category '{kind}'") + + +def emit_runtime_max(lhs, rhs): + """Lower one authored runtime scalar max operation.""" + lhs, rhs, kind = normalize_runtime_binary_operands(lhs, rhs) + if kind == "float": + return arith.MaximumFOp(lhs, rhs).result + if kind == "integer": + return arith.MaxSIOp(lhs, rhs).result + if kind == "index": + cond = arith.CmpIOp(arith.CmpIPredicate.sge, lhs, rhs).result + return arith.SelectOp(cond, lhs, rhs).result + raise TypeError(f"unsupported runtime scalar operand category '{kind}'") + + +def normalize_runtime_binary_operands(lhs, rhs): + lhs_is_value = _is_mlir_value(lhs) + rhs_is_value = _is_mlir_value(rhs) + + if not lhs_is_value and not rhs_is_value: + raise TypeError("runtime scalar operators require at least one traced runtime operand") + + if lhs_is_value and rhs_is_value: + return _reconcile_typed_operands(lhs, rhs) + + anchor_type = lhs.type if lhs_is_value else rhs.type + lhs = lhs if lhs_is_value else _materialize_literal(lhs, anchor_type) + rhs = rhs if rhs_is_value else _materialize_literal(rhs, anchor_type) + return _reconcile_typed_operands(lhs, rhs) + + +def _reconcile_typed_operands(lhs, rhs): + lhs_type = lhs.type + rhs_type = rhs.type + + if lhs_type == rhs_type: + return lhs, rhs, classify_runtime_scalar_type(lhs_type) + + if IndexType.isinstance(lhs_type) and IntegerType.isinstance(rhs_type): + rhs = arith.IndexCastOp(IndexType.get(), rhs).result + return lhs, rhs, "index" + + if IntegerType.isinstance(lhs_type) and IndexType.isinstance(rhs_type): + lhs = arith.IndexCastOp(IndexType.get(), lhs).result + return lhs, rhs, "index" + + raise TypeError( + "runtime scalar operators require matching scalar types or an index/integer pair; " + f"got {lhs_type} and {rhs_type}" + ) + + +def _materialize_literal(value, anchor_type): + if isinstance(value, bool): + raise TypeError("runtime scalar operators do not accept bool literals") + + kind = classify_runtime_scalar_type(anchor_type) + if kind == "float": + return arith.ConstantOp(anchor_type, FloatAttr.get(anchor_type, float(value))).result + + if isinstance(value, float): + raise TypeError( + "runtime scalar operators cannot materialize a floating-point literal " + f"against non-floating operand type {anchor_type}" + ) + + return arith.ConstantOp(anchor_type, int(value)).result + + +def classify_runtime_scalar_type(type_obj): + if IndexType.isinstance(type_obj): + return "index" + if IntegerType.isinstance(type_obj): + return "integer" + if any(cls.isinstance(type_obj) for cls in (BF16Type, F16Type, F32Type)): + return "float" + raise TypeError(f"runtime scalar operators only support index/int/float values, got {type_obj}") + + +def _is_mlir_value(value) -> bool: + return not isinstance(value, (bool, int, float)) and hasattr(value, "type") + + +__all__ = [ + "classify_runtime_scalar_type", + "emit_runtime_binary_op", + "emit_runtime_max", + "normalize_runtime_binary_operands", +] diff --git a/ptodsl/ptodsl/_scalar_coercion.py b/ptodsl/ptodsl/_scalar_coercion.py new file mode 100644 index 000000000..fc150a5f3 --- /dev/null +++ b/ptodsl/ptodsl/_scalar_coercion.py @@ -0,0 +1,97 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Shared authored scalar type-adaptation helpers for PTODSL surface lowering.""" + +from __future__ import annotations + +from ._runtime_scalar_ops import classify_runtime_scalar_type +from ._surface_values import unwrap_surface_value + +from mlir.dialects import arith +from mlir.ir import BF16Type, F16Type, F32Type, FloatAttr, IndexType, IntegerType + + +def coerce_scalar_to_type(value, target_type, *, context: str): + """Normalize one authored scalar value/literal to *target_type*.""" + raw_value = unwrap_surface_value(value) + if not hasattr(raw_value, "type"): + return materialize_scalar_literal(raw_value, target_type, context=context) + + if raw_value.type == target_type: + return raw_value + + source_kind = classify_runtime_scalar_type(raw_value.type) + target_kind = classify_runtime_scalar_type(target_type) + + if source_kind == "index" and target_kind == "integer": + return _coerce_integer_like(raw_value, target_type) + if source_kind == "integer" and target_kind == "index": + return arith.IndexCastOp(target_type, raw_value).result + if source_kind == "integer" and target_kind == "integer": + return _coerce_integer_like(raw_value, target_type) + if source_kind == "float" and target_kind == "float": + return _coerce_float_like(raw_value, target_type) + + raise TypeError( + f"{context} cannot coerce the authored value to the expected scalar type: " + f"got {raw_value.type}, expected {target_type}" + ) + + +def materialize_scalar_literal(value, target_type, *, context: str): + """Materialize one Python literal as an MLIR scalar constant of *target_type*.""" + if isinstance(value, bool): + raise TypeError(f"{context} does not accept bool literals") + + target_kind = classify_runtime_scalar_type(target_type) + if target_kind == "float": + return arith.ConstantOp(target_type, FloatAttr.get(target_type, float(value))).result + + if isinstance(value, float): + raise TypeError( + f"{context} cannot materialize a floating-point literal against non-floating " + f"target type {target_type}" + ) + + return arith.ConstantOp(target_type, int(value)).result + + +def _coerce_integer_like(raw_value, target_type): + if IndexType.isinstance(raw_value.type): + return arith.IndexCastOp(target_type, raw_value).result + source_width = IntegerType(raw_value.type).width + target_width = IntegerType(target_type).width + if source_width < target_width: + return arith.ExtSIOp(target_type, raw_value).result + if source_width > target_width: + return arith.TruncIOp(target_type, raw_value).result + return raw_value + + +def _coerce_float_like(raw_value, target_type): + source_width = _float_bytewidth(raw_value.type) + target_width = _float_bytewidth(target_type) + if source_width < target_width: + return arith.ExtFOp(target_type, raw_value).result + if source_width > target_width: + return arith.TruncFOp(target_type, raw_value).result + return raw_value + + +def _float_bytewidth(type_obj): + if BF16Type.isinstance(type_obj) or F16Type.isinstance(type_obj): + return 2 + if F32Type.isinstance(type_obj): + return 4 + raise TypeError(f"unsupported floating-point type {type_obj}") + + +__all__ = [ + "coerce_scalar_to_type", + "materialize_scalar_literal", +] diff --git a/ptodsl/ptodsl/_subkernels.py b/ptodsl/ptodsl/_subkernels.py new file mode 100644 index 000000000..64e4c481f --- /dev/null +++ b/ptodsl/ptodsl/_subkernels.py @@ -0,0 +1,168 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Layered PTODSL subkernel decorators.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from functools import update_wrapper +import inspect + +from ._diagnostics import ( + illegal_subkernel_placement_error, + simd_value_escape_error, + subkernel_host_tensor_boundary_error, + subkernel_signature_boundary_error, +) +from ._host_tensors import TensorSpec, looks_like_host_tensor +from ._surface_values import unwrap_surface_value +from ._tracing import current_runtime, current_session + + +class KernelRole(str, Enum): + UKERNEL = "ukernel" + CUBE = "cube" + SIMD = "simd" + SIMT = "simt" + + +@dataclass(frozen=True) +class SubkernelSpec: + """Declarative metadata for a PTODSL subkernel surface.""" + + role: KernelRole + symbol_name: str + target: str = "a5" + + +class SubkernelTemplate: + """Callable decorated PTODSL subkernel surface.""" + + def __init__(self, spec: SubkernelSpec, py_fn): + self.spec = spec + self.py_fn = py_fn + self.signature = inspect.signature(py_fn) + self._validate_definition() + update_wrapper(self, py_fn) + + def emit_body(self, *args, **kwargs): + """Emit this subkernel body into the currently active trace.""" + result = self.py_fn(*args, **kwargs) + self._validate_result(result) + return result + + def trace_body(self, *args, **kwargs): + """Backward-compatible alias for body emission.""" + return self.emit_body(*args, **kwargs) + + def __call__(self, *args, **kwargs): + runtime = current_runtime() + if runtime is None: + raise RuntimeError( + f"@pto.{self.spec.role.value} kernels may only be called while tracing " + "a compatible PTODSL kernel" + ) + self._validate_invocation(*args, **kwargs) + return runtime.dispatch_subkernel_call(self, *args, **kwargs) + + def _validate_definition(self) -> None: + for param in self.signature.parameters.values(): + if isinstance(param.annotation, TensorSpec): + raise subkernel_signature_boundary_error(self.spec.role.value, param.name) + + def _validate_invocation(self, *args, **kwargs) -> None: + session = current_session() + outer = session.current_subkernel if session is not None else None + if outer is not None: + if self.spec.role == KernelRole.UKERNEL or outer.role != KernelRole.UKERNEL.value: + raise illegal_subkernel_placement_error(self.spec.role.value, outer.role) + + bound = self.signature.bind_partial(*args, **kwargs) + for name, value in bound.arguments.items(): + if looks_like_host_tensor(value): + raise subkernel_host_tensor_boundary_error(self.spec.role.value, name) + + def _validate_result(self, result) -> None: + if self.spec.role != KernelRole.SIMD: + return + escaped_type = _find_transient_simd_escape(result) + if escaped_type is not None: + raise simd_value_escape_error(escaped_type) + + +def _find_transient_simd_escape(value): + if value is None: + return None + if isinstance(value, (tuple, list)): + for item in value: + escaped = _find_transient_simd_escape(item) + if escaped is not None: + return escaped + return None + if isinstance(value, dict): + for item in value.values(): + escaped = _find_transient_simd_escape(item) + if escaped is not None: + return escaped + return None + raw_value = unwrap_surface_value(value) + type_obj = getattr(raw_value, "type", None) + if type_obj is None: + return None + type_text = str(type_obj) + if type_text.startswith("!pto.vreg<") or type_text.startswith("!pto.mask<"): + return type_text + return None + + +def _subkernel_decorator(role: KernelRole, *, name: str | None = None, target: str = "a5"): + def decorator(fn): + return SubkernelTemplate( + SubkernelSpec( + role=role, + symbol_name=name or fn.__name__, + target=target, + ), + fn, + ) + + return decorator + + +def _decorate_subkernel(role: KernelRole, fn=None, *, name: str | None = None, target: str = "a5"): + if fn is not None: + return _subkernel_decorator(role, name=name, target=target)(fn) + return _subkernel_decorator(role, name=name, target=target) + + +def ukernel(fn=None, *, name: str | None = None, target: str = "a5"): + return _decorate_subkernel(KernelRole.UKERNEL, fn, name=name, target=target) + + +def cube(fn=None, *, name: str | None = None, target: str = "a5"): + return _decorate_subkernel(KernelRole.CUBE, fn, name=name, target=target) + + +def simd(fn=None, *, name: str | None = None, target: str = "a5"): + return _decorate_subkernel(KernelRole.SIMD, fn, name=name, target=target) + + +def simt(fn=None, *, name: str | None = None, target: str = "a5"): + return _decorate_subkernel(KernelRole.SIMT, fn, name=name, target=target) + + +__all__ = [ + "KernelRole", + "SubkernelSpec", + "SubkernelTemplate", + "ukernel", + "cube", + "simd", + "simt", +] diff --git a/ptodsl/ptodsl/_surface_types.py b/ptodsl/ptodsl/_surface_types.py new file mode 100644 index 000000000..e48cedb68 --- /dev/null +++ b/ptodsl/ptodsl/_surface_types.py @@ -0,0 +1,99 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Public PTODSL surface markers and enums.""" + +from ._bootstrap import make_context # noqa: F401 +from ._host_tensors import TensorSpec, tensor_spec + +from mlir.dialects import pto as _pto + + +class _ConstexprMarker: + """Marker annotation for PTODSL compile-time specialization parameters.""" + + def __repr__(self): + return "pto.constexpr" + + +constexpr = _ConstexprMarker() + + +class MemorySpace: + """Public PTODSL memory-space enum aliases.""" + + GM = _pto.AddressSpace.GM + UB = _pto.AddressSpace.VEC + VEC = _pto.AddressSpace.VEC + MAT = _pto.AddressSpace.MAT + LEFT = _pto.AddressSpace.LEFT + RIGHT = _pto.AddressSpace.RIGHT + ACC = _pto.AddressSpace.ACC + BIAS = _pto.AddressSpace.BIAS + SCALING = _pto.AddressSpace.SCALING + + +class BarrierType: + """Public PTODSL memory-barrier kind aliases.""" + + VV_ALL = "VV_ALL" + VST_VLD = "VST_VLD" + VLD_VST = "VLD_VST" + VST_VST = "VST_VST" + VS_ALL = "VS_ALL" + VST_LD = "VST_LD" + VLD_ST = "VLD_ST" + VST_ST = "VST_ST" + SV_ALL = "SV_ALL" + ST_VLD = "ST_VLD" + LD_VST = "LD_VST" + ST_VST = "ST_VST" + SS_ALL = "SS_ALL" + ST_LD = "ST_LD" + LD_ST = "LD_ST" + ST_ST = "ST_ST" + + +class Pipe: + """Public PTODSL pipeline aliases for pipeline-level sync ops.""" + + S = _pto.PIPE.PIPE_S + V = _pto.PIPE.PIPE_V + M = _pto.PIPE.PIPE_M + MTE1 = _pto.PIPE.PIPE_MTE1 + MTE2 = _pto.PIPE.PIPE_MTE2 + MTE3 = _pto.PIPE.PIPE_MTE3 + MTE4 = _pto.PIPE.PIPE_MTE4 + MTE5 = _pto.PIPE.PIPE_MTE5 + V2 = _pto.PIPE.PIPE_V2 + FIX = _pto.PIPE.PIPE_FIX + ALL = _pto.PIPE.PIPE_ALL + + +class TensorView: + """Authoring-time marker for a tensor-view descriptor value.""" + + +class PartitionTensorView: + """Authoring-time marker for a partitioned tensor-view descriptor value.""" + + +class Tile: + """Authoring-time marker for an on-chip tile value.""" + + +__all__ = [ + "constexpr", + "TensorSpec", + "MemorySpace", + "BarrierType", + "Pipe", + "TensorView", + "PartitionTensorView", + "Tile", + "tensor_spec", +] diff --git a/ptodsl/ptodsl/_surface_values.py b/ptodsl/ptodsl/_surface_values.py new file mode 100644 index 000000000..ca45f98a3 --- /dev/null +++ b/ptodsl/ptodsl/_surface_values.py @@ -0,0 +1,851 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Tracing-time wrappers for authored PTODSL surface values.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass + +from ._diagnostics import native_python_control_flow_error +from ._runtime_scalar_ops import emit_runtime_binary_op +from ._surface_types import PartitionTensorView, TensorView, Tile +from ._types import _normalize_address_space, _resolve, ptr + +from mlir.dialects import arith +from mlir.dialects import memref +from mlir.dialects import pto as _pto +from mlir.ir import IndexType, MemRefType, ShapedType, StridedLayoutAttr, Type + + +def unwrap_surface_value(value): + """Return the underlying MLIR SSA value for a surface wrapper.""" + return value.value if isinstance(value, _SurfaceValue) else value + + +def _unwrap_sequence(values): + normalized = [] + for value in values: + if isinstance(value, int): + normalized.append(_index_const(value)) + else: + normalized.append(unwrap_surface_value(value)) + return normalized + + +def _normalize_index(value): + return unwrap_surface_value(value) + + +def _index_const(value: int): + return arith.ConstantOp(IndexType.get(), value).result + + +def _add_index(lhs, rhs): + lhs = _normalize_index(lhs) + rhs = _normalize_index(rhs) + if isinstance(lhs, int) and isinstance(rhs, int): + return lhs + rhs + if isinstance(lhs, int): + lhs = _index_const(lhs) + if isinstance(rhs, int): + rhs = _index_const(rhs) + return arith.AddIOp(lhs, rhs).result + + +def _maybe_cast_tensor_view_type(type_obj): + try: + return _pto.TensorViewType(type_obj) + except Exception: + return None + + +def _maybe_cast_partition_tensor_view_type(type_obj): + try: + return _pto.PartitionTensorViewType(type_obj) + except Exception: + return None + + +def _maybe_cast_tile_buf_type(type_obj): + try: + return _pto.TileBufType(type_obj) + except Exception: + return None + + +def wrap_surface_value( + value, + *, + root_tensor_view=None, + offsets=None, + sizes=None, + tile_metadata=None, +): + """Wrap a raw MLIR value into the authored PTODSL surface type when needed.""" + if isinstance(value, _SurfaceValue): + return value + + type_obj = value.type + if _maybe_cast_tensor_view_type(type_obj) is not None: + return TensorViewValue(value) + if _maybe_cast_partition_tensor_view_type(type_obj) is not None: + return PartitionTensorViewValue( + value, + root_tensor_view=root_tensor_view, + offsets=offsets, + sizes=sizes, + ) + if _maybe_cast_tile_buf_type(type_obj) is not None: + return TileValue(value, **(tile_metadata or {})) + try: + MemRefType(type_obj) + return AddressValue(value) + except Exception: + pass + return RuntimeValue(value) + + +class _SurfaceValue: + """Base class for authored PTODSL values backed by one MLIR SSA value.""" + + def __init__(self, value): + self._value = value + + @property + def value(self): + return self._value + + @property + def type(self): + return self._value.type + + @property + def surface_metadata(self): + return None + + def __bool__(self): + raise native_python_control_flow_error("if/while condition") + + def __iter__(self): + raise native_python_control_flow_error("for-loop iteration") + + def __repr__(self): + return repr(self._value) + + +class RuntimeValue(_SurfaceValue): + """Generic authored runtime value wrapper with fail-fast Python misuse diagnostics.""" + + def __index__(self): + raise native_python_control_flow_error("range()/loop bound") + + def __int__(self): + raise native_python_control_flow_error("int() coercion") + + def __add__(self, other): + return wrap_surface_value(emit_runtime_binary_op("add", self.value, unwrap_surface_value(other))) + + def __radd__(self, other): + return wrap_surface_value(emit_runtime_binary_op("add", unwrap_surface_value(other), self.value)) + + def __sub__(self, other): + return wrap_surface_value(emit_runtime_binary_op("sub", self.value, unwrap_surface_value(other))) + + def __rsub__(self, other): + return wrap_surface_value(emit_runtime_binary_op("sub", unwrap_surface_value(other), self.value)) + + def __mul__(self, other): + return wrap_surface_value(emit_runtime_binary_op("mul", self.value, unwrap_surface_value(other))) + + def __rmul__(self, other): + return wrap_surface_value(emit_runtime_binary_op("mul", unwrap_surface_value(other), self.value)) + + def __truediv__(self, other): + return wrap_surface_value(emit_runtime_binary_op("truediv", self.value, unwrap_surface_value(other))) + + def __rtruediv__(self, other): + return wrap_surface_value(emit_runtime_binary_op("truediv", unwrap_surface_value(other), self.value)) + + def __floordiv__(self, other): + return wrap_surface_value(emit_runtime_binary_op("floordiv", self.value, unwrap_surface_value(other))) + + def __rfloordiv__(self, other): + return wrap_surface_value(emit_runtime_binary_op("floordiv", unwrap_surface_value(other), self.value)) + + def __mod__(self, other): + return wrap_surface_value(emit_runtime_binary_op("mod", self.value, unwrap_surface_value(other))) + + def __rmod__(self, other): + return wrap_surface_value(emit_runtime_binary_op("mod", unwrap_surface_value(other), self.value)) + + +class MaskResultValue(_SurfaceValue): + """Mask value that also supports `(mask, remained)` unpacking.""" + + def __init__(self, mask_value, scalar_out): + super().__init__(mask_value) + self.scalar_out = wrap_surface_value(scalar_out) + + def __iter__(self): + yield self + yield self.scalar_out + + +class AddressValue(_SurfaceValue): + """Author-facing address view backed by either a PTO ptr or a memref.""" + + def __add__(self, offset): + return AddressOffsetValue(self, offset) + + def __radd__(self, offset): + return AddressOffsetValue(self, offset) + + +@dataclass(frozen=True) +class AddressOffsetValue: + """Address view plus an element offset, used by scalar.load/store sugar.""" + + base: AddressValue + offset: object + + def __add__(self, other): + return AddressOffsetValue(self.base, _add_index(self.offset, other)) + + def __radd__(self, other): + return AddressOffsetValue(self.base, _add_index(other, self.offset)) + + def __bool__(self): + raise native_python_control_flow_error("if/while condition") + + def __iter__(self): + raise native_python_control_flow_error("for-loop iteration") + + +@dataclass(frozen=True) +class TileElementRef: + """One logical tile element selected by tile[row, col] surface syntax.""" + + tile: "TileValue" + linear_offset: object + + def __bool__(self): + raise native_python_control_flow_error("if/while condition") + + def __iter__(self): + raise native_python_control_flow_error("for-loop iteration") + + +class TileSliceValue(_SurfaceValue): + """Author-facing memref view produced by `tile[row, col:]` style indexing.""" + + def __init__(self, value, *, tile: "TileValue", offsets, shape): + super().__init__(value) + self.tile = tile + self.offsets = tuple(offsets) + self.shape = tuple(shape) + + @property + def surface_metadata(self): + return { + "tile": self.tile, + "offsets": self.offsets, + "shape": self.shape, + } + + +class TensorViewValue(_SurfaceValue, TensorView): + """Author-facing tensor-view descriptor value.""" + + def __init__(self, value, *, shape=None, strides=None): + super().__init__(value) + self.shape = tuple(shape) if shape is not None else None + self.strides = tuple(strides) if strides is not None else None + + @property + def surface_metadata(self): + return { + "shape": self.shape, + "strides": self.strides, + } + + def as_ptr(self, result_ptr_type=None): + from ._ops import as_ptr + return as_ptr(self, result_ptr_type) + + +class PartitionTensorViewValue(_SurfaceValue, PartitionTensorView): + """Author-facing partitioned tensor-view descriptor value.""" + + def __init__(self, value, *, root_tensor_view=None, offsets=None, sizes=None): + super().__init__(value) + self.root_tensor_view = root_tensor_view + self.offsets = tuple(offsets) if offsets is not None else None + self.sizes = tuple(sizes) if sizes is not None else None + self.shape = self.sizes + self.strides = getattr(root_tensor_view, "strides", None) + + def as_ptr(self, result_ptr_type=None): + from ._ops import as_ptr + return as_ptr(self, result_ptr_type) + + +class _TileValidShapeView: + """Tuple-like proxy that lowers `tile.valid_shape[i]` on demand.""" + + def __init__(self, tile: "TileValue"): + self._tile = tile + self._cache: dict[int, object] = {} + + def __getitem__(self, index: int): + if index not in {0, 1}: + raise IndexError("PTODSL tile.valid_shape currently supports indices 0 and 1") + cached = self._cache.get(index) + if cached is not None: + return cached + if self._tile.static_valid_shape is not None: + dim = self._tile.static_valid_shape[index] + if dim is not None: + value = _index_const(dim) if isinstance(dim, int) else unwrap_surface_value(dim) + value = wrap_surface_value(value) + self._cache[index] = value + return value + try: + if index == 0: + value = wrap_surface_value(_pto.TileValidRowsOp(self._tile.value).result) + else: + value = wrap_surface_value(_pto.TileValidColsOp(self._tile.value).result) + except Exception: + static_dim = _fallback_static_valid_dim(self._tile.type, index) + if static_dim is None: + raise RuntimeError( + "tile.valid_shape could not be lowered because the current " + "Python bindings do not materialize pto.tile_valid_* and " + "the tile type does not carry a recoverable static bound" + ) from None + value = wrap_surface_value(_index_const(static_dim)) + self._cache[index] = value + return value + + +class TileValue(_SurfaceValue, Tile): + """Author-facing tile handle with surface-style accessors.""" + + def __init__( + self, + value, + *, + shape=None, + dtype=None, + memory_space=None, + valid_shape=None, + ): + super().__init__(value) + parsed = parse_tile_type_metadata(value.type) + self.shape = tuple(shape) if shape is not None else ( + parsed["shape_dims"] if parsed is not None else None + ) + self.dtype = dtype if dtype is not None else ( + parsed["element_type"] if parsed is not None else None + ) + self.memory_space = memory_space if memory_space is not None else ( + parsed["memory_space"] if parsed is not None else None + ) + self.static_valid_shape = tuple(valid_shape) if valid_shape is not None else ( + parsed["valid_dims"] if parsed is not None else None + ) + self._valid_shape = _TileValidShapeView(self) + + @property + def valid_shape(self): + return self._valid_shape + + @valid_shape.setter + def valid_shape(self, dims): + from ._ops import set_tile_valid_shape + + set_tile_valid_shape(self, dims) + self.static_valid_shape = tuple(dims) + self._valid_shape._cache.clear() + + @property + def surface_metadata(self): + return { + "shape": self.shape, + "dtype": self.dtype, + "memory_space": self.memory_space, + "valid_shape": self.static_valid_shape, + } + + def as_ptr(self, result_ptr_type=None): + from ._ops import as_ptr + return as_ptr(self, result_ptr_type) + + def fill(self, value): + from ._ops import fill_tile + fill_tile(self, value) + + def __getitem__(self, key): + if not isinstance(key, tuple): + key = (key,) + if self.shape is None: + raise RuntimeError("tile indexing requires tile shape metadata") + + if _is_tile_slice_key(key, self.shape): + return _materialize_tile_slice(self, key) + + if len(key) != len(self.shape): + raise TypeError( + f"tile indexing expects {len(self.shape)} indices, got {len(key)}" + ) + linear_offset = 0 + stride = 1 + for index, dim in zip(reversed(key), reversed(self.shape)): + linear_offset = _add_index(linear_offset, _mul_index(index, stride)) + if dim is None: + raise RuntimeError("tile indexing requires static tile shape metadata") + stride *= dim + return TileElementRef(self, linear_offset) + + +@dataclass(frozen=True) +class PartitionSpec: + """Logical authored partition metadata used to compose nested slices.""" + + root_tensor_view: object + offsets: tuple + sizes: tuple + + +def wrap_like_surface_value(template, value): + """Wrap *value* using the same authored surface contract as *template*.""" + if isinstance(template, PartitionTensorViewValue): + return PartitionTensorViewValue( + value, + root_tensor_view=template.root_tensor_view, + offsets=template.offsets, + sizes=template.sizes, + ) + if isinstance(template, TensorViewValue): + return TensorViewValue(value, shape=template.shape, strides=template.strides) + if isinstance(template, TileValue): + return TileValue(value, **template.surface_metadata) + if isinstance(template, AddressValue): + return AddressValue(value) + return wrap_surface_value(value) + + +def extract_partition_spec(source) -> PartitionSpec | None: + """Return the root tensor-view + composed slice metadata when available.""" + if isinstance(source, PartitionTensorViewValue) and source.root_tensor_view is not None: + return PartitionSpec( + root_tensor_view=source.root_tensor_view, + offsets=source.offsets or (), + sizes=source.sizes or (), + ) + if isinstance(source, TensorViewValue): + return PartitionSpec(root_tensor_view=source, offsets=(), sizes=()) + return None + + +def compose_partition_spec(source, *, offsets, sizes) -> PartitionSpec | None: + """Compose a nested `partition_view(...)` against an existing partition.""" + parent = extract_partition_spec(source) + if parent is None: + return None + if parent.offsets and len(parent.offsets) != len(offsets): + raise ValueError("nested partition_view rank mismatch") + composed_offsets = tuple( + _add_index(parent_offset, child_offset) + for parent_offset, child_offset in zip(parent.offsets or [0] * len(offsets), offsets) + ) + return PartitionSpec( + root_tensor_view=parent.root_tensor_view, + offsets=composed_offsets, + sizes=tuple(sizes), + ) + + +def infer_ptr_type_from_surface_value(surface_value, result_ptr_type=None): + """Infer a PTO pointer type for `as_ptr()` when the caller omits one.""" + if result_ptr_type is not None: + return _resolve(result_ptr_type) + + value_type = surface_value.type + + tv_type = _maybe_cast_tensor_view_type(value_type) + if tv_type is not None: + return _resolve(ptr(tv_type.element_type, "gm")) + + part_type = _maybe_cast_partition_tensor_view_type(value_type) + if part_type is not None: + return _resolve(ptr(part_type.element_type, "gm")) + + tile_type = _maybe_cast_tile_buf_type(value_type) + if tile_type is None: + raise TypeError("as_ptr() expects a Tile, TensorView, or PartitionTensorView surface value") + + memory_space = getattr(tile_type, "memory_space", None) + parsed = None + if memory_space is None: + parsed = parse_tile_type_metadata(value_type) + if parsed is None: + raise RuntimeError("unable to infer tile pointer type: tile type is missing memory-space metadata") + memory_space = parsed["memory_space"] + + space_enum = getattr(memory_space, "value", None) + if space_enum is not None: + space_enum = _normalize_address_space(_ADDRESS_SPACE_VALUE_TO_KEYWORD.get(space_enum)) + else: + space_enum = _normalize_address_space(str(memory_space)) + if space_enum is None: + raise RuntimeError("unable to infer tile pointer type: unsupported tile memory space") + + return _resolve(ptr(tile_type.element_type, space_enum)) + + +def emit_as_ptr(surface_value, result_ptr_type=None): + """Lower `as_ptr()` on a surface value to the appropriate PTO op.""" + value = unwrap_surface_value(surface_value) + result_type = infer_address_type_from_surface_value(surface_value, result_ptr_type) + + if isinstance(surface_value, (TensorViewValue, PartitionTensorViewValue)): + return AddressValue(_pto.TensorViewAddrOp(result_type, value).result) + if isinstance(surface_value, TileValue): + return AddressValue(_pto.TileBufAddrOp(result_type, value).result) + raise TypeError("as_ptr() expects a Tile, TensorView, or PartitionTensorView surface value") + + +_TILE_TYPE_RE = re.compile( + r"!pto\.tile_buf<(?P[^,]+),\s*(?P.+?)x(?P[^,x>]+),\s*valid=(?P[^,>]+)(?:,.*)?>" +) + + +_ADDRESS_SPACE_VALUE_TO_KEYWORD = { + 1: "gm", + 2: "mat", + 3: "left", + 4: "right", + 5: "acc", + 6: "vec", + 7: "bias", + 8: "scaling", +} + + +def _read_tile_type_metadata_from_binding(type_obj): + required = ("shape", "element_type", "memory_space", "valid_shape") + if not all(hasattr(type_obj, name) for name in required): + return None + + memory_space_attr = type_obj.memory_space + memory_space_value = getattr(memory_space_attr, "value", None) + memory_space = _ADDRESS_SPACE_VALUE_TO_KEYWORD.get(memory_space_value) + if memory_space is None: + return None + + def _normalize_dims(seq): + dims = [] + for dim in seq: + dims.append(None if dim == ShapedType.get_dynamic_size() else int(dim)) + return tuple(dims) + + return { + "memory_space": memory_space, + "shape_dims": _normalize_dims(type_obj.shape), + "element_type": type_obj.element_type, + "valid_dims": _normalize_dims(type_obj.valid_shape), + } + + +def _fallback_static_valid_dim(type_obj, index: int): + parsed = parse_tile_type_metadata(type_obj) + if parsed is None: + return None + shape_dims = parsed["shape_dims"] + valid_dims = parsed["valid_dims"] + if index >= len(shape_dims) or index >= len(valid_dims): + return None + valid_dim = valid_dims[index] + if valid_dim is not None: + return valid_dim + return shape_dims[index] + + +def parse_tile_type_metadata(type_obj): + bound = _read_tile_type_metadata_from_binding(type_obj) + if bound is not None: + return bound + + match = _TILE_TYPE_RE.match(str(type_obj)) + if match is None: + return None + shape_dims = [ + None if dim == "?" else int(dim) + for dim in match.group("shape").split("x") + ] + valid_dims = [ + None if dim == "?" else int(dim) + for dim in match.group("valid").split("x") + ] + return { + "memory_space": match.group("space"), + "shape_dims": tuple(shape_dims), + "element_type": Type.parse(match.group("elem")), + "valid_dims": tuple(valid_dims), + } + + +def infer_tile_element_type(tile): + """Recover the tile element type from authored metadata or type text.""" + if isinstance(tile, TileValue) and tile.dtype is not None: + return _resolve(tile.dtype) + parsed = parse_tile_type_metadata(tile.type if isinstance(tile, TileValue) else tile) + if parsed is None: + raise RuntimeError("unable to recover tile element type from tile surface value") + return parsed["element_type"] + + +def infer_address_type_from_surface_value(surface_value, result_ptr_type=None): + """Infer the concrete result type emitted by `as_ptr()`.""" + return infer_ptr_type_from_surface_value(surface_value, result_ptr_type) + + +def infer_memref_type_from_surface_value(surface_value): + """Build a memref address-view type that preserves element/rank/address-space.""" + if isinstance(surface_value, TileSliceValue): + return surface_value.type + + if isinstance(surface_value, TileValue): + if surface_value.shape is not None and surface_value.dtype is not None and surface_value.memory_space is not None: + space_enum = _normalize_address_space(surface_value.memory_space) + if space_enum is None: + raise RuntimeError("unsupported tile memory space for memref address view") + return MemRefType.get( + list(surface_value.shape), + _resolve(surface_value.dtype), + memory_space=_pto.AddressSpaceAttr.get(space_enum), + ) + + value_type = surface_value.type + + tv_type = _maybe_cast_tensor_view_type(value_type) + if tv_type is not None: + return MemRefType.get( + [ShapedType.get_dynamic_size()] * tv_type.rank, + tv_type.element_type, + memory_space=_pto.AddressSpaceAttr.get(_pto.AddressSpace.GM), + ) + + part_type = _maybe_cast_partition_tensor_view_type(value_type) + if part_type is not None: + return MemRefType.get( + [ShapedType.get_dynamic_size()] * part_type.rank, + part_type.element_type, + memory_space=_pto.AddressSpaceAttr.get(_pto.AddressSpace.GM), + ) + + tile_type = _maybe_cast_tile_buf_type(value_type) + if tile_type is None: + raise TypeError("memref address inference expects a Tile, TensorView, or PartitionTensorView") + + parsed = parse_tile_type_metadata(value_type) + if parsed is None: + raise RuntimeError("unable to recover tile memref shape/address-space") + space_enum = _normalize_address_space(parsed["memory_space"]) + if space_enum is None: + raise RuntimeError("unsupported tile memory space for memref address view") + return MemRefType.get( + list(parsed["shape_dims"]), + parsed["element_type"], + memory_space=_pto.AddressSpaceAttr.get(space_enum), + ) + + +def resolve_address_access(target, offset=None): + """Normalize address/tile element sugar into `(buffer, index_offset)`.""" + if isinstance(target, TileElementRef): + base = emit_as_ptr(target.tile) + resolved_offset = target.linear_offset + elif isinstance(target, AddressOffsetValue): + base = target.base + resolved_offset = target.offset + elif isinstance(target, AddressValue): + base = target + resolved_offset = 0 + else: + base = target + resolved_offset = 0 + + if offset is not None: + resolved_offset = _add_index(resolved_offset, offset) + + return unwrap_surface_value(base), _coerce_index_value(resolved_offset) + + +def _is_tile_slice_key(key, shape): + if len(shape) == 1: + return len(key) == 1 and isinstance(key[0], slice) + if len(shape) == 2: + return len(key) == 2 and isinstance(key[1], slice) + return False + + +def _materialize_tile_slice(tile: TileValue, key): + rank = len(tile.shape) + if rank == 1: + start_slice = key[0] + if start_slice.stop is not None or start_slice.step is not None: + raise TypeError("tile[start:] only supports an open-ended slice") + start = 0 if start_slice.start is None else start_slice.start + return _build_tile_slice_view(tile, raw_offsets=[start], shape=[_dynamic_extent(tile.shape[0], start)]) + + row, col_slice = key + if col_slice.stop is not None or col_slice.step is not None: + raise TypeError("tile[row, col:] only supports an open-ended column slice") + col = 0 if col_slice.start is None else col_slice.start + return _build_tile_slice_view( + tile, + raw_offsets=[row, col], + shape=[_dynamic_extent(tile.shape[1], col)], + ) + + +def _build_tile_slice_view(tile: TileValue, *, raw_offsets, shape): + base_memref = _emit_tile_memref(tile) + base_type = MemRefType(base_memref.type) + rank = len(base_type.shape) + offset_operands, static_offsets = _split_dynamic_index_operands(raw_offsets) + shape_operands, static_shape = _split_dynamic_index_operands(shape) + if rank == 1: + slice_type = _make_strided_memref_type( + [_static_extent_if_known(shape[0])], + base_type.element_type, + [1], + base_type.memory_space, + ) + slice_value = memref.SubViewOp( + slice_type, + base_memref, + offset_operands, + shape_operands, + [], + static_offsets, + static_shape, + [1], + ).result + return TileSliceValue(slice_value, tile=tile, offsets=tuple(raw_offsets), shape=shape) + + row_type = _make_strided_memref_type( + [1, _static_extent_if_known(shape[0])], + base_type.element_type, + [base_type.shape[1], 1], + base_type.memory_space, + ) + row_view = memref.SubViewOp( + row_type, + base_memref, + offset_operands, + shape_operands, + [], + static_offsets, + [1, static_shape[0]], + [1, 1], + ).result + flat_type = _make_strided_memref_type( + [_static_extent_if_known(shape[0])], + base_type.element_type, + [1], + base_type.memory_space, + ) + slice_value = memref.CollapseShapeOp(flat_type, row_view, [[0, 1]]).result + return TileSliceValue(slice_value, tile=tile, offsets=tuple(raw_offsets), shape=shape) + + +def _emit_tile_memref(tile: TileValue): + memref_type = infer_memref_type_from_surface_value(tile) + return _pto.TileBufAddrOp(memref_type, tile.value).result + + +def _dynamic_extent(static_dim, start): + if isinstance(start, int): + return static_dim - start + return arith.SubIOp(_index_const(static_dim), start).result + + +def _static_extent_if_known(extent): + return extent if isinstance(extent, int) else ShapedType.get_dynamic_size() + + +def _static_index_attr(value): + return value if isinstance(value, int) else ShapedType.get_dynamic_size() + + +def _split_dynamic_index_operands(values): + operands = [] + static_attrs = [] + for value in values: + if isinstance(value, int): + static_attrs.append(value) + else: + operands.append(_coerce_index_value(value)) + static_attrs.append(ShapedType.get_dynamic_size()) + return operands, static_attrs + + +def _make_strided_memref_type(shape, element_type, strides, memory_space): + return MemRefType.get( + list(shape), + element_type, + StridedLayoutAttr.get(ShapedType.get_dynamic_size(), list(strides)), + memory_space, + ) + + +def _mul_index(lhs, rhs): + lhs = _normalize_index(lhs) + rhs = _normalize_index(rhs) + if isinstance(lhs, int) and isinstance(rhs, int): + return lhs * rhs + if isinstance(lhs, int): + lhs = _index_const(lhs) + if isinstance(rhs, int): + rhs = _index_const(rhs) + return arith.MulIOp(lhs, rhs).result + + +def _coerce_index_value(value): + value = _normalize_index(value) + return _index_const(value) if isinstance(value, int) else value + + +__all__ = [ + "AddressOffsetValue", + "AddressValue", + "MaskResultValue", + "PartitionSpec", + "PartitionTensorViewValue", + "RuntimeValue", + "TileElementRef", + "TileSliceValue", + "TensorViewValue", + "TileValue", + "compose_partition_spec", + "emit_as_ptr", + "extract_partition_spec", + "infer_tile_element_type", + "infer_address_type_from_surface_value", + "infer_memref_type_from_surface_value", + "infer_ptr_type_from_surface_value", + "parse_tile_type_metadata", + "resolve_address_access", + "unwrap_surface_value", + "wrap_like_surface_value", + "wrap_surface_value", + "_unwrap_sequence", +] diff --git a/ptodsl/ptodsl/_tensor_factories.py b/ptodsl/ptodsl/_tensor_factories.py new file mode 100644 index 000000000..4336c3df9 --- /dev/null +++ b/ptodsl/ptodsl/_tensor_factories.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Small host-side tensor factory helpers used by PTODSL wrappers.""" + +from __future__ import annotations + + +def empty_like(tensor): + """Allocate one host-side tensor with the same logical metadata as *tensor*.""" + new_empty = getattr(tensor, "new_empty", None) + if callable(new_empty): + return new_empty(tensor.shape) + + try: + import torch # type: ignore + except Exception: + torch = None + if torch is not None and isinstance(tensor, torch.Tensor): + return torch.empty_like(tensor) + + try: + import numpy as np # type: ignore + except Exception: + np = None + if np is not None and isinstance(tensor, np.ndarray): + return np.empty_like(tensor) + + raise TypeError( + "pto.empty_like(...) could not infer how to allocate an output tensor for " + f"{type(tensor)!r}; provide O= explicitly or use a tensor type exposing " + ".new_empty(...), torch.empty_like, or numpy.empty_like support" + ) + + +__all__ = [ + "empty_like", +] diff --git a/ptodsl/ptodsl/vpto.py b/ptodsl/ptodsl/_tile_template_tracing.py similarity index 76% rename from ptodsl/ptodsl/vpto.py rename to ptodsl/ptodsl/_tile_template_tracing.py index cd8c3db23..59bcad95c 100644 --- a/ptodsl/ptodsl/vpto.py +++ b/ptodsl/ptodsl/_tile_template_tracing.py @@ -6,7 +6,7 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. """ -Experimental `ptodsl.vpto` POC for TileLang-style tile templates. +Tile-template tracing implementation for PTODSL tile templates. This module keeps the authored Python body close to TileLang-style templates, but traces execution directly into MLIR Python bindings instead of going through @@ -24,9 +24,9 @@ - ``vadd(lhs, rhs, mask)`` - ``vsts(vec, tile[row, col:], mask)`` -The goal of this POC is to validate a tracing-oriented VPTO frontend shape that -already builds real MLIR Python objects, while staying intentionally narrow and -readable for `tadd_template.py`. +The current goal is to keep a narrow tile-template tracing path that already +builds real MLIR Python objects, while keeping its scope explicit and aligned +with the main PTODSL tracing runtime. """ from __future__ import annotations @@ -34,9 +34,15 @@ import inspect from dataclasses import dataclass from pathlib import Path - from . import scalar as _scalar -from ._bootstrap import make_context +from ._surface_types import Tile +from ._tracing import ( + KernelModuleSpec, + ModuleArtifact, + ModuleStyle, + TracingRuntime, + require_active_runtime, +) from ._types import ( _resolve, float16 as _float16, @@ -52,11 +58,8 @@ vreg_type as _vreg_type, ) -from mlir.dialects import arith, func, pto as _pto, scf -from mlir.ir import Attribute, InsertionPoint, IntegerType, Location, Module, Operation, StringAttr, Type - - -_ACTIVE_TRACE = None +from mlir.dialects import arith, pto as _pto, scf +from mlir.ir import InsertionPoint, IntegerType, Type @dataclass(frozen=True) @@ -78,10 +81,6 @@ def __repr__(self) -> str: i8 = ScalarType("i8", lanes=256, mask_bits=8, bytewidth=1) -class Tile: - """Tile annotation marker for the tracing POC.""" - - @dataclass(frozen=True) class TileSpec: shape: tuple[int, int] @@ -182,13 +181,13 @@ def __getitem__(self, key): or not _is_index_like(key[0]) or not isinstance(key[1], slice) ): - raise TypeError("vpto POC only supports tile[row, col:] indexing") + raise TypeError("tile-template tracing only supports tile[row, col:] indexing") row, col_slice = key if col_slice.stop is not None or col_slice.step is not None: - raise TypeError("vpto POC only supports tile[row, col:] slices") + raise TypeError("tile-template tracing only supports tile[row, col:] slices") col = 0 if col_slice.start is None else col_slice.start if not _is_index_like(col): - raise TypeError("vpto POC only supports integer/index column offsets") + raise TypeError("tile-template tracing only supports integer/index column offsets") _validate_static_bound(row, self._spec.shape[0], "row") _validate_static_bound(col, self._spec.shape[1], "column") return _TileSlice(self, row=row, col=col) @@ -282,8 +281,16 @@ def __exit__(self, exc_type, exc, tb): self._trace._exit_for(self._handle, exc_type, exc, tb) -class _TraceBuilder: - def __init__(self, descriptor: "TracingKernelDescriptor", tile_specs: dict[str, TileSpec]): +class _TraceBuilder(TracingRuntime): + def __init__(self, descriptor: "TileTemplate", tile_specs: dict[str, TileSpec]): + super().__init__( + KernelModuleSpec( + function_name=descriptor.name, + target_arch=descriptor.target, + kernel_kind="vector", + module_style=ModuleStyle.NESTED, + ) + ) self.descriptor = descriptor self.tile_specs = tile_specs self._const_cache: dict[tuple[int, str], _Value] = {} @@ -291,63 +298,41 @@ def __init__(self, descriptor: "TracingKernelDescriptor", tile_specs: dict[str, self._row_offset_cache: dict[tuple[str, str], _Value] = {} self._loop_stack: list[dict] = [] self._inside_vecscope = False - - def build_module(self): - global _ACTIVE_TRACE - if _ACTIVE_TRACE is not None: - raise RuntimeError("nested vpto builds are not supported") - + self._ordered_specs: list[tuple[str, TileSpec]] = [] signature = inspect.signature(self.descriptor.py_fn) - ctx = make_context() - with ctx, Location.unknown(): - arg_types = [] - ordered_specs = [] - for param_name, param in signature.parameters.items(): - if not _is_tile_annotation(param.annotation): - raise TypeError( - "vpto POC currently only supports Tile parameters; " - f"parameter {param_name!r} uses {param.annotation!r}" - ) - spec = self.tile_specs.get(param_name) - if spec is None: - raise ValueError(f"missing specialization for Tile parameter {param_name!r}") - ordered_specs.append((param_name, spec)) - arg_types.append(spec.mlir_type()) - - module = Module.create() - module.operation.attributes["pto.target_arch"] = StringAttr.get(self.descriptor.target) - - with InsertionPoint(module.body): - inner_op = Operation.create("builtin.module", regions=1) - inner_op.attributes["pto.target_arch"] = StringAttr.get(self.descriptor.target) - inner_op.attributes["pto.kernel_kind"] = Attribute.parse("#pto.kernel_kind") - inner_body = inner_op.regions[0].blocks.append() - - with InsertionPoint(inner_body): - fn_ty = func.FunctionType.get(arg_types, []) - ir_fn = func.FuncOp(self.descriptor.name, fn_ty) - - entry = ir_fn.add_entry_block() - with InsertionPoint(entry): - args = [] - for arg_value, (_, spec) in zip(entry.arguments, ordered_specs): - args.append(_TileProxy(self, arg_value, spec)) - - _ACTIVE_TRACE = self - try: - self.descriptor.py_fn(*args) - finally: - _ACTIVE_TRACE = None - - if self._inside_vecscope: - raise RuntimeError("vpto kernel exited with an open vecscope block") - if self._loop_stack: - raise RuntimeError("vpto kernel exited with an open scf.for block") - - func.ReturnOp([]) - - module.operation.verify() - return module + self._signature_parameters = tuple(signature.parameters.items()) + + def compute_argument_types(self): + arg_types = [] + ordered_specs = [] + for param_name, param in self._signature_parameters: + if not _is_tile_annotation(param.annotation): + raise TypeError( + "tile-template tracing currently only supports Tile parameters; " + f"parameter {param_name!r} uses {param.annotation!r}" + ) + spec = self.tile_specs.get(param_name) + if spec is None: + raise ValueError(f"missing specialization for Tile parameter {param_name!r}") + ordered_specs.append((param_name, spec)) + arg_types.append(spec.mlir_type()) + self._ordered_specs = ordered_specs + return arg_types + + def bind_entry_arguments(self, entry_arguments): + args = [] + for arg_value, (_, spec) in zip(entry_arguments, self._ordered_specs): + args.append(_TileProxy(self, arg_value, spec)) + return tuple(args) + + def trace_entry(self, *args): + self.descriptor.py_fn(*args) + + def validate_trace_state(self): + if self._inside_vecscope: + raise RuntimeError("tile-template trace exited with an open vecscope block") + if self._loop_stack: + raise RuntimeError("tile-template trace exited with an open scf.for block") def vecscope(self) -> _VecScopeCM: return _VecScopeCM(self) @@ -368,17 +353,19 @@ def yield_(self, *vals): def _yield_loop_values(self, vals, *, surface: str, from_named_state: bool): if not self._loop_stack: - raise RuntimeError(f"{surface}(...) may only be used inside a vpto for_ block") + raise RuntimeError(f"{surface}(...) may only be used inside a tile-template for_ block") frame = self._loop_stack[-1] if frame["kind"] != "for": - raise RuntimeError(f"{surface}(...) may only be used inside a vpto for_ block") + raise RuntimeError(f"{surface}(...) may only be used inside a tile-template for_ block") if frame["state_names"] and not from_named_state: raise RuntimeError( - f"{surface}(...) is ambiguous for vpto for_ with named state; " + f"{surface}(...) is ambiguous for tile-template for_ with named state; " "use loop.yield_state(...) instead" ) if frame["yielded"]: - raise RuntimeError(f"{surface}(...) may only be emitted once per vpto for_ block") + raise RuntimeError( + f"{surface}(...) may only be emitted once per tile-template for_ block" + ) if len(vals) != len(frame["iter_args"]): raise RuntimeError( f"{surface}(...) expected {len(frame['iter_args'])} value(s), got {len(vals)}" @@ -428,7 +415,9 @@ def materialize_linear_offset(self, tile_slice: _TileSlice) -> _Value: def _enter_vecscope(self): if self._inside_vecscope: - raise RuntimeError("nested vpto vecscope blocks are not supported in this POC") + raise RuntimeError( + "nested tile-template vecscope blocks are not supported in the current implementation" + ) vecscope_op = _pto.VecScopeOp() vecscope_block = vecscope_op.body.blocks.append() vecscope_ip = InsertionPoint(vecscope_block) @@ -446,7 +435,7 @@ def _exit_vecscope(self, exc_type, exc, tb): raise RuntimeError("vecscope exit without matching enter") frame = self._loop_stack.pop() if frame["kind"] != "vecscope": - raise RuntimeError("vpto vecscope stack corruption detected") + raise RuntimeError("tile-template vecscope stack corruption detected") frame["ip"].__exit__(exc_type, exc, tb) self._inside_vecscope = False @@ -488,14 +477,14 @@ def _exit_for(self, handle: _LoopHandle | None, exc_type, exc, tb): raise RuntimeError("for_ exit without a loop handle") frame = self._loop_stack.pop() if frame["kind"] != "for" or frame["handle"] is not handle: - raise RuntimeError("vpto for_ stack corruption detected") + raise RuntimeError("tile-template for_ stack corruption detected") if exc_type is None: if frame["iter_args"] and not frame["yielded"]: if frame["state_names"]: raise RuntimeError( - "vpto for_ with named state requires explicit loop.yield_state(...)" + "tile-template for_ with named state requires explicit loop.yield_state(...)" ) - raise RuntimeError("vpto for_ with iter_args requires explicit yield_(...)") + raise RuntimeError("tile-template for_ with iter_args requires explicit yield_(...)") if not frame["iter_args"]: scf.YieldOp([]) frame["ip"].__exit__(exc_type, exc, tb) @@ -527,7 +516,7 @@ def _coerce_value(self, value) -> _Value: return self.index_const(value) if hasattr(value, "type"): return _Value(value) - raise TypeError(f"unsupported vpto scalar value {value!r}") + raise TypeError(f"unsupported tile-template scalar value {value!r}") def _coerce_like(self, value, ty: str) -> _Value: coerced = self._coerce_value(value) @@ -537,46 +526,35 @@ def _coerce_like(self, value, ty: str) -> _Value: @dataclass(frozen=True) -class TracingKernelDescriptor: +class TileTemplate: py_fn: object target: str op: str name: str source_label: str - def specialize(self, **tile_specs: TileSpec) -> "MaterializedTracingKernel": - return MaterializedTracingKernel(self, tile_specs) + def specialize(self, **tile_specs: TileSpec) -> "SpecializedTileTemplate": + return SpecializedTileTemplate(self, tile_specs) -class MaterializedTracingKernel: - def __init__(self, descriptor: TracingKernelDescriptor, tile_specs: dict[str, TileSpec]): +class SpecializedTileTemplate(ModuleArtifact): + def __init__(self, descriptor: TileTemplate, tile_specs: dict[str, TileSpec]): + super().__init__( + descriptor.name, + module_factory=lambda: _TraceBuilder(descriptor, tile_specs).build_module(), + ) self.descriptor = descriptor self.tile_specs = tile_specs - self._cached_module = None - - def build(self): - if self._cached_module is None: - self._cached_module = _TraceBuilder(self.descriptor, self.tile_specs).build_module() - return self._cached_module - def mlir_text(self) -> str: - return str(self.build()) - def emit(self, path: str | Path) -> None: - Path(path).write_text(self.mlir_text(), encoding="utf-8") - - def __str__(self) -> str: - return self.mlir_text() - - -def vkernel(*, target: str = "a5", op: str, name: str | None = None): +def tile_template(*, target: str = "a5", op: str, name: str | None = None): if target != "a5": - raise ValueError("vpto POC currently only supports target='a5'") + raise ValueError("tile-template tracing currently only supports target='a5'") def decorator(fn): source_path = Path(inspect.getsourcefile(fn) or "") descriptor_name = name or fn.__name__ - return TracingKernelDescriptor( + return TileTemplate( py_fn=fn, target=target, op=op, @@ -588,36 +566,38 @@ def decorator(fn): def vecscope() -> _VecScopeCM: - return _require_active_trace("vecscope").vecscope() + return require_active_runtime("vecscope", expected_type=_TraceBuilder).vecscope() def for_(start, stop, *, step, iter_args=None, state=None) -> _ForCM: - return _require_active_trace("for_").for_(start, stop, step=step, iter_args=iter_args, state=state) + return require_active_runtime("for_", expected_type=_TraceBuilder).for_( + start, stop, step=step, iter_args=iter_args, state=state + ) def yield_(*vals): - _require_active_trace("yield_").yield_(*vals) + require_active_runtime("yield_", expected_type=_TraceBuilder).yield_(*vals) def get_lanes(dtype: ScalarType) -> _Value: - return _require_active_trace("get_lanes").index_const(dtype.lanes) + return require_active_runtime("get_lanes", expected_type=_TraceBuilder).index_const(dtype.lanes) def scalar_const(value: int, dtype: ScalarType) -> _Value: - return _require_active_trace("scalar_const").scalar_const(value, dtype) + return require_active_runtime("scalar_const", expected_type=_TraceBuilder).scalar_const(value, dtype) def make_mask(dtype: ScalarType, remained) -> tuple[_MaskValue, _Value]: - trace = _require_active_trace("make_mask") + trace = require_active_runtime("make_mask", expected_type=_TraceBuilder) remained_val = trace._coerce_value(remained) expected_scalar_ty = str(_resolve(_scalar_descriptor(_scalar_type_for_mask(dtype)))) if remained_val.type_text != expected_scalar_ty: raise TypeError( - f"vpto POC expects make_mask remained to use {expected_scalar_ty}, got {remained_val.type_text}" + f"tile-template tracing expects make_mask remained to use {expected_scalar_ty}, got {remained_val.type_text}" ) if dtype.mask_bits not in {8, 16, 32}: raise ValueError(f"unsupported mask bit-width {dtype.mask_bits}") - mask_ty = _mask_type(f"b{dtype.mask_bits}") + mask_ty = _resolve(_mask_type(f"b{dtype.mask_bits}")) scalar_ty = IntegerType.get_signless(dtype.mask_bits) op_cls = getattr(_pto, f"PltB{dtype.mask_bits}Op", None) if op_cls is None: @@ -631,9 +611,9 @@ def make_mask(dtype: ScalarType, remained) -> tuple[_MaskValue, _Value]: def vlds(tile_slice: _TileSlice) -> _VectorValue: - trace = _require_active_trace("vlds") + trace = require_active_runtime("vlds", expected_type=_TraceBuilder) if not isinstance(tile_slice, _TileSlice): - raise TypeError("vpto POC only supports vlds(tile[row, col:])") + raise TypeError("tile-template tracing only supports vlds(tile[row, col:])") ptr_value = trace.ensure_tile_ptr(tile_slice.tile) offset = trace.materialize_linear_offset(tile_slice) vector_ty = _resolve(_vreg_type(tile_slice.tile.element_type.lanes, _scalar_descriptor(tile_slice.tile.element_type))) @@ -643,30 +623,24 @@ def vlds(tile_slice: _TileSlice) -> _VectorValue: def vadd(lhs: _VectorValue, rhs: _VectorValue, mask: _MaskValue) -> _VectorValue: if lhs.dtype != rhs.dtype: - raise TypeError("vpto POC expects vadd operands to use the same dtype") + raise TypeError("tile-template tracing expects vadd operands to use the same dtype") if lhs.dtype != mask.dtype: - raise TypeError("vpto POC expects vadd mask dtype to match vector dtype") + raise TypeError("tile-template tracing expects vadd mask dtype to match vector dtype") result = _pto.VaddOp(lhs.value.type, lhs.value, rhs.value, mask.value).result return _VectorValue(result, lhs.dtype) def vsts(vec: _VectorValue, tile_slice: _TileSlice, mask: _MaskValue) -> None: - trace = _require_active_trace("vsts") + trace = require_active_runtime("vsts", expected_type=_TraceBuilder) if vec.dtype != mask.dtype: - raise TypeError("vpto POC expects vsts mask dtype to match vector dtype") + raise TypeError("tile-template tracing expects vsts mask dtype to match vector dtype") if vec.dtype != tile_slice.tile.element_type: - raise TypeError("vpto POC expects vsts destination dtype to match vector dtype") + raise TypeError("tile-template tracing expects vsts destination dtype to match vector dtype") ptr_value = trace.ensure_tile_ptr(tile_slice.tile) offset = trace.materialize_linear_offset(tile_slice) _pto.VstsOp(vec.value, ptr_value.value, offset.value, mask.value) -def _require_active_trace(surface: str) -> _TraceBuilder: - if _ACTIVE_TRACE is None: - raise RuntimeError(f"{surface}() may only be used while tracing a vpto kernel") - return _ACTIVE_TRACE - - def _is_tile_annotation(annotation) -> bool: if annotation is Tile: return True @@ -719,8 +693,8 @@ def _scalar_type_for_mask(dtype: ScalarType) -> ScalarType: __all__ = [ "Tile", "TileSpec", - "TracingKernelDescriptor", - "MaterializedTracingKernel", + "TileTemplate", + "SpecializedTileTemplate", "ScalarType", "f32", "f16", @@ -728,7 +702,7 @@ def _scalar_type_for_mask(dtype: ScalarType) -> ScalarType: "i32", "i16", "i8", - "vkernel", + "tile_template", "vecscope", "for_", "yield_", diff --git a/ptodsl/ptodsl/_tracing/__init__.py b/ptodsl/ptodsl/_tracing/__init__.py new file mode 100644 index 000000000..70901127d --- /dev/null +++ b/ptodsl/ptodsl/_tracing/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Shared tracing runtime building blocks for PTODSL frontends.""" + +from .active import ( + activate_runtime, + activate_session, + current_runtime, + current_session, + require_active_runtime, + require_active_session, +) +from .artifacts import ModuleArtifact +from .module_builder import KernelModuleSpec, ModuleStyle, create_kernel_module +from .runtime import CallbackTracingRuntime, SignatureTracingRuntime, TracingRuntime +from .session import HelperFunctionSpec, SubkernelTraceFrame, TraceSession + +__all__ = [ + "activate_runtime", + "activate_session", + "current_runtime", + "current_session", + "require_active_runtime", + "require_active_session", + "ModuleArtifact", + "KernelModuleSpec", + "ModuleStyle", + "create_kernel_module", + "CallbackTracingRuntime", + "SignatureTracingRuntime", + "TracingRuntime", + "HelperFunctionSpec", + "SubkernelTraceFrame", + "TraceSession", +] diff --git a/ptodsl/ptodsl/_tracing/active.py b/ptodsl/ptodsl/_tracing/active.py new file mode 100644 index 000000000..0a32a2b52 --- /dev/null +++ b/ptodsl/ptodsl/_tracing/active.py @@ -0,0 +1,86 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Active tracing-runtime stack shared by PTODSL frontends.""" + +from __future__ import annotations + +from contextlib import contextmanager + +_ACTIVE_RUNTIME_STACK = [] +_ACTIVE_SESSION_STACK = [] + + +@contextmanager +def activate_runtime(runtime): + """Push *runtime* as the current active tracing runtime.""" + _ACTIVE_RUNTIME_STACK.append(runtime) + try: + yield runtime + finally: + popped = _ACTIVE_RUNTIME_STACK.pop() + if popped is not runtime: + raise RuntimeError("PTODSL active tracing runtime stack corruption detected") + + +@contextmanager +def activate_session(session): + """Push *session* as the current active trace session.""" + _ACTIVE_SESSION_STACK.append(session) + try: + yield session + finally: + popped = _ACTIVE_SESSION_STACK.pop() + if popped is not session: + raise RuntimeError("PTODSL active trace-session stack corruption detected") + + +def current_runtime(expected_type=None): + """Return the current active tracing runtime, or ``None`` if inactive.""" + if not _ACTIVE_RUNTIME_STACK: + return None + runtime = _ACTIVE_RUNTIME_STACK[-1] + if expected_type is not None and not isinstance(runtime, expected_type): + return None + return runtime + + +def current_session(): + """Return the current active trace session, or ``None`` if inactive.""" + if not _ACTIVE_SESSION_STACK: + return None + return _ACTIVE_SESSION_STACK[-1] + + +def require_active_runtime(surface: str, expected_type=None): + """Return the active runtime or raise a surface-specific error.""" + runtime = current_runtime(expected_type=expected_type) + if runtime is None: + raise RuntimeError( + f"{surface}() may only be used while tracing a compatible PTODSL kernel" + ) + return runtime + + +def require_active_session(surface: str): + """Return the active trace session or raise a surface-specific error.""" + session = current_session() + if session is None: + raise RuntimeError( + f"{surface}() may only be used while tracing a compatible PTODSL kernel" + ) + return session + + +__all__ = [ + "activate_runtime", + "activate_session", + "current_runtime", + "current_session", + "require_active_runtime", + "require_active_session", +] diff --git a/ptodsl/ptodsl/_tracing/artifacts.py b/ptodsl/ptodsl/_tracing/artifacts.py new file mode 100644 index 000000000..650d0517f --- /dev/null +++ b/ptodsl/ptodsl/_tracing/artifacts.py @@ -0,0 +1,58 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Reusable module-backed artifacts for PTODSL tracing frontends.""" + +from __future__ import annotations + +from pathlib import Path + + +class ModuleArtifact: + """ + Cached module-backed artifact. + + Subclasses may either pass an eager ``module`` or a lazy ``module_factory``. + """ + + def __init__(self, py_name: str, *, module=None, module_factory=None): + self._py_name = py_name + self._cached_module = module + self._module_factory = module_factory + + def build(self): + """Return the cached ``mlir.ir.Module``.""" + if self._cached_module is None: + if self._module_factory is None: + raise RuntimeError(f"{self._py_name} has no module factory") + self._cached_module = self._module_factory() + return self._cached_module + + def mlir_module(self): + """Return the cached ``mlir.ir.Module``.""" + return self.build() + + def mlir_text(self) -> str: + """Return the textual MLIR form.""" + return str(self.build()) + + def verify(self) -> None: + """Verify the cached module operation.""" + self.build().operation.verify() + + def emit(self, path: str | Path) -> None: + """Write the textual MLIR form to *path*.""" + Path(path).write_text(self.mlir_text(), encoding="utf-8") + + def __str__(self): + return self.mlir_text() + + def __repr__(self): + return self.mlir_text() + + +__all__ = ["ModuleArtifact"] diff --git a/ptodsl/ptodsl/_tracing/control_flow.py b/ptodsl/ptodsl/_tracing/control_flow.py new file mode 100644 index 000000000..c2c370fe7 --- /dev/null +++ b/ptodsl/ptodsl/_tracing/control_flow.py @@ -0,0 +1,92 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Tracing-time helpers for structured PTODSL control-flow lowering.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from .._runtime_index_ops import coerce_runtime_index +from .._surface_values import unwrap_surface_value + +from mlir.dialects import scf +from mlir.ir import InsertionPoint + + +@dataclass +class CarryLoopFrame: + """Active loop-carry lowering frame for one authored ``pto.for_().carry()``.""" + + for_op: object + insertion_point: InsertionPoint + state_names: tuple[str, ...] + state_templates: tuple[object, ...] + yielded: bool = False + + +def build_carry_loop_frame(start, stop, step, state_items) -> CarryLoopFrame: + """Materialize one ``scf.for`` carry loop and enter its body insertion point.""" + state_items = tuple(state_items) + state_names = tuple(name for name, _ in state_items) + state_templates = tuple(value for _, value in state_items) + iter_args = [unwrap_surface_value(value) for value in state_templates] + for_op = scf.ForOp( + _coerce_index(start), + _coerce_index(stop), + _coerce_index(step), + iter_args, + ) + insertion_point = InsertionPoint(for_op.body) + insertion_point.__enter__() + return CarryLoopFrame( + for_op=for_op, + insertion_point=insertion_point, + state_names=state_names, + state_templates=state_templates, + ) + + +def yield_carry_loop_state(frame: CarryLoopFrame, **kwargs) -> None: + """Validate one ``loop.update(...)`` call and emit the matching ``scf.yield``.""" + missing = [name for name in frame.state_names if name not in kwargs] + extra = [name for name in kwargs if name not in frame.state_names] + if missing or extra: + pieces = [] + if missing: + pieces.append(f"missing: {', '.join(missing)}") + if extra: + pieces.append(f"unexpected: {', '.join(extra)}") + raise RuntimeError("loop.update(...) must match carry names exactly; " + "; ".join(pieces)) + if frame.yielded: + raise RuntimeError("loop.update(...) may only be called once per loop body") + scf.YieldOp([unwrap_surface_value(kwargs[name]) for name in frame.state_names]) + frame.yielded = True + + +def finish_carry_loop_frame(frame: CarryLoopFrame, exc_type, exc, tb) -> None: + """Leave one active carry-loop frame and close its insertion point.""" + try: + if exc_type is None and not frame.yielded: + raise RuntimeError( + "pto.for_(...).carry(...) requires loop.update(...) before leaving the loop body" + ) + finally: + frame.insertion_point.__exit__(exc_type, exc, tb) + + +def _coerce_index(value): + raw_value = unwrap_surface_value(value) + return coerce_runtime_index(raw_value, context="pto.for_(...).carry(...) loop bound") + + +__all__ = [ + "CarryLoopFrame", + "build_carry_loop_frame", + "yield_carry_loop_state", + "finish_carry_loop_frame", +] diff --git a/ptodsl/ptodsl/_tracing/module_builder.py b/ptodsl/ptodsl/_tracing/module_builder.py new file mode 100644 index 000000000..87a1c2d0f --- /dev/null +++ b/ptodsl/ptodsl/_tracing/module_builder.py @@ -0,0 +1,81 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Common MLIR module/container builders for PTODSL tracing frontends.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + +from mlir.dialects import func +from mlir.ir import Attribute, InsertionPoint, Module, Operation, StringAttr, UnitAttr + + +class ModuleStyle(str, Enum): + """Supported top-level PTODSL module layouts.""" + + FLAT_AICORE = "flat_aicore" + NESTED = "nested" + + +@dataclass(frozen=True) +class KernelModuleSpec: + """Declarative description of a traced PTODSL kernel container.""" + + function_name: str + target_arch: str + kernel_kind: str + module_style: ModuleStyle = ModuleStyle.NESTED + + +def _kernel_kind_attr(kernel_kind: str): + return Attribute.parse(f"#pto.kernel_kind<{kernel_kind}>") + + +def _build_flat_aicore_module(spec: KernelModuleSpec, arg_types): + module = Module.create() + module.operation.attributes["pto.target_arch"] = StringAttr.get(spec.target_arch) + module.operation.attributes["pto.kernel_kind"] = _kernel_kind_attr(spec.kernel_kind) + fn_ty = func.FunctionType.get(arg_types, []) + with InsertionPoint(module.body): + ir_fn = func.FuncOp(spec.function_name, fn_ty) + ir_fn.attributes["pto.aicore"] = UnitAttr.get() + return module, ir_fn + + +def _build_nested_module(spec: KernelModuleSpec, arg_types): + outer = Module.create() + outer.operation.attributes["pto.target_arch"] = StringAttr.get(spec.target_arch) + + with InsertionPoint(outer.body): + inner_op = Operation.create("builtin.module", regions=1) + inner_op.attributes["pto.target_arch"] = StringAttr.get(spec.target_arch) + inner_op.attributes["pto.kernel_kind"] = _kernel_kind_attr(spec.kernel_kind) + inner_body = inner_op.regions[0].blocks.append() + + with InsertionPoint(inner_body): + fn_ty = func.FunctionType.get(arg_types, []) + ir_fn = func.FuncOp(spec.function_name, fn_ty) + + return outer, ir_fn + + +def create_kernel_module(spec: KernelModuleSpec, arg_types): + """Create the top-level module and entry function for *spec*.""" + if spec.module_style == ModuleStyle.FLAT_AICORE: + return _build_flat_aicore_module(spec, arg_types) + if spec.module_style == ModuleStyle.NESTED: + return _build_nested_module(spec, arg_types) + raise ValueError(f"unsupported PTODSL module style {spec.module_style!r}") + + +__all__ = [ + "KernelModuleSpec", + "ModuleStyle", + "create_kernel_module", +] diff --git a/ptodsl/ptodsl/_tracing/runtime.py b/ptodsl/ptodsl/_tracing/runtime.py new file mode 100644 index 000000000..358ce6b10 --- /dev/null +++ b/ptodsl/ptodsl/_tracing/runtime.py @@ -0,0 +1,131 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Base tracing runtimes shared by PTODSL frontends.""" + +from __future__ import annotations + +from .active import activate_runtime, activate_session, require_active_session +from .module_builder import create_kernel_module +from .session import TraceSession +from .._bootstrap import make_context +from .._types import _resolve + +from mlir.dialects import func +from mlir.ir import InsertionPoint, Location + + +class TracingRuntime: + """Shared module-building runtime for tracing-based PTODSL frontends.""" + + def __init__(self, module_spec): + self.module_spec = module_spec + + def compute_argument_types(self): + """Return the MLIR entry argument types for this runtime.""" + raise NotImplementedError + + def bind_entry_arguments(self, entry_arguments): + """Wrap raw entry-block arguments into surface values.""" + return tuple(entry_arguments) + + def trace_entry(self, *args): + """Emit the traced function body using wrapped entry arguments.""" + raise NotImplementedError + + def validate_trace_state(self): + """Validate runtime-local tracing state before the function returns.""" + + def emit_return(self): + """Emit the function return terminator.""" + func.ReturnOp([]) + + def verify_module(self, module): + """Verify the completed module.""" + module.operation.verify() + + def create_session(self, module, entry_function): + """Create the shared trace session for this build.""" + return TraceSession(self.module_spec, module, entry_function) + + def initialize_session(self, session, entry_block): + """Populate runtime-specific session state before tracing.""" + session.bind_entry_block(entry_block) + + def finalize_session(self, session): + """Finalize runtime-specific session state after tracing.""" + + def dispatch_subkernel_call(self, subkernel, *args, **kwargs): + """Dispatch a decorated PTODSL subkernel call in the active trace.""" + session = require_active_session(f"@pto.{subkernel.spec.role.value}") + if subkernel.spec.role.value in {"ukernel", "cube", "simd"}: + return session.lower_inline_subkernel(subkernel, *args, **kwargs) + if subkernel.spec.role.value == "simt": + return session.lower_simt_helper_subkernel(subkernel, *args, **kwargs) + return subkernel.emit_body(*args, **kwargs) + + def build_module(self): + """Materialize the full MLIR module for this runtime.""" + ctx = make_context() + with ctx, Location.unknown(): + arg_types = list(self.compute_argument_types()) + module, ir_fn = create_kernel_module(self.module_spec, arg_types) + session = self.create_session(module, ir_fn) + entry = ir_fn.add_entry_block() + with InsertionPoint(entry), activate_runtime(self), activate_session(session): + self.initialize_session(session, entry) + args = self.bind_entry_arguments(entry.arguments) + self.trace_entry(*args) + self.validate_trace_state() + self.emit_return() + self.finalize_session(session) + session.validate_final_state() + self.verify_module(module) + return module + + +class CallbackTracingRuntime(TracingRuntime): + """Small tracing runtime for eager callback-style module materialization.""" + + def __init__(self, module_spec, arg_types, callback): + super().__init__(module_spec) + self._arg_types = tuple(arg_types) + self._callback = callback + + def compute_argument_types(self): + return tuple(_resolve(arg_type) for arg_type in self._arg_types) + + def trace_entry(self, *args): + self._callback(*args) + + +class SignatureTracingRuntime(TracingRuntime): + """Tracing runtime that binds a parsed PTODSL kernel signature.""" + + def __init__(self, module_spec, kernel_signature, callback, *, constexpr_bindings=None): + super().__init__(module_spec) + self._kernel_signature = kernel_signature + self._callback = callback + self._constexpr_bindings = dict(constexpr_bindings or {}) + + def compute_argument_types(self): + return self._kernel_signature.compute_entry_arg_types() + + def bind_entry_arguments(self, entry_arguments): + return self._kernel_signature.bind_entry_arguments(entry_arguments) + + def trace_entry(self, *args): + kwargs = self._kernel_signature.default_constexpr_bindings() + kwargs.update(self._constexpr_bindings) + self._callback(*args, **kwargs) + + +__all__ = [ + "CallbackTracingRuntime", + "SignatureTracingRuntime", + "TracingRuntime", +] diff --git a/ptodsl/ptodsl/_tracing/session.py b/ptodsl/ptodsl/_tracing/session.py new file mode 100644 index 000000000..e12fd2a36 --- /dev/null +++ b/ptodsl/ptodsl/_tracing/session.py @@ -0,0 +1,205 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Trace-session objects shared by PTODSL tracing runtimes.""" + +from __future__ import annotations + +from contextlib import contextmanager +from dataclasses import dataclass + +from .control_flow import ( + build_carry_loop_frame, + finish_carry_loop_frame, + yield_carry_loop_state, +) +from .._surface_values import unwrap_surface_value, wrap_like_surface_value + +from mlir.dialects import arith, func +from mlir.dialects import pto as _pto +from mlir.ir import InsertionPoint, IntegerType, UnitAttr + + +@dataclass(frozen=True) +class HelperFunctionSpec: + """Declarative description of a helper function emitted during tracing.""" + + symbol_name: str + arg_types: tuple + result_types: tuple = () + attributes: tuple[tuple[str, object], ...] = () + + +@dataclass(frozen=True) +class SubkernelTraceFrame: + """Active inline-lowering frame for one PTODSL subkernel call.""" + + role: str + symbol_name: str + target: str + + +class TraceSession: + """Shared per-build state for a traced PTODSL module.""" + + def __init__(self, module_spec, module, entry_function): + self.module_spec = module_spec + self.module = module + self.entry_function = entry_function + self.entry_block = None + self._function_stack = [entry_function] + self._function_symbol_table = entry_function.operation.parent.regions[0].blocks[0] + self._helpers: dict[str, object] = {} + self._subkernel_stack: list[SubkernelTraceFrame] = [] + self._carry_loop_stack = [] + + @property + def current_function(self): + return self._function_stack[-1] + + @property + def current_subkernel(self): + if not self._subkernel_stack: + return None + return self._subkernel_stack[-1] + + @property + def subkernel_stack_depth(self): + return len(self._subkernel_stack) + + @property + def current_carry_loop(self): + if not self._carry_loop_stack: + return None + return self._carry_loop_stack[-1] + + def bind_entry_block(self, entry_block) -> None: + """Record the root entry block for the active trace.""" + self.entry_block = entry_block + + @contextmanager + def enter_function(self, ir_fn): + """Push *ir_fn* as the current active function in this session.""" + self._function_stack.append(ir_fn) + try: + yield ir_fn + finally: + popped = self._function_stack.pop() + if popped is not ir_fn: + raise RuntimeError("PTODSL trace-session function stack corruption detected") + + @contextmanager + def enter_subkernel(self, subkernel): + """Push *subkernel* as the current active inline-lowering frame.""" + frame = SubkernelTraceFrame( + role=subkernel.spec.role.value, + symbol_name=subkernel.spec.symbol_name, + target=subkernel.spec.target, + ) + self._subkernel_stack.append(frame) + try: + yield frame + finally: + popped = self._subkernel_stack.pop() + if popped is not frame: + raise RuntimeError("PTODSL trace-session subkernel stack corruption detected") + + def lower_inline_subkernel(self, subkernel, *args, **kwargs): + """Lower one inline PTODSL subkernel call through the shared session.""" + with self.enter_subkernel(subkernel): + return subkernel.emit_body(*args, **kwargs) + + def begin_carry_loop(self, start, stop, step, state_items): + """Materialize one authored ``pto.for_(...).carry(...)`` loop body.""" + frame = build_carry_loop_frame(start, stop, step, state_items) + self._carry_loop_stack.append(frame) + return frame + + def update_carry_loop(self, frame, **kwargs): + """Emit the one legal ``loop.update(...)`` for the active carry loop.""" + active = self.current_carry_loop + if active is None or active is not frame: + raise RuntimeError("loop.update(...) may only be called inside the active carry loop body") + yield_carry_loop_state(frame, **kwargs) + + def finish_carry_loop(self, frame, exc_type, exc, tb): + """Finalize one active authored carry loop and close its body insertion point.""" + if not self._carry_loop_stack: + raise RuntimeError("carry-loop exit without a matching active PTODSL trace-session frame") + popped = self._carry_loop_stack.pop() + if popped is not frame: + raise RuntimeError("PTODSL trace-session carry-loop stack corruption detected") + finish_carry_loop_frame(frame, exc_type, exc, tb) + + def lower_simt_helper_subkernel(self, subkernel, *args, **kwargs): + """Lower one ``@pto.simt`` call through a dedicated helper function.""" + outer_frame = self.current_subkernel + if outer_frame is not None and outer_frame.role == "simt": + raise RuntimeError("@pto.simt helper lowering does not support nested SIMT helper calls") + + arg_templates = tuple(args) + arg_types = tuple(unwrap_surface_value(arg).type for arg in arg_templates) + helper_spec = HelperFunctionSpec( + symbol_name=subkernel.spec.symbol_name, + arg_types=arg_types, + attributes=(("pto.simt_entry", UnitAttr.get()),), + ) + helper_fn, created = self.get_or_create_helper_function(helper_spec) + + if created: + entry_block = helper_fn.add_entry_block() + wrapped_args = tuple( + wrap_like_surface_value(template, value) + for template, value in zip(arg_templates, entry_block.arguments) + ) + with self.enter_function(helper_fn), self.enter_subkernel(subkernel), InsertionPoint(entry_block): + subkernel.emit_body(*wrapped_args, **kwargs) + func.ReturnOp([]) + + i32 = IntegerType.get_signless(32) + dim_z = arith.ConstantOp(i32, 1).result + dim_y = arith.ConstantOp(i32, 1).result + dim_x = arith.ConstantOp(i32, 1).result + _pto.StoreVfSimtInfoOp(dim_z, dim_y, dim_x) + func.CallOp(helper_fn, [unwrap_surface_value(arg) for arg in arg_templates]) + + def lookup_helper(self, symbol_name: str): + """Return a previously declared helper function, or ``None``.""" + return self._helpers.get(symbol_name) + + def get_or_create_helper_function(self, spec: HelperFunctionSpec): + """ + Look up or create a helper ``func.func`` in the current symbol table. + + Returns ``(helper_fn, created)`` where *created* reports whether a new + symbol was emitted in this trace session. + """ + helper = self._helpers.get(spec.symbol_name) + if helper is not None: + return helper, False + + fn_ty = func.FunctionType.get(list(spec.arg_types), list(spec.result_types)) + with InsertionPoint(self._function_symbol_table): + helper = func.FuncOp(spec.symbol_name, fn_ty) + for attr_name, attr_value in spec.attributes: + helper.attributes[attr_name] = attr_value + self._helpers[spec.symbol_name] = helper + return helper, True + + def validate_final_state(self) -> None: + """Check that tracing-time session stacks were fully unwound.""" + if self._subkernel_stack: + raise RuntimeError("PTODSL trace-session exited with an open subkernel lowering frame") + if self._carry_loop_stack: + raise RuntimeError("PTODSL trace-session exited with an open loop-carry lowering frame") + + +__all__ = [ + "HelperFunctionSpec", + "SubkernelTraceFrame", + "TraceSession", +] diff --git a/ptodsl/ptodsl/_types.py b/ptodsl/ptodsl/_types.py index 693b69a35..5d822f5a7 100644 --- a/ptodsl/ptodsl/_types.py +++ b/ptodsl/ptodsl/_types.py @@ -16,13 +16,14 @@ def softmax(arg0: pto.ptr(pto.float32, "GM"), ...): ... where the annotation is evaluated at *import* time (no active context), and -the actual type is materialised later by the ``@pto.to_ir`` decorator. +the actual type is materialised later by the ``@pto.jit`` decorator. """ from ._bootstrap import make_context # ensure MLIR is on sys.path from mlir.dialects import pto as _pto from mlir.ir import ( + BF16Type, F16Type, F32Type, IndexType, @@ -37,10 +38,20 @@ def softmax(arg0: pto.ptr(pto.float32, "GM"), ...): "gm": _pto.AddressSpace.GM, "vec": _pto.AddressSpace.VEC, "mat": _pto.AddressSpace.MAT, + "left": _pto.AddressSpace.LEFT, + "right": _pto.AddressSpace.RIGHT, + "acc": _pto.AddressSpace.ACC, + "bias": _pto.AddressSpace.BIAS, + "scaling": _pto.AddressSpace.SCALING, "GM": _pto.AddressSpace.GM, "UB": _pto.AddressSpace.VEC, "VEC": _pto.AddressSpace.VEC, "MAT": _pto.AddressSpace.MAT, + "LEFT": _pto.AddressSpace.LEFT, + "RIGHT": _pto.AddressSpace.RIGHT, + "ACC": _pto.AddressSpace.ACC, + "BIAS": _pto.AddressSpace.BIAS, + "SCALING": _pto.AddressSpace.SCALING, } @@ -66,14 +77,26 @@ def __init__(self, elem, space: str): def resolve(self) -> Type: elem = _resolve(self._elem) - space_enum = _ADDR_SPACE.get(self._space) + space_enum = _normalize_address_space(self._space) if space_enum is None: raise ValueError( f"Unknown address space '{self._space}'; " f"known: {list(_ADDR_SPACE)}" ) space_attr = _pto.AddressSpaceAttr.get(space_enum) - return _pto.PtrType.get(elem, memory_space=space_attr) + try: + return _pto.PtrType.get(elem, memory_space=space_attr) + except TypeError: + ptr_get_impl = getattr(_pto, "_ptr_type_get_impl", None) + if ptr_get_impl is None: + raise + if space_enum != _pto.AddressSpace.GM: + raise TypeError( + "The current PTO Python bindings only expose the default-GM " + "PtrType builder. Non-GM pointer construction is not " + "available through ptodsl._types.ptr(...) yet." + ) + return ptr_get_impl(elem) def __repr__(self): return f"" @@ -86,12 +109,35 @@ def __init__(self, lanes: int, elem): def resolve(self) -> Type: elem = _resolve(self._elem) - return Type.parse(f"!pto.vreg<{self._lanes}x{elem}>") + vreg_type_cls = getattr(_pto, "VRegType", None) + if vreg_type_cls is None: + raise TypeError( + "The current PTO Python bindings do not expose VRegType. " + "Rebuild the PTO Python extension before using pto.vreg_type(...)." + ) + return vreg_type_cls.get(self._lanes, elem) def __repr__(self): return f"" +class _MaskDescriptor(_DType): + def __init__(self, bits: str): + self._bits = bits + + def resolve(self) -> Type: + mask_type_cls = getattr(_pto, "MaskType", None) + if mask_type_cls is None: + raise TypeError( + "The current PTO Python bindings do not expose MaskType. " + "Rebuild the PTO Python extension before using pto.mask_type(...)." + ) + return mask_type_cls.get(self._bits) + + def __repr__(self): + return f"" + + def _resolve(dtype) -> Type: """Coerce a ``_DType`` descriptor or a concrete ``mlir.ir.Type`` to a Type.""" if isinstance(dtype, _DType): @@ -99,10 +145,20 @@ def _resolve(dtype) -> Type: return dtype # already an mlir.ir.Type +def _normalize_address_space(space): + if isinstance(space, str): + return _ADDR_SPACE.get(space) + if isinstance(space, _pto.AddressSpace): + return space + return None + + # ── Scalar dtype singletons ─────────────────────────────────────────────────── float32 = _DType(F32Type.get) float16 = _DType(F16Type.get) +bf16 = _DType(BF16Type.get) +int1 = _DType(lambda: IntegerType.get_signless(1)) int8 = _DType(lambda: IntegerType.get_signless(8)) int16 = _DType(lambda: IntegerType.get_signless(16)) int32 = _DType(lambda: IntegerType.get_signless(32)) @@ -122,9 +178,9 @@ def vreg_type(lanes: int, elem) -> _VRegDescriptor: return _VRegDescriptor(lanes, elem) -def mask_type(bits: str = "b32") -> Type: - """Return ``!pto.mask`` (b8 | b16 | b32). Requires active context.""" - return Type.parse(f"!pto.mask<{bits}>") +def mask_type(bits: str = "b32") -> _MaskDescriptor: + """Return a lazy descriptor for ``!pto.mask``.""" + return _MaskDescriptor(bits) def tile_buf_type(shape, dtype, valid_shape, *, @@ -142,7 +198,7 @@ def tile_buf_type(shape, dtype, valid_shape, *, Requires an active MLIR context. """ elem = _resolve(dtype) - space_enum = _ADDR_SPACE.get(address_space) + space_enum = _normalize_address_space(address_space) if space_enum is None: raise ValueError( f"Unknown address_space '{address_space}'; known: {list(_ADDR_SPACE)}" @@ -170,7 +226,7 @@ def part_tensor_view_type(rank: int, elem) -> Type: __all__ = [ "_DType", "_resolve", - "float32", "float16", "int8", "int16", "int32", "int64", "index", + "float32", "float16", "bf16", "int1", "int8", "int16", "int32", "int64", "index", "ptr", "vreg_type", "mask_type", "tile_buf_type", "tensor_view_type", "part_tensor_view_type", ] diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py index e28c34fe0..36f473fe1 100644 --- a/ptodsl/ptodsl/pto.py +++ b/ptodsl/ptodsl/pto.py @@ -22,13 +22,25 @@ # ── Types ───────────────────────────────────────────────────────────────────── from ._types import ( # noqa: F401 - float32, float16, - int8, int16, int32, int64, + float32, float16, bf16, + int1, int8, int16, int32, int64, index, ptr, vreg_type, mask_type, tile_buf_type, tensor_view_type, part_tensor_view_type, _resolve, ) +from ._surface_types import ( # noqa: F401 + constexpr, + tensor_spec, + TensorSpec, + BarrierType, + Pipe, + MemorySpace, + TensorView, + PartitionTensorView, + Tile, +) +from ._tensor_factories import empty_like # noqa: F401 # ── Operations ──────────────────────────────────────────────────────────────── from ._ops import ( # noqa: F401 @@ -36,11 +48,17 @@ castptr, addptr, vlds, vbrc_load, vsts, vsts_1pt, plt_b32, pset_b32, + make_mask, vadd, vmul, vmax, vdiv, vcmax, vcadd, vdup, vexpdif, + vexp, vcgmax, vcgadd, vsubs, make_tensor_view, partition_view, - alloc_tile, tload, tstore, tile_ptr, - get_block_idx, barrier_all, + alloc_tile, tload, tstore, tmov, as_ptr, + mte_load, mte_store, mem_bar, + mte_l1_l0a, mte_l1_l0b, mte_l0c_ub, mad, + get_block_idx, get_block_num, get_subblock_idx, get_subblock_num, + store_vfsimt_info, get_tid_x, get_tid_y, get_tid_z, + pipe_barrier, set_flag, wait_flag, ) @@ -52,7 +70,17 @@ ) # ── Decorator ───────────────────────────────────────────────────────────────── -from ._module import to_ir, KernelHandle # noqa: F401 +from ._jit import jit, KernelHandle # noqa: F401 +from ._subkernels import ukernel, cube, simd, simt # noqa: F401 # ── Scalar sub-namespace ────────────────────────────────────────────────────── from . import scalar # noqa: F401 + +# ── Shorthand dtype aliases ─────────────────────────────────────────────────── +f32 = float32 +f16 = float16 +i1 = int1 +i8 = int8 +i16 = int16 +i32 = int32 +i64 = int64 diff --git a/ptodsl/ptodsl/scalar.py b/ptodsl/ptodsl/scalar.py index 4902112c0..40c973517 100644 --- a/ptodsl/ptodsl/scalar.py +++ b/ptodsl/ptodsl/scalar.py @@ -8,15 +8,25 @@ """ Scalar arithmetic helpers – exposed as ``pto.scalar.*`` (or ``s = pto.scalar``). -All functions operate on raw ``mlir.ir.Value`` objects and emit the +Arithmetic helpers operate on raw ``mlir.ir.Value`` objects and emit the corresponding arith dialect operations at the active insertion point. +Scalar memory helpers (`load` / `store`) also accept PTODSL surface-level +address views such as `tile[row, col]` and `tile.as_ptr() + offset`. """ from ._bootstrap import make_context # ensure MLIR is on sys.path # noqa: F401 +from ._scalar_coercion import coerce_scalar_to_type +from ._runtime_scalar_ops import ( + classify_runtime_scalar_type, + emit_runtime_max, +) +from ._surface_values import resolve_address_access, unwrap_surface_value, wrap_surface_value from ._types import _resolve from mlir.dialects import arith -from mlir.ir import IndexType, IntegerType +from mlir.dialects import math +from mlir.dialects import pto as _pto +from mlir.ir import IndexType, MemRefType, Operation _CMPI_PREDICATES = { "eq": arith.CmpIPredicate.eq, @@ -34,17 +44,17 @@ def muli(lhs, rhs): """arith.muli""" - return arith.MulIOp(lhs, rhs).result + return wrap_surface_value(arith.MulIOp(unwrap_surface_value(lhs), unwrap_surface_value(rhs)).result) def addi(lhs, rhs): """arith.addi""" - return arith.AddIOp(lhs, rhs).result + return wrap_surface_value(arith.AddIOp(unwrap_surface_value(lhs), unwrap_surface_value(rhs)).result) def subi(lhs, rhs): """arith.subi""" - return arith.SubIOp(lhs, rhs).result + return wrap_surface_value(arith.SubIOp(unwrap_surface_value(lhs), unwrap_surface_value(rhs)).result) def index_cast(type_or_val, val=None): @@ -58,8 +68,8 @@ def index_cast(type_or_val, val=None): """ if val is None: # 1-arg form: cast to index - return arith.IndexCastOp(IndexType.get(), type_or_val).result - return arith.IndexCastOp(_resolve(type_or_val), val).result + return wrap_surface_value(arith.IndexCastOp(IndexType.get(), unwrap_surface_value(type_or_val)).result) + return wrap_surface_value(arith.IndexCastOp(_resolve(type_or_val), unwrap_surface_value(val)).result) def cmpi(pred: str, lhs, rhs): @@ -74,17 +84,79 @@ def cmpi(pred: str, lhs, rhs): raise ValueError( f"Unknown cmpi predicate '{pred}'; known: {list(_CMPI_PREDICATES)}" ) - return arith.CmpIOp(predicate, lhs, rhs).result + return wrap_surface_value( + arith.CmpIOp(predicate, unwrap_surface_value(lhs), unwrap_surface_value(rhs)).result + ) def cmpi_sgt(lhs, rhs): """arith.cmpi sgt (signed greater-than).""" - return arith.CmpIOp(arith.CmpIPredicate.sgt, lhs, rhs).result + return wrap_surface_value(arith.CmpIOp( + arith.CmpIPredicate.sgt, + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + ).result) def select(cond, true_val, false_val): """arith.select""" - return arith.SelectOp(cond, true_val, false_val).result - - -__all__ = ["muli", "addi", "subi", "index_cast", "cmpi", "cmpi_sgt", "select"] + return wrap_surface_value(arith.SelectOp( + unwrap_surface_value(cond), + unwrap_surface_value(true_val), + unwrap_surface_value(false_val), + ).result) + + +def max(lhs, rhs): + """Runtime scalar maximum across float / integer / index values.""" + return wrap_surface_value(emit_runtime_max( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + )) + + +def exp(value): + """Runtime scalar exponential for floating-point values.""" + raw_value = unwrap_surface_value(value) + kind = classify_runtime_scalar_type(raw_value.type) + if kind != "float": + raise TypeError(f"scalar.exp(...) expects a floating-point runtime scalar, got {raw_value.type}") + return wrap_surface_value(math.ExpOp(raw_value).result) + + +def load(ptr_or_ref, offset=None): + """Load one scalar element from a PTODSL address view or tile element.""" + buffer_value, index_value = resolve_address_access(ptr_or_ref, offset) + result_type = _infer_buffer_element_type(buffer_value.type) + return wrap_surface_value(Operation.create( + "pto.load", + results=[result_type], + operands=[buffer_value, index_value], + ).results[0]) + + +def store(value, ptr_or_ref, offset=None): + """Store one scalar element to a PTODSL address view or tile element.""" + buffer_value, index_value = resolve_address_access(ptr_or_ref, offset) + elem_type = _infer_buffer_element_type(buffer_value.type) + Operation.create( + "pto.store", + operands=[buffer_value, index_value, coerce_scalar_to_type(value, elem_type, context="scalar.store(...)")], + ) + + +def _infer_buffer_element_type(buffer_type): + try: + return _pto.PtrType(buffer_type).element_type + except Exception: + return MemRefType(buffer_type).element_type + + +__all__ = [ + "muli", "addi", "subi", + "index_cast", + "cmpi", "cmpi_sgt", + "select", + "max", "exp", + "load", "store", +] diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index c56756f56..ab5788683 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -51,6 +51,9 @@ def get_op_result_or_value(value): register_dialect = _pto_mod.register_dialect PtrType = _pto_mod.PtrType +VRegType = _pto_mod.VRegType +MaskType = _pto_mod.MaskType +AlignType = _pto_mod.AlignType AsyncSessionType = _pto_mod.AsyncSessionType AsyncEventType = _pto_mod.AsyncEventType HiF8Type = _pto_mod.HiF8Type @@ -115,9 +118,19 @@ def _ptr_type_get_compat(cls, element_type, memory_space=None, context=None): raise TypeError("PtrType.get got multiple context arguments") context = memory_space memory_space = None - return _ptr_type_get_impl( - element_type, memory_space=memory_space, context=context - ) + if memory_space is None: + if context is None: + return _ptr_type_get_impl(element_type) + return _ptr_type_get_impl(element_type, context=context) + try: + return _ptr_type_get_impl( + element_type, memory_space=memory_space, context=context + ) + except TypeError as exc: + raise TypeError( + "PtrType.get(element_type, memory_space=...) requires a PTO Python " + "extension built with non-default address-space pointer support" + ) from exc PtrType.get = classmethod(_ptr_type_get_compat) @@ -162,6 +175,9 @@ def compat_init(self, *args, __orig_init=original_init, precision_mode=None, **k "register_dialect", # Types "PtrType", + "VRegType", + "MaskType", + "AlignType", "AsyncSessionType", "AsyncEventType", "HiF8Type", diff --git a/test/python/ptodsl_jit_compile.py b/test/python/ptodsl_jit_compile.py new file mode 100644 index 000000000..754f116af --- /dev/null +++ b/test/python/ptodsl_jit_compile.py @@ -0,0 +1,618 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from pathlib import Path +import re +import sys + + +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "ptodsl")) + +from ptodsl import pto +from ptodsl._bootstrap import make_context +from ptodsl._tracing import current_session +from mlir.ir import Location + + +def expect(condition: bool, message: str) -> None: + if not condition: + raise AssertionError(message) + + +@pto.jit(target="a5") +def host_vec_copy( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + rows = A.shape[0] + cols = A.shape[1] + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + part = pto.partition_view(a_view, offsets=[0, 0], sizes=[rows, cols]) + out = pto.partition_view(o_view, offsets=[0, 0], sizes=[rows, cols]) + pto.tload(part, a_tile) + pto.tstore(o_tile, out) + + +@pto.jit(target="a5") +def runtime_metadata_kernel( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + rows = A.shape[0] + cols = A.shape[1] + a_view = pto.make_tensor_view(A) + o_view = pto.make_tensor_view(O) + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32, valid_shape=[rows, cols]) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32, valid_shape=[rows, cols]) + part = pto.partition_view(a_view, offsets=[0, 0], sizes=[rows, cols]) + out = pto.partition_view(o_view, offsets=[0, 0], sizes=[rows, cols]) + pto.tload(part, a_tile) + pto.tstore(o_tile, out) + + +SUBKERNEL_OBSERVATIONS = [] + + +@pto.simd +def nested_simd_probe(): + session = current_session() + frame = session.current_subkernel + SUBKERNEL_OBSERVATIONS.append((frame.role, frame.symbol_name, session.subkernel_stack_depth)) + + +@pto.cube +def top_level_cube_probe(): + session = current_session() + frame = session.current_subkernel + SUBKERNEL_OBSERVATIONS.append((frame.role, frame.symbol_name, session.subkernel_stack_depth)) + + +@pto.ukernel +def ukernel_probe(): + session = current_session() + frame = session.current_subkernel + SUBKERNEL_OBSERVATIONS.append((frame.role, frame.symbol_name, session.subkernel_stack_depth)) + nested_simd_probe() + + +@pto.jit(target="a5") +def shared_subkernel_lowering_probe(*, TRACE_TOKEN: pto.constexpr = 0): + top_level_cube_probe() + ukernel_probe() + nested_simd_probe() + + +@pto.simt +def simt_tid_probe(): + pto.get_tid_x() + pto.get_tid_y() + pto.get_tid_z() + + +@pto.jit(target="a5") +def simt_helper_lowering_probe(*, TRACE_TOKEN: pto.constexpr = 0): + simt_tid_probe() + simt_tid_probe() + + +@pto.jit(target="a5") +def carry_loop_lowering_probe(*, BLOCK: pto.constexpr = 128): + m_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + l_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + m_next = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + l_next = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_next = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + + m_prev.fill(0.0) + l_prev.fill(0.0) + o_prev.fill(0.0) + + kv_loop = pto.for_(0, 4, step=1).carry(m=m_prev, l=l_prev, o=o_prev) + with kv_loop: + kv_loop.m.fill(1.0) + kv_loop.l.fill(2.0) + kv_loop.o.fill(3.0) + kv_loop.update(m=m_next, l=l_next, o=o_next) + + final_o = kv_loop.final("o") + final_o.fill(4.0) + + +@pto.jit(target="a5") +def runtime_scalar_operator_probe( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 8, +): + rows = A.shape[0] + cols = A.shape[1] + block_idx = pto.get_block_idx() + o_view = pto.make_tensor_view(O) + o_part = pto.partition_view(o_view, offsets=[0, 0], sizes=[rows, cols]) + o_ptr = o_part.as_ptr() + + batch_idx = block_idx // rows + head_idx = block_idx % rows + chunks = (cols + BLOCK - 1) // BLOCK + tail = cols % BLOCK + + x = pto.const(2.0, dtype=pto.f32) + y = (x + 1.0) * 2.0 + z = 4.0 - y + w = 1.0 / z + m = pto.scalar.max(w, x) + e = pto.scalar.exp(m) + pto.scalar.store(e, o_ptr + 0) + + _ = batch_idx + _ = head_idx + _ = chunks + _ = tail + _ = w + _ = m + _ = e + + +@pto.simd +def tile_slice_vector_probe(inp_tile: pto.Tile, out_tile: pto.Tile, row: pto.index): + mask, _ = pto.plt_b32(pto.const(64, dtype=pto.i32)) + vec = pto.vlds(inp_tile[row, 0:]) + pto.vsts(vec, out_tile[row, 0:], mask) + + +@pto.jit(target="a5") +def tile_slice_surface_probe(*, BLOCK: pto.constexpr = 128): + inp_tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32) + out_tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32) + with pto.for_(0, 1, step=1) as row: + tile_slice_vector_probe(inp_tile, out_tile, row) + + +@pto.jit(target="a5") +def tile_valid_shape_update_probe( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + rows = A.shape[0] + cols = A.shape[1] + tile = pto.alloc_tile( + shape=[1, BLOCK], + dtype=pto.f32, + valid_shape=[pto.const(1), cols], + ) + tile.valid_shape = [rows, cols] + + +@pto.jit(target="a5") +def integer_loop_bound_probe(*, BLOCK: pto.constexpr = 8): + row_start = pto.const(0, dtype=pto.i32) + row_stop = pto.const(BLOCK, dtype=pto.i32) + valid_dim = pto.const(BLOCK // 2, dtype=pto.i32) + with pto.for_(row_start, row_stop, step=1) as row: + with pto.for_(0, valid_dim, step=1) as col: + _ = row + _ = col + + +@pto.jit(target="a5") +def scalar_pointer_offset_probe(): + meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 3]) + meta_ptr = meta_tile.as_ptr() + pto.scalar.store(0, meta_ptr, 0) + pto.scalar.store(1, meta_ptr, 1) + pto.scalar.store(2, meta_ptr + 2) + row_start = pto.scalar.load(meta_ptr, 0) + row_stop = pto.scalar.load(meta_ptr, 1) + valid_cols = pto.scalar.load(meta_ptr + 2) + _ = row_start + _ = row_stop + _ = valid_cols + + +@pto.simt +def simt_pointer_offset_helper(meta_ptr: pto.ptr(pto.i32, pto.MemorySpace.UB)): + pto.scalar.store(7, meta_ptr + 0) + pto.scalar.store(9, meta_ptr + 1) + + +@pto.jit(target="a5") +def simt_pointer_offset_probe(): + meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 2]) + simt_pointer_offset_helper(meta_tile.as_ptr()) + first = pto.scalar.load(meta_tile.as_ptr() + 0) + second = pto.scalar.load(meta_tile.as_ptr() + 1) + _ = first + _ = second + + +@pto.jit(target="a5") +def scalar_store_element_coercion_probe(): + meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 4]) + meta_ptr = meta_tile.as_ptr() + row_start = pto.const(0) + row_stop = pto.const(4) + pto.scalar.store(row_start, meta_ptr + 0) + pto.scalar.store(row_stop, meta_ptr + 1) + pto.scalar.store(pto.const(2, dtype=pto.i64), meta_ptr + 2) + pto.scalar.store(3, meta_ptr + 3) + + +@pto.simd +def public_vector_surface_probe(inp_tile: pto.Tile, out_tile: pto.Tile, stats_tile: pto.Tile): + col_mask = pto.make_mask(pto.f32, pto.const(16, dtype=pto.i32)) + row = pto.const(0) + s_row = pto.vlds(inp_tile[row, 0:]) + row_max = pto.vcgmax(s_row, col_mask) + s_shifted = pto.vsubs(s_row, row_max, col_mask) + p_row = pto.vexp(s_shifted, col_mask) + row_sum = pto.vcgadd(p_row, col_mask) + pto.vsts(p_row, out_tile[row, 0:], col_mask) + pto.scalar.store(row_max, stats_tile[row, 0]) + pto.scalar.store(row_sum, stats_tile[row, 1]) + + +@pto.cube +def public_cube_surface_probe( + lhs_tile: pto.Tile, + rhs_tile: pto.Tile, + lhs_l0a: pto.Tile, + rhs_l0b: pto.Tile, + acc_tile: pto.Tile, + out_tile: pto.Tile, +): + m = pto.const(16) + k = pto.const(16) + n = pto.const(16) + pto.mte_l1_l0a(lhs_tile.as_ptr(), lhs_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(rhs_tile.as_ptr(), rhs_l0b.as_ptr(), k, n, transpose=True) + pto.mad(lhs_l0a.as_ptr(), rhs_l0b.as_ptr(), acc_tile.as_ptr(), m, n, k) + pto.mte_l0c_ub(acc_tile.as_ptr(), out_tile.as_ptr(), m, n, n, n, 0) + + +@pto.ukernel +def public_mte_surface_probe( + inp_part: pto.PartitionTensorView, + out_part: pto.PartitionTensorView, + dma_tile: pto.Tile, +): + pto.mte_load(inp_part, dma_tile) + pto.pipe_barrier(pto.Pipe.ALL) + pto.mte_store(dma_tile, out_part) + pto.mem_bar(pto.BarrierType.VST_VLD) + pto.pipe_barrier(pto.Pipe.ALL) + + +@pto.jit(target="a5") +def public_surface_exports_probe( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), +): + cols = A.shape[1] + a_view = pto.make_tensor_view(A) + o_view = pto.make_tensor_view(O) + a_part = pto.partition_view(a_view, offsets=[0, 0], sizes=[1, cols]) + o_part = pto.partition_view(o_view, offsets=[0, 0], sizes=[1, cols]) + + dma_tile = pto.alloc_tile(shape=[1, 128], dtype=pto.f32, valid_shape=[1, cols]) + public_mte_surface_probe(a_part, o_part, dma_tile) + + vec_in = pto.alloc_tile(shape=[1, 128], dtype=pto.f32, valid_shape=[1, 16]) + vec_out = pto.alloc_tile(shape=[1, 128], dtype=pto.f32, valid_shape=[1, 16]) + stats_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.f32, valid_shape=[1, 2]) + public_vector_surface_probe(vec_in, vec_out, stats_tile) + + lhs_tile = pto.alloc_tile( + shape=[16, 16], + dtype=pto.f16, + memory_space=pto.MemorySpace.MAT, + valid_shape=[16, 16], + ) + rhs_tile = pto.alloc_tile( + shape=[16, 16], + dtype=pto.f16, + memory_space=pto.MemorySpace.MAT, + valid_shape=[16, 16], + ) + lhs_l0a = pto.alloc_tile( + shape=[16, 16], + dtype=pto.f16, + memory_space=pto.MemorySpace.LEFT, + valid_shape=[16, 16], + ) + rhs_l0b = pto.alloc_tile( + shape=[16, 16], + dtype=pto.f16, + memory_space=pto.MemorySpace.RIGHT, + valid_shape=[16, 16], + ) + acc_tile = pto.alloc_tile( + shape=[16, 16], + dtype=pto.f32, + memory_space=pto.MemorySpace.ACC, + valid_shape=[16, 16], + ) + cube_out = pto.alloc_tile(shape=[16, 16], dtype=pto.f32, valid_shape=[16, 16]) + public_cube_surface_probe(lhs_tile, rhs_tile, lhs_l0a, rhs_l0b, acc_tile, cube_out) + + +class _FakeTensor: + def __init__(self, shape): + self.shape = tuple(shape) + + def new_empty(self, shape): + return _FakeTensor(shape) + + +def main() -> None: + expected_public_exports = [ + "make_mask", + "vexp", + "vcgmax", + "vcgadd", + "vsubs", + "mte_load", + "mte_store", + "mem_bar", + "BarrierType", + "Pipe", + "pipe_barrier", + "mte_l1_l0a", + "mte_l1_l0b", + "mte_l0c_ub", + "mad", + "empty_like", + ] + for name in expected_public_exports: + expect(hasattr(pto, name), f"pto.{name} should be exported from the public namespace") + + fake_tensor = _FakeTensor((2, 3, 4)) + fake_empty = pto.empty_like(fake_tensor) + expect(isinstance(fake_empty, _FakeTensor), "pto.empty_like(...) should preserve host tensor factory type") + expect(fake_empty.shape == fake_tensor.shape, "pto.empty_like(...) should preserve the logical tensor shape") + expect(not hasattr(pto.scalar, "sts"), "scalar.sts should not remain in the public scalar namespace") + + with make_context() as ctx, Location.unknown(ctx): + tile_buf_ty = pto.tile_buf_type( + [16, 32], + pto.f32, + [16, 8], + address_space="mat", + blayout="ColMajor", + slayout="RowMajor", + ) + expect(hasattr(tile_buf_ty, "memory_space"), "TileBufType should expose a memory_space accessor") + expect(hasattr(tile_buf_ty, "shape"), "TileBufType should expose a shape accessor") + expect(hasattr(tile_buf_ty, "valid_shape"), "TileBufType should expose a valid_shape accessor") + expect(hasattr(tile_buf_ty, "element_type"), "TileBufType should expose an element_type accessor") + expect(tile_buf_ty.memory_space.value == pto.MemorySpace.MAT.value, "TileBufType.memory_space should preserve the authored address space") + expect(list(tile_buf_ty.shape) == [16, 32], "TileBufType.shape should preserve the authored physical shape") + expect(list(tile_buf_ty.valid_shape) == [16, 8], "TileBufType.valid_shape should preserve the authored valid shape") + expect(str(tile_buf_ty.element_type) == "f32", "TileBufType.element_type should preserve the authored element type") + + host_vec_copy.verify() + runtime_metadata_kernel.verify() + shared_subkernel_lowering_probe.verify() + simt_helper_lowering_probe.verify() + carry_loop_lowering_probe.verify() + runtime_scalar_operator_probe.verify() + tile_slice_surface_probe.verify() + tile_valid_shape_update_probe.verify() + integer_loop_bound_probe.verify() + scalar_pointer_offset_probe.verify() + simt_pointer_offset_probe.verify() + scalar_store_element_coercion_probe.verify() + public_surface_exports_probe.verify() + + default_compiled = host_vec_copy.compile() + explicit_default = host_vec_copy.compile(BLOCK=128) + block64 = host_vec_copy.compile(BLOCK=64) + + expect(default_compiled is explicit_default, "default constexpr compile should hit specialization cache") + expect(default_compiled is not block64, "different constexpr values should materialize different specializations") + expect(len(host_vec_copy.cached_specializations()) == 2, "expected exactly two cached specializations") + expect(default_compiled.constexpr_bindings == {"BLOCK": 128}, "default constexpr binding mismatch") + expect(block64.constexpr_bindings == {"BLOCK": 64}, "BLOCK=64 constexpr binding mismatch") + expect( + default_compiled.specialization_key.abi_signature == block64.specialization_key.abi_signature, + "ABI signature should stay stable across constexpr-only specializations", + ) + expect( + default_compiled.specialization_key.constexpr_signature + != block64.specialization_key.constexpr_signature, + "constexpr specialization key should differ when BLOCK changes", + ) + + default_text = default_compiled.mlir_text() + block64_text = block64.mlir_text() + expect("!pto.tile_buf" in default_text, "default specialization MLIR missing BLOCK=128 tile") + expect("!pto.tile_buf" in block64_text, "BLOCK=64 specialization MLIR missing specialized tile") + expect("valid=?" not in default_text, "default alloc_tile() should keep full static valid-shape when valid_shape= is omitted") + + runtime_metadata_text = runtime_metadata_kernel.compile().mlir_text() + expect( + "pto.make_tensor_view %arg0, shape = [%arg1, %arg2], strides = [%arg3, %arg4]" in runtime_metadata_text, + "make_tensor_view(A) should materialize runtime shape/stride metadata from the tensor proxy", + ) + expect( + "pto.alloc_tile valid_row = %arg1 valid_col = %arg2 : !pto.tile_buf" in runtime_metadata_text, + "alloc_tile(valid_shape=[rows, cols]) should lower runtime metadata through valid_row/valid_col operands", + ) + expect( + "sizes = [%arg1, %arg2]" in runtime_metadata_text, + "partition_view sizes derived from tensor metadata should remain runtime MLIR values", + ) + + tile_valid_shape_text = tile_valid_shape_update_probe.compile().mlir_text() + expect( + re.search( + r"pto\.set_validshape %[0-9]+, %arg1, %arg2 : !pto\.tile_buf", + tile_valid_shape_text, + ) is not None, + "tile.valid_shape = [rows, cols] should lower to pto.set_validshape on a dynamic-valid tile", + ) + + SUBKERNEL_OBSERVATIONS.clear() + shared_subkernel_lowering_probe.compile(TRACE_TOKEN=1) + expect( + SUBKERNEL_OBSERVATIONS == [ + ("cube", "top_level_cube_probe", 1), + ("ukernel", "ukernel_probe", 1), + ("simd", "nested_simd_probe", 2), + ("simd", "nested_simd_probe", 1), + ], + f"unexpected shared subkernel lowering observations: {SUBKERNEL_OBSERVATIONS!r}", + ) + + simt_text = simt_helper_lowering_probe.compile(TRACE_TOKEN=1).mlir_text() + expect( + simt_text.count("pto.store_vfsimt_info") == 2, + "each @pto.simt callsite should materialize a caller-side store_vfsimt_info", + ) + expect( + simt_text.count("call @simt_tid_probe()") == 2, + "each @pto.simt callsite should lower to a func.call of the helper symbol", + ) + expect( + simt_text.count("func.func @simt_tid_probe() attributes {pto.simt_entry}") == 1, + "@pto.simt helper should materialize exactly one reusable pto.simt_entry function", + ) + expect("pto.get_tid_x" in simt_text, "SIMT helper body should contain pto.get_tid_x") + expect("pto.get_tid_y" in simt_text, "SIMT helper body should contain pto.get_tid_y") + expect("pto.get_tid_z" in simt_text, "SIMT helper body should contain pto.get_tid_z") + + carry_text = carry_loop_lowering_probe.compile(BLOCK=32).mlir_text() + expect("scf.for" in carry_text, "carry loop should lower to scf.for") + expect("iter_args(" in carry_text, "carry loop should lower named state through scf.for iter_args") + expect("scf.yield" in carry_text, "carry loop should lower loop.update(...) to scf.yield") + expect( + carry_text.count("!pto.tile_buf") >= 3, + "carry loop MLIR should materialize the specialized carried tile types", + ) + expect( + re.search(r"outs\(%[^\s]+#2 : !pto\.tile_buf\)", carry_text) is not None, + "loop.final(\"o\") should materialize the third scf.for result as the final carried state", + ) + + runtime_scalar_text = runtime_scalar_operator_probe.compile(BLOCK=8).mlir_text() + expect("arith.index_cast" in runtime_scalar_text, "mixed i64/index runtime arithmetic should materialize index_cast") + expect("arith.floordivsi" in runtime_scalar_text, "runtime // should lower to arith.floordivsi") + expect("arith.remsi" in runtime_scalar_text, "runtime % should lower to arith.remsi") + expect("arith.addf" in runtime_scalar_text, "runtime float + should lower to arith.addf") + expect("arith.mulf" in runtime_scalar_text, "runtime float * should lower to arith.mulf") + expect("arith.subf" in runtime_scalar_text, "runtime float - should lower to arith.subf") + expect("arith.divf" in runtime_scalar_text, "runtime float / should lower to arith.divf") + expect("arith.maximumf" in runtime_scalar_text, "scalar.max(float, float) should lower to arith.maximumf") + expect("math.exp" in runtime_scalar_text, "scalar.exp(...) should lower to math.exp") + expect("pto.store" in runtime_scalar_text, "scalar.store(...) should lower to pto.store") + + tile_slice_text = tile_slice_surface_probe.compile(BLOCK=128).mlir_text() + expect("memref.subview" in tile_slice_text, "tile[row, col:] should lower through memref.subview") + expect("memref.collapse_shape" in tile_slice_text, "2D tile[row, col:] should flatten through memref.collapse_shape") + expect("pto.tile_buf_addr" in tile_slice_text, "tile[row, col:] should materialize a memref tile address view") + expect( + "pto.vlds" in tile_slice_text and "memref<128xf32, strided<[1], offset: ?>, #pto.address_space>" in tile_slice_text, + "vlds(tile[row, col:]) should lower against the memref slice view", + ) + expect( + "pto.vsts" in tile_slice_text and "memref<128xf32, strided<[1], offset: ?>, #pto.address_space>" in tile_slice_text, + "vsts(vec, tile[row, col:], mask) should lower against the memref slice view", + ) + + integer_loop_text = integer_loop_bound_probe.compile(BLOCK=8).mlir_text() + expect( + integer_loop_text.count("arith.index_cast") >= 2, + "integer runtime loop bounds should be normalized to index with arith.index_cast", + ) + expect( + integer_loop_text.count("scf.for") == 2, + "integer loop bound probe should still lower nested authored loops to scf.for", + ) + + scalar_pointer_offset_text = scalar_pointer_offset_probe.compile().mlir_text() + expect( + re.search(r"pto\.store %c1_i32, %\d+\[%c1\]", scalar_pointer_offset_text) is not None, + "scalar.store(ptr, 1) should lower as element offset 1", + ) + expect( + re.search(r"pto\.store %c2_i32, %\d+\[%c2\]", scalar_pointer_offset_text) is not None, + "scalar.store(ptr + 2) should lower as element offset 2", + ) + expect( + re.search(r"pto\.load %\d+\[%c1(?:_\d+)?\]", scalar_pointer_offset_text) is not None, + "scalar.load(ptr, 1) should lower as element offset 1", + ) + expect( + re.search(r"pto\.load %\d+\[%c2(?:_\d+)?\]", scalar_pointer_offset_text) is not None, + "scalar.load(ptr + 2) should lower as element offset 2", + ) + + simt_pointer_offset_text = simt_pointer_offset_probe.compile().mlir_text() + expect( + "call @simt_pointer_offset_helper" in simt_pointer_offset_text, + "@pto.simt pointer helper should lower to a helper func.call", + ) + expect( + re.search(r"pto\.store %c9_i32, %(?:arg0|\d+)\[%c1(?:_\d+)?\]", simt_pointer_offset_text) is not None, + "ptr+offset sugar inside @pto.simt helpers should lower as address offsets, not scalar add", + ) + expect( + re.search(r"pto\.load %\d+\[%c1(?:_\d+)?\]", simt_pointer_offset_text) is not None, + "@pto.simt pointer helper probe should preserve ptr+offset load syntax on the caller side", + ) + + scalar_store_coercion_text = scalar_store_element_coercion_probe.compile().mlir_text() + expect( + scalar_store_coercion_text.count("arith.index_cast") >= 2, + "scalar.store(...) should coerce index runtime values to the destination integer element type", + ) + expect( + "arith.trunci" in scalar_store_coercion_text, + "scalar.store(...) should coerce wider integer runtime values down to the destination element type", + ) + expect( + scalar_store_coercion_text.count("pto.store") == 4, + "scalar.store(...) coercion probe should still lower to four pto.store operations", + ) + + public_surface_text = public_surface_exports_probe.compile().mlir_text() + expect("pto.mte_gm_ub" in public_surface_text, "mte_load(...) should lower to pto.mte_gm_ub") + expect("pto.mte_ub_gm" in public_surface_text, "mte_store(...) should lower to pto.mte_ub_gm") + expect(public_surface_text.count("pto.mem_bar") >= 1, "mem_bar(...) should still lower explicit memory barriers") + expect("pto.barrier " in public_surface_text, "pipe_barrier(Pipe.ALL) should lower to pto.barrier") + expect("pto.vexp" in public_surface_text, "vexp(...) should lower to pto.vexp") + expect("pto.vcgmax" in public_surface_text, "vcgmax(...) should lower to pto.vcgmax") + expect("pto.vcgadd" in public_surface_text, "vcgadd(...) should lower to pto.vcgadd") + expect("pto.vadds" in public_surface_text, "vsubs(...) should lower via scalar negation plus pto.vadds") + expect("pto.mte_l1_l0a" in public_surface_text, "mte_l1_l0a(...) should lower to pto.mte_l1_l0a") + expect("pto.mte_l1_l0b" in public_surface_text, "mte_l1_l0b(...) should lower to pto.mte_l1_l0b") + expect("pto.mte_l0c_ub" in public_surface_text, "mte_l0c_ub(...) should lower to pto.mte_l0c_ub") + expect("pto.mad" in public_surface_text, "mad(...) should lower to pto.mad") + + try: + block64[1, None] + except NotImplementedError as exc: + expect("compile / inspect / verify / emit" in str(exc), "runtime-launch diagnostic text mismatch") + else: + raise AssertionError("compiled handle unexpectedly accepted runtime launch syntax") + + print("ptodsl_jit_compile: PASS") + + +if __name__ == "__main__": + main() diff --git a/test/python/ptodsl_jit_diagnostics.py b/test/python/ptodsl_jit_diagnostics.py new file mode 100644 index 000000000..9d9bf98d4 --- /dev/null +++ b/test/python/ptodsl_jit_diagnostics.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from pathlib import Path +import sys + + +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "ptodsl")) + +from ptodsl import pto +from ptodsl._host_tensors import inspect_host_tensor_metadata + + +def expect(condition: bool, message: str) -> None: + if not condition: + raise AssertionError(message) + + +def expect_raises(callback, exc_type, *message_fragments: str) -> None: + try: + callback() + except exc_type as exc: + text = str(exc) + for fragment in message_fragments: + expect(fragment in text, f"expected diagnostic fragment {fragment!r} in {text!r}") + else: + raise AssertionError(f"expected {exc_type.__name__} to be raised") + + +@pto.jit(target="a5") +def native_python_if_runtime_const_probe(): + if pto.const(1): + pto.pipe_barrier(pto.Pipe.ALL) + + +@pto.jit(target="a5") +def native_python_range_runtime_metadata_probe(A: pto.tensor_spec(rank=2, dtype=pto.f32)): + for _ in range(A.shape[0]): + pto.pipe_barrier(pto.Pipe.ALL) + + +@pto.jit(target="a5") +def float_loop_bound_probe(): + with pto.for_(0, pto.const(1.5, dtype=pto.f32), step=1): + pto.pipe_barrier(pto.Pipe.ALL) + + +@pto.jit(target="a5") +def carry_update_mismatch_probe(*, BLOCK: pto.constexpr = 8): + acc = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + loop = pto.for_(0, 1, step=1).carry(acc=acc) + with loop: + loop.update(other=acc) + + +@pto.jit(target="a5") +def carry_final_mismatch_probe(*, BLOCK: pto.constexpr = 8): + acc = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + loop = pto.for_(0, 1, step=1).carry(acc=acc) + with loop: + loop.update(acc=acc) + loop.final("missing") + + +@pto.jit(target="a5") +def misaligned_row_major_tile_probe(): + pto.alloc_tile(shape=[128, 1], dtype=pto.f32, valid_shape=[128, 1]) + + +class MissingDTypeTensor: + shape = (4, 8) + strides = (8, 1) + + def data_ptr(self): + return 1024 + + +class BadDataHandleTensor: + shape = (4, 8) + strides = (8, 1) + dtype = "float32" + + def data_ptr(self): + return "not-an-int" + + +def define_missing_constexpr_default_probe(): + @pto.jit(target="a5") + def bad_probe(*, BLOCK: pto.constexpr): + pto.pipe_barrier(pto.Pipe.ALL) + + return bad_probe + + +def main() -> None: + expect_raises( + native_python_if_runtime_const_probe.compile, + TypeError, + "native Python if/while condition", + "pto.if_(...)", + "pto.constexpr", + ) + expect_raises( + native_python_range_runtime_metadata_probe.compile, + TypeError, + "native Python range()/loop bound", + "pto.for_(...)", + "runtime value", + ) + expect_raises( + float_loop_bound_probe.compile, + TypeError, + "pto.for_(...) loop bound", + "expects an index or integer runtime scalar", + "f32", + ) + expect_raises( + carry_update_mismatch_probe.compile, + RuntimeError, + "loop.update(...) must match carry names exactly", + "missing: acc", + "unexpected: other", + ) + expect_raises( + carry_final_mismatch_probe.compile, + RuntimeError, + "loop.final(...) requested unknown carry state 'missing'", + "expected one of: acc", + ) + expect_raises( + misaligned_row_major_tile_probe.compile, + TypeError, + "alloc_tile(shape=...) physical row layout is invalid", + "shape=[128, 1]", + "row byte size of 4", + "32-byte aligned", + "prefer blayout='ColMajor'", + ) + expect_raises( + define_missing_constexpr_default_probe, + TypeError, + "@pto.jit constexpr parameter 'BLOCK' must declare a default value", + ) + expect_raises( + lambda: inspect_host_tensor_metadata(MissingDTypeTensor()), + TypeError, + "host tensor metadata is incomplete or unsupported", + "missing .dtype", + ) + expect_raises( + lambda: inspect_host_tensor_metadata(BadDataHandleTensor()), + TypeError, + "host tensor metadata is incomplete or unsupported", + "data_ptr must return an integer-like data handle", + ) + print("ptodsl_jit_diagnostics: PASS") + + +if __name__ == "__main__": + main() diff --git a/test/python/ptodsl_subkernel_diagnostics.py b/test/python/ptodsl_subkernel_diagnostics.py new file mode 100644 index 000000000..b26898e8a --- /dev/null +++ b/test/python/ptodsl_subkernel_diagnostics.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from pathlib import Path +import sys + + +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "ptodsl")) + +from ptodsl import pto + + +def expect(condition: bool, message: str) -> None: + if not condition: + raise AssertionError(message) + + +def expect_raises(callback, exc_type, *message_fragments: str) -> None: + try: + callback() + except exc_type as exc: + text = str(exc) + for fragment in message_fragments: + expect(fragment in text, f"expected diagnostic fragment {fragment!r} in {text!r}") + else: + raise AssertionError(f"expected {exc_type.__name__} to be raised") + + +def define_bad_subkernel_signature_probe(): + @pto.ukernel + def bad_tensor_formal(A: pto.tensor_spec(rank=2, dtype=pto.f32)): + pto.pipe_barrier(pto.Pipe.ALL) + + return bad_tensor_formal + + +@pto.ukernel +def host_tensor_operand_probe(tensor): + pto.pipe_barrier(pto.Pipe.ALL) + + +@pto.jit(target="a5") +def host_tensor_into_subkernel_probe(A: pto.tensor_spec(rank=2, dtype=pto.f32)): + host_tensor_operand_probe(A) + + +@pto.simt +def nested_simt_probe(): + pto.get_tid_x() + + +@pto.simd +def illegal_simt_placement_probe(): + nested_simt_probe() + + +@pto.jit(target="a5") +def nested_simt_from_simd_entry(*, TRACE_TOKEN: pto.constexpr = 0): + illegal_simt_placement_probe() + + +@pto.simd +def simd_value_escape_probe(): + return pto.pset_b32("PAT_ALL") + + +@pto.jit(target="a5") +def simd_value_escape_entry(*, TRACE_TOKEN: pto.constexpr = 0): + simd_value_escape_probe() + + +def main() -> None: + expect_raises( + define_bad_subkernel_signature_probe, + TypeError, + "@pto.ukernel parameter 'A' cannot be annotated with pto.tensor_spec(...)", + "@pto.jit positional parameters", + ) + expect_raises( + host_tensor_into_subkernel_probe.compile, + TypeError, + "@pto.ukernel parameter 'tensor' uses a host tensor value", + "host tensors only belong at the @pto.jit boundary", + ) + expect_raises( + nested_simt_from_simd_entry.compile, + RuntimeError, + "@pto.simt helper materialization is only supported from the top-level @pto.jit body or inside @pto.ukernel", + "inside @pto.simd", + ) + expect_raises( + simd_value_escape_entry.compile, + RuntimeError, + "@pto.simd cannot return transient SIMD values", + "!pto.mask", + "Write the value back to a Tile/UB buffer instead", + ) + print("ptodsl_subkernel_diagnostics: PASS") + + +if __name__ == "__main__": + main() From 78c4cf8c93e15b637ba4a41138c589fe4f80f73b Mon Sep 17 00:00:00 2001 From: Giacomo Castiglioni Date: Tue, 19 May 2026 18:10:55 +0200 Subject: [PATCH 16/31] pip install ptoas --- README_en.md | 236 +++++++++++++++++++++++++++++++++++++++ _ptoas_build_backend.py | 242 ++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 26 ++++- 3 files changed, 499 insertions(+), 5 deletions(-) create mode 100644 README_en.md create mode 100644 _ptoas_build_backend.py diff --git a/README_en.md b/README_en.md new file mode 100644 index 000000000..b7a060a0f --- /dev/null +++ b/README_en.md @@ -0,0 +1,236 @@ +# ptoas (PTO Assembler & Optimizer) + +## 1. Introduction + +**ptoas** is a specialized compiler toolchain built on top of **LLVM/MLIR (llvmorg-19.1.7)** *(Commit cd708029e0b2869e80abe31ddb175f7c35361f90)*, designed specifically for **PTO Bytecode** (Programming Tiling Operator Bytecode). + +Acting as the bridge between upper-level AI frameworks and underlying NPU/GPGPU/CPU hardware, `ptoas` is built in an **Out-of-Tree** architecture and provides complete C++ and Python interfaces. Its primary responsibilities include: + +1. **IR Parsing & Verification**: Parses `.pto` input files and verifies the semantic correctness of PTO Dialect operations (Ops). +2. **Compilation & Optimization (Passes)**: Executes optimization passes targeting the Da Vinci Architecture, such as operator fusion and automatic synchronization insertion. +3. **Code Generation (Lowering)**: Supports lowering PTO IR to `EmitC` / `Linalg` dialects, ultimately generating code that calls the `pto-isa` C++ library. +4. **Python Bindings**: Provides seamlessly integrated Python modules. Through integration with MLIR Core bindings, frameworks such as **PyPTO**, **TileLang**, and **CuTile** can build, manipulate, and compile PTO Bytecode directly from Python. + +--- + +## 2. Directory Structure + +```text +PTOAS/ +├── include/ +│ └── PTO/ # PTO Dialect headers and TableGen definitions (.td) +├── lib/ +│ ├── PTO/ # Dialect core implementation (IR) and Pass logic (Transforms) +│ ├── CAPI/ # C language interface exposure +│ └── Bindings/Python/ # Python Binding C++ implementation (Pybind11) +├── python/ # Python module build scripts and helper code +├── test/ +│ └── samples/ # Test cases +├── tools/ +│ ├── ptoas/ # ptoas command-line tool entry point (Output: ptoas) +│ └── ptobc/ # ptobc command-line tool entry point (Output: ptobc) +└── CMakeLists.txt # Top-level build configuration +``` + +--- + +## 3. Build Instructions + +⚠️ **Important**: This project strictly requires **LLVM llvmorg-19.1.7**. + +### 3.0 Environment Variable Configuration + +To simplify the build process, **first modify and run the following commands according to your environment**. Subsequent steps reference these variables directly. + +```bash +# ================= Configuration (edit here) ================= +# Set your workspace root directory +# (recommended: a dedicated directory for LLVM and PTOAS) +export WORKSPACE_DIR=$HOME/llvm-workspace + +# LLVM source and build paths +export LLVM_SOURCE_DIR=$WORKSPACE_DIR/llvm-project +export LLVM_BUILD_DIR=$LLVM_SOURCE_DIR/build-shared + +# PTOAS source and install paths +export PTO_SOURCE_DIR=$WORKSPACE_DIR/PTOAS +export PTO_INSTALL_DIR=$PTO_SOURCE_DIR/install +# ============================================================= + +# Create the workspace directory +mkdir -p $WORKSPACE_DIR +``` + +### 3.1 Prerequisites + +* **OS**: Linux (Ubuntu 20.04+ recommended) +* **Compiler**: GCC >= 9 or Clang (C++17 support required) +* **Build System**: CMake >= 3.20, Ninja +* **Python**: 3.8+ +* **Python Packages**: `pybind11`, `numpy` + +```bash +python3 -m pip install pybind11==2.12.0 numpy +``` + +> **Note**: The current LLVM/MLIR Python bindings are not compatible with `pybind11` 3.x. +> If you encounter errors like `def_property family does not currently support keep_alive` +> when building LLVM, run the downgrade command above first. + +### 3.2 Step 1: Build LLVM/MLIR (Dependency) + +Download the LLVM source, check out the `llvmorg-19.1.7` tag, and build with **shared libraries** to ensure correct linking for Python bindings. + +```bash +# 1. Clone LLVM +cd $WORKSPACE_DIR +git clone https://github.com/llvm/llvm-project.git +cd $LLVM_SOURCE_DIR + +# 2. [Critical] Check out llvmorg-19.1.7 +git checkout llvmorg-19.1.7 + +# 3. Configure CMake (build shared libs with Python bindings enabled) +cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DBUILD_SHARED_LIBS=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE=$(which python3) \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_TARGETS_TO_BUILD="host" + +# 4. Build LLVM (this step takes a long time) +ninja -C $LLVM_BUILD_DIR +``` + +### 3.3 Step 2: Build PTOAS (Out-of-Tree) + +Clone the PTOAS source and build against the LLVM 19 you just compiled. + +```bash +# 1. Clone PTOAS +cd $WORKSPACE_DIR +git clone https://gitcode.com/cann/pto-as.git PTOAS +cd $PTO_SOURCE_DIR + +# 2. Build and install via pip +# The build backend (pyproject.toml) drives CMake + Ninja automatically. +pip install . +``` + +This produces the same artifacts as a manual CMake build: + +```text +# CLI tools +$PTO_SOURCE_DIR/build/tools/ptoas/ptoas +$PTO_SOURCE_DIR/build/tools/ptobc/ptobc + +# Native extension installed into the MLIR Python package +$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core/ +└── mlir + └── _mlir_libs + └── _pto.cpython-*.so + +# Python dialect files +$PTO_INSTALL_DIR/ +└── mlir + └── dialects + ├── pto.py + └── _pto_ops_gen.py +``` + +### 3.4 Step 3: Python Editable Install (Optional, for Python development) + +If you want to develop and test Python code against the in-tree build without reinstalling after every C++ change, use an **editable install**. + +```bash +pip install -e . --no-build-isolation +``` + +> **Why `--no-build-isolation`?** Without this flag, pip uses a temporary virtual environment for the build, records its pybind11 path in `CMakeCache.txt`, then deletes the venv — breaking any subsequent `ninja` reconfigure. + +If you previously ran `pip install -e .` without the flag and your build is now broken, fix the existing `CMakeCache.txt` with: + +```bash +cmake -B build -Dpybind11_DIR=$(python3 -m pybind11 --cmakedir) +``` + +--- + +## 4. Usage + +### 4.1 Command-Line Interface (CLI) + +```bash +# Parse and print PTO IR +ptoas test/lit/pto/empty_func.pto + +# Run the AutoSyncInsert pass +ptoas test/lit/pto/empty_func.pto --enable-insert-sync -o outputfile.cpp + +# Specify target hardware architecture (A3 / A5) +ptoas test/lit/pto/empty_func.pto --pto-arch=a5 -o outputfile.cpp + +# Specify build level (level3 disables PlanMemory/InsertSync) +ptoas test/lit/pto/empty_func.pto --pto-level=level3 -o outputfile.cpp + +# Print the current ptoas release version +ptoas --version +``` + +### 4.2 Python API + +After configuring the environment variables, the PTO Dialect is loaded as part of `mlir.dialects`. + +```python +from mlir.ir import Context, Module, Location +# [Key] Import pto from mlir.dialects — the standard pattern for out-of-tree bindings +from mlir.dialects import pto + +with Context() as ctx, Location.unknown(): + pto.register_dialect(ctx, load=True) + module = Module.create() + print("PTO Dialect registered successfully!") +``` + +### 4.3 Running Tests + +```bash +# Run Python binding tests +cd $PTO_SOURCE_DIR/test/samples/MatMul/ +python3 ./tmatmulk.py > ./tmatmulk.pto + +# Run ptoas tests +$PTO_SOURCE_DIR/build/tools/ptoas/ptoas ./tmatmulk.pto -o ./tmatmulk.cpp +``` + +### 4.4 On-Board Validation + +This flow generates NPU validation test cases from the `.cpp` files produced by ptoas (under `test/samples/`) and runs them on an NPU. The example below reuses `MatMul/tmatmulk.cpp` generated in section 4.3. + +> For compile-only validation on a machine without an NPU card, see [docs/no_npu_compile_only_guide_zh.md](docs/no_npu_compile_only_guide_zh.md). + +```bash +# 1) Generate the npu_validation test directory +# (creates npu_validation/ under the current sample directory) + +# A2/A3 example: +python3 test/npu_validation/scripts/generate_testcase.py \ + --input test/samples/MatMul/tmatmulk.cpp \ + --run-mode npu \ + --soc-version Ascend910B1 + +# A5 example: +python3 test/npu_validation/scripts/generate_testcase.py \ + --input test/samples/MatMul/tmatmulk.cpp \ + --run-mode npu \ + --soc-version Ascend950 + +# 2) Run validation (run.sh requires no additional arguments) +test/samples/MatMul/npu_validation/tmatmulk/run.sh +``` + +Notes: +- `test/samples/MatMul/npu_validation/tmatmulk/` will contain `tmatmulk_kernel.cpp`, `main.cpp`, `golden.py`, `compare.py`, `run.sh`, and `CMakeLists.txt`. +- `golden.py` generates random inputs by default; outputs default to all zeros (only the count, shape, and data type of inputs/outputs match the kernel parameters). +- `compare.py` compares `golden*.bin` against `output*.bin` and reports an error if they differ. diff --git a/_ptoas_build_backend.py b/_ptoas_build_backend.py new file mode 100644 index 000000000..794ff8905 --- /dev/null +++ b/_ptoas_build_backend.py @@ -0,0 +1,242 @@ +""" +PEP 517 build backend for ptoas. + +Runs the CMake/Ninja build (assuming LLVM is already built), then delegates +wheel packaging to docker/create_wheel.sh. + +Environment variables (all optional): + LLVM_BUILD_DIR Path to LLVM build dir + (default: /llvm-workspace/llvm-project/build-shared) + PTO_INSTALL_DIR Install prefix (default: /install) + PTOAS_PYTHON_PACKAGE_VERSION Wheel version override +""" +from __future__ import annotations + +import base64 +import glob +import hashlib +import io +import os +import shutil +import subprocess +import sys +import zipfile +from pathlib import Path + +_REPO = Path(__file__).parent.resolve() +_LLVM_BUILD_DIR = Path( + os.environ.get("LLVM_BUILD_DIR", + "/llvm-workspace/llvm-project/build-shared") +) +_PTO_INSTALL_DIR = Path( + os.environ.get("PTO_INSTALL_DIR", str(_REPO / "install")) +) +_BUILD_DIR = _REPO / "build" +_MLIR_PY_PKG = ( + _LLVM_BUILD_DIR / "tools" / "mlir" / "python_packages" / "mlir_core" +) + + +def get_requires_for_build_wheel(config_settings=None): + return ["setuptools>=68", "wheel", "pybind11"] + + +def get_requires_for_build_editable(config_settings=None): + return ["setuptools>=68", "wheel", "pybind11"] + + +def get_requires_for_build_sdist(config_settings=None): + return [] + + +def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): + """Return wheel metadata without running the full build.""" + import email.message + + version = os.environ.get("PTOAS_PYTHON_PACKAGE_VERSION", "0.1.0") + dist_info = Path(metadata_directory) / f"ptoas-{version}.dist-info" + dist_info.mkdir(parents=True, exist_ok=True) + + meta = email.message.Message() + meta["Metadata-Version"] = "2.1" + meta["Name"] = "ptoas" + meta["Version"] = version + meta["Summary"] = "PTO Assembler & Optimizer" + meta["Requires-Python"] = ">=3.9" + meta["License"] = "Apache-2.0" + meta["Requires-Dist"] = "numpy" + meta["Requires-Dist"] = f"ptodsl @ file://{_REPO / 'ptodsl'}" + (dist_info / "METADATA").write_text(str(meta)) + (dist_info / "WHEEL").write_text( + "Wheel-Version: 1.0\nGenerator: _ptoas_build_backend\n" + "Root-Is-Purelib: True\nTag: py3-none-any\n" + ) + return dist_info.name + + +prepare_metadata_for_build_editable = prepare_metadata_for_build_wheel + + +def build_sdist(sdist_directory, config_settings=None): + raise NotImplementedError( + "ptoas does not support sdist. Use `pip install .` to build a wheel." + ) + + +def _cmake_configure_and_build(): + """CMake configure + Ninja build + install.""" + _BUILD_DIR.mkdir(exist_ok=True) + + pybind11_dir = subprocess.check_output( + [sys.executable, "-m", "pybind11", "--cmakedir"], text=True + ).strip() + + cmake_cmd = [ + "cmake", "-GNinja", + f"-S{_REPO}", f"-B{_BUILD_DIR}", + "-DCMAKE_BUILD_TYPE=Release", + f"-DLLVM_DIR={_LLVM_BUILD_DIR}/lib/cmake/llvm", + f"-DMLIR_DIR={_LLVM_BUILD_DIR}/lib/cmake/mlir", + f"-DPython3_ROOT_DIR={sys.prefix}", + f"-DPython3_EXECUTABLE={sys.executable}", + "-DPython3_FIND_STRATEGY=LOCATION", + f"-Dpybind11_DIR={pybind11_dir}", + f"-DMLIR_PYTHON_PACKAGE_DIR={_MLIR_PY_PKG}", + f"-DCMAKE_INSTALL_PREFIX={_PTO_INSTALL_DIR}", + ] + + hardening_cache = _REPO / "cmake" / "LinuxHardeningCache.cmake" + if hardening_cache.exists(): + cmake_cmd.insert(1, f"-C{hardening_cache}") + + subprocess.check_call(cmake_cmd) + subprocess.check_call(["ninja", "-C", str(_BUILD_DIR)]) + subprocess.check_call(["ninja", "-C", str(_BUILD_DIR), "install"]) + + +def _install_dialect_files(): + """Copy PTO dialect .py files and TileLang resources into the MLIR package dir.""" + dialects_src = _PTO_INSTALL_DIR / "mlir" / "dialects" + dialects_dst = _MLIR_PY_PKG / "mlir" / "dialects" + if dialects_src.exists() and dialects_dst.exists(): + for f in dialects_src.glob("*.py"): + shutil.copy2(f, dialects_dst / f.name) + + tilelang_src = _PTO_INSTALL_DIR / "tilelang_dsl" + tileops_src = _PTO_INSTALL_DIR / "share" / "ptoas" / "TileOps" + if tilelang_src.exists(): + dst = _MLIR_PY_PKG / "tilelang_dsl" + if dst.exists(): + shutil.rmtree(dst) + shutil.copytree(tilelang_src, dst) + if tileops_src.exists(): + dst = _MLIR_PY_PKG / "TileOps" + if dst.exists(): + shutil.rmtree(dst) + shutil.copytree(tileops_src, dst) + + +def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): + _cmake_configure_and_build() + + env = os.environ.copy() + env.update({ + "PTO_SOURCE_DIR": str(_REPO), + "PTO_INSTALL_DIR": str(_PTO_INSTALL_DIR), + "LLVM_BUILD_DIR": str(_LLVM_BUILD_DIR), + }) + subprocess.check_call( + ["bash", str(_REPO / "docker" / "create_wheel.sh")], + env=env, + ) + + wheels = sorted( + glob.glob(str(_MLIR_PY_PKG / "dist" / "ptoas-*.whl")), + key=os.path.getmtime, + ) + if not wheels: + raise RuntimeError( + f"No ptoas-*.whl found in {_MLIR_PY_PKG / 'dist'} after build." + ) + + wheel_path = Path(wheels[-1]) + dest = Path(wheel_directory) / wheel_path.name + shutil.copy2(wheel_path, dest) + return dest.name + + +def build_editable(wheel_directory, config_settings=None, metadata_directory=None): + """PEP 660 editable install. + + Builds the C++ extensions in-place, then produces a minimal wheel that + installs a .pth file pointing sys.path at the build tree. No files are + copied into site-packages except the .pth file itself. + """ + _cmake_configure_and_build() + + # Copy dialect .py files so `from mlir.dialects import pto` works + _install_dialect_files() + + version = os.environ.get("PTOAS_PYTHON_PACKAGE_VERSION", "0.1.0") + + # Paths that must be on sys.path for the package to be importable + pth_paths = [ + # mlir.* namespace + _pto.so (installed there by CMake) + str(_MLIR_PY_PKG), + # _pto.so output directory (CMAKE_LIBRARY_OUTPUT_DIRECTORY) + str(_BUILD_DIR / "python" / "pto"), + # handwritten Python sources (pto/dialects/pto.py, etc.) + str(_REPO / "python"), + # ptodsl pure-Python sub-package + str(_REPO / "ptodsl"), + ] + + pth_content = "\n".join(pth_paths) + "\n" + pth_filename = "ptoas-editable.pth" + + # ---- Build the editable wheel (a zip with .pth + dist-info) ---- + tag = f"py3-none-any" + wheel_name = f"ptoas-{version}-{tag}.whl" + wheel_path = Path(wheel_directory) / wheel_name + + dist_info_dir = f"ptoas-{version}.dist-info" + + def _sha256_record(data: bytes) -> str: + digest = hashlib.sha256(data).digest() + b64 = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + return f"sha256={b64}" + + pth_bytes = pth_content.encode() + wheel_meta = ( + "Wheel-Version: 1.0\n" + "Generator: _ptoas_build_backend\n" + "Root-Is-Purelib: True\n" + f"Tag: {tag}\n" + "Build: editable\n" + ).encode() + metadata_content = ( + "Metadata-Version: 2.1\n" + "Name: ptoas\n" + f"Version: {version}\n" + "Summary: PTO Assembler & Optimizer\n" + "Requires-Python: >=3.9\n" + "License: Apache-2.0\n" + "Requires-Dist: numpy\n" + f"Requires-Dist: ptodsl @ file://{_REPO / 'ptodsl'}\n" + ).encode() + + record_lines = [ + f"{pth_filename},{_sha256_record(pth_bytes)},{len(pth_bytes)}", + f"{dist_info_dir}/WHEEL,{_sha256_record(wheel_meta)},{len(wheel_meta)}", + f"{dist_info_dir}/METADATA,{_sha256_record(metadata_content)},{len(metadata_content)}", + f"{dist_info_dir}/RECORD,,", + ] + record_content = "\n".join(record_lines).encode() + + with zipfile.ZipFile(wheel_path, "w", compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr(pth_filename, pth_bytes) + zf.writestr(f"{dist_info_dir}/WHEEL", wheel_meta) + zf.writestr(f"{dist_info_dir}/METADATA", metadata_content) + zf.writestr(f"{dist_info_dir}/RECORD", record_content) + + return wheel_name diff --git a/pyproject.toml b/pyproject.toml index 56f81ad59..2a03b7276 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,22 @@ # pyproject.toml - Project metadata # -# NOTE: This project has a complex build process that requires LLVM/MLIR to be -# built first. The wheel is created from MLIR's python packages directory, not -# from this repo root. See .github/workflows/build_wheel.yml for the full build -# process. +# Build flow (requires LLVM already built): +# pip install . +# +# This will: +# 1. CMake configure + Ninja build + install +# 2. Package Python bindings into a wheel via docker/create_wheel.sh +# +# Environment variables (all optional): +# LLVM_BUILD_DIR Path to LLVM build dir +# (default: /llvm-workspace/llvm-project/build-shared) +# PTO_INSTALL_DIR Install prefix (default: /install) +# PTOAS_PYTHON_PACKAGE_VERSION Wheel version override + +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "_ptoas_build_backend" +backend-path = ["."] [project] name = "ptoas" @@ -11,7 +24,10 @@ version = "0.1.0" description = "PTO Assembler & Optimizer" readme = "README.md" requires-python = ">=3.9" -license = {text = "Apache-2.0"} +license = "Apache-2.0" +dependencies = [ + "numpy", +] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", From 8de19683dc551b5a6bd92c5059f8d2e506ec3e3e Mon Sep 17 00:00:00 2001 From: Giacomo Castiglioni Date: Sat, 23 May 2026 06:46:01 +0200 Subject: [PATCH 17/31] use pip install in CI (#385) * pip install ptoas * use pip install in CI * wheels pipelines use pip install * add missing license header * fix pip setup --- .github/workflows/build_wheel.yml | 18 ++--------- .github/workflows/build_wheel_mac.yml | 18 ++--------- .github/workflows/ci.yml | 30 +++-------------- _ptoas_build_backend.py | 18 ++++++++--- docker/create_wheel.sh | 3 ++ .../lit/vpto/expand_tileop_to_vpto_result.pto | 32 ------------------- 6 files changed, 26 insertions(+), 93 deletions(-) delete mode 100644 test/lit/vpto/expand_tileop_to_vpto_result.pto diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index d430fc312..e3c409f68 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -156,22 +156,8 @@ jobs: - name: Build PTOAS run: | export PATH="${PY_PATH}/bin:$PATH" - cd $PTO_SOURCE_DIR - cmake -C "$PTO_SOURCE_DIR/cmake/LinuxHardeningCache.cmake" -G Ninja \ - -S . \ - -B build \ - -DLLVM_DIR=$LLVM_BUILD_DIR/lib/cmake/llvm \ - -DMLIR_DIR=$LLVM_BUILD_DIR/lib/cmake/mlir \ - -DPython3_ROOT_DIR=${PY_PATH} \ - -DPython3_EXECUTABLE=${PY_PATH}/bin/python \ - -DPython3_FIND_STRATEGY=LOCATION \ - -Dpybind11_DIR=$(${PY_PATH}/bin/python -m pybind11 --cmakedir) \ - -DMLIR_PYTHON_PACKAGE_DIR=${LLVM_BUILD_DIR}/tools/mlir/python_packages/mlir_core \ - -DPTOAS_RELEASE_VERSION_OVERRIDE=${PTOAS_VERSION} \ - -DCMAKE_INSTALL_PREFIX=${PTO_INSTALL_DIR} \ - -DCMAKE_BUILD_TYPE=Release - ninja -C build - ninja -C build install + PTOAS_RELEASE_VERSION_OVERRIDE="${PTOAS_VERSION}" \ + pip install . --no-build-isolation - name: Create Python wheel if: false diff --git a/.github/workflows/build_wheel_mac.yml b/.github/workflows/build_wheel_mac.yml index 0e370a33d..809424447 100644 --- a/.github/workflows/build_wheel_mac.yml +++ b/.github/workflows/build_wheel_mac.yml @@ -154,22 +154,8 @@ jobs: - name: Build PTOAS run: | - cd $PTO_SOURCE_DIR - cmake -G Ninja \ - -S . \ - -B build \ - -DLLVM_DIR=$LLVM_BUILD_DIR/lib/cmake/llvm \ - -DMLIR_DIR=$LLVM_BUILD_DIR/lib/cmake/mlir \ - -DPython3_ROOT_DIR=${PY_PATH} \ - -DPython3_EXECUTABLE=$(which python) \ - -DPython3_FIND_STRATEGY=LOCATION \ - -Dpybind11_DIR=$(python -m pybind11 --cmakedir) \ - -DMLIR_PYTHON_PACKAGE_DIR=${LLVM_BUILD_DIR}/tools/mlir/python_packages/mlir_core \ - -DPTOAS_RELEASE_VERSION_OVERRIDE=${PTOAS_VERSION} \ - -DCMAKE_INSTALL_PREFIX=${PTO_INSTALL_DIR} \ - -DCMAKE_BUILD_TYPE=Release - ninja -C build - ninja -C build install + PTOAS_RELEASE_VERSION_OVERRIDE="${PTOAS_VERSION}" \ + pip install . --no-build-isolation - name: Create Python wheel if: false diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36d2ac3c4..8651ab2d3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -200,20 +200,9 @@ jobs: - name: Build PTOAS run: | - export PYBIND11_CMAKE_DIR="$(python3 -m pybind11 --cmakedir)" - cmake -C "${GITHUB_WORKSPACE}/cmake/LinuxHardeningCache.cmake" -G Ninja -S . -B build \ - -DLLVM_DIR="${LLVM_DIR}/lib/cmake/llvm" \ - -DMLIR_DIR="${LLVM_DIR}/lib/cmake/mlir" \ - -DPython3_EXECUTABLE=python3 \ - -DPython3_FIND_STRATEGY=LOCATION \ - -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DMLIR_PYTHON_PACKAGE_DIR="${LLVM_DIR}/tools/mlir/python_packages/mlir_core" \ - -DCMAKE_INSTALL_PREFIX="${PTO_INSTALL_DIR}" \ - -DCMAKE_BUILD_TYPE=Release - ninja -C build ptoas - ninja -C build ptobc - ninja -C build install + # LLVM_BUILD_DIR is the env var read by the build backend (_ptoas_build_backend.py). + # PTO_INSTALL_DIR is already set at the job level. + LLVM_BUILD_DIR="${LLVM_DIR}" pip install . --no-build-isolation - name: Run lit tests shell: bash @@ -398,17 +387,8 @@ jobs: shell: bash run: | set -euo pipefail - export PYBIND11_CMAKE_DIR="$(python3 -m pybind11 --cmakedir)" - cmake -G Ninja -S . -B build \ - -DLLVM_DIR="${LLVM_DIR}/lib/cmake/llvm" \ - -DMLIR_DIR="${LLVM_DIR}/lib/cmake/mlir" \ - -DPython3_EXECUTABLE=python3 \ - -DPython3_FIND_STRATEGY=LOCATION \ - -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DMLIR_PYTHON_PACKAGE_DIR="${LLVM_DIR}/tools/mlir/python_packages/mlir_core" \ - -DCMAKE_BUILD_TYPE=Release - ninja -C build ptoas + # LLVM_BUILD_DIR is the env var read by the build backend (_ptoas_build_backend.py). + LLVM_BUILD_DIR="${LLVM_DIR}" pip install . --no-build-isolation - name: Resolve simulator environment shell: bash diff --git a/_ptoas_build_backend.py b/_ptoas_build_backend.py index 794ff8905..9cf83e445 100644 --- a/_ptoas_build_backend.py +++ b/_ptoas_build_backend.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """ PEP 517 build backend for ptoas. @@ -38,11 +46,11 @@ def get_requires_for_build_wheel(config_settings=None): - return ["setuptools>=68", "wheel", "pybind11"] + return ["setuptools>=68", "wheel", "pybind11<3"] def get_requires_for_build_editable(config_settings=None): - return ["setuptools>=68", "wheel", "pybind11"] + return ["setuptools>=68", "wheel", "pybind11<3"] def get_requires_for_build_sdist(config_settings=None): @@ -65,7 +73,6 @@ def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): meta["Requires-Python"] = ">=3.9" meta["License"] = "Apache-2.0" meta["Requires-Dist"] = "numpy" - meta["Requires-Dist"] = f"ptodsl @ file://{_REPO / 'ptodsl'}" (dist_info / "METADATA").write_text(str(meta)) (dist_info / "WHEEL").write_text( "Wheel-Version: 1.0\nGenerator: _ptoas_build_backend\n" @@ -105,6 +112,10 @@ def _cmake_configure_and_build(): f"-DCMAKE_INSTALL_PREFIX={_PTO_INSTALL_DIR}", ] + release_version = os.environ.get("PTOAS_RELEASE_VERSION_OVERRIDE", "") + if release_version: + cmake_cmd.append(f"-DPTOAS_RELEASE_VERSION_OVERRIDE={release_version}") + hardening_cache = _REPO / "cmake" / "LinuxHardeningCache.cmake" if hardening_cache.exists(): cmake_cmd.insert(1, f"-C{hardening_cache}") @@ -222,7 +233,6 @@ def _sha256_record(data: bytes) -> str: "Requires-Python: >=3.9\n" "License: Apache-2.0\n" "Requires-Dist: numpy\n" - f"Requires-Dist: ptodsl @ file://{_REPO / 'ptodsl'}\n" ).encode() record_lines = [ diff --git a/docker/create_wheel.sh b/docker/create_wheel.sh index 2145fb9e7..a3ecf38ee 100755 --- a/docker/create_wheel.sh +++ b/docker/create_wheel.sh @@ -51,6 +51,9 @@ rm -rf "${PY_PACKAGE_DIR}/tilelang_dsl" "${PY_PACKAGE_DIR}/TileOps" cp -R "${PTO_INSTALL_DIR}/tilelang_dsl" "${PY_PACKAGE_DIR}/tilelang_dsl" cp -R "${PTO_INSTALL_DIR}/share/ptoas/TileOps" "${PY_PACKAGE_DIR}/TileOps" +# Copy ptodsl into the wheel so it is always shipped with ptoas +cp -R "${PTO_SOURCE_DIR}/ptodsl/ptodsl" "${PY_PACKAGE_DIR}/ptodsl" + # Copy platform-specific setup.py to package directory. # On macOS, use setup_mac.py and rename it to setup.py in the build dir. SETUP_TEMPLATE="${PTO_SOURCE_DIR}/docker/setup.py" diff --git a/test/lit/vpto/expand_tileop_to_vpto_result.pto b/test/lit/vpto/expand_tileop_to_vpto_result.pto deleted file mode 100644 index 9644fc204..000000000 --- a/test/lit/vpto/expand_tileop_to_vpto_result.pto +++ /dev/null @@ -1,32 +0,0 @@ -// Generated by command: -// ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand --mlir-print-ir-after-all ./expand_tile_op_tilelang.pto -o out.pto - -module attributes {pto.target_arch = "a5"} { - module attributes {pto.kernel_kind = #pto.kernel_kind, pto.target_arch = "a5"} { - func.func @TADD() { - %c0_i64 = arith.constant 0 : i64 - %c16 = arith.constant 16 : index - %c4096_i64 = arith.constant 4096 : i64 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64_i32 = arith.constant 64 : i32 - %c64 = arith.constant 64 : index - pto.vecscope { - %0 = pto.castptr %c4096_i64 : i64 -> !pto.ptr - %1 = pto.castptr %c0_i64 : i64 -> !pto.ptr - scf.for %arg0 = %c0 to %c16 step %c1 { - %mask, %scalar_out = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 - %2 = arith.muli %arg0, %c64 : index - %3 = pto.addptr %0, %2 : -> - %4 = pto.vlds %3[%c0] : !pto.ptr -> !pto.vreg<64xf32> - %5 = pto.addptr %1, %2 : -> - %6 = pto.vlds %5[%c0] : !pto.ptr -> !pto.vreg<64xf32> - %7 = pto.vadd %4, %6, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - pto.vsts %7, %5[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask - } - } - return - } - } -} - From 36bb9c51f4830f47e55e0b40bdd75401351a63c0 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 19 May 2026 16:13:30 +0800 Subject: [PATCH 18/31] feature(ptodsl): align ptodsl implementation with user guide --- lib/PTO/IR/PTO.cpp | 6 - lib/PTO/IR/VPTO.cpp | 23 +- ptodsl/README.md | 120 +- ptodsl/check_ir.py | 151 - ptodsl/docs/user_guide/01-introduction.md | 18 +- ptodsl/docs/user_guide/02-quick-start.md | 186 +- .../03-kernel-entry-and-subkernels.md | 212 +- .../user_guide/04-type-system-and-buffer.md | 61 +- ptodsl/docs/user_guide/05-control-flow.md | 156 +- .../user_guide/06-scalar-and-pointer-ops.md | 88 +- .../docs/user_guide/07-data-movement-ops.md | 261 +- .../docs/user_guide/08-compute-operations.md | 391 ++- .../user_guide/09-predicate-and-mask-ops.md | 52 +- ptodsl/docs/user_guide/10-sync-ops.md | 245 +- .../11-flash-attention-walkthrough.md | 265 +- .../docs/user_guide/12-additional-examples.md | 259 +- .../flash_attention_sketch.py | 43 +- ptodsl/examples/softmax_dsl.py | 28 +- ptodsl/examples/tadd_dsl.py | 4 +- ptodsl/ptodsl/_control_flow.py | 308 +- ptodsl/ptodsl/_diagnostics.py | 9 + ptodsl/ptodsl/_host_tensors.py | 6 +- ptodsl/ptodsl/_ops.py | 2923 ++++++++++++++--- ptodsl/ptodsl/_runtime_scalar_ops.py | 198 +- ptodsl/ptodsl/_scalar_coercion.py | 36 +- ptodsl/ptodsl/_subkernels.py | 61 +- ptodsl/ptodsl/_surface_types.py | 101 + ptodsl/ptodsl/_surface_values.py | 95 +- ptodsl/ptodsl/_tile_namespace.py | 128 + ptodsl/ptodsl/_tracing/session.py | 20 +- ptodsl/ptodsl/_types.py | 248 +- ptodsl/ptodsl/pto.py | 59 +- ptodsl/ptodsl/scalar.py | 87 +- ptodsl/tests/test_vector_cube_ops.py | 381 +++ test/python/ptodsl_docs_as_test.py | 463 +++ test/python/ptodsl_docs_fragment_fixtures.py | 1583 +++++++++ .../ptodsl_flash_attention_demo_compile.py | 92 + test/python/ptodsl_jit_compile.py | 959 +++++- test/python/ptodsl_jit_diagnostics.py | 135 + test/python/ptodsl_ptoas_frontend_verify.py | 108 + test/python/ptodsl_subkernel_diagnostics.py | 17 + tools/ptoas/ptoas.cpp | 3 + 42 files changed, 8860 insertions(+), 1729 deletions(-) delete mode 100644 ptodsl/check_ir.py rename ptodsl/{demos => examples}/flash_attention_sketch.py (96%) create mode 100644 ptodsl/ptodsl/_tile_namespace.py create mode 100644 ptodsl/tests/test_vector_cube_ops.py create mode 100644 test/python/ptodsl_docs_as_test.py create mode 100644 test/python/ptodsl_docs_fragment_fixtures.py create mode 100644 test/python/ptodsl_flash_attention_demo_compile.py create mode 100644 test/python/ptodsl_ptoas_frontend_verify.py diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 68ded38e7..ec3a8c4fe 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -2255,12 +2255,6 @@ void mlir::pto::annotatePTOEntryFunctions(ModuleOp module) { LogicalResult AllocTileOp::verify() { auto ty = getResult().getType(); // TileBufType - - Type elemTy = ty.getElementType(); - if (isPTOLowPrecisionType(elemTy)) - return emitOpError() << "result dtype " << elemTy - << " is not supported by pto.alloc_tile yet"; - if (failed(verifyTileBufLayoutConstraints(*this, ty, "result"))) return failure(); diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 44d18c890..d16e6ceda 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -115,6 +115,12 @@ static bool isMaskGranularityAdjacentWidening(StringRef inputGranularity, (inputGranularity == "b16" && resultGranularity == "b32"); } +static bool isMaskGranularityAdjacentNarrowing(StringRef inputGranularity, + StringRef resultGranularity) { + return (inputGranularity == "b16" && resultGranularity == "b8") || + (inputGranularity == "b32" && resultGranularity == "b16"); +} + LogicalResult PTOLoadOp::verify() { return verifyVPTOScalarAccessTypes(getOperation(), getPtr().getType(), getValue().getType(), "load"); @@ -4250,8 +4256,17 @@ LogicalResult PpackOp::verify() { if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) return failure(); - if (getPart() != "LOWER") - return emitOpError("currently supports only LOWER part"); + if (!isSupportedPartToken(getPart())) + return emitOpError("requires part to be LOWER or HIGHER"); + auto inputMaskType = cast(getInput().getType()); + auto resultMaskType = cast(getResult().getType()); + StringRef inputGranularity = inputMaskType.getGranularity(); + StringRef resultGranularity = resultMaskType.getGranularity(); + if (inputGranularity != resultGranularity && + !isMaskGranularityAdjacentNarrowing(inputGranularity, resultGranularity)) { + return emitOpError( + "requires result mask granularity to match the input or narrow by one step"); + } return success(); } @@ -4259,8 +4274,8 @@ LogicalResult PunpackOp::verify() { if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) return failure(); - if (getPart() != "LOWER") - return emitOpError("currently supports only LOWER part"); + if (!isSupportedPartToken(getPart())) + return emitOpError("requires part to be LOWER or HIGHER"); auto inputMaskType = cast(getInput().getType()); auto resultMaskType = cast(getResult().getType()); StringRef inputGranularity = inputMaskType.getGranularity(); diff --git a/ptodsl/README.md b/ptodsl/README.md index 58809fcba..7e5b092fb 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -13,8 +13,8 @@ types as lazy descriptors, and control-flow maps 1-to-1 to MLIR operations. ptodsl/ ├── ptodsl/ # pip-installable package │ ├── __init__.py # exports: pto, scalar -│ ├── pto.py # main pto.* namespace -│ ├── scalar.py # pto.scalar.* arith helpers +│ ├── pto.py # main PTO DSL namespace +│ ├── scalar.py # top-level scalar.* helper namespace │ ├── _bootstrap.py # MLIR path setup + context factory │ ├── _types.py # lazy dtype descriptors and type constructors │ ├── _ops.py # PTO operation wrappers @@ -28,7 +28,6 @@ ptodsl/ │ ├── softmax_lowlevel.py # Softmax – raw MLIR Python binding calls │ └── softmax_dsl.py # Softmax – @pto.jit DSL style ├── pyproject.toml # pip install -e . -├── check_ir.py # IR correctness test runner └── README.md ``` @@ -56,41 +55,91 @@ pip install -e . --- -## Running the IR check +## Running regression checks ```bash -# From $PTOAS_REPO_ROOT/ptodsl/ -python3 check_ir.py - -# From the repository root ($PTOAS_REPO_ROOT) -python3 ptodsl/check_ir.py +cd $PTOAS_REPO_ROOT +python3 test/python/ptodsl_jit_compile.py +python3 test/python/ptodsl_jit_diagnostics.py +python3 test/python/ptodsl_subkernel_diagnostics.py +python3 test/python/ptodsl_flash_attention_demo_compile.py +python3 test/python/ptodsl_ptoas_frontend_verify.py +python3 test/python/ptodsl_docs_as_test.py ``` Expected output: ``` -ptodsl IR check -================================================== - PASS TADD low-level - PASS TADD dsl-style - PASS softmax low-level - PASS softmax dsl-style -================================================== -Result: ALL PASS +ptodsl_jit_compile: PASS +ptodsl_jit_diagnostics: PASS +ptodsl_subkernel_diagnostics: PASS +ptodsl_flash_attention_demo_compile: PASS +ptodsl_ptoas_frontend_verify: PASS +ptodsl_docs_as_test: PASS ``` -Exit code is `0` on full pass, `1` on any failure. A unified diff of up to -60 diverging lines is printed for each failing case. +`ptodsl_docs_as_test.py` is the docs-as-test regression for the PTODSL user +guide under `ptodsl/docs/user_guide/`. It scans every Python fenced code block +and requires each one to be explicitly classified with either +`ptodsl-doc-test` or `ptodsl-doc-pending` metadata. + +- `mode="compile"` blocks are executed as-authored and must pass the PTODSL + compile-only path, MLIR verify, and shared PTOAS frontend validation. +- `mode="compile_fragment"` blocks are embedded into explicit test fixtures so + representative partial snippets can be compiled under a declared outer + kernel context instead of relying on hidden heuristic context synthesis. +- `ptodsl-doc-pending` marks snippets the manual intends to treat as contract + later, but which are still blocked on missing implementation or missing test + harness support. + +Run it directly while editing the manual: + +```bash +cd $PTOAS_REPO_ROOT +python3 test/python/ptodsl_docs_as_test.py +``` + +When it fails, the diagnostic includes the Markdown path, starting line number, +and target symbol so the drift can be fixed in the manual instead of searching +through generated IR logs. + +These PTODSL regressions are intentionally complementary: + +- `ptodsl_jit_compile.py` protects canonical authored compile probes and + lowering contracts for the public PTODSL surface. +- `ptodsl_flash_attention_demo_compile.py` protects the bundled + `ptodsl/examplesflash_attention_sketch.py` authored demo as a stable end-to-end + contract. +- `ptodsl_ptoas_frontend_verify.py` protects the handoff from PTODSL-emitted + MLIR into standalone `ptoas` frontend verification. +- `ptodsl_docs_as_test.py` protects the user manual itself: documented + self-contained examples must still compile, fixture-backed partial fragments + must still compile inside their declared context, and explicitly marked + pending snippets remain visible as docs/test debt. + +`ptodsl_docs_as_test.py` is not a replacement for the authored compile/demo +regressions above. It reuses the same compile-only and frontend-validation +boundaries, but its job is to keep `ptodsl/docs/user_guide/` honest rather than +to redefine the canonical demo contracts. + +The legacy `ptodsl/check_ir.py` script has been retired. PTODSL validation now +lives under `test/python/` so every regression shares the same bootstrap, +public surface, and canonical authored targets as the tracing/JIT +implementation. --- ## DSL-style API quick reference ```python -from ptodsl import pto -s = pto.scalar # arith shorthand alias +from ptodsl import pto, scalar +s = scalar # arith shorthand alias ``` +`pto` is the main DSL namespace. `scalar` is a separate top-level helper +namespace for runtime scalar load/store, arithmetic helpers, and scalar math; +it is intentionally not exported as `pto.scalar`. + ### Kernel decorator ```python @@ -155,19 +204,20 @@ with pto.for_(c0, c128, step=c64, iter_args=(a, b)) as loop: pto.yield_(nx, ny) # scf.yield with values fx, fy = loop.results -with pto.if_(has_rows): # simple scf.if - ... # scf.yield inserted automatically - -with pto.if_(has_chunk, results=(vf32, vf32)) as br: +with pto.if_(has_rows) as br: # simple scf.if with br.then_: ... - pto.yield_(merged_max, merged_sum) + +with pto.if_(has_chunk) as br: + with br.then_: + br.assign(x=merged_max, y=merged_sum) with br.else_: - pto.yield_(running_max, running_sum) -x, y = br.results + br.assign(x=running_max, y=running_sum) +x = br.x +y = br.y ``` -### Scalar arithmetic (`s = pto.scalar`) +### Scalar arithmetic (`s = scalar`) ```python s.muli(a, b) # arith.muli @@ -175,8 +225,8 @@ s.addi(a, b) # arith.addi s.subi(a, b) # arith.subi s.index_cast(val) # arith.index_cast → index s.index_cast(pto.int32, val) # arith.index_cast → i32 -s.cmpi_sgt(a, b) # arith.cmpi sgt -s.cmpi("slt", a, b) # arith.cmpi with named predicate +(a > b) # scalar compare → pto.i1 +(a <= b) # scalar compare → pto.i1 s.select(cond, t, f) # arith.select ``` @@ -185,19 +235,21 @@ s.select(cond, t, f) # arith.select ```python pto.castptr(addr, ptr_type) # pto.castptr pto.addptr(ptr, offset) # pto.addptr -pto.vlds(ptr, offset, vreg_type) # pto.vlds +pto.vlds(ptr, offset) # pto.vlds, result vreg inferred from ptr element type pto.vbrc_load(ptr, offset, vreg_type) # pto.vlds {dist="BRC_B32"} pto.vsts(v, ptr, offset, mask) # pto.vsts pto.vsts_1pt(v, ptr, offset, mask) # pto.vsts {dist="1PT_B32"} pto.plt_b32(scalar) # → (mask, scalar_out) pto.pset_b32("PAT_ALL") # pto.pset_b32 → mask +pto.vbitcast(v, dtype) # pto.vbitcast +pto.pbitcast(mask, mask_type) # pto.pbitcast pto.vadd(a, b, mask) # infers result type from a.type pto.vmul / vmax / vdiv / vcmax / vcadd / vdup / vexpdif # similarly pto.make_tensor_view(ptr, shape=…, strides=…) # type inferred pto.partition_view(tv, offsets=…, sizes=…) # type inferred pto.alloc_tile(shape=…, dtype=…, memory_space=…) # authored surface -pto.tload(part, tile) -pto.tstore(tile, part) +pto.tile.load(part, tile) +pto.tile.store(tile, part) tile.as_ptr() / view.as_ptr() pto.get_block_idx() # → i64 pto.set_flag("MTE2", "V", event_id=0) diff --git a/ptodsl/check_ir.py b/ptodsl/check_ir.py deleted file mode 100644 index 6be6fdd71..000000000 --- a/ptodsl/check_ir.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -""" -IR correctness check for all ptodsl example scripts. - -Run from the repository root or from this directory: - python3 ptodsl/check_ir.py # from ptoas_a5/ - python3 check_ir.py # from ptoas_a5/ptodsl/ - -Each example's ``build()`` function is called; its output is compared against -the corresponding hand-written reference ``.pto`` file. - -Comparison methodology -────────────────────── -Both the generated module and the reference file are parsed by the MLIR Python -API (``Module.parse``), then printed back to a string. This round-trip: - - • Strips ``//`` comments present in hand-written ``.pto`` files - • Normalises SSA value names (``%block_idx`` → ``%0``, …) - • Normalises attribute ordering - -The resulting canonical strings are compared with ``==``. A unified diff of -the first 60 diverging lines is printed on failure. -""" - -import difflib -import importlib -import os -import sys - -# ── Path setup ──────────────────────────────────────────────────────────────── - -_HERE = os.path.dirname(os.path.abspath(__file__)) -_EXAMPLES = os.path.join(_HERE, "examples") -_MLIR_INSTALL = os.path.join(_HERE, "..", "install", "mlir") - -for _p in (_MLIR_INSTALL, _HERE, _EXAMPLES): - if _p not in sys.path: - sys.path.insert(0, _p) - -from mlir.ir import Context, Module # noqa: E402 -from mlir.dialects import pto as _pto_mod # noqa: E402 - - -# ── Helpers ─────────────────────────────────────────────────────────────────── - -def _normalize(mlir_text: str) -> str: - """Parse *mlir_text* with MLIR and return the canonical printed form.""" - with Context() as ctx: - _pto_mod.register_dialect(ctx, load=True) - return str(Module.parse(mlir_text)) - - -def _strip_comments(text: str) -> str: - """Remove ``//`` comment lines found in hand-written ``.pto`` files.""" - return "\n".join( - line for line in text.splitlines() if not line.strip().startswith("//") - ) - - -# ── Test cases ──────────────────────────────────────────────────────────────── -# Each entry: (label, module_name, reference_pto_path) - -_REPO_ROOT = os.path.abspath(os.path.join(_HERE, "..")) -_TADD_REF = os.path.join(_REPO_ROOT, "test/lit/vpto/expand_tileop_to_vpto_result.pto") -_SOFTMAX_REF = os.path.join(_REPO_ROOT, - "test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto") - -CASES = [ - ("TADD low-level ", "tadd_lowlevel", _TADD_REF), - ("TADD dsl-style ", "tadd_dsl", _TADD_REF), - ("softmax low-level", "softmax_lowlevel", _SOFTMAX_REF), - ("softmax dsl-style", "softmax_dsl", _SOFTMAX_REF), -] - - -# ── Runner ──────────────────────────────────────────────────────────────────── - -def run_checks(cases=CASES) -> bool: - """Execute every check case; return ``True`` if all passed.""" - all_passed = True - - for label, module_name, ref_path in cases: - # -- import the example and call build() -- - try: - # Re-import on every run so state doesn't leak between cases - spec = importlib.util.spec_from_file_location( - module_name, os.path.join(_EXAMPLES, f"{module_name}.py") - ) - builder = importlib.util.module_from_spec(spec) - spec.loader.exec_module(builder) - generated_text = str(builder.build()) - except Exception as exc: - print(f" FAIL {label} [builder error: {exc}]") - all_passed = False - continue - - # -- load and prepare the reference -- - try: - ref_raw = open(ref_path).read() - except FileNotFoundError: - print(f" FAIL {label} [reference not found: {ref_path}]") - all_passed = False - continue - - ref_clean = _strip_comments(ref_raw) - - # -- normalise both through the MLIR parser -- - try: - gen_norm = _normalize(generated_text) - ref_norm = _normalize(ref_clean) - except Exception as exc: - print(f" FAIL {label} [MLIR parse error: {exc}]") - all_passed = False - continue - - # -- compare -- - if gen_norm == ref_norm: - print(f" PASS {label}") - else: - all_passed = False - diff_lines = list( - difflib.unified_diff( - ref_norm.splitlines(), - gen_norm.splitlines(), - fromfile="reference", - tofile="generated", - lineterm="", - ) - ) - snippet = "\n".join(diff_lines[:60]) - print(f" FAIL {label}\n{snippet}") - if len(diff_lines) > 60: - print(f" … ({len(diff_lines) - 60} more diff lines)") - - return all_passed - - -if __name__ == "__main__": - print("ptodsl IR check") - print("=" * 50) - passed = run_checks() - print("=" * 50) - print("Result:", "ALL PASS" if passed else "SOME TESTS FAILED") - sys.exit(0 if passed else 1) diff --git a/ptodsl/docs/user_guide/01-introduction.md b/ptodsl/docs/user_guide/01-introduction.md index 4837a675f..64d75b5df 100644 --- a/ptodsl/docs/user_guide/01-introduction.md +++ b/ptodsl/docs/user_guide/01-introduction.md @@ -56,7 +56,7 @@ PTODSL organizes kernel code into three layers, each building on the one below i ``` Python Wrapper L0 user-facing wrapper (NumPy, torch-npu, pure Python) └─ @pto.jit L1 compile + cache + launch - ├─ Tile Ops tile-level: tload, tstore, tadd, ... + ├─ Tile Ops tile-level: tile.load, tile.store, tile.add, ... └─ @pto.ukernel L2 micro-instruction orchestration ├─ MTE Ops mte_load / mte_store / copy_gm_to_ubuf / ... ├─ @pto.cube matrix products (mad, mte_l1_l0a, mte_l0c_ub, ...) @@ -68,6 +68,7 @@ Python Wrapper L0 user-facing wrapper (NumPy, torch-npu, pure Pyth The outermost layer is plain Python. It handles ergonomic runtime concerns: allocating output tensors, extracting shapes and strides from framework tensors, compiling the JIT kernel, and launching it. Because L0 is just Python, you can freely mix in NumPy, torch-npu, or any other Python framework for pre- and post-processing, data preparation, or composing multiple kernel launches. This layer knows nothing about NPU internals — it is just a convenience function that most end users will call. + ```python def flash_attention(Q, K, V, *, O=None, causal=False): if O is None: @@ -89,7 +90,11 @@ Decorating a function with `@pto.jit` marks it as a launchable PTO kernel. This The parameters of a `@pto.jit` function are Python-native tensors (not PTODSL-specific descriptors). In PTODSL v1, their ABI contract is declared with `pto.tensor_spec(...)` in the function signature; this is a compile-time annotation, not a runtime object the Python wrapper must construct. The kernel body materializes `TensorView` descriptors from the runtime tensors via `make_tensor_view`, then partitions the problem with `partition_view`. Compile-time constants are declared as keyword-only arguments with `pto.constexpr`: + ```python +from ptodsl import pto + + @pto.jit(target="a5") def flash_attention_kernel( Q: pto.tensor_spec(rank=4, dtype=pto.f32), @@ -101,10 +106,11 @@ def flash_attention_kernel( BLOCK_KV: pto.constexpr = 128, CAUSAL: pto.constexpr = False, ): - ... + # ... tile allocation, block partitioning, and sub-kernel dispatch ... + return ``` -L1 is the primary layer for expressing **tile-level semantics**. Inside `@pto.jit`, you allocate tile buffers (`alloc_tile`), move data between GM and UB at block granularity (`tload`, `tstore`), and perform tile-level compute (`tadd`, `texp`, `trowsum`, etc.). When the built-in Tile Ops are not sufficient, you can drop down to `@pto.ukernel` to write custom tile-level semantics with micro-instructions. +L1 is the primary layer for expressing **tile-level semantics**. Inside `@pto.jit`, you allocate tile buffers (`alloc_tile`), move data between GM and UB at block granularity (`tile.load`, `tile.store`), and perform tile-level compute (`tile.add`, `tile.exp`, `tile.rowsum`, etc.). When the built-in Tile Ops are not sufficient, you can drop down to `@pto.ukernel` to write custom tile-level semantics with micro-instructions. The SPMD launch contract is also owned here: the runtime grid (e.g., `batch * heads` blocks) is declared at the call site, and block/subblock indices are queried via `pto.get_block_idx()` and friends. @@ -129,7 +135,7 @@ These are hardware-bound compute sub-kernels, each mapped to a specific NPU comp - **`@pto.simt`** is a scalar-programmable processor group that executes scalar instructions across many work-items in parallel. Typical operations: `lds`, `sts`, scalar arithmetic and comparison. Well-suited for per-element tile walks, boundary metadata, and pointwise blends. -L3 sub-kernels can be invoked in two ways: as named decorated functions (`@pto.cube` / `@pto.simd` / `@pto.simt`) — reusable and callable from `@pto.ukernel` or directly from `@pto.jit` — or inline as context managers (`with pto.cube():` / `with pto.simd():` / `with pto.simt():`) for quick prototyping. When called directly from `@pto.jit`, you stage data with `tload`/`tstore` instead of `mte_load`/`mte_store`; PTOAS handles the synchronization between Tile Ops and L3 compute automatically. +L3 sub-kernels can be invoked in two ways: as named decorated functions (`@pto.cube` / `@pto.simd` / `@pto.simt`) — reusable and callable from `@pto.ukernel` or directly from `@pto.jit` — or inline as context managers (`with pto.cube():` / `with pto.simd():` / `with pto.simt():`) for quick prototyping. When called directly from `@pto.jit`, you stage data with `tile.load`/`tile.store` instead of `mte_load`/`mte_store`; PTOAS handles the synchronization between Tile Ops and L3 compute automatically. The boundary contract is strict: vreg values do not escape a simd kernel, cube-local state does not leak into UB, and data crosses layer boundaries only through UB-backed tiles or typed UB pointers. @@ -153,9 +159,9 @@ Chapter 5 (Control Flow) and Chapter 6 (Scalar & Pointer Operations) cover this ## 1.4 A worked example -The flash attention kernel from Section 1.2 is not just an architectural diagram — it is a complete, runnable design sketch distributed with PTODSL (`demos/flash_attention_sketch.py`). Here is how the layers map to actual code: +The flash attention kernel from Section 1.2 is not just an architectural diagram — it is a complete, runnable design sketch distributed with PTODSL (`examplesflash_attention_sketch.py`). Here is how the layers map to actual code: -**L1 (`@pto.jit`)** allocates tiles for the Q block, KV block, online-softmax state (m/l/o ping-pong tiles), and cube-local scratch. It loops over Q blocks (outer `pto.for_`) and KV blocks (inner `pto.for_` with carry state), calling `kv_block_process` for each KV block and using `tload`/`tstore` at the GM boundary. +**L1 (`@pto.jit`)** allocates tiles for the Q block, KV block, online-softmax state (m/l/o ping-pong tiles), and cube-local scratch. It loops over Q blocks (outer `pto.for_`) and KV blocks (inner `pto.for_` with carry state), calling `kv_block_process` for each KV block and using `tile.load`/`tile.store` at the GM boundary. **L2 (`@pto.ukernel`)** stages the current K and V blocks with `mte_load`, issues `pipe_barrier(Pipe.ALL)` at phase boundaries, then sequences four sub-kernel calls: `qk_matmul` (cube), `online_softmax_rows` (simd), `pv_matmul` (cube), `blend_output_rows` (simt). diff --git a/ptodsl/docs/user_guide/02-quick-start.md b/ptodsl/docs/user_guide/02-quick-start.md index 6830fbe73..2a219c386 100644 --- a/ptodsl/docs/user_guide/02-quick-start.md +++ b/ptodsl/docs/user_guide/02-quick-start.md @@ -1,41 +1,41 @@ # 2. Quick Start -This chapter walks through a minimal but complete PTODSL kernel — elementwise vector addition — covering the essential concepts you need to start writing your own kernels. +This chapter walks through a minimal but complete PTODSL kernel — a tiled copy from one GM tensor to another — covering the essential concepts you need to start writing your own kernels. -## 2.1 A first kernel: elementwise vector add +## 2.1 A first kernel: tiled copy + ```python from ptodsl import pto @pto.jit(target="a5") -def vec_add(A, B, O, *, N: pto.constexpr): - """O = A + B, elementwise, for vectors of length N.""" +def tile_copy( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + """Copy one 2D tensor tile from A to O.""" - # Describe the GM tensors. - a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) - b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) - o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) - - # Allocate a UB tile to hold one block of each vector. - a_tile = pto.alloc_tile(shape=[N], dtype=pto.f32) - b_tile = pto.alloc_tile(shape=[N], dtype=pto.f32) - o_tile = pto.alloc_tile(shape=[N], dtype=pto.f32) + rows = A.shape[0] + cols = A.shape[1] - # Partition the GM views to cover the whole vector. - a_part = pto.partition_view(a_view, offsets=[0], sizes=[N]) - b_part = pto.partition_view(b_view, offsets=[0], sizes=[N]) - o_part = pto.partition_view(o_view, offsets=[0], sizes=[N]) + # Describe the GM tensors. + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) - # Load A and B from GM into UB tiles. - pto.tload(a_part, a_tile) - pto.tload(b_part, b_tile) + # Allocate UB tiles for one row-strip block. + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) - # Elementwise add on the tiles. - pto.tadd(a_tile, b_tile, o_tile) + # Partition the GM views to cover the current logical slice. + a_part = pto.partition_view(a_view, offsets=[0, 0], sizes=[rows, cols]) + o_part = pto.partition_view(o_view, offsets=[0, 0], sizes=[rows, cols]) - # Store the result back to GM. - pto.tstore(o_tile, o_part) + # Load from GM into UB, then store back out. + pto.tile.load(a_part, a_tile) + pto.tile.store(o_tile, o_part) ``` Let us step through each piece. @@ -44,15 +44,15 @@ Let us step through each piece. ```python @pto.jit(target="a5") -def vec_add(A, B, O, *, N: pto.constexpr): +def tile_copy(A, O, *, BLOCK: pto.constexpr = 128): ``` -`@pto.jit` marks this function as a launchable PTO kernel. The positional parameters `A`, `B`, `O` are Python-native tensors — they arrive from NumPy, torch-npu, or any framework that provides a shape and strides. The keyword-only argument `N` is a compile-time constant declared with `pto.constexpr`; the compiler specializes the kernel for each value of `N`. +`@pto.jit` marks this function as a launchable PTO kernel. The positional parameters `A` and `O` are Python-native tensors — they arrive from NumPy, torch-npu, or any framework that provides a shape and strides. Their ABI contract is declared with `pto.tensor_spec(...)`. The keyword-only argument `BLOCK` is a compile-time constant declared with `pto.constexpr`; the compiler specializes the kernel for each tile width. ### Describing GM tensors ```python -a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) +a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) ``` `make_tensor_view` wraps a Python tensor into a `TensorView` — a descriptor that tells the kernel how to address the tensor in global memory. You provide the logical shape and the stride (in elements) of each dimension. @@ -60,7 +60,7 @@ a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) ### Allocating on-chip buffers ```python -a_tile = pto.alloc_tile(shape=[N], dtype=pto.f32) +a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) ``` `alloc_tile` reserves space in the Unified Buffer (UB). A `Tile` is a 2D buffer that lives on-chip during kernel execution. Every tile has a `shape` and a `dtype`. @@ -68,78 +68,87 @@ a_tile = pto.alloc_tile(shape=[N], dtype=pto.f32) ### Partitioning GM views ```python -a_part = pto.partition_view(a_view, offsets=[0], sizes=[N]) +a_part = pto.partition_view(a_view, offsets=[0, 0], sizes=[rows, cols]) ``` -`partition_view` creates a sub-view of a `TensorView` at a given offset and size. It describes *which part* of the GM tensor a `tload` or `tstore` should operate on. For this simple whole-vector example the offset is zero and the size equals the full length; in a blocked kernel you would slide the offset through a loop. +`partition_view` creates a sub-view of a `TensorView` at a given offset and size. It describes *which part* of the GM tensor a `tile.load` or `tile.store` should operate on. For this simple whole-tensor example the offset is zero and the size matches the logical tensor extent; in a blocked kernel you would slide the offset through a loop. -### Moving data: tload and tstore +### Moving data: tile.load and tile.store ```python -pto.tload(a_part, a_tile) # GM → UB -pto.tstore(o_tile, o_part) # UB → GM +pto.tile.load(a_part, a_tile) # GM → UB +pto.tile.store(o_tile, o_part) # UB → GM ``` -`tload` copies a block of data from GM (described by a partition) into a UB tile. `tstore` copies a UB tile back to GM. These are **Tile Ops** — they operate on entire tile buffers at once. +`tile.load` copies a block of data from GM (described by a partition) into a UB tile. `tile.store` copies a UB tile back to GM. These are **Tile Ops** — they operate on entire tile buffers at once. -### Computing on tiles +### Why start with copy ```python -pto.tadd(a_tile, b_tile, o_tile) +pto.tile.load(a_part, a_tile) +pto.tile.store(o_tile, o_part) ``` -`tadd` performs elementwise addition of two tiles. The result is written to a third tile. PTODSL provides a rich set of Tile-level compute instructions — `texp`, `trowsum`, `tcvt`, `tsel`, and many more — covered in Chapter 8. +A copy kernel strips the example down to the essential PTODSL boundary objects: -## 2.2 A blocked version with a loop - -The kernel above assumes the entire vector fits in one UB tile. For vectors longer than the maximum tile size, you need to process them in blocks. The length `N` is not known until the kernel is launched — it comes from the actual input tensor: - -```python -@pto.jit(target="a5") -def vec_add_blocked(A, B, O, *, BLOCK: pto.constexpr): - N = A.shape[0] +- host tensors entering `@pto.jit` +- `TensorView` descriptors over GM tensors +- UB `Tile` allocation +- `PartitionTensorView` slices +- tile-level movement with `tile.load` / `tile.store` - a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) - b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) - o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) +Once these pieces are clear, arithmetic and sub-kernel orchestration become much easier to layer on. - a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) - b_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) - o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) +## 2.2 A blocked version with a loop - num_blocks = (N + BLOCK - 1) // BLOCK +The kernel above touches one logical slice directly. To introduce device-side control flow, we can iterate over the rows of a 2D tensor and copy one row-strip at a time: - with pto.for_(0, num_blocks, step=1) as i: - offset = i * BLOCK + +```python +from ptodsl import pto - a_part = pto.partition_view(a_view, offsets=[offset], sizes=[BLOCK]) - b_part = pto.partition_view(b_view, offsets=[offset], sizes=[BLOCK]) - o_part = pto.partition_view(o_view, offsets=[offset], sizes=[BLOCK]) - pto.tload(a_part, a_tile) - pto.tload(b_part, b_tile) - pto.tadd(a_tile, b_tile, o_tile) - pto.tstore(o_tile, o_part) +@pto.jit(target="a5") +def blocked_copy( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + rows = A.shape[0] + cols = A.shape[1] + + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + + tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + + with pto.for_(0, rows, step=1) as row: + a_part = pto.partition_view(a_view, offsets=[row, 0], sizes=[1, cols]) + o_part = pto.partition_view(o_view, offsets=[row, 0], sizes=[1, cols]) + + pto.tile.load(a_part, tile) + pto.tile.store(tile, o_part) ``` -Here `N` is dynamic — it comes from `A.shape[0]` and can differ across launches. The loop bound `num_blocks` depends on `N`, so `pto.for_` records a structured loop in the IR rather than unrolling at trace time. The `BLOCK` parameter stays `constexpr` because it is a tuning knob, not data-dependent. Chapter 5 covers this distinction in detail. +Here `rows` and `cols` are dynamic — they come from `A.shape` and can differ across launches. The loop bound depends on `rows`, so `pto.for_` records a structured loop in the IR rather than unrolling at trace time. The `BLOCK` parameter stays `constexpr` because it is a tuning knob, not data-dependent. Chapter 5 covers this distinction in detail. ## 2.3 Compile and launch Once the kernel is defined, you compile it and then launch it: + ```python # Compile once, cache the result. -compiled = vec_add.compile(N=1024) +compiled = blocked_copy.compile(BLOCK=128) # Allocate or obtain input/output tensors (NumPy, torch-npu, ...). import numpy as np -A = np.random.randn(1024).astype(np.float32) -B = np.random.randn(1024).astype(np.float32) +A = np.random.randn(4, 128).astype(np.float32) O = np.empty_like(A) # Launch on the NPU. -compiled[1, None](A, B, O) +compiled[1, None](A, O) ``` - `.compile(**constexprs)` traces the kernel body, lowers it through the PTOAS pipeline, and returns a compiled handle. Repeated calls with the same tensor ABI contract and constexpr configuration hit the cache. @@ -165,16 +174,18 @@ This lets you map different data slices to different blocks — for example, one ## 2.5 Dropping down to micro-instructions -The examples above used Tile Ops (`tload`, `tadd`, `tstore`), which operate on entire tiles at once. When you need finer control — for instance, writing a custom softmax or an activation that maps directly to vector hardware — you can drop down to the micro-instruction level. This involves three layers working together: +The examples above used Tile Ops (`tile.load` / `tile.store` here, and arithmetic Tile Ops in later chapters), which operate on entire tiles at once. When you need finer control — for instance, writing a custom softmax or an activation that maps directly to vector hardware — you can drop down to the micro-instruction level. This involves three layers working together: + ```python # L3: hardware-bound SIMD kernel — vector instructions on individual rows. @pto.simd def add_rows(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, - rows: pto.i32, cols: pto.i32): + rows: pto.index, cols: pto.index): VEC = pto.elements_per_vreg(pto.f32) + initial_remained = scalar.index_cast(pto.i32, cols) with pto.for_(0, rows, step=1) as r: - col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) + col_loop = pto.for_(0, cols, step=VEC).carry(remained=initial_remained) with col_loop: c = col_loop.iv remained = col_loop.remained @@ -192,41 +203,52 @@ def add_block(a_part: pto.PartitionTensorView, b_part: pto.PartitionTensorView, o_part: pto.PartitionTensorView, a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, - rows: pto.i32, cols: pto.i32): - pto.mte_load(a_part, a_tile) - pto.mte_load(b_part, b_tile) + rows: pto.index, cols: pto.index): + row_bytes = cols * pto.bytewidth(pto.f32) + pto.mte_load(a_part.as_ptr(), a_tile.as_ptr(), 0, row_bytes, + nburst=(rows, 0, 0)) + pto.mte_load(b_part.as_ptr(), b_tile.as_ptr(), 0, row_bytes, + nburst=(rows, 0, 0)) pto.pipe_barrier(pto.Pipe.ALL) add_rows(a_tile, b_tile, o_tile, rows, cols) pto.pipe_barrier(pto.Pipe.ALL) - pto.mte_store(o_tile, o_part) + pto.mte_store(o_tile.as_ptr(), o_part.as_ptr(), row_bytes, + nburst=(rows, 0, 0)) # L1: JIT entry — tile allocation, partitioning, launch. @pto.jit(target="a5") -def vec_add_micro(A, B, O, *, BLOCK: pto.constexpr): +def vec_add_micro( + A: pto.tensor_spec(rank=1, dtype=pto.f32), + B: pto.tensor_spec(rank=1, dtype=pto.f32), + O: pto.tensor_spec(rank=1, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): N = A.shape[0] a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) - a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) - b_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) - o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) num_blocks = (N + BLOCK - 1) // BLOCK with pto.for_(0, num_blocks, step=1) as i: offset = i * BLOCK - a_part = pto.partition_view(a_view, offsets=[offset], sizes=[BLOCK]) - b_part = pto.partition_view(b_view, offsets=[offset], sizes=[BLOCK]) - o_part = pto.partition_view(o_view, offsets=[offset], sizes=[BLOCK]) - add_block(a_part, b_part, o_part, a_tile, b_tile, o_tile, 1, BLOCK) + this_block = scalar.min(N - offset, BLOCK) + a_part = pto.partition_view(a_view, offsets=[offset], sizes=[this_block]) + b_part = pto.partition_view(b_view, offsets=[offset], sizes=[this_block]) + o_part = pto.partition_view(o_view, offsets=[offset], sizes=[this_block]) + add_block(a_part, b_part, o_part, a_tile, b_tile, o_tile, 1, this_block) ``` - **L1 `@pto.jit`**: allocates tiles, partitions the GM views, and loops over blocks — the same tile-level orchestration as Section 2.2, but now calling a ukernel instead of Tile Ops. -- **L2 `@pto.ukernel`**: stages data with `mte_load`, synchronizes with `mem_bar`, dispatches the SIMD kernel, synchronizes again, then writes back with `mte_store`. The ukernel owns the hardware-level sequencing. +- **L2 `@pto.ukernel`**: stages data with ptr-based `mte_load`, inserts explicit `pipe_barrier` phase boundaries, dispatches the SIMD kernel, synchronizes again, then writes back with `mte_store`. The ukernel owns the hardware-level sequencing. - **L3 `@pto.simd`**: the outer `pto.for_` iterates over rows, the inner `pto.for_` iterates over column chunks of the hardware vector width (`elements_per_vreg`). Each iteration loads a vector-width slice into a `vreg`, does the addition under a mask (for tail elements), and stores the result back. Both loops are recorded as structured control flow IR — the compiler decides whether to keep them or unroll them. diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 41f6cb564..b04a4c682 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -15,9 +15,9 @@ PTODSL provides five decorators that mark functions as PTO kernels, plus three c L3 sub-kernels can be invoked in two ways: 1. **As decorated functions** (`@pto.cube` / `@pto.simd` / `@pto.simt`) — reusable, named sub-kernels that can be called from `@pto.ukernel` or directly from `@pto.jit`. -2. **As context managers** (`with pto.cube():` / `with pto.simd():` / `with pto.simt():`) — inline L3 blocks for quick prototyping or one-off compute snippets inside any kernel. +2. **As context managers** (`with pto.cube():` / `with pto.simd():` / `with pto.simt():`) — inline L3 blocks for quick prototyping or one-off compute snippets inside `@pto.jit` or `@pto.ukernel`. -Calling an L3 sub-kernel directly from `@pto.jit` skips the ukernel layer: you stage data with `tload`/`tstore` instead of `mte_load`/`mte_store`, and PTOAS handles the synchronization between Tile Ops and L3 compute automatically. This is the recommended path for most users — drop down to `@pto.ukernel` only when you need explicit control over micro-instruction ordering and synchronization. +Calling an L3 sub-kernel directly from `@pto.jit` skips the ukernel layer: you stage data with `tile.load`/`tile.store` instead of `mte_load`/`mte_store`, and PTOAS handles the synchronization between Tile Ops and L3 compute automatically. This is the recommended path for most users — drop down to `@pto.ukernel` only when you need explicit control over micro-instruction ordering and synchronization. ## 3.2 `@pto.jit` — top-level JIT entry @@ -27,16 +27,18 @@ Calling an L3 sub-kernel directly from `@pto.jit` skips the ukernel layer: you s ### Signature + ```python @pto.jit(target="a5") def kernel_name( - tensor_arg_1, # Python-native tensor (positional) - tensor_arg_2, # Python-native tensor (positional) - ..., + tensor_arg_1: pto.tensor_spec(rank=1, dtype=pto.f32), # Python-native tensor (positional) + tensor_arg_2: pto.tensor_spec(rank=1, dtype=pto.f32), # Python-native tensor (positional) *, - CONST_A: pto.constexpr = default, # compile-time constant (keyword-only) - CONST_B: pto.constexpr = default, # compile-time constant (keyword-only) + CONST_A: pto.constexpr = 128, # compile-time constant (keyword-only) + CONST_B: pto.constexpr = 64, # compile-time constant (keyword-only) ): + # ... tensor views, tile allocation, and kernel logic ... + return ``` **Positional parameters** are Python-native tensors — they arrive from NumPy, torch-npu, or any framework with `.shape` and `.strides`. Inside the body, wrap them with `make_tensor_view` to create GM descriptors. @@ -45,6 +47,7 @@ def kernel_name( ### Compilation and launch + ```python # Compile (traces the body, lowers through PTOAS, caches the result) compiled = kernel_name.compile(CONST_A=128, CONST_B=64) @@ -71,73 +74,93 @@ Available inside a `@pto.jit` body: ```python @pto.jit(target="a5") -def my_kernel(A, B, O, *, BLOCK: pto.constexpr): - N = A.shape[0] - a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) - b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) - o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) - - a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) - b_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) - o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) - - num_blocks = (N + BLOCK - 1) // BLOCK - with pto.for_(0, num_blocks, step=1) as i: - offset = i * BLOCK - a_part = pto.partition_view(a_view, offsets=[offset], sizes=[BLOCK]) - b_part = pto.partition_view(b_view, offsets=[offset], sizes=[BLOCK]) - o_part = pto.partition_view(o_view, offsets=[offset], sizes=[BLOCK]) - - pto.tload(a_part, a_tile) - pto.tload(b_part, b_tile) - pto.tadd(a_tile, b_tile, o_tile) - pto.tstore(o_tile, o_part) +def my_kernel( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + B: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + rows = A.shape[0] + cols = A.shape[1] + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + b_view = pto.make_tensor_view(B, shape=B.shape, strides=B.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + + with pto.for_(0, rows, step=1) as row: + a_part = pto.partition_view(a_view, offsets=[row, 0], sizes=[1, cols]) + b_part = pto.partition_view(b_view, offsets=[row, 0], sizes=[1, cols]) + o_part = pto.partition_view(o_view, offsets=[row, 0], sizes=[1, cols]) + + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) + pto.tile.add(a_tile, b_tile, o_tile) + pto.tile.store(o_tile, o_part) ``` ### Calling L3 sub-kernels directly -When you call an L3 sub-kernel directly from `@pto.jit`, data movement is handled by Tile Ops (`tload`/`tstore`) instead of MTE micro-instructions. PTOAS handles the synchronization between Tile Ops and L3 compute — the sub-kernel itself is unchanged: +When you call an L3 sub-kernel directly from `@pto.jit`, data movement is handled by Tile Ops (`tile.load`/`tile.store`) instead of MTE micro-instructions. PTOAS handles the synchronization between Tile Ops and L3 compute — the sub-kernel itself is unchanged: + ```python -@pto.cube -def my_matmul(a_tile, b_tile, l0a, l0b, acc, o_tile): - m = a_tile.valid_shape[0] - k = a_tile.valid_shape[1] - n = b_tile.valid_shape[0] - pto.mte_l1_l0a(a_tile.as_ptr(), l0a.as_ptr(), m, k) - pto.mte_l1_l0b(b_tile.as_ptr(), l0b.as_ptr(), k, n, transpose=True) - pto.mad(l0a.as_ptr(), l0b.as_ptr(), acc.as_ptr(), m, n, k) - pto.mte_l0c_ub(acc.as_ptr(), o_tile.as_ptr(), m, n, n, n, 0) +@pto.simd +def add_rows( + a_tile: pto.Tile, + b_tile: pto.Tile, + o_tile: pto.Tile, + rows: pto.index, + cols: pto.index, +): + VEC = pto.elements_per_vreg(pto.f32) + initial_remained = scalar.index_cast(pto.i32, cols) + with pto.for_(0, rows, step=1) as r: + col_loop = pto.for_(0, cols, step=VEC).carry(remained=initial_remained) + with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + a_vec = pto.vlds(a_tile[r, c:]) + b_vec = pto.vlds(b_tile[r, c:]) + o_vec = pto.vadd(a_vec, b_vec, mask) + pto.vsts(o_vec, o_tile[r, c:], mask) + col_loop.update(remained=remained) @pto.jit(target="a5") -def my_kernel(A, B, O, *, BLOCK: pto.constexpr): - N = A.shape[0] - a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) - b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) - o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) - - a_tile = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32) - b_tile = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32) - o_tile = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32) - l0a = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32, memory_space=pto.MemorySpace.LEFT) - l0b = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32, memory_space=pto.MemorySpace.RIGHT) - acc = pto.alloc_tile(shape=[BLOCK, BLOCK], dtype=pto.f32, memory_space=pto.MemorySpace.ACC) - - num_blocks = (N + BLOCK - 1) // BLOCK - with pto.for_(0, num_blocks, step=1) as i: - offset = i * BLOCK - a_part = pto.partition_view(a_view, offsets=[offset, 0], sizes=[BLOCK, BLOCK]) - b_part = pto.partition_view(b_view, offsets=[offset, 0], sizes=[BLOCK, BLOCK]) - o_part = pto.partition_view(o_view, offsets=[offset, 0], sizes=[BLOCK, BLOCK]) +def my_kernel( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + B: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + rows = A.shape[0] + cols = A.shape[1] + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + b_view = pto.make_tensor_view(B, shape=B.shape, strides=B.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + + with pto.for_(0, rows, step=1) as row: + a_part = pto.partition_view(a_view, offsets=[row, 0], sizes=[1, cols]) + b_part = pto.partition_view(b_view, offsets=[row, 0], sizes=[1, cols]) + o_part = pto.partition_view(o_view, offsets=[row, 0], sizes=[1, cols]) # Tile Ops stage data from GM to UB (replaces mte_load at L1) - pto.tload(a_part, a_tile) - pto.tload(b_part, b_tile) + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) - # Direct L3 call — PTOAS handles sync between tload and compute - my_matmul(a_tile, b_tile, l0a, l0b, acc, o_tile) + # Direct L3 call — PTOAS handles sync between tile.load and compute + add_rows(a_tile, b_tile, o_tile, 1, cols) - pto.tstore(o_tile, o_part) + pto.tile.store(o_tile, o_part) ``` This is the recommended path for users who want hardware-unit compute without writing explicit MTE Ops and manual sync. Mixing direct L3 calls with Tile Ops and ukernel calls in the same `@pto.jit` body is supported — the compiler unifies the lowering. @@ -146,32 +169,42 @@ This is the recommended path for users who want hardware-unit compute without wr ### Role -`@pto.ukernel` (short for *micro-instruction kernel*) is the entry point for writing PTO micro-instructions directly. Unlike `@pto.jit` where you work with tile-level ops (`tload`, `tadd`, etc.), a ukernel lets you write explicit MTE, SIMD, SIMT, and Cube instructions — staging data with `mte_load`, synchronizing with `mem_bar`, and dispatching L3 sub-kernels. This is an advanced programming mode for expert users who need precise control over instruction ordering and hardware-level data movement. +`@pto.ukernel` (short for *micro-instruction kernel*) is the entry point for writing PTO micro-instructions directly. Unlike `@pto.jit` where you work with tile-level ops (`tile.load`, `tile.add`, etc.), a ukernel lets you write explicit MTE, SIMD, SIMT, and Cube instructions — staging data with `mte_load`, synchronizing with `mem_bar`, and dispatching L3 sub-kernels. This is an advanced programming mode for expert users who need precise control over instruction ordering and hardware-level data movement. ### Signature + ```python @pto.ukernel def my_ukernel( part: pto.PartitionTensorView, # GM partition descriptors tile: pto.Tile, # UB tile buffers scratch: pto.Tile, # cube-local scratch (LEFT, RIGHT, ...) - ptr: pto.ptr(dtype, space), # typed UB pointers - scalar: pto.i32, # PTO scalar values + ptr: pto.ptr(pto.f32, pto.MemorySpace.UB), # typed UB pointers + scalar_value: pto.i32, # PTO scalar values ): + return ``` Parameters are PTO-specific types — `Tile`, `PartitionTensorView`, `pto.ptr`, and PTO scalar types. Unlike `@pto.jit`, a ukernel does not accept Python-native tensors. ### Typical body + ```python @pto.ukernel -def process_block(k_part, v_part, k_tile, v_tile, - s_tile, o_tile, rows: pto.i32, cols: pto.i32): +def process_block(q_tile, k_part, v_part, k_tile, v_tile, + s_tile, o_tile, o_part, rows: pto.i32, cols: pto.i32): + in_row_bytes = cols * pto.bytewidth(pto.f16) + out_row_bytes = cols * pto.bytewidth(pto.f32) + gm_row_stride = k_part.strides[0] * pto.bytewidth(pto.f16) + ub_row_stride = k_tile.shape[1] * pto.bytewidth(pto.f16) + # Stage current block from GM to UB - pto.mte_load(k_part, k_tile) - pto.mte_load(v_part, v_tile) + pto.mte_load(k_part.as_ptr(), k_tile.as_ptr(), 0, in_row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) + pto.mte_load(v_part.as_ptr(), v_tile.as_ptr(), 0, in_row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) pto.pipe_barrier(pto.Pipe.ALL) # Dispatch sub-kernels @@ -182,10 +215,11 @@ def process_block(k_part, v_part, k_tile, v_tile, pto.pipe_barrier(pto.Pipe.ALL) # Write result back - pto.mte_store(o_tile, o_part) + pto.mte_store(o_tile.as_ptr(), o_part.as_ptr(), out_row_bytes, + nburst=(rows, ub_row_stride, gm_row_stride)) ``` -A ukernel stays below the tile-op boundary — GM↔UB movement is expressed with `mte_load`/`mte_store` (MTE Ops) rather than `tload`/`tstore`. +A ukernel stays below the tile-op boundary — GM↔UB movement is expressed with ptr-based `mte_load`/`mte_store` (MTE Ops) rather than `tile.load`/`tile.store`. ## 3.4 `@pto.cube` — Cube unit sub-kernel @@ -195,6 +229,7 @@ A ukernel stays below the tile-op boundary — GM↔UB movement is expressed wit ### Signature + ```python @pto.cube def my_cube_kernel( @@ -204,12 +239,14 @@ def my_cube_kernel( right_scratch: pto.Tile, # RIGHT buffer (cube-local) acc_scratch: pto.Tile, # ACC buffer (cube-local) ): + return ``` All parameters are `Tile` references. Tiles marked as cube-local must be allocated with the appropriate `memory_space` (e.g., `pto.MemorySpace.LEFT`, `pto.MemorySpace.ACC`). ### Typical body + ```python @pto.cube def qk_matmul( @@ -222,7 +259,7 @@ def qk_matmul( ): m = q_tile.valid_shape[0] k = q_tile.valid_shape[1] - n = k_tile.valid_shape[0] + n = k_tile.valid_shape[1] pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) @@ -245,6 +282,7 @@ Cube-local state (LEFT, RIGHT, ACC, BIAS) never leaks into UB — it is the call ### Signature + ```python @pto.simd def my_simd_kernel( @@ -253,19 +291,22 @@ def my_simd_kernel( rows: pto.i32, # PTO scalar cols: pto.i32, # PTO scalar ): + return ``` Parameters are UB `Tile` references and PTO scalar values (`pto.i32`, `pto.f32`, etc.). Scalar parameters may come from `lds` reads or compile-time constants. ### Typical body + ```python @pto.simd def add_rows(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, - rows: pto.i32, cols: pto.i32): + rows: pto.index, cols: pto.index): VEC = pto.elements_per_vreg(pto.f32) + initial_remained = scalar.index_cast(pto.i32, cols) with pto.for_(0, rows, step=1) as r: - col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) + col_loop = pto.for_(0, cols, step=VEC).carry(remained=initial_remained) with col_loop: c = col_loop.iv remained = col_loop.remained @@ -292,17 +333,20 @@ The boundary contract: `vreg` values (`a_vec`, `b_vec`, `o_vec`) are local to th ### Signature + ```python @pto.simt def my_simt_kernel( tile: pto.Tile, # UB tile - ptr: pto.ptr(dtype, space), # typed UB pointer - scalar: pto.i32, # PTO scalar + ptr: pto.ptr(pto.f32, pto.MemorySpace.UB), # typed UB pointer + scalar_value: pto.i32, # PTO scalar ): + return ``` ### Typical body + ```python @pto.simt def blend_output_rows( @@ -330,10 +374,11 @@ SIMT kernels read and write individual scalar elements from tiles. The unit exec ## 3.7 Context manager syntax for L3 sub-kernels -In addition to the decorator form, each L3 sub-kernel unit provides a context manager: `with pto.cube():`, `with pto.simd():`, and `with pto.simt():`. These open an inline L3 block without requiring a separate named function — useful for quick prototyping, one-off compute snippets, or when the logic is too trivial to extract. +In addition to the decorator form, each L3 sub-kernel unit provides a context manager: `with pto.cube():`, `with pto.simd():`, and `with pto.simt():`. These open an inline L3 block without requiring a separate named function — useful for quick prototyping, one-off compute snippets, or when the logic is too trivial to extract. The inline form is supported in top-level `@pto.jit` bodies and inside `@pto.ukernel`. ### Syntax + ```python with pto.simd(): # Direct L3 instructions — vreg ops, scalar loads/stores @@ -343,6 +388,7 @@ with pto.simd(): pto.vsts(o_vec, o_tile[r, c:], mask) ``` + ```python with pto.simt(): alpha = scalar.load(alpha_tile[row, 0]) @@ -351,6 +397,7 @@ with pto.simt(): scalar.store(o_next, o_next_tile[row, col]) ``` + ```python with pto.cube(): pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) @@ -396,10 +443,17 @@ Data crosses decorator boundaries only through UB-backed tiles or typed UB point `pto.constexpr` marks a `@pto.jit` keyword-only parameter as a compile-time constant. The compiler specializes the kernel for each combination of constexpr values, and the compiled artifact is cached by specialization key together with the kernel's tensor ABI contract. + ```python @pto.jit(target="a5") -def kernel(A, *, BLOCK: pto.constexpr = 128, DTYPE: pto.constexpr = pto.f32): - ... +def kernel( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, + DTYPE: pto.constexpr = pto.f32, +): + # ... use BLOCK / DTYPE in tile shapes, loop bounds, or dtype-specialized paths ... + return ``` - Must appear as a keyword-only argument (after `*`). diff --git a/ptodsl/docs/user_guide/04-type-system-and-buffer.md b/ptodsl/docs/user_guide/04-type-system-and-buffer.md index 22804daab..0c2b74b54 100644 --- a/ptodsl/docs/user_guide/04-type-system-and-buffer.md +++ b/ptodsl/docs/user_guide/04-type-system-and-buffer.md @@ -29,6 +29,7 @@ Python literals are automatically typed by the tracer: `bool` → `pto.i1`, `int For explicit typing, use type constructors: + ```python x = pto.i32(1024) y = pto.ui16(7) @@ -37,7 +38,7 @@ z: pto.i32 = 1024 ### Low-precision types (storage only) -The following types are available for storage and data movement, but **not** for computation. Use them to reduce memory bandwidth; convert to a compute-capable type before arithmetic. +The following types are **storage-only**: they may only appear as element types when constructing `Tile`, `TensorView`, and `PartitionTensorView` values for storage and data movement. They **cannot** be used to construct scalars, vectors, pointers, or `tensor_spec(...)` ABI contracts. Use them to reduce memory bandwidth; convert to a compute-capable type before arithmetic. | DSL Type | Description | |----------|-------------| @@ -47,10 +48,21 @@ The following types are available for storage and data movement, but **not** for | `pto.f8e4m3` | 8-bit float (E4M3) | | `pto.f8e5m2` | 8-bit float (E5M2) | +These types can be used when constructing on-chip tiles and view descriptors: + + +```python +lp_tile = pto.alloc_tile(shape=[128, 64], dtype=pto.f8e4m3) +fp4_tile = pto.alloc_tile(shape=[64, 32], dtype=pto.f4e2m1x2) +``` + +Constructing a scalar, vector, pointer, or host tensor ABI contract with a low-precision type is **not supported** — `pto.f8e4m3(1.0)`, `pto.vreg_type(64, pto.f8e4m3)`, `pto.ptr(pto.f8e4m3)`, and `pto.tensor_spec(rank=2, dtype=pto.f8e4m3)` will raise an error. Load data as the storage type, then convert to a compute-capable type before arithmetic. + ### Integer literal guidance Prefer plain integer literals. Hex string literals are reserved for explicit bit-pattern authoring: + ```python count = pto.i32(1024) delta = pto.i16(-12) @@ -59,6 +71,7 @@ hi_bit = pto.i32("0x80000000") # bit-pattern: -2147483648 ### Floating-point literal forms + ```python a = pto.f16(-1.5) b = pto.f32("inf") @@ -82,6 +95,7 @@ Constraint: `element_count × bitwidth(dtype) = 2048`. Use `pto.elements_per_vreg(dtype)` to query the element count: + ```python lanes = pto.elements_per_vreg(pto.f32) # 64 ``` @@ -90,6 +104,7 @@ lanes = pto.elements_per_vreg(pto.f32) # 64 Reinterpret the bits of a vector register as a different element type: + ```python fvec = pto.vlds(ptr, offset) # !pto.vreg<64xf32> ivec = pto.vbitcast(fvec, pto.i32) # !pto.vreg<64xi32> @@ -108,8 +123,28 @@ Masks are typed by bit granularity and must match the vector element width: | `pto.mask_b16` | 16-bit | `f16`, `bf16`, `i16`, `si16`, `ui16` | | `pto.mask_b32` | 32-bit | `f32`, `i32`, `si32`, `ui32` | -Bitcast between mask types with `pto.pbitcast`: +### Constructing masks + +Use `make_mask` to generate a mask from a pattern or scalar — it automatically selects the correct bit width from the element dtype: + + +```python +active = pto.make_mask(pto.f16, "PAT_ALL") # pattern-based full mask +tail_mask, _ = pto.make_mask(pto.f32, tail_count) # load mask from tail count scalar +``` + +The bit-width-specific `pset_b32` and `plt_b32` forms are also available: + +```python +active = pto.pset_b32("PAT_ALL") +one_mask, _ = pto.plt_b32(c1_i32) +``` + +### Reinterpreting masks +`pbitcast` reinterprets a mask register at a different granularity: + + ```python mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) ``` @@ -118,6 +153,7 @@ mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) Pointers combine an element type and a memory space: + ```python ptr_gm = pto.ptr(pto.f32, pto.MemorySpace.GM) ptr_ub = pto.ptr(pto.f16, pto.MemorySpace.UB) @@ -139,13 +175,19 @@ ptr_ub = pto.ptr(pto.f16, pto.MemorySpace.UB) `TensorView` is a descriptor for a tensor in Global Memory. Create one inside a `@pto.jit` body with `make_tensor_view`: + ```python @pto.jit(target="a5") -def kernel(A, *, BLOCK: pto.constexpr): - tv = pto.make_tensor_view(A, shape=[N], strides=A.strides) +def kernel( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + tv = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + return ``` -`make_tensor_view` wraps a Python-native tensor. You provide the logical shape and the stride of each dimension in **elements** (not bytes). The resulting `TensorView` can be partitioned for `tload`/`tstore`. +`make_tensor_view` wraps a Python-native tensor. You provide the logical shape and the stride of each dimension in **elements** (not bytes). The resulting `TensorView` can be partitioned for `tile.load`/`tile.store`. ### TensorView attributes @@ -159,8 +201,9 @@ Strides support non-contiguous tensors. Pass `strides=A.strides` from the source ## 4.6 PartitionTensorView -`partition_view` creates a sub-view of a TensorView at a given offset and size. It describes *which part* of the GM tensor a `tload` or `tstore` should operate on: +`partition_view` creates a sub-view of a TensorView at a given offset and size. It describes *which part* of the GM tensor a `tile.load` or `tile.store` should operate on: + ```python part = pto.partition_view(tv, offsets=[row_offset, 0], sizes=[BLOCK, dim]) ``` @@ -171,6 +214,7 @@ The result is a `PartitionTensorView` — a lightweight descriptor, not a data b A `Tile` is an on-chip buffer allocated in UB or cube-local memory. Allocate tiles with `alloc_tile`: + ```python # UB tile a_tile = pto.alloc_tile(shape=[BLOCK, dim], dtype=pto.f32) @@ -188,6 +232,8 @@ For narrow logical column tiles such as `[Br, 1]`, author them with `blayout="ColMajor"`. Row-major none-box tiles are validated against a 32-byte physical row-alignment rule. +For packed types (`pto.f4e1m2x2`, `pto.f4e2m1x2`), `shape` dimensions refer to the number of **packed** elements, each containing 2 f4 values. For example, `alloc_tile(shape=[128, 64], dtype=pto.f4e1m2x2)` allocates a 128×64 tile of packed elements, holding 128×64×2 individual 4-bit floats. The same applies to TensorView shapes when the tensor spec uses a packed dtype. + ### Tile attributes | Attribute | Type | Description | @@ -204,12 +250,15 @@ physical row-alignment rule. | `tile.fill(value)` | Fill the entire tile with a scalar value | | `tile.as_ptr()` | Obtain a typed pointer to the tile's base address | + ```python m_prev_tile.fill(float("-inf")) l_prev_tile.fill(0.0) rows = q_tile.valid_shape[0] cols = k_tile.valid_shape[1] +meta_tile.valid_shape = [pto.const(1), pto.const(2)] +tail_tile.valid_shape = [rows] meta_ptr = meta_tile.as_ptr() ``` diff --git a/ptodsl/docs/user_guide/05-control-flow.md b/ptodsl/docs/user_guide/05-control-flow.md index 0ef1f60c6..b2da72d56 100644 --- a/ptodsl/docs/user_guide/05-control-flow.md +++ b/ptodsl/docs/user_guide/05-control-flow.md @@ -26,9 +26,9 @@ def unrolled_kernel(A, O, *, N: pto.constexpr): o_part = pto.partition_view(o_view, offsets=[i], sizes=[1]) a_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) o_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) - pto.tload(a_part, a_tile) - pto.tadd(a_tile, a_tile, o_tile) - pto.tstore(o_tile, o_part) + pto.tile.load(a_part, a_tile) + pto.tile.add(a_tile, a_tile, o_tile) + pto.tile.store(o_tile, o_part) ``` This works when the loop bound is a compile-time constant (like a `constexpr` parameter). But if `N` comes from a tensor shape and varies per launch, `range(N)` would trace a different number of iterations each time — you would get a cache miss and recompilation on every new value. For dynamic bounds, use `pto.for_`. @@ -39,10 +39,10 @@ This works when the loop bound is a compile-time constant (like a `constexpr` pa ### Basic form + ```python -with pto.for_(start, stop, step) as iv: - # iv is the loop index (0-based relative to start) - ... +with pto.for_(start, stop, step=step) as iv: + pto.tile.load(pto.partition_view(a_view, offsets=[iv, 0], sizes=[1, cols]), tile) ``` - `start`, `stop`, `step` are PTO scalar expressions. They are evaluated on the device. @@ -51,60 +51,63 @@ with pto.for_(start, stop, step) as iv: Compare the two approaches: + ```python # Trace-time unrolling — BLOCK must be constexpr for i in range(BLOCK): - ... + pto.tile.load(pto.partition_view(a_view, offsets=[0, 0], sizes=[1, cols]), tile) # Device-side loop — num_blocks can be dynamic with pto.for_(0, num_blocks, step=1) as i: - offset = i * BLOCK - ... + pto.tile.load(pto.partition_view(a_view, offsets=[i, 0], sizes=[1, cols]), tile) ``` ### Nested loops + ```python with pto.for_(0, rows, step=1) as r: with pto.for_(0, cols, step=1) as c: val = scalar.load(tile[r, c]) - ... ``` Both loops execute on the device. The outer loop bound `rows` and inner loop bound `cols` can be runtime values. ### Loop with carry state -When a loop needs to propagate state from one iteration to the next, use the `.carry(...)` method. This is the PTODSL equivalent of a loop that accumulates or updates variables across iterations: +When a loop needs to propagate state from one iteration to the next, use the `.carry(...)` method. This is the PTODSL equivalent of a loop that accumulates or updates variables across iterations. The following self-contained kernel is the smallest compileable carry example used by the docs-as-test harness: + ```python -kv_loop = pto.for_(0, num_blocks, step=1).carry( - m=m_prev_tile, - l=l_prev_tile, - o=o_prev_tile, -) -with kv_loop: - i = kv_loop.iv # current iteration index - m_cur = kv_loop.m # value carried in from previous iteration - l_cur = kv_loop.l - o_cur = kv_loop.o - - # ... compute m_next, l_next, o_next from m_cur, l_cur, o_cur ... - - kv_loop.update( - m=m_next_tile, - l=l_next_tile, - o=o_next_tile, - ) - -# After the loop, retrieve the final carried values -final_o = kv_loop.final("o") +@pto.jit(target="a5") +def carry_loop_probe(*, BLOCK: pto.constexpr = 128): + m_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + l_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + m_next = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + l_next = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_next = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + + m_prev.fill(0.0) + l_prev.fill(0.0) + o_prev.fill(0.0) + + kv_loop = pto.for_(0, 4, step=1).carry(m=m_prev, l=l_prev, o=o_prev) + with kv_loop: + kv_loop.m.fill(1.0) + kv_loop.l.fill(2.0) + kv_loop.o.fill(3.0) + kv_loop.update(m=m_next, l=l_next, o=o_next) + + final_o = kv_loop.final("o") + final_o.fill(4.0) ``` `.carry(name=initial_value)` declares named state variables that are passed from one iteration to the next. Inside the loop body, access the current value with `loop.name`. At the end of the body, call `loop.update(name=new_value)` to set what the next iteration receives. After the loop exits, `loop.final("name")` retrieves the value from the last iteration. This pattern is central to algorithms like online softmax, where each KV block updates running statistics (row max, sum, output accumulator). The ping-pong tile pattern — allocating two tiles and swapping them each iteration — is the idiomatic way to manage this state: + ```python # Allocate ping-pong state tiles m_prev = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") @@ -121,7 +124,8 @@ with loop: m_cur = loop.m l_cur = loop.l - # ... compute new m and l into m_next, l_next ... + m_next.fill(1.0) + l_next.fill(2.0) loop.update(m=m_next, l=l_next) ``` @@ -130,6 +134,7 @@ with loop: For SIMD kernels that process data in vector-width chunks, use a carry loop to track the remaining element count across column iterations: + ```python VEC = pto.elements_per_vreg(pto.f32) col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) @@ -149,11 +154,29 @@ with col_loop: `pto.if_` records a device-side conditional branch. Unlike a Python `if`, the condition can be a runtime PTO scalar, and both branches are recorded into the program so the hardware can choose at runtime. -The condition must be a PTO scalar value (e.g., the result of a comparison like `scalar.gt(a, b)` or a value loaded from a tile). Python booleans evaluated at trace time should use a plain `if` instead. +The condition must be a PTO scalar value (e.g., the result of a comparison like `a > b` or a value loaded from a tile). Python booleans evaluated at trace time should use a plain `if` instead. + +### Recommended block structure + +PTODSL should treat one device-side conditional as one explicit branch object. +The recommended surface is: + +```python +with pto.if_(cond) as br: + with br.then_: + ... + with br.else_: + ... +``` + +This keeps the `if` / `else` pairing explicit. The `else_` branch is optional +for side-effect-only conditionals. -### Value merge across branches +### Automatic named merge across branches -When a variable is assigned inside both branches of `pto.if_`/`pto.else_`, the assignments are recorded and the variable holds the merged value after the conditional block. This is the standard SSA-style merge — the downstream code sees whichever value was produced by the taken branch: +When a value must flow out of both branches, PTODSL should merge by explicit +name. Each branch assigns the same output names with `br.assign(...)`, and the +merged results are read back from the branch handle after the conditional: ```python @pto.simt @@ -167,32 +190,22 @@ def conditional_scale( with pto.for_(0, rows, step=1) as r: with pto.for_(0, cols, step=1) as c: val = scalar.load(tile[r, c]) - big = scalar.gt(val, threshold) - - with pto.if_(big): - # Branch A: scale the value up - val = val * scale - with pto.else_(): - # Branch B: leave it as-is - pass - - # val is usable here — it is the merged result from both branches. - # If big was true, val = original * scale. - # If big was false, val = original (passed through unchanged). - scalar.store(val, tile[r, c]) -``` - -In this example, `val` is reassigned in the `if_` branch but left untouched in the `else_` branch. After the conditional block, `val` correctly represents the merged result and is stored back to the tile. You can reassign the same variable in both branches as well — the downstream code always sees the correct value. + big = val > threshold -### Expression form + with pto.if_(big) as br: + with br.then_: + br.assign(val=val * scale) + with br.else_: + br.assign(val=val) -For simple either-or selection, `pto.if_` also works as an expression that directly returns the merged value: - -```python -result = pto.if_(cond, then_value, else_value) + val = br.val + scalar.store(val, tile[r, c]) ``` -This is equivalent to the block form above and is convenient when each branch simply produces a different scalar or tile reference. +In this example, both branches define the merged value named `val`. After the +conditional closes, `br.val` is the SSA-merged result seen by downstream code. +This surface avoids explicit result-type declarations and explicit +`pto.yield_(...)` in user code while still keeping the merge contract explicit. ## 5.4 `pto.constexpr` and tracing @@ -200,22 +213,35 @@ This is equivalent to the block form above and is convenient when each branch si ```python @pto.jit(target="a5") -def kernel(A, *, BLOCK: pto.constexpr = 128, UNROLL: pto.constexpr = False): +def kernel( + A, + *, + BLOCK: pto.constexpr = 128, + NUM_BLOCKS: pto.constexpr = 8, + UNROLL: pto.constexpr = False, +): N = A.shape[0] num_blocks = (N + BLOCK - 1) // BLOCK + # N and num_blocks are runtime values derived from tensor metadata. + # They can drive device-side control flow such as pto.for_(...), + # but they are not Python integers and cannot be used in range(...). + with pto.for_(0, num_blocks, step=1) as i: + ... + if UNROLL: - # Trace-time: UNROLL is known, so this branch resolves at compile time. - # Each iteration records separately — the loop is fully unrolled. - for i in range(num_blocks): + # Trace-time: UNROLL and NUM_BLOCKS are both known during tracing. + # Each iteration records separately, so the loop is fully unrolled. + for i in range(NUM_BLOCKS): ... else: - # Device-side: a single loop instruction is recorded. - with pto.for_(0, num_blocks, step=1) as i: + # The non-unrolled path can still use a device-side loop whose bound + # is a constexpr value captured into the traced program. + with pto.for_(0, NUM_BLOCKS, step=1) as i: ... ``` -This lets you write a single kernel that specializes into different strategies based on constexpr knobs. +This lets you write a single kernel that specializes into different strategies based on constexpr knobs, while still using runtime tensor metadata for device-side control flow. ## 5.5 Summary diff --git a/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md index 783e210db..b520b89d2 100644 --- a/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md +++ b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md @@ -1,12 +1,12 @@ # 6. Scalar and Pointer Operations -Chapter 5 established the rule: Python constructs are resolved at trace time, PTO constructs produce device-side behavior. This chapter applies that distinction to scalars and pointers — when to use a plain Python number, when to use a `scalar.*` operation, and how to work with typed pointers. +Chapter 5 established the rule: Python constructs are resolved at trace time, PTO constructs produce device-side behavior. This chapter applies that distinction to scalars and pointers — when to use a plain Python number, when to use a top-level `scalar.*` helper, and how to work with typed pointers. ## 6.1 Python scalars vs PTO scalars A **Python scalar** is any value computed by Python during tracing: a literal (`3.14159`), a constexpr parameter (`BLOCK`), or an arithmetic expression built only from compile-time-known values (`1.0 / sqrt(128)`). These are evaluated at trace time and their results are baked into the device code as constants. -A **PTO scalar** is a value that lives on the device at runtime. It comes from a `scalar.load` read, a device-side computation (`scalar.max`, `scalar.exp`), a runtime query (`pto.get_block_idx()`), or `@pto.jit` tensor metadata such as `A.shape[0]` / `A.strides[1]`. PTO scalars flow through the recorded program and are not resolved until the kernel executes. +A **PTO scalar** is a value that lives on the device at runtime. It comes from a `scalar.load` read, a device-side computation (`scalar.max`, `scalar.exp`), a runtime query (`pto.get_block_idx()`), or `@pto.jit` tensor metadata such as `A.shape[0]` / `A.strides[1]`. PTO scalars flow through the recorded program and are not resolved until the kernel executes. The helper functions that operate on them live in the top-level `scalar` namespace, not under `pto.*`. ### The mixed expression @@ -56,6 +56,7 @@ When in doubt, ask: *can this value change between launches of the same compiled **Tile-index form** — the preferred syntax when loading from a tile: + ```python val = scalar.load(tile[row, col]) ``` @@ -64,6 +65,7 @@ val = scalar.load(tile[row, col]) **Pointer forms**: + ```python val = scalar.load(ptr, offset) # explicit offset val = scalar.load(ptr + offset) # pointer arithmetic shorthand @@ -87,12 +89,14 @@ val = scalar.load(ptr + offset) # pointer arithmetic shorthand **Tile-index form**: + ```python scalar.store(value, tile[row, col]) ``` **Pointer forms**: + ```python scalar.store(value, ptr, offset) ``` @@ -103,6 +107,7 @@ scalar.store(value, ptr, offset) `scalar.load` and `scalar.store` are the primary data access pattern inside `@pto.simt` kernels. Each `load`/`store` operates on one element per work-item, but the SIMT unit executes the same instruction across many work-items in parallel: + ```python @pto.simt def blend_output_rows( @@ -121,14 +126,29 @@ def blend_output_rows( scalar.store(o_next, o_next_tile[row, col]) ``` -When writing to a raw pointer (e.g., a small metadata buffer obtained via `as_ptr()`), use the pointer-plus-offset form: +When writing to a raw pointer (e.g., a small metadata buffer obtained via `as_ptr()`), use the pointer-plus-offset form. The following self-contained kernel is the smallest compileable pointer-offset example: + ```python -meta_ptr = meta_tile.as_ptr() -scalar.store(0, meta_ptr, 0) # store at element offset 0 -scalar.store(valid_rows, meta_ptr, 1) # store at element offset 1 -row_start = scalar.load(meta_ptr, 0) -row_stop = scalar.load(meta_ptr, 1) +from ptodsl import pto, scalar + + +@pto.jit(target="a5") +def scalar_pointer_offset_probe(): + meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 3]) + meta_ptr = meta_tile.as_ptr() + + scalar.store(0, meta_ptr, 0) + scalar.store(1, meta_ptr, 1) + scalar.store(2, meta_ptr + 2) + + row_start = scalar.load(meta_ptr, 0) + row_stop = scalar.load(meta_ptr, 1) + valid_cols = scalar.load(meta_ptr + 2) + + _ = row_start + _ = row_stop + _ = valid_cols ``` ## 6.3 Scalar arithmetic and comparisons @@ -137,6 +157,7 @@ row_stop = scalar.load(meta_ptr, 1) Addition, subtraction, multiplication, and division of PTO scalars use standard Python syntax. The tracer records the corresponding device-side instructions automatically: + ```python o_next = alpha * o_prev + beta * pv_val # multiply-add l_scaled = l_prev * scalar.exp(m_prev - m_next) # subtraction inside exp @@ -147,7 +168,7 @@ When both operands are PTO scalars (loaded from device memory or produced by ano ### Math functions: `scalar.*` -Non-trivial scalar math functions live under the `scalar` namespace (imported as `from pto import scalar` or accessed as `pto.scalar`): +Non-trivial scalar math functions live under the top-level `scalar` namespace (imported as `from ptodsl import scalar`). They are intentionally separate from the `pto.*` namespace: #### `scalar.max(a: ScalarType, b: ScalarType) -> ScalarType` @@ -173,31 +194,42 @@ Non-trivial scalar math functions live under the `scalar` namespace (imported as **Description**: Absolute value. -#### `scalar.gt(a: ScalarType, b: ScalarType) -> pto.i1` - -**Description**: Greater-than comparison. Returns `pto.i1`. - -#### `scalar.lt(a: ScalarType, b: ScalarType) -> pto.i1` + +```python +lo = scalar.min(m_prev, row_max) +mag = scalar.abs(m_prev - row_max) +ln = scalar.log(threshold + 1.0) +root = scalar.sqrt(threshold + 4.0) +``` -**Description**: Less-than comparison. Returns `pto.i1`. +### Comparisons -#### `scalar.eq(a: ScalarType, b: ScalarType) -> pto.i1` +**Description**: PTO scalars use Python's native comparison operators. The tracer records the corresponding device-side comparison instruction and returns a `pto.i1` result. -**Description**: Equality comparison. Returns `pto.i1`. +| Operator | Predicate (signed) | Predicate (unsigned) | Predicate (float) | +|----------|---------------------|-----------------------|--------------------| +| `>` | `sgt` | `ugt` | `ogt` | +| `<` | `slt` | `ult` | `olt` | +| `==` | `eq` | `eq` | `oeq` | +| `!=` | `ne` | `ne` | `one` | +| `>=` | `sge` | `uge` | `oge` | +| `<=` | `sle` | `ule` | `ole` | **Example**: + ```python m_next = scalar.max(m_prev, row_max) l_scaled = l_prev * scalar.exp(m_prev - m_next) -need_scale = scalar.gt(val, threshold) +need_scale = val > threshold # pto.i1 result +is_zero_mask = val == threshold +in_range = (val >= threshold) & (val <= row_max) ``` -For readability in files with many scalar operations, assign `pto.scalar` to a short local name: +For readability in files with many scalar operations, use the top-level `scalar` namespace directly: + ```python -scalar = pto.scalar - m_next = scalar.max(m_prev, row_max) l_scaled = l_prev * scalar.exp(m_prev - m_next) ``` @@ -212,6 +244,7 @@ Typed pointers (Section 4.4) carry both an element type and a memory space. This Tiles and tensor views expose their base address via `as_ptr()`: + ```python gm_ptr = partition.as_ptr() # GM pointer from a PartitionTensorView ub_ptr = tile.as_ptr() # UB pointer from a Tile @@ -240,8 +273,9 @@ ub_ptr = tile.as_ptr() # UB pointer from a Tile **Example**: + ```python -ptr = pto.addptr(base_ptr, 1024) # advances by 1024 * sizeof(T) bytes +ptr = pto.addptr(base_ptr, 1024) ``` The `+` shorthand on pointers also counts in elements, not bytes. @@ -267,6 +301,11 @@ The `+` shorthand on pointers also counts in elements, not bytes. This is an advanced operation. Prefer `as_ptr()` when the source already carries type information. + +```python +ptr = pto.castptr(addr, pto.ptr(pto.i32, pto.MemorySpace.UB)) +``` + ## 6.5 Compile-time queries These functions return values that are known at trace time from type information or hardware constants. @@ -289,6 +328,7 @@ These functions return values that are known at trace time from type information **Example**: + ```python bw = pto.bytewidth(pto.f32) # 4 bw = pto.bytewidth(pto.f16) # 2 @@ -315,6 +355,7 @@ bw = pto.bytewidth(pto.i8) # 1 **Example**: + ```python vec = pto.elements_per_vreg(pto.f32) # 64 vec = pto.elements_per_vreg(pto.f16) # 128 @@ -323,6 +364,7 @@ vec = pto.elements_per_vreg(pto.i8) # 256 This is the standard stride for chunking column loops in SIMD kernels: + ```python VEC = pto.elements_per_vreg(pto.f32) with pto.for_(0, cols, step=VEC) as c: @@ -333,6 +375,7 @@ with pto.for_(0, cols, step=VEC) as c: `@pto.simt` kernels are the natural home for per-element scalar work. A typical pattern uses nested `pto.for_` loops to walk over a tile row by row, column by column: + ```python @pto.simt def elementwise_scale( @@ -353,6 +396,7 @@ This reads each element from `src_tile`, multiplies by `scale`, and writes to `d For operations that need per-row metadata alongside per-element computation, lift the row-level scalar out of the inner loop: + ```python @pto.simt def blend_with_per_row_coeffs( diff --git a/ptodsl/docs/user_guide/07-data-movement-ops.md b/ptodsl/docs/user_guide/07-data-movement-ops.md index 40c1f02c7..225cdb25d 100644 --- a/ptodsl/docs/user_guide/07-data-movement-ops.md +++ b/ptodsl/docs/user_guide/07-data-movement-ops.md @@ -2,11 +2,11 @@ This chapter covers every operation that moves data between memory spaces in PTODSL — tile-level transfers, DMA micro-instructions, vector loads and stores, and cube data movement. Operations are organized by abstraction level: tile ops (L1), DMA ops (L2), vector memory ops (L3 SIMD), and cube memory ops (L3 cube). -## 7.1 Tile-level movement: tload and tstore +## 7.1 Tile-level movement: tile.load and tile.store Tile ops move entire blocks between Global Memory and the Unified Buffer in a single call. They are the primary data movement interface inside `@pto.jit`. -#### `pto.tload(partition: PartitionTensorView, tile: Tile) -> None` +#### `pto.tile.load(partition: PartitionTensorView, tile: Tile) -> None` **Description**: Copies data from a GM partition into a UB tile. The transfer size is determined by the partition's `sizes` and the tile's shape — they must be compatible. @@ -21,15 +21,16 @@ Tile ops move entire blocks between Global Memory and the Unified Buffer in a si **Example**: + ```python -a_part = pto.partition_view(a_view, offsets=[offset], sizes=[BLOCK]) -a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) -pto.tload(a_part, a_tile) +a_part = pto.partition_view(a_view, offsets=[offset, 0], sizes=[1, cols]) +a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) +pto.tile.load(a_part, a_tile) ``` --- -#### `pto.tstore(tile: Tile, partition: PartitionTensorView) -> None` +#### `pto.tile.store(tile: Tile, partition: PartitionTensorView) -> None` **Description**: Copies data from a UB tile back to a GM partition. The tile's `valid_shape` determines how many elements are written; elements outside `valid_shape` are not stored. @@ -44,13 +45,14 @@ pto.tload(a_part, a_tile) **Example**: + ```python -pto.tstore(o_tile, o_part) +pto.tile.store(o_tile, o_part) ``` --- -Both `tload` and `tstore` operate at **tile granularity** — they are the idiomatic choice inside `@pto.jit` loops. When you need finer control over DMA scheduling, drop down to the micro-instruction level. +Both `tile.load` and `tile.store` operate at **tile granularity** — they are the idiomatic choice inside `@pto.jit` loops. When you need finer control over DMA scheduling, drop down to the micro-instruction level. ## 7.2 DMA micro-instructions (ukernel) @@ -65,8 +67,6 @@ Inside `@pto.ukernel`, data movement between memory spaces is expressed with gro All four share a common structure: a required innermost `nburst(...)` group that defines the repeated burst transfer, plus optional outer `loop(...)` groups for multi-level repetition. `pto.mte_gm_ub` additionally supports `pad(...)` for UB row padding. -> **Convenience wrappers**: `pto.mte_load(src, dst)` and `pto.mte_store(src, dst)` are Python-level shorthands that expand to `mte_gm_ub` / `mte_ub_gm` with inferred strides. The reference operations below are the full grouped MTE interfaces. - ### 7.2.1 GM → UB: `pto.mte_gm_ub` #### `pto.mte_gm_ub(gm_src: PtrType, ub_dst: PtrType, l2_cache_ctl: int, len_burst: int, *, nburst: tuple[int, int, int], loops: list[tuple[int, int, int]] | None = None, pad: tuple[ScalarType, int, int] | tuple[ScalarType] | None = None) -> None` @@ -94,16 +94,18 @@ All four share a common structure: a required innermost `nburst(...)` group that **Example** — load a 32×32 f32 tile from contiguous GM into contiguous UB: + ```python -pto.mte_gm_ub(gm_ptr, ub_ptr, 0, 128, +pto.mte_gm_ub(gm_src, ub_dst, 0, 128, nburst=(32, 128, 128)) # 32 rows, 128 bytes per row, contiguous in both GM and UB ``` **Example** — load a 64×128 f16 tile from a larger GM matrix (1024×512) into UB: + ```python -pto.mte_gm_ub(gm_ptr, ub_ptr, 0, 256, +pto.mte_gm_ub(gm_src, ub_dst, 0, 256, nburst=(64, 1024, 256)) # 64 rows of 256 bytes each. # GM: each row is 1024 bytes apart (full matrix row stride). @@ -112,8 +114,9 @@ pto.mte_gm_ub(gm_ptr, ub_ptr, 0, 256, **Example** — load with padding (100 valid f16 columns into a 128-wide UB tile): + ```python -pto.mte_gm_ub(gm_ptr, ub_ptr, 0, 200, +pto.mte_gm_ub(gm_src, ub_dst, 0, 200, nburst=(64, 200, 256), pad=(0.0, 0, 0)) # 64 rows, 200 valid bytes per row, 256-byte UB stride. @@ -122,8 +125,9 @@ pto.mte_gm_ub(gm_ptr, ub_ptr, 0, 200, **Example** — multi-level loop: load 4 batches of 8×128 f16 tiles: + ```python -pto.mte_gm_ub(gm_ptr, ub_ptr, 0, 256, +pto.mte_gm_ub(gm_src, ub_dst, 0, 256, nburst=(8, 256, 256), loops=[(4, 2048, 2048)]) # Innermost: 8 rows × 256B (one tile). @@ -152,15 +156,17 @@ pto.mte_gm_ub(gm_ptr, ub_ptr, 0, 256, **Example** — store a 32×32 f32 tile from UB to GM: + ```python -pto.mte_ub_gm(ub_ptr, gm_ptr, 128, +pto.mte_ub_gm(ub_src_f32, gm_dst_f32, 128, nburst=(32, 128, 128)) ``` **Example** — store a 64×128 f16 tile back to a larger GM matrix: + ```python -pto.mte_ub_gm(ub_ptr, gm_ptr, 256, +pto.mte_ub_gm(ub_src, gm_dst, 256, nburst=(64, 256, 1024)) # UB: contiguous rows (256-byte stride). # GM: rows spaced at 1024-byte intervals (full matrix width). @@ -189,6 +195,7 @@ Each burst copies `len_burst * 32` bytes. The next burst starts at `src + (len_b **Example**: + ```python pto.mte_ub_ub(ub_src, ub_dst, 8, nburst=(16, 0, 4)) @@ -245,6 +252,7 @@ For `mte_ub_ub` and `mte_ub_l1`, the parameters are in **32-byte units**. Each b ### 7.2.6 Typical ukernel DMA pattern + ```python @pto.ukernel def process_block(k_part, v_part, k_tile, v_tile, o_tile, o_part, @@ -277,8 +285,13 @@ Inside `@pto.simd`, data moves between UB tiles and vector registers (`vreg`). V All vector load and store operations support the element-indexing syntax, which eliminates manual byte-offset calculation: + ```python vec = pto.vlds(tile[row, col:]) # load from row, starting at column col +``` + + +```python vec = pto.vlds(tile[start:]) # 1D tile, starting at element start ``` @@ -316,6 +329,10 @@ The compiler automatically computes the byte offset from the tile's shape, eleme **Description**: Dual vector load with deinterleave (AoS → SoA). Loads interleaved data and deinterleaves into two vectors. +PTODSL accepts both pointer-based forms and tile-slice forms. The tile-slice +spellings are PTODSL surface sugar; the pointer form `buf[offset] + dist` is +the canonical form. + **Parameters**: | Parameter | Type | Description | @@ -324,7 +341,7 @@ The compiler automatically computes the byte offset from the tile's shape, eleme | `tile[start:]` | Tile index | 1D tile with starting element (vector-width range) | | `buf` | `PtrType` (UB) | Pointer to buffer in UB (pointer form) | | `offset` | `Index` | Byte offset (pointer form) | -| `dist` | `DeinterleaveDist` | `DINTLV` (alternating elements) or `BDINTLV` (block deinterleave) | +| `dist` | `DeinterleaveDist` | `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` (alternating elements) or `BDINTLV` (block deinterleave) | **Returns**: @@ -341,6 +358,9 @@ The compiler automatically computes the byte offset from the tile's shape, eleme **Description**: Primes the alignment buffer for a subsequent unaligned load stream. Returns alignment state consumed by `vldus`. +PTODSL accepts both pointer-based forms and tile-slice forms. The tile-slice +spellings are PTODSL surface sugar; the pointer form is the canonical form. + **Parameters**: | Parameter | Type | Description | @@ -357,12 +377,15 @@ The compiler automatically computes the byte offset from the tile's shape, eleme --- -#### `pto.vldus(tile[row, col:], align: AlignType) -> (VRegType, AlignType, PtrType)` -#### `pto.vldus(tile[start:], align: AlignType) -> (VRegType, AlignType, PtrType)` -#### `pto.vldus(buf: PtrType, align: AlignType) -> (VRegType, AlignType, PtrType)` +#### `pto.vldus(tile[row, col:], align: AlignType) -> (VRegType, AlignType)` +#### `pto.vldus(tile[start:], align: AlignType) -> (VRegType, AlignType)` +#### `pto.vldus(buf: PtrType, align: AlignType) -> (VRegType, AlignType)` **Description**: Unaligned load with alignment state threading. Requires alignment state from `vldas` or a previous `vldus`. +PTODSL accepts both pointer-based forms and tile-slice forms. The tile-slice +spellings are PTODSL surface sugar; the pointer form is the canonical form. + **Parameters**: | Parameter | Type | Description | @@ -370,7 +393,6 @@ The compiler automatically computes the byte offset from the tile's shape, eleme | `tile[row, col:]` | Tile index | 2D tile row with starting column (vector-width range) | | `tile[start:]` | Tile index | 1D tile with starting element (vector-width range) | | `buf` | `PtrType` (UB) | Pointer to buffer in UB (pointer form) | -| `offset` | `Index` | Byte offset (pointer form) | | `align` | `AlignType` | Alignment state from `vldas` or previous `vldus` | **Returns**: @@ -379,13 +401,18 @@ The compiler automatically computes the byte offset from the tile's shape, eleme |--------------|------|-------------| | `vec` | `VRegType` | Assembled vector | | `align_out` | `AlignType` | Updated alignment state for next load | -| `base_out` | `PtrType` | Post-update base pointer | - **Example**: + ```python align = pto.vldas(tile[row, col:]) -vec, align, base = pto.vldus(tile[row, col:], align) +vec, align = pto.vldus(tile[row, col:], align) +``` + + +```python +align = pto.vldas(tile[start:]) +vec, align = pto.vldus(tile[start:], align) ``` --- @@ -412,9 +439,10 @@ vec, align, base = pto.vldus(tile[row, col:], align) --- -#### `pto.vgather2(buf: PtrType, offsets: Index, active_lanes: Index) -> VRegType` +#### `pto.vgather2(buf: PtrType, offsets: Index, mask: MaskType) -> VRegType` -**Description**: Indexed gather from UB using per-lane offsets. Only the first `active_lanes` lanes participate. +**Description**: Indexed gather from UB using per-lane offsets. Only masked-on +lanes participate. **Parameters**: @@ -422,7 +450,7 @@ vec, align, base = pto.vldus(tile[row, col:], align) |-----------|------|-------------| | `buf` | `PtrType` (UB) | Source buffer | | `offsets` | `Index` | Per-lane element offsets (vector register) | -| `active_lanes` | `Index` | Number of participating lanes | +| `mask` | `MaskType` | Predicate mask gating lane participation | **Returns**: @@ -452,16 +480,18 @@ vec, align, base = pto.vldus(tile[row, col:], align) --- -#### `pto.vgatherb(buf: PtrType, offsets: Index) -> VRegType` +#### `pto.vgatherb(buf: PtrType, offsets: Index, mask: MaskType) -> VRegType` -**Description**: Byte-granularity gather load. +**Description**: Block gather load. Participating lanes gather 32-byte blocks +from UB using byte offsets. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `buf` | `PtrType` | Source buffer | -| `offsets` | `Index` | Byte offsets | +| `buf` | `PtrType` (UB) | Source buffer | +| `offsets` | `Index` | Per-block byte offsets | +| `mask` | `MaskType` | `b32` predicate controlling which blocks participate | **Returns**: @@ -471,17 +501,20 @@ vec, align, base = pto.vldus(tile[row, col:], align) --- -#### `pto.vsldb(tile[row, col], offset: Index, mask: MaskType) -> VRegType` -#### `pto.vsldb(tile[pos], offset: Index, mask: MaskType) -> VRegType` -#### `pto.vsldb(buf: PtrType, offset: Index, mask: MaskType) -> VRegType` +#### `pto.vsldb(tile[row, col], block_stride: Index, repeat_stride: Index, mask: MaskType) -> VRegType` +#### `pto.vsldb(tile[pos], block_stride: Index, repeat_stride: Index, mask: MaskType) -> VRegType` +#### `pto.vsldb(buf: PtrType, block_stride: Index, repeat_stride: Index, mask: MaskType) -> VRegType` -**Description**: Block-strided load. The `offset` encodes a packed stride/control word, not a plain byte displacement. Masked-off blocks are zeroed. +**Description**: Block-strided load. The source is interpreted as a sequence of +32-byte blocks addressed by `repeat_stride + blk * block_stride`. Masked-off +blocks are zero-filled. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `offset` | `Index` | Packed stride/control word | +| `block_stride` | `Index` | 16-bit block stride field | +| `repeat_stride` | `Index` | 16-bit repeat stride field | | `mask` | `MaskType` | Mask controlling which blocks participate | **Returns**: @@ -498,7 +531,8 @@ Vector stores write `vreg` contents back to UB tiles. Like loads, they support t #### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType, dist: VStoreDist | None = None) -> None` #### `pto.vsts(vec: VRegType, buf: PtrType, offset: Index, mask: MaskType, dist: VStoreDist | None = None) -> None` -**Description**: Stateless vector store to UB. The mask gates which lanes are written. +**Description**: Stateless vector store to UB. The mask gates writes for the +distributions that use predicate masking. **Parameters**: @@ -510,15 +544,27 @@ Vector stores write `vreg` contents back to UB tiles. Like loads, they support t | `buf` | `PtrType` (UB) | Destination buffer (pointer form) | | `offset` | `Index` | Byte offset (pointer form) | | `mask` | `MaskType` | Predicate mask gating writes | -| `dist` | `VStoreDist` or `None` | Optional store distribution: `NORM_B32` (default), `PK_B16`/`PK_B32`/`PK_B64`, `ONE_POINT_B8`/`ONE_POINT_B16`/`ONE_POINT_B32` | +| `dist` | `VStoreDist` or `None` | Store distribution token. When omitted, PTODSL defaults to `NORM_B32` on the current surface. | **Returns**: None (side-effect operation). +**Distribution families**: + +| Family | Notes | +|--------|-------| +| `NORM_B8` / `NORM_B16` / `NORM_B32` | Contiguous vector store | +| `1PT_B8` / `1PT_B16` / `1PT_B32` | First-element-only store; predicate is ignored | +| `PK_B16` / `PK_B32` / `PK_B64` | Packed store families | +| `PK4_B32` | 4-way packed store | +| `MRG4CHN_B8` | 4-channel merge store | +| `MRG2CHN_B8` / `MRG2CHN_B16` | 2-channel merge store | + --- -#### `pto.psts(mask: MaskType, buf: PtrType, offset: Index) -> None` +#### `pto.psts(mask: MaskType, buf: PtrType, offset: Index, *, dist: PredicateDist = PredicateDist.NORM) -> None` -**Description**: Predicate store. Writes the packed predicate payload of `mask` to UB memory. +**Description**: Predicate store. Writes the packed predicate payload of `mask` +to UB memory. **Parameters**: @@ -527,6 +573,7 @@ Vector stores write `vreg` contents back to UB tiles. Like loads, they support t | `mask` | `MaskType` | Predicate payload to store | | `buf` | `PtrType` (UB) | Destination buffer | | `offset` | `Index` | Byte offset | +| `dist` | `PredicateDist` | Predicate payload layout. PTODSL defaults to `NORM` on the current surface. | **Returns**: None (side-effect operation). @@ -536,7 +583,8 @@ Vector stores write `vreg` contents back to UB tiles. Like loads, they support t #### `pto.vstsx2(low: VRegType, high: VRegType, tile[start:], dist: InterleaveDist, mask: MaskType) -> None` #### `pto.vstsx2(low: VRegType, high: VRegType, buf: PtrType, offset: Index, dist: InterleaveDist, mask: MaskType) -> None` -**Description**: Dual interleaving store (SoA → AoS). Interleaves two vectors into one destination. +**Description**: Dual interleaving store (SoA → AoS). Interleaves two vectors +into one destination. **Parameters**: @@ -548,58 +596,38 @@ Vector stores write `vreg` contents back to UB tiles. Like loads, they support t | `tile[start:]` | Tile index | 1D destination (vector-width range) | | `buf` | `PtrType` (UB) | Destination buffer (pointer form) | | `offset` | `Index` | Byte offset (pointer form) | -| `dist` | `InterleaveDist` | `INTLV` | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: None (side-effect operation). - ---- - -#### `pto.vsst(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` -#### `pto.vsst(scalar: ScalarType, tile[start:], mask: MaskType) -> None` -#### `pto.vsst(scalar: ScalarType, buf: PtrType, offset: Index, mask: MaskType) -> None` - -**Description**: Scalar broadcast store. Stores a scalar value replicated to all lanes under `mask`. - -**Parameters**: - -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `ScalarType` | Scalar value to broadcast | -| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | -| `tile[start:]` | Tile index | 1D destination (vector-width range) | -| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | -| `offset` | `Index` | Byte offset (pointer form) | -| `mask` | `MaskType` | Predicate mask | +| `dist` | `InterleaveDist` | `INTLV_B8` / `INTLV_B16` / `INTLV_B32` | +| `mask` | `MaskType` | Parameter retained for call-shape regularity; for the `INTLV_B*` family it does not affect the stored result | **Returns**: None (side-effect operation). --- -#### `pto.vsstb(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` -#### `pto.vsstb(scalar: ScalarType, tile[start:], mask: MaskType) -> None` -#### `pto.vsstb(scalar: ScalarType, buf: PtrType, offset: Index, mask: MaskType) -> None` +#### `pto.vsstb(tile[row, col], block_stride: Index, repeat_stride: Index, mask: MaskType) -> None` +#### `pto.vsstb(tile[pos], block_stride: Index, repeat_stride: Index, mask: MaskType) -> None` +#### `pto.vsstb(buf: PtrType, block_stride: Index, repeat_stride: Index, mask: MaskType) -> None` -**Description**: Enhanced scalar broadcast store. Same semantics as `vsst`. +**Description**: Block-strided store. Stores 32-byte source blocks to a +block-strided UB destination. Masked-off blocks do not write memory. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `scalar` | `ScalarType` | Scalar value to broadcast | -| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | -| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `tile[row, col]` | Tile index | 2D starting element | +| `tile[pos]` | Tile index | 1D starting element | | `buf` | `PtrType` (UB) | Destination buffer (pointer form) | -| `offset` | `Index` | Byte offset (pointer form) | -| `mask` | `MaskType` | Predicate mask | +| `block_stride` | `Index` | 16-bit block stride field | +| `repeat_stride` | `Index` | 16-bit repeat stride field | +| `mask` | `MaskType` | Mask controlling which blocks participate | **Returns**: None (side-effect operation). --- -#### `pto.vsta(align: AlignType, tile[row, col:]) -> None` -#### `pto.vsta(align: AlignType, tile[start:]) -> None` -#### `pto.vsta(align: AlignType, buf: PtrType, offset: Index) -> None` +#### `pto.vstar(align: AlignType, tile[row, col:]) -> None` +#### `pto.vstar(align: AlignType, tile[start:]) -> None` +#### `pto.vstar(align: AlignType, buf: PtrType) -> None` **Description**: Flush alignment state to memory. Commits buffered tail bytes from an unaligned store stream. Consumes the alignment state. @@ -611,7 +639,6 @@ Vector stores write `vreg` contents back to UB tiles. Like loads, they support t | `tile[row, col:]` | Tile index | 2D destination (vector-width range) | | `tile[start:]` | Tile index | 1D destination (vector-width range) | | `buf` | `PtrType` (UB) | Destination buffer (pointer form) | -| `offset` | `Index` | Byte offset (pointer form) | **Returns**: None (side-effect operation). @@ -621,7 +648,7 @@ Vector stores write `vreg` contents back to UB tiles. Like loads, they support t #### `pto.vstas(align: AlignType, tile[start:], offset: Index) -> None` #### `pto.vstas(align: AlignType, buf: PtrType, offset: Index) -> None` -**Description**: Scalar-register-offset form of alignment-state flush. Same buffered-tail semantics as `vsta` with an explicit scalar offset. +**Description**: Scalar-register-offset form of alignment-state flush. Same buffered-tail semantics as `vstar` with an explicit scalar offset. **Parameters**: @@ -637,27 +664,7 @@ Vector stores write `vreg` contents back to UB tiles. Like loads, they support t --- -#### `pto.vstar(align: AlignType, tile[row, col:]) -> None` -#### `pto.vstar(align: AlignType, tile[start:]) -> None` -#### `pto.vstar(align: AlignType, buf: PtrType) -> None` - -**Description**: Register-update form of alignment-state flush. Consumes the implicit update state from the matching store stream. - -**Parameters**: - -| Parameter | Type | Description | -|-----------|------|-------------| -| `align` | `AlignType` | Pending store-alignment state | - -| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | -| `tile[start:]` | Tile index | 1D destination (vector-width range) | -| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | - -**Returns**: None (side-effect operation). - ---- - -#### `pto.vscatter(vec: VRegType, buf: PtrType, offsets: Index, active_lanes: Index) -> None` +#### `pto.vscatter(vec: VRegType, buf: PtrType, offsets: Index, mask: MaskType) -> None` **Description**: Indexed scatter to UB. Stores vector lanes to irregular locations using per-lane offsets. @@ -668,7 +675,7 @@ Vector stores write `vreg` contents back to UB tiles. Like loads, they support t | `vec` | `VRegType` | Source vector to scatter | | `buf` | `PtrType` (UB) | Destination buffer | | `offsets` | `Index` | Per-lane element offsets (vector register) | -| `active_lanes` | `Index` | Number of participating lanes | +| `mask` | `MaskType` | Predicate mask gating lane participation | **Returns**: None (side-effect operation). @@ -678,54 +685,28 @@ Vector stores write `vreg` contents back to UB tiles. Like loads, they support t For streaming unaligned stores with explicit alignment threading: -#### `pto.vstu(align_in: AlignType, base_in: PtrType, vec: VRegType, buf: PtrType, mode: Index) -> (AlignType, PtrType)` - -**Description**: Unaligned store with explicit threaded alignment/base state. Returns updated state for the next store in the stream. - -**Parameters**: - -| Parameter | Type | Description | -|-----------|------|-------------| -| `align_in` | `AlignType` | Incoming store-alignment state | -| `base_in` | `PtrType` | Current stream base pointer | -| `vec` | `VRegType` | Vector to store | -| `buf` | `PtrType` (UB) | Destination buffer | -| `mode` | `Index` | Post-update mode | - -**Returns**: - -| Return Value | Type | Description | -|--------------|------|-------------| -| `align_out` | `AlignType` | Updated buffered-tail state | -| `base_out` | `PtrType` | Post-update base pointer | - ---- - -#### `pto.vstus(align_in: AlignType, base_in: PtrType, vec: VRegType, buf: PtrType, offset: Index) -> (AlignType, PtrType)` +#### `pto.vstus(align_in: AlignType, offset: Index, vec: VRegType, buf: PtrType) -> AlignType` -**Description**: Scalar-offset unaligned store. Same roles as `vstu` with explicit scalar displacement. +**Description**: Scalar-offset unaligned store. Returns updated alignment state for the next store in the stream. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `align_in` | `AlignType` | Incoming store-alignment state | -| `base_in` | `PtrType` | Current stream base pointer | +| `offset` | `Index` | Scalar displacement | | `vec` | `VRegType` | Vector to store | | `buf` | `PtrType` (UB) | Destination buffer | -| `offset` | `Index` | Scalar displacement | -| `mode` | `Index` | Post-update mode | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| | `align_out` | `AlignType` | Updated buffered-tail state | -| `base_out` | `PtrType` | Post-update base pointer | --- -#### `pto.vstur(align_in: AlignType, vec: VRegType, buf: PtrType, mode: PostUpdateMode = PostUpdateMode.NO_POST_UPDATE) -> AlignType` +#### `pto.vstur(align_in: AlignType, vec: VRegType, buf: PtrType, mode: PostUpdate = PostUpdate.OFF) -> AlignType` **Description**: Register-update unaligned store. Updates only residual alignment state without base pointer update. @@ -736,7 +717,7 @@ For streaming unaligned stores with explicit alignment threading: | `align_in` | `AlignType` | Incoming store-alignment state | | `vec` | `VRegType` | Vector to store | | `buf` | `PtrType` (UB) | Destination buffer | -| `mode` | `PostUpdateMode` | `NO_POST_UPDATE` (default) or `POST_UPDATE` | +| `mode` | `PostUpdate` | `PostUpdate.OFF` (default) or `PostUpdate.ON` | **Returns**: @@ -769,10 +750,13 @@ For streaming unaligned stores with explicit alignment threading: **Unaligned store stream pattern** — prime, thread, flush: + ```python -align, base = pto.vstu(align0, base0, vec0, ub_ptr, mode) -align, base = pto.vstu(align, base, vec1, ub_ptr, mode) -pto.vsta(align, ub_ptr, flush_offset) +align = pto.init_align() +vec0 = pto.vlds(ub_src_f32, pto.const(0)) +align = pto.vstur(align, vec0, ub_dst_f32, pto.PostUpdate.OFF) +align = pto.vstus(align, pto.const(32), vec0, ub_dst_f32) +pto.vstas(align, ub_dst_f32, pto.const(64)) ``` ### Distribution enums reference @@ -780,11 +764,11 @@ pto.vsta(align, ub_ptr, flush_offset) | Enum | Values | Used with | |------|--------|-----------| | `VLoadDist` | `NORM`, `UNPK_B8`, `UNPK_B16`, `UNPK_B32`, `BRC_B8`, `BRC_B16`, `BRC_B32`, `US_B8`, `US_B16`, `DS_B8`, `DS_B16` | `vlds` | -| `VStoreDist` | `NORM_B8`, `NORM_B16`, `NORM_B32`, `ONE_POINT_B8`, `ONE_POINT_B16`, `ONE_POINT_B32`, `PK_B16`, `PK_B32`, `PK_B64`, `PK4_B32`, `MRG4CHN_B8`, `MRG2CHN_B8`, `MRG2CHN_B16` | `vsts` | -| `DeinterleaveDist` | `DINTLV`, `BDINTLV` | `vldsx2` | -| `InterleaveDist` | `INTLV` | `vstsx2` | +| `VStoreDist` | `NORM_B8`, `NORM_B16`, `NORM_B32`, `1PT_B8`, `1PT_B16`, `1PT_B32`, `PK_B16`, `PK_B32`, `PK_B64`, `PK4_B32`, `MRG4CHN_B8`, `MRG2CHN_B8`, `MRG2CHN_B16` | `vsts` | +| `DeinterleaveDist` | `DINTLV_B8`, `DINTLV_B16`, `DINTLV_B32`, `BDINTLV` | `vldsx2` | +| `InterleaveDist` | `INTLV_B8`, `INTLV_B16`, `INTLV_B32` | `vstsx2` | | `StrideMode` | `S3_B16`, `S4_B64`, `S8_B32`, `S2_B64` | `vsld` | -| `PostUpdateMode` | `NO_POST_UPDATE`, `POST_UPDATE` | `vstur` | +| `PostUpdate` | `OFF`, `ON` | `vstur` | ## 7.5 Cube data movement (cube) @@ -999,6 +983,7 @@ Inside `@pto.cube`, data flows through a hierarchy of private buffers: GM → L1 A full cube matmul (`@pto.cube`) follows this dataflow pattern: + ```python @pto.cube def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): diff --git a/ptodsl/docs/user_guide/08-compute-operations.md b/ptodsl/docs/user_guide/08-compute-operations.md index 75ec8e511..41703f495 100644 --- a/ptodsl/docs/user_guide/08-compute-operations.md +++ b/ptodsl/docs/user_guide/08-compute-operations.md @@ -10,11 +10,11 @@ Tile compute ops are the primary arithmetic surface inside `@pto.jit`. They oper Element-wise operations between two tiles of the same shape. -#### `pto.tadd(src0: Tile, src1: Tile, dst: Tile) -> None` -#### `pto.tsub(src0: Tile, src1: Tile, dst: Tile) -> None` -#### `pto.tmul(src0: Tile, src1: Tile, dst: Tile) -> None` -#### `pto.tmax(src0: Tile, src1: Tile, dst: Tile) -> None` -#### `pto.tmin(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.add(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.sub(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.mul(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.max(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.min(src0: Tile, src1: Tile, dst: Tile) -> None` **Description**: Element-wise `dst[i,j] = src0[i,j] src1[i,j]`. @@ -31,13 +31,13 @@ Element-wise operations between two tiles of the same shape. **Example**: ```python -pto.tadd(a_tile, b_tile, o_tile) -pto.tmul(scale_tile, data_tile, scaled_tile) +pto.tile.add(a_tile, b_tile, o_tile) +pto.tile.mul(scale_tile, data_tile, scaled_tile) ``` --- -#### `pto.tdiv(src0: Tile, src1: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tile.div(src0: Tile, src1: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` **Description**: Element-wise division. `precision_mode` can be `DEFAULT` or `HIGH_PRECISION` (f16/f32 only). @@ -58,11 +58,11 @@ pto.tmul(scale_tile, data_tile, scaled_tile) Element-wise operations between a tile and a scalar. -#### `pto.tadds(src: Tile, scalar: ScalarType, dst: Tile) -> None` -#### `pto.tsubs(src: Tile, scalar: ScalarType, dst: Tile) -> None` -#### `pto.tmuls(src: Tile, scalar: ScalarType, dst: Tile) -> None` -#### `pto.tmaxs(src: Tile, scalar: ScalarType, dst: Tile) -> None` -#### `pto.tmins(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.adds(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.subs(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.muls(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.maxs(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.mins(src: Tile, scalar: ScalarType, dst: Tile) -> None` **Description**: Element-wise `dst[i,j] = src[i,j] scalar`. @@ -78,9 +78,9 @@ Element-wise operations between a tile and a scalar. --- -#### `pto.tdivs(numer: Tile | ScalarType, denom: Tile | ScalarType, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tile.divs(src: Tile, scalar: ScalarType, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` -**Description**: Element-wise tile-scalar division. Accepts both `(tile, scalar)` and `(scalar, tile)` operand orders. +**Description**: Element-wise tile-scalar division: `dst[i,j] = src[i,j] / scalar`. --- @@ -88,11 +88,11 @@ Element-wise operations between a tile and a scalar. Single-source element-wise math functions. -#### `pto.texp(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` -#### `pto.tlog(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` -#### `pto.tsqrt(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` -#### `pto.trsqrt(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` -#### `pto.trecip(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tile.exp(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tile.log(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tile.sqrt(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tile.rsqrt(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tile.recip(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` **Description**: Element-wise `exp`, `ln`, `sqrt`, `1/sqrt`, `1/x`. @@ -108,8 +108,8 @@ Single-source element-wise math functions. --- -#### `pto.tabs(src: Tile, dst: Tile) -> None` -#### `pto.tneg(src: Tile, dst: Tile) -> None` +#### `pto.tile.abs(src: Tile, dst: Tile) -> None` +#### `pto.tile.neg(src: Tile, dst: Tile) -> None` **Description**: Element-wise absolute value and negation. No precision mode attribute. @@ -117,11 +117,11 @@ Single-source element-wise math functions. ### 8.1.4 Activation -#### `pto.trelu(src: Tile, dst: Tile) -> None` +#### `pto.tile.relu(src: Tile, dst: Tile) -> None` **Description**: `dst[i,j] = max(0, src[i,j])`. Supported on f16, f32, i32. -#### `pto.tlrelu(src: Tile, slope: float, dst: Tile) -> None` +#### `pto.tile.lrelu(src: Tile, slope: float, dst: Tile) -> None` **Description**: Leaky ReLU — `dst[i,j] = src[i,j] >= 0 ? src[i,j] : slope * src[i,j]`. @@ -133,22 +133,22 @@ Reductions collapse one dimension of a 2D tile, producing a tile with one row or #### Row reductions -#### `pto.trowsum(src: Tile, tmp: Tile, dst: Tile) -> None` -#### `pto.trowmax(src: Tile, tmp: Tile, dst: Tile) -> None` -#### `pto.trowmin(src: Tile, tmp: Tile, dst: Tile) -> None` -#### `pto.trowprod(src: Tile, tmp: Tile, dst: Tile) -> None` -#### `pto.trowargmax(src: Tile, tmp: Tile, dst: Tile) -> None` -#### `pto.trowargmin(src: Tile, tmp: Tile, dst: Tile) -> None` +#### `pto.tile.rowsum(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` +#### `pto.tile.rowmax(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` +#### `pto.tile.rowmin(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` +#### `pto.tile.rowprod(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` +#### `pto.tile.rowargmax(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` +#### `pto.tile.rowargmin(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` -**Description**: For each row `i`, reduce across columns: `dst[i, 0] = _j src[i, j]`. `trowargmax`/`trowargmin` return the column index of the extremum. +**Description**: For each row `i`, reduce across columns: `dst[i, 0] = _j src[i, j]`. `tile.rowargmax`/`tile.rowargmin` return the column index of the extremum. In the public PTODSL wrapper, `tmp` is optional; when omitted, PTODSL allocates a matching scratch tile automatically. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `src` | `Tile` | Source tile (`[rows, cols]`) | -| `tmp` | `Tile` | Scratch tile for intermediate reduction state | | `dst` | `Tile` | Destination tile (`[rows, 1]`) | +| `tmp` | `Tile | None` | Optional scratch tile for intermediate reduction state; when omitted, PTODSL synthesizes a matching scratch tile automatically | **Returns**: None. @@ -156,10 +156,10 @@ Reductions collapse one dimension of a 2D tile, producing a tile with one row or #### Column reductions -#### `pto.tcolsum(src: Tile, dst: Tile) -> None` -#### `pto.tcolmax(src: Tile, dst: Tile) -> None` -#### `pto.tcolmin(src: Tile, dst: Tile) -> None` -#### `pto.tcolprod(src: Tile, dst: Tile) -> None` +#### `pto.tile.colsum(src: Tile, dst: Tile) -> None` +#### `pto.tile.colmax(src: Tile, dst: Tile) -> None` +#### `pto.tile.colmin(src: Tile, dst: Tile) -> None` +#### `pto.tile.colprod(src: Tile, dst: Tile) -> None` **Description**: For each column `j`, reduce across rows: `dst[0, j] = _i src[i, j]`. @@ -180,7 +180,7 @@ Expansion ops take a narrow source (scalar, row vector, or column vector) and br #### Scalar broadcast -#### `pto.texpands(scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.expands(scalar: ScalarType, dst: Tile) -> None` **Description**: `dst[i,j] = scalar` — fills every element of `dst` with the same scalar value. @@ -188,7 +188,7 @@ Expansion ops take a narrow source (scalar, row vector, or column vector) and br #### Row expansion -#### `pto.trowexpand(src: Tile, dst: Tile) -> None` +#### `pto.tile.rowexpand(src: Tile, dst: Tile) -> None` **Description**: `dst[row, col] = src[row, 0]` — broadcasts each row's single value across all columns of `dst`. @@ -205,7 +205,7 @@ Expansion ops take a narrow source (scalar, row vector, or column vector) and br #### Column expansion -#### `pto.tcolexpand(src: Tile, dst: Tile) -> None` +#### `pto.tile.colexpand(src: Tile, dst: Tile) -> None` **Description**: `dst[row, col] = src[0, col]` — broadcasts each column's single value across all rows of `dst`. @@ -217,13 +217,13 @@ These combine broadcasting with an arithmetic operation: `src1` is a per-row coe | Op | Semantics | |----|-----------| -| `pto.trowexpandadd(src0, src1, dst)` | `dst = src0 + expand_rows(src1)` | -| `pto.trowexpandsub(src0, src1, dst)` | `dst = src0 - expand_rows(src1)` | -| `pto.trowexpandmul(src0, src1, dst)` | `dst = src0 * expand_rows(src1)` | -| `pto.trowexpanddiv(src0, src1, dst)` | `dst = src0 / expand_rows(src1)` (f-only) | -| `pto.trowexpandmax(src0, src1, dst)` | `dst = max(src0, expand_rows(src1))` | -| `pto.trowexpandmin(src0, src1, dst)` | `dst = min(src0, expand_rows(src1))` | -| `pto.trowexpandexpdif(src0, src1, dst)` | `dst = exp(src0 - expand_rows(src1))` (f-only) | +| `pto.tile.rowexpandadd(src0, src1, dst)` | `dst = src0 + expand_rows(src1)` | +| `pto.tile.rowexpandsub(src0, src1, dst)` | `dst = src0 - expand_rows(src1)` | +| `pto.tile.rowexpandmul(src0, src1, dst)` | `dst = src0 * expand_rows(src1)` | +| `pto.tile.rowexpanddiv(src0, src1, dst)` | `dst = src0 / expand_rows(src1)` (f-only) | +| `pto.tile.rowexpandmax(src0, src1, dst)` | `dst = max(src0, expand_rows(src1))` | +| `pto.tile.rowexpandmin(src0, src1, dst)` | `dst = min(src0, expand_rows(src1))` | +| `pto.tile.rowexpandexpdif(src0, src1, dst)` | `dst = exp(src0 - expand_rows(src1))` (f-only) | **Parameters**: @@ -239,8 +239,8 @@ These combine broadcasting with an arithmetic operation: `src1` is a per-row coe ```python # alpha_tile: [rows, 1], beta_tile: [rows, 1], data_tile: [rows, cols] -pto.trowexpandmul(data_tile, alpha_tile, scaled_tile) -pto.trowexpandadd(scaled_tile, beta_tile, result_tile) +pto.tile.rowexpandmul(data_tile, alpha_tile, scaled_tile) +pto.tile.rowexpandadd(scaled_tile, beta_tile, result_tile) ``` --- @@ -251,31 +251,31 @@ Same pattern as row-expand arithmetic, but `src1` is a per-column coefficient ti | Op | Semantics | |----|-----------| -| `pto.tcolexpandadd(src0, src1, dst)` | `dst = src0 + expand_cols(src1)` | -| `pto.tcolexpandsub(src0, src1, dst)` | `dst = src0 - expand_cols(src1)` | -| `pto.tcolexpandmul(src0, src1, dst)` | `dst = src0 * expand_cols(src1)` | -| `pto.tcolexpanddiv(src0, src1, dst)` | `dst = src0 / expand_cols(src1)` (f-only) | -| `pto.tcolexpandmax(src0, src1, dst)` | `dst = max(src0, expand_cols(src1))` | -| `pto.tcolexpandmin(src0, src1, dst)` | `dst = min(src0, expand_cols(src1))` | -| `pto.tcolexpandexpdif(src0, src1, dst)` | `dst = exp(src0 - expand_cols(src1))` (f-only) | +| `pto.tile.colexpandadd(src0, src1, dst)` | `dst = src0 + expand_cols(src1)` | +| `pto.tile.colexpandsub(src0, src1, dst)` | `dst = src0 - expand_cols(src1)` | +| `pto.tile.colexpandmul(src0, src1, dst)` | `dst = src0 * expand_cols(src1)` | +| `pto.tile.colexpanddiv(src0, src1, dst)` | `dst = src0 / expand_cols(src1)` (f-only) | +| `pto.tile.colexpandmax(src0, src1, dst)` | `dst = max(src0, expand_cols(src1))` | +| `pto.tile.colexpandmin(src0, src1, dst)` | `dst = min(src0, expand_cols(src1))` | +| `pto.tile.colexpandexpdif(src0, src1, dst)` | `dst = exp(src0 - expand_cols(src1))` (f-only) | --- ### 8.1.7 Selection -#### `pto.tsel(mask: Tile, src0: Tile, src1: Tile, tmp: Tile, dst: Tile) -> None` +#### `pto.tile.sel(mask: Tile, src0: Tile, src1: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` -**Description**: Element-wise ternary: `dst[i,j] = mask[i,j] ? src0[i,j] : src1[i,j]`. The `mask` is an integer tile where zero means false and non-zero means true. +**Description**: Element-wise ternary: `dst[i,j] = mask[i,j] ? src0[i,j] : src1[i,j]`. The `mask` is an integer tile where zero means false and non-zero means true. `tmp` is an optional scratch tile override; when omitted, PTODSL synthesizes any architecture-specific scratch tile automatically. -#### `pto.tsels(mask: Tile, src: Tile, scalar: ScalarType, tmp: Tile, dst: Tile) -> None` +#### `pto.tile.sels(mask: Tile, src: Tile, scalar: ScalarType, dst: Tile, *, tmp: Tile | None = None) -> None` -**Description**: Element-wise select with scalar fallback: `dst[i,j] = mask[i,j] ? src[i,j] : scalar`. +**Description**: Element-wise select with scalar fallback: `dst[i,j] = mask[i,j] ? src[i,j] : scalar`. As with `tile.sel`, `tmp` is optional and PTODSL synthesizes any required scratch tile automatically when it is omitted. --- ### 8.1.8 Type conversion -#### `pto.tcvt(src: Tile, dst: Tile, *, rmode: RoundMode = RoundMode.NONE) -> None` +#### `pto.tile.cvt(src: Tile, dst: Tile, *, rmode: RoundMode = RoundMode.NONE) -> None` **Description**: Element-wise type conversion. The destination tile's `dtype` determines the target type. @@ -291,24 +291,187 @@ Same pattern as row-expand arithmetic, but `src1` is a per-column coefficient ti --- -### 8.1.9 Tile compute quick reference +### 8.1.9 Bitwise ops + +Bitwise operations on integer tiles (i8, i16, i32, etc.). All follow the standard `(src, dst)` or `(src0, src1, dst)` pattern. + +#### Unary bitwise + +#### `pto.tile.bit_not(src: Tile, dst: Tile) -> None` + +**Description**: Element-wise bitwise NOT: `dst[i,j] = ~src[i,j]`. Integer types only. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile (integer dtype) | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +--- + +#### Binary bitwise (tile-tile) + +#### `pto.tile.bit_and(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.bit_or(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.bit_shl(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.bit_shr(src0: Tile, src1: Tile, dst: Tile) -> None` + +**Description**: Element-wise bitwise `dst[i,j] = src0[i,j] src1[i,j]`. Integer types only. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | First source tile | +| `src1` | `Tile` | Second source tile | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +--- + +#### `pto.tile.bit_xor(src0: Tile, src1: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` + +**Description**: Element-wise bitwise XOR. Requires an additional scratch buffer `tmp` of the same type as `dst`. When `tmp` is omitted, PTODSL synthesizes a matching scratch tile automatically. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | First source tile | +| `src1` | `Tile` | Second source tile | +| `dst` | `Tile` | Destination tile | +| `tmp` | `Tile | None` | Optional scratch tile; when omitted, PTODSL synthesizes one automatically | + +**Returns**: None. + +--- + +#### Binary bitwise (tile-scalar) + +#### `pto.tile.bit_ands(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.bit_ors(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.bit_shls(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.bit_shrs(src: Tile, scalar: ScalarType, dst: Tile) -> None` + +**Description**: Element-wise `dst[i,j] = src[i,j] scalar`. The scalar is broadcast to all elements. Integer types only. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile | +| `scalar` | `ScalarType` | Scalar operand (Python int or PTO scalar) | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +--- + +#### `pto.tile.bit_xors(src: Tile, scalar: ScalarType, dst: Tile, *, tmp: Tile | None = None) -> None` + +**Description**: Element-wise bitwise XOR with scalar. Requires an additional scratch buffer `tmp` of the same type as `dst`. When `tmp` is omitted, PTODSL synthesizes a matching scratch tile automatically. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile | +| `scalar` | `ScalarType` | Scalar operand | +| `dst` | `Tile` | Destination tile | +| `tmp` | `Tile | None` | Optional scratch tile; when omitted, PTODSL synthesizes one automatically | + +**Returns**: None. + +--- + +### 8.1.10 Partial elementwise ops + +Partial elementwise ops compute over the **intersection** of the valid regions of two source tiles. This allows element-wise arithmetic between tiles that have different `valid_shape`s — only the overlapping area is computed. + +#### `pto.tile.partadd(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.partmul(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.partmax(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.partmin(src0: Tile, src1: Tile, dst: Tile) -> None` + +**Description**: Element-wise `dst[i,j] = src0[i,j] src1[i,j]` over the intersection of `src0.valid_shape` and `src1.valid_shape`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | First source tile (may have a partial valid region) | +| `src1` | `Tile` | Second source tile (may have a partial valid region) | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +**Example** — adding tiles with different valid regions: + +```python +# a_tile: valid_shape = [64, 32], b_tile: valid_shape = [64, 64] +# The partial add only operates on the intersection: 64 columns × min(32, 64) = 32 columns +pto.tile.partadd(a_tile, b_tile, result_tile) +``` + +--- + +### 8.1.11 Fill/padding + +Fill-padding ops copy a source tile's valid region into a destination tile, filling the remaining physical elements (outside `src.valid_shape`) with a configured pad value. The pad value is specified at tile allocation time via the tile's `PadValue` attribute (`Null`, `Zero`, `Max`, or `Min`). + +#### `pto.tile.fillpad(src: Tile, dst: Tile) -> None` + +**Description**: Copies `src`'s valid region into `dst` and fills extra elements of `dst` with the pad value configured on `dst`'s type. The `dst` physical shape must be at least as large as `src.valid_shape`. + +#### `pto.tile.fillpad_expand(src: Tile, dst: Tile) -> None` + +**Description**: Like `fillpad`, but the destination tile may have a different shape in the partition/tensor view. The src valid region is copied and the expanded area is filled with the pad value. Useful when expanding a tile into a larger buffer for downstream processing. + +#### `pto.tile.fillpad_inplace(src: Tile, dst: Tile) -> None` + +**Description**: In-place variant of `fillpad`. `src` and `dst` may refer to the same tile buffer, padding the tile's own valid region in place. + +**Parameters** (all three ops): + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile (with valid region to copy) | +| `dst` | `Tile` | Destination tile (carries `PadValue` attribute set at allocation) | + +**Returns**: None. + +**Example** — padding a partial tile to full shape: + +```python +# tile has valid_shape [32, 16] in a physical buffer of [32, 32] +# pad=Zero at allocation time fills extra columns with zeros +pto.tile.fillpad(partial_tile, padded_tile) +``` + +--- + +### 8.1.12 Tile compute quick reference | Category | Operations | |----------|------------| -| Binary tile-tile | `tadd`, `tsub`, `tmul`, `tdiv`, `tmax`, `tmin` | -| Tile-scalar | `tadds`, `tsubs`, `tmuls`, `tdivs`, `tmaxs`, `tmins` | -| Unary math | `texp`, `tlog`, `tsqrt`, `trsqrt`, `trecip`, `tabs`, `tneg` | -| Activation | `trelu`, `tlrelu` | -| Row reductions | `trowsum`, `trowmax`, `trowmin`, `trowprod`, `trowargmax`, `trowargmin` | -| Column reductions | `tcolsum`, `tcolmax`, `tcolmin`, `tcolprod` | -| Broadcast | `texpands`, `trowexpand`, `tcolexpand` | -| Row-expand arith | `trowexpandadd`, `trowexpandsub`, `trowexpandmul`, `trowexpanddiv`, `trowexpandmax`, `trowexpandmin`, `trowexpandexpdif` | -| Col-expand arith | `tcolexpandadd`, `tcolexpandsub`, `tcolexpandmul`, `tcolexpanddiv`, `tcolexpandmax`, `tcolexpandmin`, `tcolexpandexpdif` | -| Selection | `tsel`, `tsels` | -| Type conversion | `tcvt` | -| Bitwise | `tnot`, `tand`, `tor`, `txor`, `tshl`, `tshr`, `tands`, `tors`, `txors`, `tshls`, `tshrs` | -| Partial elementwise | `tpartadd`, `tpartmul`, `tpartmax`, `tpartmin` | -| Fill/padding | `tfillpad`, `tfillpad_expand`, `tfillpad_inplace` | +| Binary tile-tile | `tile.add`, `tile.sub`, `tile.mul`, `tile.div`, `tile.max`, `tile.min` | +| Tile-scalar | `tile.adds`, `tile.subs`, `tile.muls`, `tile.divs`, `tile.maxs`, `tile.mins` | +| Unary math | `tile.exp`, `tile.log`, `tile.sqrt`, `tile.rsqrt`, `tile.recip`, `tile.abs`, `tile.neg` | +| Activation | `tile.relu`, `tile.lrelu` | +| Row reductions | `tile.rowsum`, `tile.rowmax`, `tile.rowmin`, `tile.rowprod`, `tile.rowargmax`, `tile.rowargmin` | +| Column reductions | `tile.colsum`, `tile.colmax`, `tile.colmin`, `tile.colprod` | +| Broadcast | `tile.expands`, `tile.rowexpand`, `tile.colexpand` | +| Row-expand arith | `tile.rowexpandadd`, `tile.rowexpandsub`, `tile.rowexpandmul`, `tile.rowexpanddiv`, `tile.rowexpandmax`, `tile.rowexpandmin`, `tile.rowexpandexpdif` | +| Col-expand arith | `tile.colexpandadd`, `tile.colexpandsub`, `tile.colexpandmul`, `tile.colexpanddiv`, `tile.colexpandmax`, `tile.colexpandmin`, `tile.colexpandexpdif` | +| Selection | `tile.sel`, `tile.sels` | +| Type conversion | `tile.cvt` | +| Bitwise | `tile.bit_not`, `tile.bit_and`, `tile.bit_or`, `tile.bit_xor`, `tile.bit_shl`, `tile.bit_shr`, `tile.bit_ands`, `tile.bit_ors`, `tile.bit_xors`, `tile.bit_shls`, `tile.bit_shrs` | +| Partial elementwise | `tile.partadd`, `tile.partmul`, `tile.partmax`, `tile.partmin` | +| Fill/padding | `tile.fillpad`, `tile.fillpad_expand`, `tile.fillpad_inplace` | --- @@ -347,6 +510,7 @@ All vector ops in this section follow the pattern established in Section 7.3 for **Example**: + ```python exp_vec = pto.vexp(s_row, col_mask) ``` @@ -418,6 +582,7 @@ exp_vec = pto.vexp(s_row, col_mask) **Example** — subtract row max from score row (online softmax): + ```python s_shifted = pto.vsubs(s_row, m_next, col_mask) ``` @@ -452,11 +617,11 @@ s_shifted = pto.vsubs(s_row, m_next, col_mask) These reduce within each hardware vector lane group (typically 8 groups per vector). Useful when a vector register holds multiple independent sub-vectors that need separate reductions. -#### `pto.vcgadd(vec: VRegType, mask: MaskType) -> VRegType` -#### `pto.vcgmax(vec: VRegType, mask: MaskType) -> VRegType` -#### `pto.vcgmin(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vcgadd(vec: VRegType, mask: MaskType) -> ScalarType` +#### `pto.vcgmax(vec: VRegType, mask: MaskType) -> ScalarType` +#### `pto.vcgmin(vec: VRegType, mask: MaskType) -> ScalarType` -**Description**: Per-group sum, max, or min. Each group's result is placed in the first lane of that group. +**Description**: Per-group sum, max, or min. The underlying vector reduction places each group's result in the first lane of that group; the ptodsl surface extracts lane 0 and returns it as a runtime scalar. **Parameters**: @@ -469,13 +634,14 @@ These reduce within each hardware vector lane group (typically 8 groups per vect | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `VRegType` | Vector with per-group reduction results | +| `result` | `ScalarType` | Lane-0 scalar extracted from the grouped reduction result | **Example** — row max and row sum from online softmax: + ```python -row_max = pto.vcgmax(s_row, col_mask) # per-group max → first lane of each group -row_sum = pto.vcgadd(p_row, col_mask) # per-group sum → first lane of each group +row_max = pto.vcgmax(s_row, col_mask) # grouped reduction, surfaced as a runtime scalar +row_sum = pto.vcgadd(p_row, col_mask) # grouped reduction, surfaced as a runtime scalar ``` --- @@ -490,9 +656,9 @@ row_sum = pto.vcgadd(p_row, col_mask) # per-group sum → first lane of each g These combine an arithmetic operation with a math function or activation in a single instruction. -#### `pto.vexpdif(vec: VRegType, max_vec: VRegType, mask: MaskType, *, part: PartMode = PartMode.EVEN) -> VRegType` +#### `pto.vexpdif(vec: VRegType, max_vec: VRegType, mask: MaskType, *, part: PartMode = PartMode.ODD) -> VRegType` -**Description**: `exp(vec[i] - max_vec[i])` — the stable softmax numerator. `part` controls which half of the vector is computed: `EVEN` or `ODD`. Result type is always f32. +**Description**: `exp(vec[i] - max_vec[i])` — the stable softmax numerator. `part` controls which half of the vector is computed: `EVEN` or `ODD`. The result keeps the same `VRegType` as the input vector. --- @@ -565,20 +731,16 @@ These combine an arithmetic operation with a math function or activation in a si | Category | Operations | |----------|------------| -| Unary | `vexp`, `vln`, `vsqrt`, `vabs`, `vneg`, `vrec`, `vrsqrt`, `vrelu`, `vnot`, `vmov`, `vcls`, `vbcnt` | -| Binary | `vadd`, `vsub`, `vmul`, `vdiv`, `vmax`, `vmin`, `vand`, `vor`, `vxor`, `vshl`, `vshr`, `vmod` | -| Vector-scalar | `vadds`, `vsubs`, `vmuls`, `vmaxs`, `vmins`, `vshls`, `vshrs`, `vlrelu`, `vands`, `vors`, `vxors` | -| Broadcast | `vbr`, `vdup` | +| Unary | `vexp`, `vln`, `vsqrt`, `vabs`, `vneg`, `vrec`, `vrsqrt`, `vrelu`, `vnot` | +| Binary | `vadd`, `vsub`, `vmul`, `vdiv`, `vmax`, `vmin`, `vand`, `vor`, `vxor`, `vshl`, `vshr` | +| Vector-scalar | `vadds`, `vsubs`, `vmuls`, `vmaxs`, `vmins`, `vlrelu` | +| Broadcast | `vdup` | | Full reduction | `vcadd`, `vcmax`, `vcmin` | | Group reduction | `vcgadd`, `vcgmax`, `vcgmin` | | Scan | `vcpadd` | -| Fused | `vexpdif`, `vaxpy`, `vprelu`, `vaddrelu`, `vsubrelu`, `vmulconv`, `vaddreluconv` | -| Compare/select | `vcmp`, `vcmps`, `vsel`, `vselr`, `vselrv2` | -| Carry | `vaddc`, `vsubc`, `vaddcs`, `vsubcs` | -| Extended arith | `vmull`, `vmula` | -| Conversion | `vcvt`, `vtrc`, `vbitcast`, `pbitcast` | -| Index gen | `vci` | -| Rearrangement | `vintlv`, `vdintlv`, `vintlvv2`, `vdintlvv2`, `vsqz`, `vusqz`, `vpack`, `vsunpack`, `vzunpack`, `vperm`, `vshift`, `vslide`, `vsort32`, `vmrgsort`, `vtranspose` | +| Fused | `vexpdif`, `vaxpy`, `vaddrelu`, `vsubrelu` | +| Compare/select | `vcmp`, `vcmps`, `vsel` | +| Conversion | `vbitcast`, `pbitcast` | --- @@ -607,30 +769,49 @@ The Cube unit performs matrix multiplication. Its operands are typed pointers in --- -#### `pto.mad_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, k: int, n: int) -> None` +#### `pto.mad_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int) -> None` **Description**: Accumulating matrix multiply: `dst[M×N] += lhs[M×K] * rhs[K×N]`. `dst` must already hold a prior accumulation result. --- -#### `pto.mad_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, k: int, n: int) -> None` +#### `pto.mad_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, n: int, k: int) -> None` **Description**: Bias-initialized matrix multiply: `dst[M×N] = lhs[M×K] * rhs[K×N] + bias[M×N]`. `bias` is a BIAS pointer. --- +#### `pto.mad_mx(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int) -> None` + +**Description**: MX-format zero-initialized matrix multiply. This variant is intended for MX-enabled operand formats such as f8 payloads with their associated scale data already staged into cube-local buffers. + +--- + +#### `pto.mad_mx_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int) -> None` + +**Description**: MX-format accumulating matrix multiply: `dst[M×N] += lhs[M×K] * rhs[K×N]`. + +--- + +#### `pto.mad_mx_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, n: int, k: int) -> None` + +**Description**: MX-format bias-initialized matrix multiply: `dst[M×N] = lhs[M×K] * rhs[K×N] + bias[M×N]`. + +--- + ### 8.3.2 Typical cube matmul pattern A full cube matmul follows a three-stage pattern: stage operands into L0A/L0B, compute, write back to UB. + ```python @pto.cube def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): m = q_tile.valid_shape[0] k = q_tile.valid_shape[1] - n = k_tile.valid_shape[0] + n = k_tile.valid_shape[1] - # Stage: UB → L0A / L0B + # Stage: source tiles → L0A / L0B pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) @@ -641,7 +822,7 @@ def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) ``` -The `mte_l1_l0a`/`mte_l1_l0b` stage operands from UB into cube-local buffers. `mad` performs the matrix multiply into L0C. `mte_l0c_ub` writes the result back to a UB tile for downstream processing. At this micro-op layer, the operands are explicit pointer views obtained with `.as_ptr()`. +The `mte_l1_l0a`/`mte_l1_l0b` stage operands from the authored source tiles into cube-local buffers. `mad` performs the matrix multiply into L0C. `mte_l0c_ub` writes the result back to a UB tile for downstream processing. At this micro-op layer, the operands are explicit pointer views obtained with `.as_ptr()`. --- @@ -650,10 +831,10 @@ The `mte_l1_l0a`/`mte_l1_l0b` stage operands from UB into cube-local buffers. `m | Operation | Semantics | |-----------|-----------| | `pto.mad(lhs, rhs, dst, m, n, k)` | `dst = lhs * rhs` (zero-init) | -| `pto.mad_acc(lhs, rhs, dst, m, k, n)` | `dst += lhs * rhs` (accumulating) | -| `pto.mad_bias(lhs, rhs, dst, bias, m, k, n)` | `dst = lhs * rhs + bias` | -| `pto.mad_mx(lhs, rhs, dst, m, k, n)` | MX-format zero-init matmul | -| `pto.mad_mx_acc(lhs, rhs, dst, m, k, n)` | MX-format accumulating matmul | -| `pto.mad_mx_bias(lhs, rhs, dst, bias, m, k, n)` | MX-format bias-init matmul | +| `pto.mad_acc(lhs, rhs, dst, m, n, k)` | `dst += lhs * rhs` (accumulating) | +| `pto.mad_bias(lhs, rhs, dst, bias, m, n, k)` | `dst = lhs * rhs + bias` | +| `pto.mad_mx(lhs, rhs, dst, m, n, k)` | MX-format zero-init matmul | +| `pto.mad_mx_acc(lhs, rhs, dst, m, n, k)` | MX-format accumulating matmul | +| `pto.mad_mx_bias(lhs, rhs, dst, bias, m, n, k)` | MX-format bias-init matmul | MX variants require MX-enabled dtypes (f8) and pre-loaded scale payloads. For most users, the standard `mad`, `mad_acc`, and `mad_bias` are the primary interface. diff --git a/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md b/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md index e8cc6bf6b..2ba02213a 100644 --- a/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md +++ b/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md @@ -42,6 +42,7 @@ The recommended front door for creating masks is `pto.make_mask`. It dispatches **Example** — chunked SIMD loop with tail handling: + ```python VEC = pto.elements_per_vreg(pto.f32) col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) @@ -61,6 +62,7 @@ with col_loop: When the mask pattern is known at compile time, pass a `MaskPattern` instead: + ```python full_mask = pto.make_mask(pto.f32, pto.MaskPattern.ALL) ``` @@ -77,6 +79,17 @@ When you need explicit control over the mask granularity, use these ops directly `pset` generates a mask from a named pattern. `pge` generates a tail mask where the first N lanes are active (N encoded in the pattern). + +```python +full_mask = pto.pset_b32(pto.MaskPattern.ALL) +``` + + +```python +mask8 = pto.pset_b8(pto.MaskPattern.ALL) +mask16 = pto.pset_b16(pto.MaskPattern.ALL) +``` + #### `pto.pset_b8(pattern: MaskPattern) -> pto.mask_b8` #### `pto.pset_b16(pattern: MaskPattern) -> pto.mask_b16` #### `pto.pset_b32(pattern: MaskPattern) -> pto.mask_b32` @@ -109,6 +122,11 @@ When you need explicit control over the mask granularity, use these ops directly `plt` generates a tail mask from a live `i32` scalar — the idiomatic choice for dynamic tail handling when not using `make_mask`. + +```python +mask, remained = pto.plt_b32(remained) +``` + #### `pto.plt_b8(scalar: pto.i32) -> (pto.mask_b8, pto.i32)` #### `pto.plt_b16(scalar: pto.i32) -> (pto.mask_b16, pto.i32)` #### `pto.plt_b32(scalar: pto.i32) -> (pto.mask_b32, pto.i32)` @@ -134,6 +152,11 @@ When you need explicit control over the mask granularity, use these ops directly ## 9.4 Mask logical operations + +```python +merged = pto.pand(src0, src1, gate) +``` + Once created, masks can be combined with bitwise logical ops. All take a gating mask that selects which lanes participate; inactive lanes are zeroed in the result. #### `pto.pand(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` @@ -180,6 +203,7 @@ These ops reshape masks between granularities and layouts without changing the u **Example**: + ```python # Reinterpret a b16 mask as b32 mask32 = pto.pbitcast(mask16, pto.mask_b32) @@ -195,6 +219,12 @@ mask32 = pto.pbitcast(mask16, pto.mask_b32) **Description**: Widening unpack — reads the selected half of the source, zero-extends each 1-bit element into a 2-bit group in the result. + +```python +packed_hi = pto.ppack(mask32, pto.PredicatePart.HIGHER) +unpacked_hi = pto.punpack(packed_hi, pto.PredicatePart.HIGHER) +``` + --- #### `pto.pintlv_b8(src0: pto.mask_b8, src1: pto.mask_b8) -> (pto.mask_b8, pto.mask_b8)` @@ -242,6 +272,7 @@ Vector comparisons produce predicate masks from vector data. The result can feed **Example** — threshold a vector: + ```python big = pto.vcmps(scores, threshold, seed, pto.CmpMode.GT) # big[i] = 1 where scores[i] > threshold @@ -249,7 +280,7 @@ big = pto.vcmps(scores, threshold, seed, pto.CmpMode.GT) --- -**Tile-level comparisons** (`pto.tcmp`, `pto.tcmps`) compare two tiles and write packed predicate bytes into an `i8` destination tile. They are used when the comparison result needs to be stored to UB for later selection (`tsel`) or cross-kernel communication. +**Tile-level comparisons** (`pto.tile.cmp`, `pto.tile.cmps`) compare two tiles and write packed predicate bytes into an `i8` destination tile. They are used when the comparison result needs to be stored to UB for later selection (`tile.sel`) or cross-kernel communication. --- @@ -261,7 +292,7 @@ Masks can be persisted to and loaded from UB memory, enabling cross-stage predic #### `pto.plds(buf: PtrType, offset: Index, *, dist: PredicateDist = PredicateDist.NORM) -> MaskType` -**Description**: Load a predicate mask from UB memory at the given byte offset. The mask granularity is inferred from context. +**Description**: Load a predicate mask from UB memory at the given byte offset. The mask granularity is determined by the pointer element type of `buf` (`ui8`/`ui16`/`ui32` -> `mask_b8`/`mask_b16`/`mask_b32`). **Parameters**: @@ -302,14 +333,14 @@ Masks can be persisted to and loaded from UB memory, enabling cross-stage predic #### `pto.pstu(align_in: AlignType, mask: MaskType, buf: PtrType) -> (AlignType, PtrType)` -**Description**: Unaligned predicate store with alignment state threading. Threads the `align` state through a stream of stores, ensuring tail bytes are correctly buffered. The base pointer type is determined by the mask granularity (`ui16` for `b16`, `ui32` for `b32`). +**Description**: Unaligned predicate store with alignment state threading. Threads the `align` state through a stream of stores, ensuring tail bytes are correctly buffered. This op currently supports only `mask_b16` and `mask_b32`; the base pointer type is determined by the mask granularity (`ui16` for `b16`, `ui32` for `b32`). **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `align_in` | `AlignType` | Incoming alignment state (from `init_align` or previous `pstu`) | -| `mask` | `MaskType` | Predicate mask to store | +| `mask` | `MaskType` | Predicate mask to store (`mask_b16` or `mask_b32` only) | | `buf` | `PtrType` (UB) | Destination buffer | **Returns**: @@ -333,6 +364,7 @@ The mask granularity must match the vector element type. Using a `mask_b16` with **Typical pattern** — tail-safe vector processing: + ```python VEC = pto.elements_per_vreg(pto.f32) with pto.for_(0, rows, step=1) as r: @@ -355,11 +387,11 @@ The `mask` gates the `vexp` (masked-off lanes produce 0) and the `vsts` (masked- ## 9.9 Tile-level mask operations -When working at the tile level (L1, `@pto.jit`), masks are carried in `i8` tile buffers holding packed predicate bytes. The key consumer of tile-level masks is `tsel`. +When working at the tile level (L1, `@pto.jit`), masks are carried in `i8` tile buffers holding packed predicate bytes. The key consumer of tile-level masks is `tile.sel`. -#### `pto.tsel(mask: Tile, src0: Tile, src1: Tile, tmp: Tile, dst: Tile) -> None` +#### `pto.tile.sel(mask: Tile, src0: Tile, src1: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` -**Description**: Element-wise ternary select: `dst[i,j] = mask[i,j] ? src0[i,j] : src1[i,j]`. `mask` is an integer tile (typically `i8`) where zero means false. `tmp` is a scratch tile for the underlying implementation. +**Description**: Element-wise ternary select: `dst[i,j] = mask[i,j] ? src0[i,j] : src1[i,j]`. `mask` is an integer tile (typically `i8`) where zero means false. `tmp` is an optional scratch tile override; when omitted, PTODSL synthesizes any architecture-specific scratch tile automatically. **Parameters**: @@ -368,16 +400,16 @@ When working at the tile level (L1, `@pto.jit`), masks are carried in `i8` tile | `mask` | `Tile` | Integer mask tile (zero = false) | | `src0` | `Tile` | True-branch source tile | | `src1` | `Tile` | False-branch source tile | -| `tmp` | `Tile` | Scratch tile | +| `tmp` | `Tile \| None` | Optional scratch tile override | | `dst` | `Tile` | Destination tile | **Returns**: None. --- -#### `pto.tsels(mask: Tile, src: Tile, scalar: ScalarType, tmp: Tile, dst: Tile) -> None` +#### `pto.tile.sels(mask: Tile, src: Tile, scalar: ScalarType, dst: Tile, *, tmp: Tile | None = None) -> None` -**Description**: Element-wise select with scalar fallback: `dst[i,j] = mask[i,j] ? src[i,j] : scalar`. +**Description**: Element-wise select with scalar fallback: `dst[i,j] = mask[i,j] ? src[i,j] : scalar`. As with `tile.sel`, `tmp` is optional and PTODSL synthesizes any required scratch tile automatically when it is omitted. --- diff --git a/ptodsl/docs/user_guide/10-sync-ops.md b/ptodsl/docs/user_guide/10-sync-ops.md index a0a9b3a03..06416d587 100644 --- a/ptodsl/docs/user_guide/10-sync-ops.md +++ b/ptodsl/docs/user_guide/10-sync-ops.md @@ -46,22 +46,16 @@ Hardware pipeline identifiers used with `pto.set_flag`, `pto.wait_flag`, and `pt The most commonly used pipes in synchronization are `MTE2` (GM ↔ UB DMA), `MTE3` (UB ↔ UB DMA), `V` (vector compute), and `M` (matrix compute). -### `Event` +### `event_id` -Event identifiers for pipeline synchronization flags. The hardware provides 8 event IDs (0–7) per pipeline pair, supporting up to 8 concurrent in-flight DMA/compute sequences. +Event identifiers for pipeline synchronization flags. The hardware provides 8 event IDs (`0`–`7`) per pipeline pair, supporting up to 8 concurrent in-flight DMA/compute sequences. -| Member | Value | -|--------|-------| -| `ID0` | Event 0 | -| `ID1` | Event 1 | -| `ID2` | Event 2 | -| `ID3` | Event 3 | -| `ID4` | Event 4 | -| `ID5` | Event 5 | -| `ID6` | Event 6 | -| `ID7` | Event 7 | +In PTODSL, `event_id` may be either: -Events are per-pipeline-pair: the same `ID0` used between `MTE2 → V` is independent from `ID0` used between `MTE3 → V`. +- a Python integer literal in `0`–`7` +- a runtime index-like PTO scalar value + +Events are per-pipeline-pair: the same `event_id=0` used between `MTE2 → V` is independent from `event_id=0` used between `MTE3 → V`. --- @@ -69,7 +63,7 @@ Events are per-pipeline-pair: the same `ID0` used between `MTE2 → V` is indepe Pipeline synchronization is the primary mechanism for ordering work across pipelines. The pattern is always **signal then wait**: the producer pipeline sets a flag when its work is done; the consumer pipeline waits on that flag before proceeding. -### `pto.set_flag(pipe_from, pipe_to, event_id)` +### `pto.set_flag(pipe_from, pipe_to, *, event_id=0)` **Description**: Sets a synchronization flag between two hardware pipelines. The producing pipeline signals that work up to this point is complete. @@ -79,20 +73,19 @@ Pipeline synchronization is the primary mechanism for ordering work across pipel |-----------|------|-------------| | `pipe_from` | `Pipe` | Source pipeline — the pipeline that has completed its work | | `pipe_to` | `Pipe` | Destination pipeline — the pipeline being notified | -| `event_id` | `Event` | Event identifier for this specific synchronization point | +| `event_id` | `int` or index-like PTO scalar | Event identifier for this specific synchronization point (`0`–`7`) | **Returns**: None (side-effect operation). **Example**: + ```python -from pto import Pipe, Event - # MTE2 has finished loading tile data — signal Vector pipeline -pto.set_flag(Pipe.MTE2, Pipe.V, Event.ID0) +pto.set_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) ``` -### `pto.wait_flag(pipe_from, pipe_to, event_id)` +### `pto.wait_flag(pipe_from, pipe_to, *, event_id=0)` **Description**: Waits for a synchronization flag. The consuming pipeline blocks until the flag is set by the producing pipeline. @@ -102,17 +95,16 @@ pto.set_flag(Pipe.MTE2, Pipe.V, Event.ID0) |-----------|------|-------------| | `pipe_from` | `Pipe` | Source pipeline that set the flag | | `pipe_to` | `Pipe` | Destination pipeline — the pipeline that is waiting | -| `event_id` | `Event` | Event identifier matching the corresponding `set_flag` | +| `event_id` | `int` or index-like PTO scalar | Event identifier matching the corresponding `set_flag` (`0`–`7`) | **Returns**: None (side-effect operation). **Example**: + ```python -from pto import Pipe, Event - # Vector pipeline waits for MTE2 to finish loading -pto.wait_flag(Pipe.MTE2, Pipe.V, Event.ID0) +pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) ``` ### `pto.pipe_barrier(pipes)` @@ -129,41 +121,60 @@ pto.wait_flag(Pipe.MTE2, Pipe.V, Event.ID0) **Example**: + ```python -from pto import Pipe - # Full hardware barrier — all pipelines synchronize -pto.pipe_barrier(Pipe.ALL) +pto.pipe_barrier(pto.Pipe.ALL) ``` ### Typical usage pattern A common ukernel pattern interleaves DMA and compute with `set_flag` / `wait_flag` pairs: + ```python @pto.ukernel -def gemm_block(q_tile, k_tile, v_tile, o_tile, ...): +def gemm_block( + q_tile: pto.Tile, + k_part: pto.PartitionTensorView, + v_part: pto.PartitionTensorView, + k_tile: pto.Tile, + v_tile: pto.Tile, + p_tile: pto.Tile, + o_tile: pto.Tile, + o_part: pto.PartitionTensorView, + rows: pto.i32, + cols: pto.i32, +): # DMA: load K and V tiles from GM to UB - # mte_load derives strides, burst sizes, etc. from k_part / k_tile types - pto.mte_load(k_part, k_tile) - pto.mte_load(v_part, v_tile) + row_bytes = cols * pto.bytewidth(pto.f16) + gm_row_stride = k_part.strides[0] * pto.bytewidth(pto.f16) + ub_row_stride = k_tile.shape[1] * pto.bytewidth(pto.f16) + out_row_bytes = cols * pto.bytewidth(pto.f32) + out_gm_row_stride = o_part.strides[0] * pto.bytewidth(pto.f32) + out_ub_row_stride = o_tile.shape[1] * pto.bytewidth(pto.f32) + pto.mte_load(k_part.as_ptr(), k_tile.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) + pto.mte_load(v_part.as_ptr(), v_tile.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) # Signal: DMA done, UB data ready - pto.set_flag(Pipe.MTE2, Pipe.V, Event.ID0) + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) # Wait: vector pipeline stalls until data arrives - pto.wait_flag(Pipe.MTE2, Pipe.V, Event.ID0) + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) # Compute: now safe to use k_tile and v_tile - qk_matmul(q_tile, k_tile, ...) - pv_matmul(p_tile, v_tile, ...) + qk_matmul(q_tile, k_tile, p_tile) + pv_matmul(p_tile, v_tile, o_tile) # Signal: compute done, results ready for store - pto.set_flag(Pipe.V, Pipe.MTE3, Event.ID1) - pto.wait_flag(Pipe.V, Pipe.MTE3, Event.ID1) + pto.set_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) + pto.wait_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) # DMA: store results back to GM - pto.mte_store(o_tile, o_part) + pto.mte_store(o_tile.as_ptr(), o_part.as_ptr(), out_row_bytes, + nburst=(rows, out_ub_row_stride, out_gm_row_stride)) ``` --- @@ -202,23 +213,22 @@ Double-buffering is a common optimization in NPU kernels: while one buffer is be ### Double-buffering example + ```python -from pto import Pipe - # Pipeline V acquires buffer 0 for compute -pto.get_buf(Pipe.V, 0, 0) +pto.get_buf(pto.Pipe.V, 0, 0) # ... compute into buffer 0 ... # Release buffer 0 — DMA can now refill it -pto.rls_buf(Pipe.V, 0, 0) +pto.rls_buf(pto.Pipe.V, 0, 0) # Pipeline MTE2 acquires buffer 0 for reload -pto.get_buf(Pipe.MTE2, 0, 0) +pto.get_buf(pto.Pipe.MTE2, 0, 0) # ... DMA loads next block into buffer 0 ... -pto.rls_buf(Pipe.MTE2, 0, 0) +pto.rls_buf(pto.Pipe.MTE2, 0, 0) ``` --- @@ -241,11 +251,10 @@ Within a single pipeline, load and store instructions may be reordered by the ha **Example**: + ```python -from pto import BarrierType - # Ensure all prior vector stores are visible before any subsequent vector loads -pto.mem_bar(BarrierType.VST_VLD) +pto.mem_bar(pto.BarrierType.VST_VLD) ``` The most commonly used barrier types in practice: @@ -262,29 +271,49 @@ The most commonly used barrier types in practice: In flash attention, phase boundaries use `pipe_barrier(Pipe.ALL)`, while `mem_bar` remains the tool for narrower intra-pipeline ordering: + ```python @pto.ukernel -def flash_attention_block(q_tile, k_tile, v_tile, ...): +def flash_attention_block( + q_tile: pto.Tile, + k_part: pto.PartitionTensorView, + v_part: pto.PartitionTensorView, + k_tile: pto.Tile, + v_tile: pto.Tile, + s_tile: pto.Tile, + p_tile: pto.Tile, + pv_tile: pto.Tile, + o_prev_tile: pto.Tile, + o_next_tile: pto.Tile, + rows: pto.i32, + cols: pto.i32, +): # Phase 1: load K/V - pto.mte_load(k_part, k_tile) - pto.mte_load(v_part, v_tile) - pto.pipe_barrier(Pipe.ALL) + row_bytes = cols * pto.bytewidth(pto.f16) + gm_row_stride = k_part.strides[0] * pto.bytewidth(pto.f16) + ub_row_stride = k_tile.shape[1] * pto.bytewidth(pto.f16) + pto.mte_load(k_part.as_ptr(), k_tile.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) + pto.mte_load(v_part.as_ptr(), v_tile.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) + pto.pipe_barrier(pto.Pipe.ALL) # Phase 2: S = Q @ K^T - qk_matmul(q_tile, k_tile, ...) - pto.pipe_barrier(Pipe.ALL) + qk_matmul(q_tile, k_tile, s_tile) + pto.pipe_barrier(pto.Pipe.ALL) # Phase 3: softmax(S) - online_softmax(s_tile, ...) - pto.pipe_barrier(Pipe.ALL) + online_softmax(s_tile, p_tile, rows, cols) + pto.mem_bar(pto.BarrierType.VV_ALL) + pto.pipe_barrier(pto.Pipe.ALL) # Phase 4: PV = P @ V - pv_matmul(p_tile, v_tile, ...) - pto.pipe_barrier(Pipe.ALL) + pv_matmul(p_tile, v_tile, pv_tile) + pto.pipe_barrier(pto.Pipe.ALL) # Phase 5: blend output - blend_output(o_prev_tile, pv_tile, ...) - pto.pipe_barrier(Pipe.ALL) + blend_output(o_prev_tile, pv_tile, o_next_tile, rows, cols) + pto.pipe_barrier(pto.Pipe.ALL) ``` --- @@ -293,85 +322,82 @@ def flash_attention_block(q_tile, k_tile, v_tile, ...): Section 10.2 covers the general pipe-to-pipe sync mechanism (`set_flag`/`wait_flag`). This section covers two additional sync domains that the pipe-flag mechanism does not address: **cross-core** communication between separate NPU cores, and **intra-block** synchronization between the Cube and Vector units within a block. -### 10.5.1 Cross-core sync: `set_cross_core`, `wait_flag_dev` +### 10.5.1 Cross-core sync: `set_cross_flag`, `wait_cross_flag` -When a kernel spans multiple cores, cores need to coordinate through shared resources. `set_cross_core` sends a signal to another core; `wait_flag_dev` blocks the calling core until the expected signal arrives. +When a kernel spans multiple cores, cores need to coordinate through shared resources. `set_cross_flag` sends a signal to another core; `wait_cross_flag` blocks the calling core until the expected signal arrives. -These are core-level (SU) operations — `wait_flag_dev` stalls the entire core, not just a single pipeline. Use them sparingly: splitting work so that each core operates independently for as long as possible minimises cross-core sync overhead. +These are core-level (SU) operations — `wait_cross_flag` stalls the entire core, not just a single pipeline. Use them sparingly: splitting work so that each core operates independently for as long as possible minimises cross-core sync overhead. -#### `pto.set_cross_core(core_id, event_id)` +#### `pto.set_cross_flag(pipe, event_id)` -**Description**: Signal an event to another core, indicating that shared data or a pipeline stage is ready. +**Description**: Signal an event on a synchronization endpoint. In the current PTODSL surface this is authored with a `Pipe`; the backend maps it to the architecture-specific cross-core / intra-block builtin during lowering. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `core_id` | `pto.i64` | Target core identifier (platform-specific mapping) | -| `event_id` | `Event` | Cross-core event identifier | +| `pipe` | `Pipe` | Producing endpoint for the synchronization event. The public DSL accepts `Pipe.FIX` here. | +| `event_id` | `int` | Cross-core event identifier (`0`–`7`) | **Returns**: None (side-effect operation). **Example**: + ```python -from pto import Event - -# Signal core 0 that our computation is complete -pto.set_cross_core(0, Event.ID0) +# Signal from the FIX/Cube-side endpoint +pto.set_cross_flag(pto.Pipe.FIX, 0) ``` -#### `pto.wait_flag_dev(core_id, event_id)` +#### `pto.wait_cross_flag(pipe, event_id)` -**Description**: Wait for an event from another core. Core-level (SU) blocking — the entire core stalls until the event is received. +**Description**: Wait for an event on a synchronization endpoint. On architectures that lower this surface to the backend `sync.wait` primitive, the wait is core-level (SU) blocking. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `core_id` | `pto.i64` | Source core identifier | -| `event_id` | `Event` | Event identifier to wait on | +| `pipe` | `Pipe` | Waiting endpoint for the synchronization event. The public DSL accepts `Pipe.FIX` here. | +| `event_id` | `int` | Event identifier to wait on (`0`–`7`) | **Returns**: None (side-effect operation). **Example**: + ```python -from pto import Event - -# Core 1 waits for core 0 to signal event ID0 -pto.wait_flag_dev(0, Event.ID0) +# Wait on the FIX/Cube-side endpoint +pto.wait_cross_flag(pto.Pipe.FIX, 0) ``` -### 10.5.2 Intra-block sync: `set_intra_block`, `wait_intra_core` +### 10.5.2 Intra-block sync: `set_intra_flag`, `wait_intra_flag` -The Cube unit (matrix pipeline) has a dedicated synchronization channel separate from the standard pipe-flag mechanism used by MTE and Vector pipelines. `set_intra_block` and `wait_intra_core` synchronize Cube and Vector within the same block, ensuring that shared UB tile data is not accessed before the producer finishes. +The Cube unit (matrix pipeline) has a dedicated synchronization channel separate from the standard pipe-flag mechanism used by MTE and Vector pipelines. `set_intra_flag` and `wait_intra_flag` synchronize Cube and Vector within the same block, ensuring that shared UB tile data is not accessed before the producer finishes. -Unlike `wait_flag_dev`, `wait_intra_core` only stalls the specified pipeline — the SU and other pipelines continue executing. +Unlike `wait_cross_flag`, `wait_intra_flag` only stalls the specified pipeline — the SU and other pipelines continue executing. -#### `pto.set_intra_block(block_id, event_id)` +#### `pto.set_intra_flag(pipe, event_id)` -**Description**: Signal a synchronization event within a block. Specifies which trigger pipe fires the event. +**Description**: Signal a synchronization event within a block. The current PTODSL surface authors the trigger endpoint explicitly as a `Pipe`. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `block_id` | `pto.i64` | Block or pipeline identifier for the trigger source | -| `event_id` | `Event` | Event identifier | +| `pipe` | `Pipe` | Trigger endpoint for the synchronization event. The public DSL accepts `Pipe.MTE3` here. | +| `event_id` | `int` | Event identifier (`0`–`7`) | **Returns**: None (side-effect operation). **Example**: + ```python -from pto import Event - -# Signal event ID0 on block/pipeline 0 -pto.set_intra_block(0, Event.ID0) +# Signal event ID0 from the MTE3-side endpoint +pto.set_intra_flag(pto.Pipe.MTE3, 0) ``` -#### `pto.wait_intra_core(block_id, event_id)` +#### `pto.wait_intra_flag(pipe, event_id)` **Description**: Wait for an intra-block event. Only the specified pipeline stalls — the SU and other pipelines continue executing independently. @@ -379,42 +405,19 @@ pto.set_intra_block(0, Event.ID0) | Parameter | Type | Description | |-----------|------|-------------| -| `block_id` | `pto.i64` | Block or pipeline identifier specifying which pipeline waits | -| `event_id` | `Event` | Event identifier to wait on | +| `pipe` | `Pipe` | Waiting endpoint for the synchronization event. The public DSL accepts `Pipe.V` here. | +| `event_id` | `int` | Event identifier to wait on (`0`–`7`) | **Returns**: None (side-effect operation). **Example**: + ```python -from pto import Event - -# Pipeline 1 waits for event ID0 from pipeline 0 within the same block -pto.wait_intra_core(1, Event.ID0) +# Vector-side endpoint waits for event ID0 +pto.wait_intra_flag(pto.Pipe.V, 0) ``` -### 10.5.3 Intra-core configuration: `set_intra_core` - -#### `pto.set_intra_core(config)` - -**Description**: Configures intra-core synchronization parameters. The meaning of `config` is hardware-specific. - -**Parameters**: - -| Parameter | Type | Description | -|-----------|------|-------------| -| `config` | `pto.i32` | Hardware-specific configuration value | - -**Returns**: None (side-effect operation). - -**Example**: - -```python -pto.set_intra_core(3) -``` - ---- - ## 10.6 Synchronization in the abstraction hierarchy Where do sync operations belong in PTODSL's layered model? @@ -429,17 +432,17 @@ Where do sync operations belong in PTODSL's layered model? ### Auto-sync at the tile level -When writing `@pto.jit` code with tile ops (`tload`, `tstore`, `tadd`, etc.), each op carries a pipe assignment (e.g., `tload` → `PIPE_MTE2`, `tadd` → `PIPE_V`). PTOAS's sync-insertion pass analyzes the op sequence, infers the necessary `set_flag`/`wait_flag` pairs from the pipe transitions, and injects them into the lowered code. The tile ops themselves still require synchronization — the difference is that the compiler, not the user, writes it. +When writing `@pto.jit` code with tile ops (`tile.load`, `tile.store`, `tile.add`, etc.), each op carries a pipe assignment (e.g., `tile.load` → `PIPE_MTE2`, `tile.add` → `PIPE_V`). PTOAS's sync-insertion pass analyzes the op sequence, infers the necessary `set_flag`/`wait_flag` pairs from the pipe transitions, and injects them into the lowered code. The tile ops themselves still require synchronization — the difference is that the compiler, not the user, writes it. ### Quick reference: which sync for which scenario | Scenario | Sync primitive | |----------|----------------| -| DMA load must finish before compute | `set_flag(MTE2, V, id)` + `wait_flag(MTE2, V, id)` | -| Compute must finish before DMA store | `set_flag(V, MTE3, id)` + `wait_flag(V, MTE3, id)` | +| DMA load must finish before compute | `set_flag(MTE2, V, event_id=id)` + `wait_flag(MTE2, V, event_id=id)` | +| Compute must finish before DMA store | `set_flag(V, MTE3, event_id=id)` + `wait_flag(V, MTE3, event_id=id)` | | Two compute phases must not overlap | `mem_bar(BarrierType.VV_ALL)` | | Store must be visible to later load (same UB) | `mem_bar(BarrierType.VST_VLD)` | | Full pipeline sync point | `pipe_barrier(Pipe.ALL)` | | Double-buffer handoff (compute → DMA) | `rls_buf(V, id)` + `get_buf(MTE2, id)` | | Double-buffer handoff (DMA → compute) | `rls_buf(MTE2, id)` + `get_buf(V, id)` | -| Core A notifies core B | `set_cross_core(B, id)` + `wait_flag_dev(A, id)` | +| Core A notifies core B | `set_cross_flag(B, id)` + `wait_cross_flag(A, id)` | diff --git a/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md index 154a88eac..845017eba 100644 --- a/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md +++ b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md @@ -1,6 +1,6 @@ # 11. Flash Attention Complete Walkthrough -This chapter walks through `demos/flash_attention_sketch.py` layer by layer, tracing a complete flash attention implementation from the user-facing Python wrapper down to hardware-bound sub-kernels. Every API discussed in Chapters 1–10 appears in context here. +This chapter walks through `examplesflash_attention_sketch.py` layer by layer, tracing a complete flash attention implementation from the user-facing Python wrapper down to hardware-bound sub-kernels. Every API discussed in Chapters 1–10 appears in context here. The sketch computes **online-softmax flash attention** for one `(batch, head)` slice per launch instance. It partitions Q into blocks along the sequence dimension, iterates over KV blocks for each Q block, and maintains rolling softmax state across KV iterations. @@ -9,7 +9,7 @@ The sketch computes **online-softmax flash attention** for one `(batch, head)` s ``` flash_attention(...) L0 user-facing wrapper └─ @pto.jit flash_attention_kernel - ├─ Tile Ops tload / tstore at the GM↔UB boundary + ├─ Tile Ops tile.load / tile.store at the GM↔UB boundary └─ @pto.ukernel kv_block_process ├─ @pto.simt materialize_tile_bounds ├─ @pto.cube qk_matmul @@ -60,21 +60,29 @@ L0 knows nothing about tiles, UB, or pipelines. It is the boundary between the u ## 11.3 L1 — `@pto.jit` kernel entry + ```python @pto.jit(target="a5") def flash_attention_kernel( - Q, K, V, O, *, + Q: pto.tensor_spec(rank=4, dtype=pto.f32), + K: pto.tensor_spec(rank=4, dtype=pto.f32), + V: pto.tensor_spec(rank=4, dtype=pto.f32), + O: pto.tensor_spec(rank=4, dtype=pto.f32), + *, BLOCK_Q: pto.constexpr = 128, BLOCK_KV: pto.constexpr = 128, CAUSAL: pto.constexpr = False, NUM_STAGES: pto.constexpr = 2, ): + # Walkthrough body omitted in this signature overview. + return ``` The `@pto.jit` decorator marks the compile + launch boundary. Inputs are Python-native tensors; outputs are written in-place to `O`. Keyword-only `constexpr` parameters (`BLOCK_Q`, `BLOCK_KV`, `CAUSAL`) are baked at compile time. ### 11.3.1 TensorView construction + ```python q_view = pto.make_tensor_view(Q, shape=[batch, seq_q, heads, dim], strides=Q.strides) @@ -90,6 +98,7 @@ o_view = pto.make_tensor_view(O, shape=[batch, seq_q, heads, dim], ### 11.3.2 SPMD launch contract + ```python block_idx = pto.get_block_idx() block_num = pto.get_block_num() @@ -104,6 +113,7 @@ The launch grid is `[batch * heads]`. Each block computes one `(batch, head)` sl ### 11.3.3 Per-head view partitioning + ```python q_head = pto.partition_view( q_view, @@ -131,14 +141,44 @@ There is no dedicated `select_head_view` public helper anymore. Each `(batch, he ### 11.3.4 Tile allocation -Two categories of tiles are allocated: +Three categories of tiles are allocated: -**UB-resident tiles** — data tiles that live in the Unified Buffer: +**MAT-backed bridge tiles** — the logical Q/K/V/P blocks that feed the cube path: + ```python -q_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) -k_tile = pto.alloc_tile(shape=[Bc, D], dtype=pto.f32, valid_shape=[full_bc, dim]) -v_tile = pto.alloc_tile(shape=[Bc, D], dtype=pto.f32, valid_shape=[full_bc, dim]) +q_mat = pto.alloc_tile( + shape=[Br, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_br, dim], + blayout="ColMajor", + slayout="RowMajor", +) +k_mat = pto.alloc_tile( + shape=[Bc, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_bc, dim], + blayout="ColMajor", + slayout="RowMajor", +) +v_mat = pto.alloc_tile( + shape=[Bc, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_bc, dim], + blayout="ColMajor", + slayout="RowMajor", +) +p_mat = pto.alloc_tile( + shape=[Br, Bc], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_br, full_bc], + blayout="ColMajor", + slayout="RowMajor", +) o_prev_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) o_next_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) @@ -147,23 +187,43 @@ m_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") l_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") -s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) -p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) -pv_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) +p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) +pv_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") -beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +``` + +The walkthrough keeps Q/K/V/P on the MAT path so the cube sub-kernels consume the same tile objects that the L1 schedule owns. Runtime tails still live in `valid_shape`; the physical tile shapes stay static. + +**UB-resident state and scratch tiles** — the online-softmax state plus intermediate outputs: + +```python +o_prev_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +o_next_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +m_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +m_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +l_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + +s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) +p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) +pv_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") ``` The online-softmax algorithm requires **ping-pong state tiles**: `m_prev`/`m_next`, `l_prev`/`l_next`, `o_prev`/`o_next`. After each KV block, `next` becomes `prev` for the following iteration. **Cube-local scratch tiles** — allocated in specific memory spaces: + ```python -q_l0a = pto.alloc_tile(shape=[Br, D], dtype=pto.f16, - memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, dim]) -p_l0a = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f16, - memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, full_bc]) -rhs_l0b = pto.alloc_tile(shape=[Bc, D], dtype=pto.f16, +q_l0a = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, + memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, dim]) +p_l0a = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, + memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, full_bc]) +rhs_l0b = pto.alloc_tile(shape=[Bc, D], dtype=pto.f32, memory_space=pto.MemorySpace.RIGHT, valid_shape=[full_bc, dim]) qk_acc_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, memory_space=pto.MemorySpace.ACC, valid_shape=[full_br, full_bc]) @@ -175,6 +235,7 @@ Cube scratch tiles are NOT UB buffers. `LEFT`, `RIGHT`, and `ACC` are distinct h ### 11.3.5 SIMT metadata buffer + ```python meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 3]) meta_ptr = meta_tile.as_ptr() @@ -189,14 +250,29 @@ row-major padded physical width just to satisfy row-byte alignment. ### 11.3.6 Outer Q loop + inner KV loop + ```python with pto.for_(0, q_blocks, step=1) as qi: + q_rows = _block_valid_extent(seq_q, qi, Br) q_part = pto.partition_view(q_head, offsets=[0, qi * Br, 0, 0], - sizes=[1, Br, 1, dim]) + sizes=[1, q_rows, 1, dim]) o_part = pto.partition_view(o_head, offsets=[0, qi * Br, 0, 0], - sizes=[1, Br, 1, dim]) - - pto.tload(q_part, q_tile) + sizes=[1, q_rows, 1, dim]) + + q_mat.valid_shape = [q_rows, dim] + o_prev_tile.valid_shape = [q_rows, dim] + o_next_tile.valid_shape = [q_rows, dim] + m_prev_tile.valid_shape = [q_rows, one] + m_next_tile.valid_shape = [q_rows, one] + l_prev_tile.valid_shape = [q_rows, one] + l_next_tile.valid_shape = [q_rows, one] + alpha_tile.valid_shape = [q_rows, one] + beta_tile.valid_shape = [q_rows, one] + p_mat.valid_shape = [q_rows, full_bc] + pv_tile.valid_shape = [q_rows, dim] + q_l0a.valid_shape = [q_rows, dim] + + pto.tile.load(q_part, q_mat) m_prev_tile.fill(float("-inf")) l_prev_tile.fill(0.0) @@ -210,16 +286,28 @@ with pto.for_(0, q_blocks, step=1) as qi: m_cur = kv_loop.m l_cur = kv_loop.l o_cur = kv_loop.o + kv_rows = _block_valid_extent(seq_k, kj, Bc) k_part = pto.partition_view(k_head, - offsets=[0, kj * Bc, 0, 0], sizes=[1, Bc, 1, dim]) + offsets=[0, kj * Bc, 0, 0], sizes=[1, kv_rows, 1, dim]) v_part = pto.partition_view(v_head, - offsets=[0, kj * Bc, 0, 0], sizes=[1, Bc, 1, dim]) + offsets=[0, kj * Bc, 0, 0], sizes=[1, kv_rows, 1, dim]) + + k_mat.valid_shape = [kv_rows, dim] + v_mat.valid_shape = [kv_rows, dim] + s_tile.valid_shape = [q_rows, kv_rows] + p_tile.valid_shape = [q_rows, kv_rows] + p_mat.valid_shape = [q_rows, kv_rows] + pv_tile.valid_shape = [q_rows, dim] + p_l0a.valid_shape = [q_rows, kv_rows] + rhs_l0b.valid_shape = [kv_rows, dim] + qk_acc_tile.valid_shape = [q_rows, kv_rows] + pv_acc_tile.valid_shape = [q_rows, dim] kv_block_process( - q_tile, k_part, v_part, k_tile, v_tile, + q_mat, k_part, v_part, k_mat, v_mat, o_cur, o_next_tile, m_cur, l_cur, m_next_tile, l_next_tile, - s_tile, p_tile, pv_tile, + s_tile, p_tile, p_mat, pv_tile, alpha_tile, beta_tile, q_l0a, p_l0a, rhs_l0b, qk_acc_tile, pv_acc_tile, @@ -229,25 +317,26 @@ with pto.for_(0, q_blocks, step=1) as qi: kv_loop.update(m=m_next_tile, l=l_next_tile, o=o_next_tile) o_final_tile = kv_loop.final("o") - pto.tstore(o_final_tile, o_part) + pto.tile.store(o_final_tile, o_part) ``` Key points: -- **`tload` at the L1 boundary**: Q is loaded once per Q block using a tile op. The compiler auto-inserts the necessary `set_flag`/`wait_flag` pairs. +- **Static physical shape, dynamic valid extent**: `alloc_tile(shape=...)` stays constexpr. Tail handling is expressed by updating `valid_shape` before each block load and sub-kernel call. +- **`tile.load` at the L1 boundary**: Q is loaded once per Q block using a tile op into the MAT-backed bridge tile `q_mat`. The compiler auto-inserts the necessary `set_flag`/`wait_flag` pairs. - **State initialization**: `fill(float("-inf"))` and `fill(0.0)` initialize the online-softmax accumulators before the first KV block. - **Carry state**: the inner `kv_loop` carries three ping-pong tiles (`m`, `l`, `o`) across iterations using `.carry(...)` / `.update(...)` / `.final(...)`. After each KV block, the loop updates the carried values to the `_next` tiles. After the loop, `.final("o")` extracts the final output accumulator. -- **`tstore` at the L1 boundary**: writes the final result for this Q block back to GM. +- **`tile.store` at the L1 boundary**: writes the final result for this Q block back to GM. ## 11.4 L2 — `@pto.ukernel` ```python @pto.ukernel def kv_block_process( - q_tile, k_part, v_part, k_tile, v_tile, + q_mat, k_part, v_part, k_mat, v_mat, o_prev_tile, o_next_tile, m_prev_tile, l_prev_tile, m_next_tile, l_next_tile, - s_tile, p_tile, pv_tile, + s_tile, p_tile, p_mat, pv_tile, alpha_tile, beta_tile, q_l0a, p_l0a, rhs_l0b, qk_acc_tile, pv_acc_tile, @@ -259,20 +348,29 @@ The ukernel processes one KV block against an already-loaded Q tile. It owns the ### Phase 0 — Stage K/V data + ```python -pto.mte_load(k_part, k_tile) -pto.mte_load(v_part, v_tile) +rows = k_mat.valid_shape[0] +cols = k_mat.valid_shape[1] +row_bytes = cols * pto.bytewidth(pto.f32) +gm_row_stride = k_part.strides[0] * pto.bytewidth(pto.f32) +mat_row_stride = k_mat.shape[1] * pto.bytewidth(pto.f32) +pto.mte_load(k_part.as_ptr(), k_mat.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, mat_row_stride)) +pto.mte_load(v_part.as_ptr(), v_mat.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, mat_row_stride)) pto.pipe_barrier(pto.Pipe.ALL) ``` -`mte_load` copies the current K and V block from GM to UB. `pipe_barrier(Pipe.ALL)` makes the phase boundary explicit before the cube unit reads `k_tile`/`v_tile`. +`mte_load` is the ptr-based GM→MAT DMA wrapper used by this walkthrough. The ukernel passes explicit GM/MAT pointers plus the DMA grouping parameters, and `pipe_barrier(Pipe.ALL)` makes the phase boundary explicit before the cube unit reads `k_mat`/`v_mat`. ### Phase 0b — Materialize loop bounds + ```python materialize_tile_bounds(meta_ptr, - q_tile.valid_shape[0], - k_tile.valid_shape[0]) + q_mat.valid_shape[0], + k_mat.valid_shape[0]) row_start = scalar.load(meta_ptr + 0) row_stop = scalar.load(meta_ptr + 1) valid_cols = scalar.load(meta_ptr + 2) @@ -282,8 +380,9 @@ The SIMT sub-kernel `materialize_tile_bounds` writes `{0, valid_rows, valid_cols ### Phase 1 — `S = Q @ K^T` + ```python -qk_matmul(q_tile, k_tile, q_l0a, rhs_l0b, qk_acc_tile, s_tile) +qk_matmul(q_mat, k_mat, q_l0a, rhs_l0b, qk_acc_tile, s_tile) pto.pipe_barrier(pto.Pipe.ALL) ``` @@ -291,6 +390,7 @@ Dispatches the cube sub-kernel. `pipe_barrier(Pipe.ALL)` separates the matrix mu ### Phase 2 — Online softmax + ```python online_softmax_rows( s_tile, p_tile, @@ -306,20 +406,25 @@ The simd sub-kernel computes per-row softmax on `S`, updates the running `m`/`l` ### Phase 3 — `PV = P @ V` + ```python -pv_matmul(p_tile, v_tile, p_l0a, rhs_l0b, pv_acc_tile, pv_tile) +pto.tile.mov(p_tile, p_mat) +pto.pipe_barrier(pto.Pipe.ALL) + +pv_matmul(p_mat, v_mat, p_l0a, rhs_l0b, pv_acc_tile, pv_tile) pto.pipe_barrier(pto.Pipe.ALL) ``` -Second cube dispatch. `rhs_l0b` is reused for `V` (it previously held `K`). `pv_acc_tile` is reused from the QK^T accumulator. +The probability tile is first staged onto the MAT path with `pto.tile.mov(p_tile, p_mat)`. Then the second cube dispatch reuses `rhs_l0b` for `V` and `pv_acc_tile` for the accumulator. ### Phase 4 — Blend output + ```python blend_output_rows( o_prev_tile, pv_tile, alpha_tile, beta_tile, o_next_tile, row_start, row_stop, - v_tile.valid_shape[1], + v_mat.valid_shape[1], ) pto.pipe_barrier(pto.Pipe.ALL) ``` @@ -334,15 +439,16 @@ Each `pipe_barrier(Pipe.ALL)` between phases is explicit in the ukernel body. Th ### `qk_matmul` — `S = Q @ K^T` + ```python @pto.cube -def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): - m = q_tile.valid_shape[0] - k = q_tile.valid_shape[1] - n = k_tile.valid_shape[0] +def qk_matmul(q_mat, k_mat, q_l0a, k_l0b, s_acc, s_tile): + m = q_mat.valid_shape[0] + k = q_mat.valid_shape[1] + n = k_mat.valid_shape[0] - pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) - pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) + pto.mte_l1_l0a(q_mat.as_ptr(), q_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(k_mat.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) ``` @@ -358,15 +464,16 @@ The cube kernel does not allocate scratch — the caller (L1) owns scratch lifet ### `pv_matmul` — `PV = P @ V` + ```python @pto.cube -def pv_matmul(p_tile, v_tile, p_l0a, v_l0b, pv_acc, pv_tile): - m = p_tile.valid_shape[0] - k = p_tile.valid_shape[1] - n = v_tile.valid_shape[1] +def pv_matmul(p_mat, v_mat, p_l0a, v_l0b, pv_acc, pv_tile): + m = p_mat.valid_shape[0] + k = p_mat.valid_shape[1] + n = v_mat.valid_shape[1] - pto.mte_l1_l0a(p_tile.as_ptr(), p_l0a.as_ptr(), m, k) - pto.mte_l1_l0b(v_tile.as_ptr(), v_l0b.as_ptr(), k, n) + pto.mte_l1_l0a(p_mat.as_ptr(), p_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(v_mat.as_ptr(), v_l0b.as_ptr(), k, n) pto.mad(p_l0a.as_ptr(), v_l0b.as_ptr(), pv_acc.as_ptr(), m, n, k) pto.mte_l0c_ub(pv_acc.as_ptr(), pv_tile.as_ptr(), m, n, n, n, 0) ``` @@ -388,6 +495,7 @@ def online_softmax_rows( The simd kernel iterates over rows with `pto.for_`, processing one row per iteration: + ```python with pto.for_(row_start, row_stop, step=1) as row: col_mask = pto.make_mask(pto.f32, valid_cols) @@ -403,6 +511,7 @@ with pto.for_(row_start, row_stop, step=1) as row: ### Softmax computation + ```python row_max = pto.vcgmax(s_row, col_mask) m_next = scalar.max(m_prev, row_max) @@ -430,6 +539,7 @@ This implements the online-softmax update from the Flash Attention paper: ### Store results + ```python pto.vsts(p_row, p_tile[row, 0:], col_mask) scalar.store(m_next, m_next_tile[row, 0]) @@ -447,6 +557,7 @@ This implements the online-softmax update from the Flash Attention paper: ### `materialize_tile_bounds` — scalar metadata + ```python @pto.simt def materialize_tile_bounds(meta_ptr, valid_rows, valid_cols): @@ -459,6 +570,7 @@ Three scalar stores write the loop bounds into the metadata buffer. `meta_ptr` i ### `blend_output_rows` — output accumulation + ```python @pto.simt def blend_output_rows(o_prev_tile, pv_tile, alpha_tile, beta_tile, @@ -486,28 +598,17 @@ The SIMT kernel walks the tile element by element with nested `pto.for_` loops. ### Context manager alternative -For trivial sub-kernels like `materialize_tile_bounds`, a named function is overkill — the context manager form keeps the logic inline where it's used. Here is how the ukernel body would look with `materialize_tile_bounds` inlined: +For trivial sub-kernels like `materialize_tile_bounds`, a named function is overkill — the context manager form keeps the logic inline where it's used. The inline SIMT scope itself looks like this: + ```python -@pto.ukernel -def kv_block_process(...): - pto.mte_load(k_part, k_tile) - pto.mte_load(v_part, v_tile) - pto.pipe_barrier(pto.Pipe.ALL) - - # Inline SIMT: materialize loop bounds (replaces the named @pto.simt function) - with pto.simt(): - scalar.store(0, meta_ptr + 0) - scalar.store(valid_rows, meta_ptr + 1) - scalar.store(valid_cols, meta_ptr + 2) - - pto.pipe_barrier(pto.Pipe.ALL) - - qk_matmul(q_tile, k_tile, ...) - ... +with pto.simt(): + scalar.store(0, meta_ptr + 0) + scalar.store(q_mat.valid_shape[0], meta_ptr + 1) + scalar.store(k_mat.valid_shape[0], meta_ptr + 2) ``` -The `with pto.simt():` block is semantically identical to calling a `@pto.simt` function — the compiler treats it as an anonymous sub-kernel. For 3-line helpers that have no reuse, the context manager avoids the indirection of a separate function. For complex, reusable logic like `online_softmax_rows` or `qk_matmul`, the named decorator form remains the better fit. +The `with pto.simt():` block acts as an anonymous inline sub-kernel scope. For 3-line helpers that have no reuse, the context manager avoids the indirection of a separate function. For complex, reusable logic like `online_softmax_rows` or `qk_matmul`, the named decorator form remains the better fit. ## 11.8 Putting it all together: one KV block execution @@ -515,21 +616,23 @@ For one KV block, the full execution sequence is: | Step | Layer | Operation | Hardware | |------|-------|-----------|----------| -| 1 | L1 | `tload(q_part, q_tile)` | MTE2 → UB | -| 2 | L2 | `mte_load(k_part, k_tile)` | MTE2 → UB | -| 3 | L2 | `mte_load(v_part, v_tile)` | MTE2 → UB | -| 4 | L2 | `mem_bar(SYNC)` | — | +| 1 | L1 | `tile.load(q_part, q_mat)` | GM → MAT | +| 2 | L2 | `mte_load(k_part.as_ptr(), k_mat.as_ptr(), ...)` | GM → MAT | +| 3 | L2 | `mte_load(v_part.as_ptr(), v_mat.as_ptr(), ...)` | GM → MAT | +| 4 | L2 | `pipe_barrier(Pipe.ALL)` | — | | 5 | L3c | `materialize_tile_bounds` | SIMT | | 6 | L3a | `qk_matmul` (mte_l1_l0a, mte_l1_l0b, mad, mte_l0c_ub) | Cube | -| 7 | L2 | `mem_bar(SYNC)` | — | +| 7 | L2 | `pipe_barrier(Pipe.ALL)` | — | | 8 | L3b | `online_softmax_rows` (vlds, vcgmax, vexp, vcgadd, vsts, ...) | SIMD | -| 9 | L2 | `mem_bar(SYNC)` | — | -| 10 | L3a | `pv_matmul` | Cube | -| 11 | L2 | `mem_bar(SYNC)` | — | -| 12 | L3c | `blend_output_rows` | SIMT | -| 13 | L2 | `mem_bar(SYNC)` | — | +| 9 | L2 | `pipe_barrier(Pipe.ALL)` | — | +| 10 | L2 | `tile.mov(p_tile, p_mat)` | Tile copy | +| 11 | L2 | `pipe_barrier(Pipe.ALL)` | — | +| 12 | L3a | `pv_matmul` | Cube | +| 13 | L2 | `pipe_barrier(Pipe.ALL)` | — | +| 14 | L3c | `blend_output_rows` | SIMT | +| 15 | L2 | `pipe_barrier(Pipe.ALL)` | — | -After all KV blocks: L1 issues `tstore(o_final_tile, o_part)` to write the result back to GM. +After all KV blocks: L1 issues `tile.store(o_final_tile, o_part)` to write the result back to GM. ## 11.9 Design patterns in this sketch @@ -537,7 +640,7 @@ After all KV blocks: L1 issues `tstore(o_final_tile, o_part)` to write the resul **Scratch reuse**: `rhs_l0b` serves both `K` (in `qk_matmul`) and `V` (in `pv_matmul`). `pv_acc_tile` reuses the accumulator from QK^T. The caller (L1) allocates once; the ukernel passes them to both cube sub-kernels. -**Tile-level boundary vs micro-instruction boundary**: `tload`/`tstore` appear only in `@pto.jit`. `mte_load`/`mte_store` appear only in `@pto.ukernel`. This is the key abstraction split: L1 operates on tiles, L2 operates on micro-instructions. +**Tile-level boundary vs micro-instruction boundary**: `tile.load`/`tile.store` appear only in `@pto.jit`. `mte_load` appears only in `@pto.ukernel`, and it is authored in the explicit ptr-based DMA form. This is the key abstraction split: L1 operates on tiles, L2 operates on micro-instructions. **No vreg across sub-kernel boundaries**: vector registers are local to each `@pto.simd` kernel. Data crosses sub-kernel boundaries through UB tiles — the boundary contract is enforced by the type system. diff --git a/ptodsl/docs/user_guide/12-additional-examples.md b/ptodsl/docs/user_guide/12-additional-examples.md index e7b730dc3..fbd518580 100644 --- a/ptodsl/docs/user_guide/12-additional-examples.md +++ b/ptodsl/docs/user_guide/12-additional-examples.md @@ -31,21 +31,22 @@ def mat_add(A, B, O, *, BLOCK_M: pto.constexpr = 64, BLOCK_N: pto.constexpr = 12 b_part = pto.partition_view(b_view, offsets=[m_off, n_off], sizes=[BLOCK_M, BLOCK_N]) o_part = pto.partition_view(o_view, offsets=[m_off, n_off], sizes=[BLOCK_M, BLOCK_N]) - pto.tload(a_part, a_tile) - pto.tload(b_part, b_tile) - pto.tadd(a_tile, b_tile, o_tile) - pto.tstore(o_tile, o_part) + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) + pto.tile.add(a_tile, b_tile, o_tile) + pto.tile.store(o_tile, o_part) ``` **Key points**: - Nested `pto.for_` loops produce a 2D block traversal. Both loops are recorded as device-side control flow — they adapt to the runtime shape `M`. -- Tile shape `[BLOCK_M, BLOCK_N]` is 2D; all three tiles use the same shape so `tadd` is elementwise. +- Tile shape `[BLOCK_M, BLOCK_N]` is 2D; all three tiles use the same shape so `tile.add` is elementwise. - `partition_view` takes 2D offsets and sizes. - `BLOCK_M` and `BLOCK_N` are `constexpr` — the compiler specializes the kernel per tile shape. The L0 wrapper follows the same pattern as Chapter 2: + ```python def mat_add_wrapper(A, B, O=None, stream=None): if O is None: @@ -66,6 +67,7 @@ When a data dimension is not evenly divisible by the tile size or the hardware v Below is a self-contained `@pto.simd` kernel that adds two tiles row by row, handling column tails with `make_mask`: + ```python @pto.simd def add_rows_with_tail(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, @@ -96,50 +98,57 @@ The pattern: ### 12.2.2 Tile-level tail handling -At the Tile Op level, tail handling is built into `tload` and `tstore`. When a partition size along a dimension is smaller than the tile size, the tile's `valid_shape` tracks the actual data extent: +At the Tile Op level, tail handling is built into `tile.load` and `tile.store`. When a partition size along a dimension is smaller than the tile size, the tile's `valid_shape` tracks the actual data extent: + ```python @pto.jit(target="a5") -def vec_add_with_tail(A, B, O, *, BLOCK: pto.constexpr): +def vec_add_with_tail( + A: pto.tensor_spec(rank=1, dtype=pto.f32), + B: pto.tensor_spec(rank=1, dtype=pto.f32), + O: pto.tensor_spec(rank=1, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): N = A.shape[0] a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) - a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) - b_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) - o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) + a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32, valid_shape=[pto.const(BLOCK)]) + b_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32, valid_shape=[pto.const(BLOCK)]) + o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32, valid_shape=[pto.const(BLOCK)]) num_blocks = (N + BLOCK - 1) // BLOCK with pto.for_(0, num_blocks, step=1) as i: offset = i * BLOCK - this_block = min(BLOCK, N - offset) + this_block = scalar.min(BLOCK, N - offset) a_part = pto.partition_view(a_view, offsets=[offset], sizes=[this_block]) b_part = pto.partition_view(b_view, offsets=[offset], sizes=[this_block]) o_part = pto.partition_view(o_view, offsets=[offset], sizes=[this_block]) - pto.tload(a_part, a_tile) - pto.tload(b_part, b_tile) + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) a_tile.valid_shape = [this_block] b_tile.valid_shape = [this_block] o_tile.valid_shape = [this_block] - pto.tadd(a_tile, b_tile, o_tile) - pto.tstore(o_tile, o_part) + pto.tile.add(a_tile, b_tile, o_tile) + pto.tile.store(o_tile, o_part) ``` -- `this_block = min(BLOCK, N - offset)` computes the actual block size for the tail iteration. -- `sizes=[this_block]` on the partition and `valid_shape` on the tile tell `tload`/`tadd`/`tstore` how many elements are live. +- `this_block = scalar.min(BLOCK, N - offset)` computes the actual block size for the tail iteration on the device side. +- `sizes=[this_block]` on the partition and `tile.valid_shape = [...]` on the tile tell `tile.load`/`tile.add`/`tile.store` how many elements are live. ### 12.2.3 The general rule | Tail scenario | Mechanism | |---------------|-----------| -| Tile Op boundary (tload/tstore) | `valid_shape` on tile + smaller `sizes` on partition | +| Tile Op boundary (tile.load/tile.store) | `valid_shape` on tile + smaller `sizes` on partition | | SIMD vector boundary (vlds/vadd/vsts) | `make_mask` + mask parameter on op | | SIMT scalar loop boundary | `min(BLOCK, N - offset)` in loop bound | @@ -149,28 +158,37 @@ This example demonstrates a complete GEMM kernel: `C = A @ B` where A is `[M, K] ### 12.3.1 L3: Cube sub-kernel + ```python @pto.cube -def gemm_tile(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, +def gemm_tile(a_mat: pto.Tile, b_mat: pto.Tile, o_tile: pto.Tile, a_l0a: pto.Tile, b_l0b: pto.Tile, o_acc: pto.Tile): - m = a_tile.valid_shape[0] - k = a_tile.valid_shape[1] - n = b_tile.valid_shape[0] + m = a_mat.valid_shape[0] + k = a_mat.valid_shape[1] + n = b_mat.valid_shape[1] - pto.mte_l1_l0a(a_tile.as_ptr(), a_l0a.as_ptr(), m, k) - pto.mte_l1_l0b(b_tile.as_ptr(), b_l0b.as_ptr(), k, n, transpose=True) + pto.mte_l1_l0a(a_mat.as_ptr(), a_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(b_mat.as_ptr(), b_l0b.as_ptr(), k, n) pto.mad(a_l0a.as_ptr(), b_l0b.as_ptr(), o_acc.as_ptr(), m, n, k) pto.mte_l0c_ub(o_acc.as_ptr(), o_tile.as_ptr(), m, n, n, n, 0) ``` -The cube sub-kernel consumes UB tiles and cube-local scratch buffers. The four-step sequence — stage left operand, stage right operand, multiply, writeback — is the canonical cube compute pattern. +The cube sub-kernel consumes MAT staging tiles plus cube-local scratch buffers. The four-step sequence — stage left operand, stage right operand, multiply, writeback — is the canonical cube compute pattern. ### 12.3.2 L1: Tile orchestration + ```python @pto.jit(target="a5") -def gemm(A, B, O, *, BLOCK_M: pto.constexpr = 64, - BLOCK_K: pto.constexpr = 64, BLOCK_N: pto.constexpr = 64): +def gemm( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + B: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK_M: pto.constexpr = 64, + BLOCK_K: pto.constexpr = 64, + BLOCK_N: pto.constexpr = 64, +): M, K_ = A.shape _, N_ = B.shape @@ -178,8 +196,10 @@ def gemm(A, B, O, *, BLOCK_M: pto.constexpr = 64, b_view = pto.make_tensor_view(B, shape=[K_, N_], strides=B.strides) o_view = pto.make_tensor_view(O, shape=[M, N_], strides=O.strides) - a_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f32) - b_tile = pto.alloc_tile(shape=[BLOCK_K, BLOCK_N], dtype=pto.f32) + a_mat = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f32, + memory_space=pto.MemorySpace.MAT) + b_mat = pto.alloc_tile(shape=[BLOCK_K, BLOCK_N], dtype=pto.f32, + memory_space=pto.MemorySpace.MAT) o_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32) a_l0a = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f32, @@ -197,6 +217,8 @@ def gemm(A, B, O, *, BLOCK_M: pto.constexpr = 64, m_off = mi * BLOCK_M with pto.for_(0, num_n, step=1) as ni: n_off = ni * BLOCK_N + o_part = pto.partition_view(o_view, offsets=[m_off, n_off], + sizes=[BLOCK_M, BLOCK_N]) o_tile.fill(0.0) @@ -207,27 +229,26 @@ def gemm(A, B, O, *, BLOCK_M: pto.constexpr = 64, sizes=[BLOCK_M, BLOCK_K]) b_part = pto.partition_view(b_view, offsets=[k_off, n_off], sizes=[BLOCK_K, BLOCK_N]) - o_part = pto.partition_view(o_view, offsets=[m_off, n_off], - sizes=[BLOCK_M, BLOCK_N]) - pto.tload(a_part, a_tile) - pto.tload(b_part, b_tile) + pto.tile.load(a_part, a_mat) + pto.tile.load(b_part, b_mat) - gemm_tile(a_tile, b_tile, o_tile, a_l0a, b_l0b, o_acc) + gemm_tile(a_mat, b_mat, o_tile, a_l0a, b_l0b, o_acc) - pto.tstore(o_tile, o_part) + pto.tile.store(o_tile, o_part) ``` **Key points**: - **Triply nested loops**: M, N, and K dimensions are all blocked. The K loop accumulates partial results into `o_tile`. - **Accumulation**: `o_tile.fill(0.0)` resets the accumulator before the K loop. Each K-block calls `gemm_tile` which writes its partial product back to `o_tile`. The Cube unit accumulates implicitly via `mad` — each K-block's partial result is added to the running total in `o_acc`. -- **Cube-local scratch**: `a_l0a`, `b_l0b`, and `o_acc` are allocated with explicit `memory_space` parameters (`LEFT`, `RIGHT`, `ACC`). Cube-local state does not leak into UB. -- **Direct L3 call**: `gemm_tile` is called directly from `@pto.jit` — no ukernel needed. The compiler handles sync between `tload` and the Cube sub-kernel. +- **MAT staging + cube-local scratch**: `a_mat` and `b_mat` are explicit MAT tiles that satisfy the `mte_l1_l0a` / `mte_l1_l0b` source contract. `a_l0a`, `b_l0b`, and `o_acc` are cube-local scratch (`LEFT`, `RIGHT`, `ACC`). +- **Direct L3 call**: `gemm_tile` is called directly from `@pto.jit` — no ukernel needed. The compiler handles sync between `tile.load` and the Cube sub-kernel. - **Cube sub-kernel reuse**: the same `gemm_tile` function is called for every K-block — the named decorator form enables reuse. ### 12.3.3 L0 wrapper + ```python def gemm_wrapper(A, B, O=None, stream=None): if O is None: @@ -241,11 +262,11 @@ This pattern extends directly to batch-GEMM: pass a grid of `batch` and use `pto ### 12.3.4 Comparison with ukernel path -For reference, the same GEMM could be written using `@pto.ukernel` for explicit MTE control. The ukernel would replace the inner `tload`/`tstore` calls with `mte_load`/`mte_store` and add `mem_bar` synchronization between DMA and compute. The direct-call path used above is recommended for most users — the ukernel path is for cases that need hand-tuned DMA scheduling. +For reference, the same GEMM could be written using `@pto.ukernel` for explicit MTE control. The ukernel would replace the inner `tile.load`/`tile.store` calls with `mte_load`/`mte_store` and add `mem_bar` synchronization between DMA and compute. The direct-call path used above is recommended for most users — the ukernel path is for cases that need hand-tuned DMA scheduling. ## 12.4 Online normalization with loop-carried state -Chapter 11 demonstrated online softmax with ping-pong state tiles. A simpler but instructive case is **online layer normalization** — computing mean and variance incrementally across blocks without a second pass. +Chapter 11 demonstrated online softmax with ping-pong state tiles. A simpler but instructive case is **online layer normalization** — computing mean and variance incrementally across blocks while carrying only scalar state between iterations. Given a vector `X` of length `N`, the streaming Welford algorithm updates the running mean `mu` and variance `var` as each new element `x` arrives: @@ -256,130 +277,88 @@ mu_next = mu_prev + delta / n_next m2_next = m2_prev + delta * (x - mu_next) ``` -The example below applies this pattern block by block, using a ukernel for the per-block SIMD work and `pto.for_` carry state to shuttle the running statistics between blocks. +The example below keeps the whole pattern inside one `@pto.jit` kernel. The first pass carries `mu`, `n`, and `m2` across blocks; the second pass reloads each block and applies the normalization explicitly with scalar loads and stores. This version assumes `N > 0`. -### 12.4.1 L3: SIMD block statistics - -```python -@pto.simd -def block_mean_var(x_tile: pto.Tile, block_size: pto.i32, - mu_prev: pto.f32, n_prev: pto.f32, m2_prev: pto.f32, - mu_next_tile: pto.Tile, n_next_tile: pto.Tile, - m2_next_tile: pto.Tile): - VEC = pto.elements_per_vreg(pto.f32) - - # Per-row cross-lane reductions to compute the block sum and sum-of-squares - row_sum = pto.vdup(0.0, pto.f32) - row_sum2 = pto.vdup(0.0, pto.f32) - - col_loop = pto.for_(0, block_size, step=VEC).carry(row_sum=row_sum, row_sum2=row_sum2) - with col_loop: - c = col_loop.iv - remained = pto.i32(block_size) - c - mask, _ = pto.make_mask(pto.f32, remained) - - x_vec = pto.vlds(x_tile[0, c:]) - row_sum = pto.vcadd(x_vec, mask) - row_sum2 = pto.vcadd(pto.vmul(x_vec, x_vec, mask), mask) - col_loop.update(row_sum=row_sum, row_sum2=row_sum2) - - block_n = pto.cvt(block_size, pto.f32) - block_mean = pto.vdiv(col_loop.final("row_sum"), block_n) - block_mean_sq = pto.vdiv(col_loop.final("row_sum2"), block_n) - - # Welford update: merge block statistics into running state - n_next = n_prev + block_n - delta = block_mean - mu_prev - mu_next = mu_prev + delta * block_n / n_next - m2_next = m2_prev + pto.vdiv(row_sum2, block_n) * block_n # simplified - - scalar.store(n_next, n_next_tile[0, 0]) - scalar.store(mu_next, mu_next_tile[0, 0]) - scalar.store(m2_next, m2_next_tile[0, 0]) -``` - -### 12.4.2 L2: Ukernel with carry orchestration - -```python -@pto.ukernel -def norm_block(x_part: pto.PartitionTensorView, x_tile: pto.Tile, - block_size: pto.i32, - mu_prev: pto.f32, n_prev: pto.f32, m2_prev: pto.f32, - mu_next_tile: pto.Tile, n_next_tile: pto.Tile, - m2_next_tile: pto.Tile): - pto.mte_load(x_part, x_tile) - pto.pipe_barrier(pto.Pipe.ALL) - - block_mean_var(x_tile, block_size, - mu_prev, n_prev, m2_prev, - mu_next_tile, n_next_tile, m2_next_tile) - pto.pipe_barrier(pto.Pipe.ALL) -``` - -### 12.4.3 L1: JIT entry with carry state +### 12.4.1 JIT example with loop-carried Welford state + ```python @pto.jit(target="a5") -def online_layernorm(X, O, *, BLOCK: pto.constexpr): +def online_layernorm( + X: pto.tensor_spec(rank=1, dtype=pto.f32), + O: pto.tensor_spec(rank=1, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): N = X.shape[0] x_view = pto.make_tensor_view(X, shape=[N], strides=X.strides) o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) - x_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) - o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) - - mu_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) - n_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) - m2_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) + x_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32, valid_shape=[pto.const(BLOCK)]) + o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32, valid_shape=[pto.const(BLOCK)]) num_blocks = (N + BLOCK - 1) // BLOCK - # Carry: running statistics across blocks - block_loop = pto.for_(0, num_blocks, step=1).carry( + # Pass 1: running Welford state across blocks. + stats_loop = pto.for_(0, num_blocks, step=1).carry( mu=pto.f32(0.0), n=pto.f32(0.0), m2=pto.f32(0.0) ) - with block_loop: - i = block_loop.iv + with stats_loop: + i = stats_loop.iv offset = i * BLOCK - this_block = min(BLOCK, N - offset) - + this_block = scalar.min(BLOCK, N - offset) x_part = pto.partition_view(x_view, offsets=[offset], sizes=[this_block]) - - mu_prev = block_loop.mu - n_prev = block_loop.n - m2_prev = block_loop.m2 - - norm_block(x_part, x_tile, pto.i32(this_block), - mu_prev, n_prev, m2_prev, - mu_tile, n_tile, m2_tile) - - n_next = scalar.load(n_tile[0, 0]) - mu_next = scalar.load(mu_tile[0, 0]) - m2_next = scalar.load(m2_tile[0, 0]) - - block_loop.update(mu=mu_next, n=n_next, m2=m2_next) - - # After all blocks: finalize normalization with the running stats - global_var = m2_next / n_next - - # Second pass: normalize each block (using same tiling) + pto.tile.load(x_part, x_tile) + x_tile.valid_shape = [this_block] + + elem_loop = pto.for_(0, this_block, step=1).carry( + mu=stats_loop.mu, n=stats_loop.n, m2=stats_loop.m2 + ) + with elem_loop: + j = elem_loop.iv + x = scalar.load(x_tile.as_ptr(), j) + n_next = elem_loop.n + 1.0 + delta = x - elem_loop.mu + mu_next = elem_loop.mu + delta / n_next + delta2 = x - mu_next + m2_next = elem_loop.m2 + delta * delta2 + elem_loop.update(mu=mu_next, n=n_next, m2=m2_next) + + stats_loop.update( + mu=elem_loop.final("mu"), + n=elem_loop.final("n"), + m2=elem_loop.final("m2"), + ) + + mean = stats_loop.final("mu") + count = stats_loop.final("n") + inv_std = 1.0 / scalar.sqrt(stats_loop.final("m2") / count + pto.f32(1.0e-5)) + + # Pass 2: apply (x - mean) / sqrt(var + eps) block by block. with pto.for_(0, num_blocks, step=1) as i: offset = i * BLOCK - this_block = min(BLOCK, N - offset) + this_block = scalar.min(BLOCK, N - offset) x_part = pto.partition_view(x_view, offsets=[offset], sizes=[this_block]) o_part = pto.partition_view(o_view, offsets=[offset], sizes=[this_block]) - pto.tload(x_part, x_tile) - pto.tnormalize(x_tile, mu_next, global_var, o_tile) - pto.tstore(o_tile, o_part) + pto.tile.load(x_part, x_tile) + x_tile.valid_shape = [this_block] + o_tile.valid_shape = [this_block] + + with pto.for_(0, this_block, step=1) as j: + x = scalar.load(x_tile.as_ptr(), j) + y = (x - mean) * inv_std + scalar.store(y, o_tile.as_ptr(), j) + + pto.tile.store(o_tile, o_part) ``` **Key points**: -- **Carry state**: `.carry(mu=..., n=..., m2=...)` on the `pto.for_` declares three loop-carried values. Each iteration reads the previous values via `block_loop.mu` etc. and feeds the updated values via `block_loop.update(...)`. -- **Ping-pong implicit**: The carry mechanism produces a clean SSA-style handoff between iterations — no explicit swap of tile pairs needed. -- **Two-pass algorithm**: The first pass accumulates statistics; the second pass applies the normalization. For a single-pass online version, the normalized output would be written block-by-block inside the first loop, but that requires storing the running statistics per element — a tradeoff between memory and passes. -- **Compare to flash attention**: The flash attention carry in Chapter 11 carries six values (`m_prev`/`m_next`, `l_prev`/`l_next`, `o_prev`/`o_next`) and uses ping-pong tiles. This example shows that for simpler scalar carries, direct values (no tile swap) suffice. +- **Carry state**: `.carry(mu=..., n=..., m2=...)` on both loops keeps the running Welford state in SSA form. The outer loop carries state across blocks; the inner loop carries state across elements inside one block. +- **Tail handling**: `scalar.min(BLOCK, N - offset)` computes the live width of the current block, and `tile.valid_shape = [this_block]` keeps the tile contract aligned with that tail. +- **No special tile op required**: the normalization pass is written explicitly with `scalar.load(...)`, scalar arithmetic, `scalar.sqrt(...)`, and `scalar.store(...)`. There is no dependency on a dedicated `tnormalize` op. +- **Compare to flash attention**: the flash attention carry in Chapter 11 moves several tiles through ping-pong buffers. Here the carried state is only three scalars, so the same `.carry(...)` surface reads more like a conventional streaming reduction. ## 12.5 Design guidelines @@ -390,7 +369,7 @@ def online_layernorm(X, O, *, BLOCK: pto.constexpr): | Goal | Use | |------|-----| | Whole-kernel orchestration, GM↔UB boundary | `@pto.jit` | -| Tile-level data movement | `tload` / `tstore` | +| Tile-level data movement | `tile.load` / `tile.store` | | Custom row-wise vector math | `@pto.simd` | | Custom per-element logic | `@pto.simt` | | Matrix multiply | `@pto.cube` | diff --git a/ptodsl/demos/flash_attention_sketch.py b/ptodsl/examples/flash_attention_sketch.py similarity index 96% rename from ptodsl/demos/flash_attention_sketch.py rename to ptodsl/examples/flash_attention_sketch.py index 039ca3c9a..0a3819a5b 100644 --- a/ptodsl/demos/flash_attention_sketch.py +++ b/ptodsl/examples/flash_attention_sketch.py @@ -14,7 +14,7 @@ emit_flash_attention_mlir(...) compile/inspect wrapper └─ @pto.jit flash_attention_kernel - ├─ Tile Ops tload / tstore at the GM↔UB boundary + ├─ Tile Ops tile.load / tile.store at the GM↔UB boundary └─ @pto.ukernel one KV-block worth of MTE/sync orchestration ├─ @pto.cube matrix products (QK^T and P@V) ├─ @pto.simd row-wise online softmax @@ -33,7 +33,7 @@ 4. ``ukernel`` owns the per-block execution sandwich: stage the current K/V block with explicit micro-instructions, synchronize, call hardware-bound sub-kernels, and manage scratch/state. -5. ``@pto.jit`` may use tile ops such as ``tload`` / ``tstore`` at the logical +5. ``@pto.jit`` may use tile ops such as ``tile.load`` / ``tile.store`` at the logical scheduling boundary, but ``ukernel`` stays below that abstraction level. Once execution enters ``ukernel``, GM<->UB movement is expressed with MTE micro-instructions such as ``mte_load`` instead of tile ops. @@ -76,14 +76,12 @@ "Unable to locate the PTODSL Python package root from flash_attention_sketch.py" ) -from ptodsl import pto - -scalar = pto.scalar +from ptodsl import pto, scalar def _min_index(lhs, rhs): - return pto.scalar.select( - pto.scalar.cmpi("slt", lhs, rhs), + return scalar.select( + lhs < rhs, lhs, rhs, ) @@ -300,7 +298,7 @@ def flash_attention_kernel( pv_tile.valid_shape = [q_rows, dim] q_l0a.valid_shape = [q_rows, dim] - pto.tload(q_part, q_mat) + pto.tile.load(q_part, q_mat) # Initial online-softmax state for this Q block. # ``CAUSAL`` is threaded at the API boundary even though the masking @@ -369,7 +367,7 @@ def flash_attention_kernel( ) o_final_tile = kv_loop.final("o") - pto.tstore(o_final_tile, o_part) + pto.tile.store(o_final_tile, o_part) # ═══════════════════════════════════════════════════════════════════════════════ @@ -581,9 +579,26 @@ def kv_block_process( - wiring together the explicit state transition (prev -> next for m/l/o). """ - # Current-block GM->MAT staging via MTE micro-instructions. - pto.mte_load(k_part, k_mat) - pto.mte_load(v_part, v_mat) + # Current-block GM->MAT staging via explicit ptr-based DMA parameters. + rows = k_mat.valid_shape[0] + cols = k_mat.valid_shape[1] + row_bytes = cols * pto.bytewidth(pto.f32) + gm_row_stride = k_part.strides[0] * pto.bytewidth(pto.f32) + mat_row_stride = k_mat.shape[1] * pto.bytewidth(pto.f32) + pto.mte_load( + k_part.as_ptr(), + k_mat.as_ptr(), + 0, + row_bytes, + nburst=(rows, gm_row_stride, mat_row_stride), + ) + pto.mte_load( + v_part.as_ptr(), + v_mat.as_ptr(), + 0, + row_bytes, + nburst=(rows, gm_row_stride, mat_row_stride), + ) pto.pipe_barrier(pto.Pipe.ALL) materialize_tile_bounds( @@ -616,7 +631,7 @@ def kv_block_process( pto.pipe_barrier(pto.Pipe.ALL) # Stage the probability tile onto the cube MAT path. - pto.tmov(p_tile, p_mat) + pto.tile.mov(p_tile, p_mat) pto.pipe_barrier(pto.Pipe.ALL) # 3. PV = P @ V @@ -651,7 +666,7 @@ def kv_block_process( # │ L1 @pto.jit compile + cache + top-level orchestration │ # │ │ # │ flash_attention_kernel.compile(...).mlir_text() │ -# │ TensorView metadata / alloc_tile / partition_view / tload / tstore │ +# │ TensorView metadata / alloc_tile / partition_view / tile.load / tile.store │ # │ outer Q loop + inner KV loop + ping-pong state ownership │ # │ │ # │ Key idea: one launchable entry owns both runtime binding and logical │ diff --git a/ptodsl/examples/softmax_dsl.py b/ptodsl/examples/softmax_dsl.py index decdaa4d6..f4f6ebbf4 100644 --- a/ptodsl/examples/softmax_dsl.py +++ b/ptodsl/examples/softmax_dsl.py @@ -19,7 +19,7 @@ %arg0: !pto.ptr, …, %arg7: i32, …) # arg0: pto.ptr(…), … scf.if %has_rows { # with pto.if_(has_rows): - pto.tload ins(…) outs(…) # pto.tload(part, tile) + pto.tload ins(…) outs(…) # pto.tile.load(part, tile) pto.vecscope { # with pto.vecscope(): scf.for %row = … { # with pto.for_(…) as row: %final_max, %final_sum = # @@ -37,9 +37,9 @@ pto.barrier # pto.pipe_barrier(pto.Pipe.ALL) """ -from ptodsl import pto +from ptodsl import pto, scalar -s = pto.scalar # arith shorthand alias +s = scalar # arith shorthand alias @pto.jit( @@ -94,8 +94,8 @@ def online_softmax_update_kernel_2d( _ = s.index_cast(pto.int32, c8) # block_rows_i32 row_base_i32 = s.index_cast(pto.int32, row_base) remaining_rows= s.subi(arg8, row_base_i32) - has_rows = s.cmpi_sgt(remaining_rows, c0_i32) - too_many_rows = s.cmpi_sgt(remaining_rows, c8_i32) + has_rows = remaining_rows > c0_i32 + too_many_rows = remaining_rows > c8_i32 row_count_i32 = s.select(too_many_rows, c8_i32, remaining_rows) row_count = s.index_cast(row_count_i32) # → index seq = s.index_cast(arg7) # → index @@ -143,9 +143,9 @@ def online_softmax_update_kernel_2d( expmax_tile = pto.alloc_tile(tile_col, addr=c16896_i64, valid_row=row_count) # ── Tile loads from GM ──────────────────────────────────────────────── - pto.tload(oldmax_part, oldmax_tile) - pto.tload(oldsum_part, oldsum_tile) - pto.tload(qk_part, qk_tile) + pto.tile.load(oldmax_part, oldmax_tile) + pto.tile.load(oldsum_part, oldsum_tile) + pto.tile.load(qk_part, qk_tile) pto.set_flag("MTE2", "V", event_id=0) pto.wait_flag("MTE2", "V", event_id=0) @@ -178,7 +178,7 @@ def online_softmax_update_kernel_2d( chunk_i32 = s.index_cast(pto.int32, chunk) remaining_cols = s.subi(arg7, chunk_i32) - has_chunk = s.cmpi_sgt(remaining_cols, c0_i32) + has_chunk = remaining_cols > c0_i32 # scf.if with results – produce (next_max, next_sum) with pto.if_(has_chunk, results=(vf32, vf32)) as br: @@ -216,7 +216,7 @@ def online_softmax_update_kernel_2d( # Output normalisation loop with pto.for_(c0, c128, step=c64) as chunk2: rem2 = s.subi(arg7, s.index_cast(pto.int32, chunk2)) - has_chunk2= s.cmpi_sgt(rem2, c0_i32) + has_chunk2= rem2 > c0_i32 with pto.if_(has_chunk2): cmask2, _ = pto.plt_b32(rem2) cbase2 = s.addi(row_qk, chunk2) @@ -229,10 +229,10 @@ def online_softmax_update_kernel_2d( pto.wait_flag("V", "MTE3", event_id=0) # Tile stores to GM - pto.tstore(newmax_tile, newmax_part) - pto.tstore(newsum_tile, newsum_part) - pto.tstore(expmax_tile, expmax_part) - pto.tstore(out_tile, out_part) + pto.tile.store(newmax_tile, newmax_part) + pto.tile.store(newsum_tile, newsum_part) + pto.tile.store(expmax_tile, expmax_part) + pto.tile.store(out_tile, out_part) pto.pipe_barrier(pto.Pipe.ALL) diff --git a/ptodsl/examples/tadd_dsl.py b/ptodsl/examples/tadd_dsl.py index b55a65693..e02076058 100644 --- a/ptodsl/examples/tadd_dsl.py +++ b/ptodsl/examples/tadd_dsl.py @@ -28,9 +28,9 @@ } """ -from ptodsl import pto +from ptodsl import pto, scalar -s = pto.scalar # arith shorthand alias +s = scalar # arith shorthand alias @pto.jit(name="TADD", kernel_kind="vector", target="a5") diff --git a/ptodsl/ptodsl/_control_flow.py b/ptodsl/ptodsl/_control_flow.py index 23f3a12f4..54cacdf97 100644 --- a/ptodsl/ptodsl/_control_flow.py +++ b/ptodsl/ptodsl/_control_flow.py @@ -15,7 +15,7 @@ ``vecscope()`` – ``pto.vecscope { … }`` ``for_(lo, hi, step, *, iter_args)`` – ``scf.for`` with optional iter_args or named carry state -``if_(cond, *, results)`` – ``scf.if`` with optional results + else +``if_(cond)`` – ``scf.if`` via explicit branch handle + automatic named merge ``yield_(*vals)`` – ``scf.yield`` """ @@ -23,7 +23,6 @@ from ._runtime_index_ops import coerce_runtime_index from ._tracing.active import current_session from ._surface_values import unwrap_surface_value, wrap_like_surface_value, wrap_surface_value -from ._types import _resolve from mlir.dialects import pto as _pto, scf from mlir.ir import InsertionPoint @@ -284,96 +283,299 @@ def _coerce_index(value): # ── if_ ─────────────────────────────────────────────────────────────────────── -class _BlockCM: - """Enters the InsertionPoint of a single block for ``with br.then_:`` style.""" - - def __init__(self, block): +def _find_parent_block(op_view): + """Return the block that directly contains *op_view*.""" + parent_op = op_view.operation.parent + if parent_op is None: + raise RuntimeError("unable to locate the parent block for pto.if_(...)") + for region in parent_op.regions: + for block in region.blocks: + for candidate in block.operations: + if candidate.operation is op_view.operation: + return block + raise RuntimeError("unable to locate the parent block for pto.if_(...)") + + +def _move_block_ops(src_block, dst_block, *, yield_values): + """Move all non-terminator ops from *src_block* into *dst_block* and yield.""" + with InsertionPoint(dst_block): + terminator = scf.YieldOp(list(yield_values)) + yield_anchor = terminator.operation.opview + for op in list(src_block.operations): + if op.operation.name == "scf.yield": + continue + op.move_before(yield_anchor) + + +class _IfBranchCM: + """Enters the insertion point of one branch block for ``with br.then_:`` style.""" + + def __init__(self, owner, branch_name, block): + self._owner = owner + self._branch_name = branch_name self._block = block self._ip = None def __enter__(self): + self._owner._enter_branch(self._branch_name) self._ip = InsertionPoint(self._block) self._ip.__enter__() def __exit__(self, *exc): - self._ip.__exit__(*exc) + try: + self._ip.__exit__(*exc) + finally: + self._owner._leave_branch(self._branch_name) class BranchHandle: """ - Handle for ``scf.if`` with results and an else branch. + Handle for one authored ``pto.if_(...)`` branch pair. Usage:: - with pto.if_(cond, results=(vf32, vf32)) as br: + with pto.if_(cond) as br: with br.then_: - ... - pto.yield_(a, b) + br.assign(val=x) with br.else_: - pto.yield_(c, d) - x, y = br.results + br.assign(val=y) + out = br.val """ - def __init__(self, if_op): - self._op = if_op - self.then_ = _BlockCM(if_op.then_block) - self.else_ = _BlockCM(if_op.else_block) + def __init__(self, owner): + self._owner = owner + self.then_ = _IfBranchCM(owner, "then", owner._tmp_if.then_block) + self.else_ = _IfBranchCM(owner, "else", owner._tmp_if.else_block) - @property - def results(self): - return tuple(wrap_surface_value(result) for result in self._op.results) + def assign(self, **kwargs): + self._owner._assign_branch_values(kwargs) + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(name) + return self._owner._get_merged_value(name) class _IfCM: - def __init__(self, cond, result_types): + def __init__(self, cond): self._cond = cond - self._result_types = [_resolve(t) for t in result_types] if result_types else [] - self._if_op = None - self._ip = None + self._cond_value = None + self._tmp_if = None + self._parent_block = None + self._active_branch = None + self._branch_closed = {"then": False, "else": False} + self._branch_entered = {"then": False, "else": False} + self._branch_assignments = {"then": None, "else": None} + self._merged_values = None + self._finalized = False + self._handle = None def __enter__(self): - cond = unwrap_surface_value(self._cond) - if self._result_types: - # if/else with results: create IfOp but don't enter any block; - # the caller manages blocks via br.then_ / br.else_ - self._if_op = scf.IfOp(cond, self._result_types, hasElse=True) - return BranchHandle(self._if_op) + self._cond_value = unwrap_surface_value(self._cond) + self._tmp_if = scf.IfOp(self._cond_value, hasElse=True) + self._parent_block = _find_parent_block(self._tmp_if) + self._handle = BranchHandle(self) + return self._handle + + def __exit__(self, exc_type, exc, tb): + if exc_type is not None: + self._erase_tmp_if() + return None + try: + self._finalize() + except Exception: + self._erase_tmp_if() + raise + return None + + def _enter_branch(self, branch_name): + if self._finalized: + raise RuntimeError("pto.if_(...) branches are no longer available after the conditional closes") + if self._active_branch is not None: + raise RuntimeError( + "pto.if_(...) does not support nested branch entry; close the current " + f"br.{self._active_branch}_ block before entering br.{branch_name}_" + ) + if self._branch_closed[branch_name]: + raise RuntimeError(f"br.{branch_name}_ may only be entered once per pto.if_(...)") + self._active_branch = branch_name + self._branch_entered[branch_name] = True + + def _leave_branch(self, branch_name): + if self._active_branch == branch_name: + self._active_branch = None + self._branch_closed[branch_name] = True + + def _assign_branch_values(self, kwargs): + if self._active_branch is None: + raise RuntimeError("br.assign(...) may only be used inside br.then_ or br.else_") + if not kwargs: + raise ValueError("br.assign(...) requires at least one named value") + branch_name = self._active_branch + if self._branch_assignments[branch_name] is not None: + raise RuntimeError(f"br.{branch_name}_ may call br.assign(...) at most once") + raw_values = {} + templates = {} + order = tuple(kwargs.keys()) + for name, value in kwargs.items(): + raw_value = unwrap_surface_value(value) + if not hasattr(raw_value, "type"): + raise TypeError( + "br.assign(...) expects PTO runtime values or authored surface values; " + f"'{name}' received {type(value).__name__}" + ) + raw_values[name] = raw_value + templates[name] = value + self._branch_assignments[branch_name] = { + "order": order, + "raw_values": raw_values, + "templates": templates, + } + + def _get_merged_value(self, name): + if not self._finalized: + raise RuntimeError(f"br.{name} is only available after the pto.if_(...) block closes") + if self._merged_values is None or name not in self._merged_values: + expected = () + if self._merged_values: + expected = tuple(self._merged_values.keys()) + if expected: + raise AttributeError( + f"br.{name} was not assigned by this conditional; " + f"expected one of: {', '.join(expected)}" + ) + raise AttributeError(f"br.{name} was not assigned by this conditional") + return self._merged_values[name] + + def _finalize(self): + self._validate_no_stray_ops() + if not any(self._branch_entered.values()): + raise RuntimeError( + "pto.if_(...) requires at least one explicit branch block; " + "use 'with br.then_:' and optionally 'with br.else_:'" + ) + merge_spec = self._validate_merge_spec() + if merge_spec is None: + self._finalize_side_effect_if() else: - # simple if without results: enter then_block automatically - self._if_op = scf.IfOp(cond) - self._ip = InsertionPoint(self._if_op.then_block) - self._ip.__enter__() + self._finalize_merged_if(merge_spec) + self._finalized = True + + def _validate_no_stray_ops(self): + parent_ops = list(self._parent_block.operations) + if not parent_ops or parent_ops[-1].operation is not self._tmp_if.operation: + raise RuntimeError( + "pto.if_(...) body may only contain explicit 'with br.then_:' / " + "'with br.else_:' blocks; PTODSL found operations emitted directly " + "in the outer if body" + ) + + def _validate_merge_spec(self): + then_assignment = self._branch_assignments["then"] + else_assignment = self._branch_assignments["else"] + if then_assignment is None and else_assignment is None: return None + if then_assignment is None or else_assignment is None: + raise RuntimeError( + "automatic branch merge requires both br.then_ and br.else_ to call br.assign(...)" + ) - def __exit__(self, *exc): - if not self._result_types: - scf.YieldOp([]) - self._ip.__exit__(*exc) - # for if/else with results: blocks are managed by BranchHandle; nothing to do + then_names = set(then_assignment["raw_values"].keys()) + else_names = set(else_assignment["raw_values"].keys()) + if then_names != else_names: + missing_in_else = sorted(then_names - else_names) + missing_in_then = sorted(else_names - then_names) + pieces = [] + if missing_in_else: + pieces.append(f"missing in else: {', '.join(missing_in_else)}") + if missing_in_then: + pieces.append(f"missing in then: {', '.join(missing_in_then)}") + raise RuntimeError("br.assign(...) names must match across branches; " + "; ".join(pieces)) + + order = then_assignment["order"] + result_types = [] + for name in order: + then_value = then_assignment["raw_values"][name] + else_value = else_assignment["raw_values"][name] + if then_value.type != else_value.type: + raise RuntimeError( + f"br.assign(...) type mismatch for '{name}': " + f"then branch yields {then_value.type}, else branch yields {else_value.type}" + ) + result_types.append(then_value.type) + + return { + "order": order, + "result_types": result_types, + "then": then_assignment, + "else": else_assignment, + } + + def _finalize_side_effect_if(self): + has_else = self._branch_entered["else"] + final_if = scf.IfOp(self._cond_value, hasElse=has_else) + _move_block_ops(self._tmp_if.then_block, final_if.then_block, yield_values=[]) + if has_else: + _move_block_ops(self._tmp_if.else_block, final_if.else_block, yield_values=[]) + self._merged_values = {} + self._tmp_if.erase() + self._tmp_if = final_if + + def _finalize_merged_if(self, merge_spec): + final_if = scf.IfOp(self._cond_value, merge_spec["result_types"], hasElse=True) + then_yield_values = [ + merge_spec["then"]["raw_values"][name] + for name in merge_spec["order"] + ] + else_yield_values = [ + merge_spec["else"]["raw_values"][name] + for name in merge_spec["order"] + ] + _move_block_ops(self._tmp_if.then_block, final_if.then_block, yield_values=then_yield_values) + _move_block_ops(self._tmp_if.else_block, final_if.else_block, yield_values=else_yield_values) + + merged = {} + for name, template, result in zip( + merge_spec["order"], + (merge_spec["then"]["templates"][name] for name in merge_spec["order"]), + final_if.results, + ): + merged[name] = wrap_like_surface_value(template, result) + self._merged_values = merged + self._tmp_if.erase() + self._tmp_if = final_if + + def _erase_tmp_if(self): + if self._tmp_if is None: + return + try: + self._tmp_if.erase() + except Exception: + pass + finally: + self._tmp_if = None -def if_(cond, *, results=None) -> _IfCM: +def if_(cond) -> _IfCM: """ - ``scf.if`` context manager. + ``scf.if`` context manager with explicit branch handles. - Without ``results`` – simple if with no else; ``scf.yield`` is inserted - automatically:: + Side-effect-only form:: - with pto.if_(has_rows): - ... + with pto.if_(has_rows) as br: + with br.then_: + ... - With ``results`` – if/else pair that produces SSA values; the caller must - manage ``br.then_`` and ``br.else_`` and emit ``pto.yield_(…)`` in each:: + Automatic named merge form:: - with pto.if_(has_chunk, results=(vf32, vf32)) as br: + with pto.if_(has_chunk) as br: with br.then_: - ... - pto.yield_(merged_max, merged_sum) + br.assign(x=a) with br.else_: - pto.yield_(running_max, running_sum) - x, y = br.results + br.assign(x=b) + x = br.x """ - return _IfCM(cond, results) + return _IfCM(cond) # ── yield_ ──────────────────────────────────────────────────────────────────── diff --git a/ptodsl/ptodsl/_diagnostics.py b/ptodsl/ptodsl/_diagnostics.py index 48af69166..4b4693719 100644 --- a/ptodsl/ptodsl/_diagnostics.py +++ b/ptodsl/ptodsl/_diagnostics.py @@ -67,6 +67,14 @@ def illegal_subkernel_placement_error(role: str, outer_role: str | None) -> Runt ) +def illegal_inline_subkernel_placement_error(role: str, outer_role: str | None) -> RuntimeError: + """Return one diagnostic for an inline subkernel scope placed outside the supported layer graph.""" + return RuntimeError( + f"inline pto.{role}() may only be used from the top-level @pto.jit body or inside @pto.ukernel; " + f"nested use inside @pto.{outer_role} is not part of the PTODSL layer contract." + ) + + def simd_value_escape_error(type_text: str) -> RuntimeError: """Return one diagnostic for transient SIMD values escaping a simd subkernel boundary.""" return RuntimeError( @@ -90,6 +98,7 @@ def tile_row_alignment_error(*, shape, dtype, row_bytes: int, required_alignment __all__ = [ "PTODSLTracingMisuseError", "host_tensor_metadata_error", + "illegal_inline_subkernel_placement_error", "illegal_subkernel_placement_error", "native_python_control_flow_error", "simd_value_escape_error", diff --git a/ptodsl/ptodsl/_host_tensors.py b/ptodsl/ptodsl/_host_tensors.py index 9f270642d..0217ffb5f 100644 --- a/ptodsl/ptodsl/_host_tensors.py +++ b/ptodsl/ptodsl/_host_tensors.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from ._diagnostics import host_tensor_metadata_error -from ._types import _resolve, index, ptr +from ._types import _ensure_non_storage_only_authored_dtype, _resolve, index, ptr def _normalize_tensor_shape(shape): @@ -110,6 +110,10 @@ class TensorSpec: def __post_init__(self): if self.rank <= 0: raise ValueError("tensor_spec(rank=...) expects a positive rank") + _ensure_non_storage_only_authored_dtype( + self.dtype, + context="pto.tensor_spec(...)", + ) def entry_arg_types(self): data_type = _resolve(ptr(self.dtype, self.address_space)) diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py index 8f1cbe7da..9bd845a58 100644 --- a/ptodsl/ptodsl/_ops.py +++ b/ptodsl/ptodsl/_ops.py @@ -13,8 +13,10 @@ Design rules: - Vector math ops infer the result type from the first operand's type. -- ``vlds`` / ``vbrc_load`` still require an explicit ``vreg_type`` argument - because the result type cannot be inferred from the pointer alone. +- ``vlds(tile[row, col:])`` and ``vlds(ptr, offset)`` infer the result + ``vreg`` type from the source element type. ``vbrc_load`` still requires an + explicit result ``vreg`` type because broadcast widths are authored + explicitly in the current surface. - ``make_tensor_view`` infers the TensorViewType from ``len(shape)`` and the pointer's element type. - ``partition_view`` infers the PartitionTensorViewType from the source type. @@ -31,6 +33,7 @@ TensorViewValue, TileSliceValue, TileValue, + _coerce_index_value, _unwrap_sequence, compose_partition_spec, emit_as_ptr, @@ -39,7 +42,17 @@ unwrap_surface_value, wrap_surface_value, ) -from ._types import _resolve, mask_type, part_tensor_view_type, tensor_view_type, vreg_type +from ._types import ( + _isinstance_pto_type, + _integer_signedness, + _materialize_integer_literal, + _resolve, + _strip_integer_signedness, + mask_type, + part_tensor_view_type, + tensor_view_type, + vreg_type, +) from mlir.dialects import arith, pto as _pto from mlir.ir import ( @@ -47,6 +60,8 @@ BF16Type, F16Type, F32Type, + Float8E4M3FNType, + Float8E5M2Type, FloatAttr, IndexType, IntegerType, @@ -80,6 +95,37 @@ def _event_attr(event_id: int): return getattr(_pto, f"EVENT_ID{event_id}") +def _canonical_pipe_token(pipe): + if isinstance(pipe, str): + canonical = _PIPE_ALIASES.get(pipe, pipe) + if not canonical.startswith("PIPE_"): + canonical = "PIPE_" + canonical + return canonical + + for canonical in ( + "PIPE_FIX", "PIPE_MTE1", "PIPE_MTE2", "PIPE_MTE3", "PIPE_MTE4", + "PIPE_V", "PIPE_M", "PIPE_S", "PIPE_V2", "PIPE_ALL", + ): + pipe_attr = getattr(_pto.PIPE, canonical, None) + if pipe_attr is not None and pipe == pipe_attr: + return canonical + return None + + +def _validate_static_event_id(event_id, *, context: str): + if isinstance(event_id, int) and not 0 <= event_id <= 7: + raise ValueError(f"{context} expects static event_id in [0, 7], got {event_id}") + + +def _validate_sync_pipe(pipe, *, context: str, allowed: tuple[str, ...]): + canonical = _canonical_pipe_token(pipe) + if canonical is None: + raise TypeError(f"{context} expects a concrete Pipe value, got {pipe!r}") + if canonical not in allowed: + expected = ", ".join(f"<{name}>" for name in allowed) + raise ValueError(f"{context} expects pipe to be one of {expected}, got <{canonical}>") + + # ── Constants ──────────────────────────────────────────────────────────────── def const(value: int, *, dtype=None): @@ -91,6 +137,10 @@ def const(value: int, *, dtype=None): """ from ._types import index as _idx_dtype mlir_type = _resolve(dtype) if dtype is not None else _resolve(_idx_dtype) + if any(cls.isinstance(mlir_type) for cls in (F16Type, BF16Type, F32Type)): + return wrap_surface_value(arith.ConstantOp(mlir_type, FloatAttr.get(mlir_type, value)).result) + if IntegerType.isinstance(mlir_type): + return wrap_surface_value(_materialize_integer_literal(mlir_type, value)) return wrap_surface_value(arith.ConstantOp(mlir_type, value).result) @@ -106,7 +156,10 @@ def castptr(int_addr, result_ptr_type): def addptr(base_ptr, index_offset): """``pto.addptr`` – advance a pointer by an index offset.""" return wrap_surface_value( - _pto.AddPtrOp(unwrap_surface_value(base_ptr), unwrap_surface_value(index_offset)).result + _pto.AddPtrOp( + unwrap_surface_value(base_ptr), + _coerce_index(index_offset, context="addptr(ptr, offset)"), + ).result ) @@ -123,8 +176,10 @@ def vlds(src_ptr, offset=None, result_vreg_type=None): _index_zero(), ).result) - if offset is None or result_vreg_type is None: - raise TypeError("vlds(ptr, offset, result_vreg_type) requires both offset and result_vreg_type") + if offset is None: + raise TypeError("vlds(ptr, offset, result_vreg_type=None) requires an explicit offset") + if result_vreg_type is None: + result_vreg_type = _infer_vreg_type_from_address_source(src_ptr) return wrap_surface_value(_pto.VldsOp( _resolve(result_vreg_type), unwrap_surface_value(src_ptr), @@ -132,6 +187,97 @@ def vlds(src_ptr, offset=None, result_vreg_type=None): ).result) +def vldas(source): + """``pto.vldas`` – prime alignment state for a following unaligned load stream.""" + if isinstance(source, TileSliceValue): + source = _tile_slice_ptr(source) + return wrap_surface_value( + _pto.VldasOp( + _pto.AlignType.get(), + unwrap_surface_value(source), + ).result + ) + + +def vldus(source, align): + """``pto.vldus`` – unaligned vector load threaded through alignment state.""" + result_type = ( + _infer_vreg_type_from_tile_slice(source) + if isinstance(source, TileSliceValue) + else _infer_vreg_type_from_address_source(source) + ) + if isinstance(source, TileSliceValue): + source = _tile_slice_ptr(source) + op = _pto.VldusOp( + result_type, + _pto.AlignType.get(), + unwrap_surface_value(source), + unwrap_surface_value(align), + ) + return wrap_surface_value(op.result), wrap_surface_value(op.updated_align) + + +_DEINTERLEAVE_DIST_TOKENS = {"DINTLV_B8", "DINTLV_B16", "DINTLV_B32", "BDINTLV"} +_INTERLEAVE_DIST_TOKENS = {"INTLV_B8", "INTLV_B16", "INTLV_B32"} +_VSTORE_DIST_TOKENS = { + "NORM_B8", "NORM_B16", "NORM_B32", + "1PT_B8", "1PT_B16", "1PT_B32", + "PK_B16", "PK_B32", "PK_B64", "PK4_B32", + "MRG4CHN_B8", "MRG2CHN_B8", "MRG2CHN_B16", +} + + +def _normalize_dist_token(dist, *, allowed: set[str], context: str): + token = dist + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + normalized = token.strip().upper() + if normalized.startswith("_"): + normalized = normalized[1:] + if normalized not in allowed: + expected = ", ".join(sorted(allowed)) + raise ValueError(f"{context} does not support dist {dist!r}; expected one of {expected}") + return normalized + + +def vldsx2(source, offset_or_dist, dist=None): + """``pto.vldsx2`` – dual vector load with deinterleave.""" + if isinstance(source, TileSliceValue): + if dist is not None: + raise TypeError("vldsx2(tile[row, col:], dist) does not accept a separate offset argument") + result_type = _infer_vreg_type_from_tile_slice(source) + op = _pto.Vldsx2Op( + result_type, + result_type, + unwrap_surface_value(source), + _index_zero(), + _normalize_dist_token( + offset_or_dist, + allowed=_DEINTERLEAVE_DIST_TOKENS, + context="vldsx2(..., dist)", + ), + ) + return wrap_surface_value(op.low), wrap_surface_value(op.high) + + if dist is None: + raise TypeError("vldsx2(ptr, offset, dist) requires an explicit offset and dist") + result_type = _infer_vreg_type_from_address_source(source) + op = _pto.Vldsx2Op( + result_type, + result_type, + unwrap_surface_value(source), + _coerce_index(offset_or_dist, context="vldsx2(ptr, offset, dist)"), + _normalize_dist_token( + dist, + allowed=_DEINTERLEAVE_DIST_TOKENS, + context="vldsx2(..., dist)", + ), + ) + return wrap_surface_value(op.low), wrap_surface_value(op.high) + + def vbrc_load(src_ptr, offset, result_vreg_type): """``pto.vlds {dist="BRC_B32"}`` – broadcast a scalar into all lanes.""" return wrap_surface_value( @@ -144,6 +290,28 @@ def vbrc_load(src_ptr, offset, result_vreg_type): ) +def vbitcast(vector_value, to_dtype): + """``pto.vbitcast`` – reinterpret one vector register as a different element type.""" + target_elem = _resolve(to_dtype) + target_type = _resolve(vreg_type(_elements_per_vreg(target_elem), target_elem)) + return wrap_surface_value( + _pto.VbitcastOp( + target_type, + unwrap_surface_value(vector_value), + ).result + ) + + +def pbitcast(mask_value, to_type): + """``pto.pbitcast`` – reinterpret one mask register at a different granularity.""" + return wrap_surface_value( + _pto.PbitcastOp( + _resolve(to_type), + unwrap_surface_value(mask_value), + ).result + ) + + def vsts(val, dst_ptr, offset, mask=None): """``pto.vsts`` – vector store to a tile slice or to *dst_ptr* at *offset*.""" if isinstance(dst_ptr, TileSliceValue): @@ -178,424 +346,1969 @@ def vsts_1pt(val, dst_ptr, offset, mask): ) -# ── Mask / predicate ops ────────────────────────────────────────────────────── - -def plt_b32(scalar): - """ - ``pto.plt_b32`` – predicate-load from a 32-bit scalar. +def vstsx2(low, high, dst_ptr, offset_or_dist, dist_or_mask=None, mask=None): + """``pto.vstsx2`` – dual interleaving vector store.""" + if isinstance(dst_ptr, TileSliceValue): + if mask is not None: + raise TypeError("vstsx2(low, high, tile[row, col:], dist, mask) does not accept a separate offset argument") + _pto.Vstsx2Op( + unwrap_surface_value(low), + unwrap_surface_value(high), + unwrap_surface_value(dst_ptr), + _index_zero(), + _normalize_dist_token( + offset_or_dist, + allowed=_INTERLEAVE_DIST_TOKENS, + context="vstsx2(..., dist)", + ), + unwrap_surface_value(dist_or_mask), + ) + return - Returns ``(mask_value, scalar_out)``. ``scalar_out`` is often unused - and can be discarded with ``_``. - """ - plt_op = _pto.PltB32Op( - _resolve(mask_type("b32")), - IntegerType.get_signless(32), - unwrap_surface_value(scalar), + if mask is None: + raise TypeError("vstsx2(low, high, ptr, offset, dist, mask) requires an explicit offset, dist, and mask") + _pto.Vstsx2Op( + unwrap_surface_value(low), + unwrap_surface_value(high), + unwrap_surface_value(dst_ptr), + _coerce_index(offset_or_dist, context="vstsx2(ptr, offset, dist, mask)"), + _normalize_dist_token( + dist_or_mask, + allowed=_INTERLEAVE_DIST_TOKENS, + context="vstsx2(..., dist)", + ), + unwrap_surface_value(mask), ) - return wrap_surface_value(plt_op.mask), wrap_surface_value(plt_op.scalar_out) - -def pset_b32(pattern: str): - """``pto.pset_b32 "PATTERN"`` → ``!pto.mask``.""" - return wrap_surface_value(_pto.PsetB32Op(_resolve(mask_type("b32")), pattern).result) +def vgather2(buf, offsets, mask, result_vreg_type=None): + """``pto.vgather2`` – indexed gather from UB.""" + rt = result_vreg_type if result_vreg_type is not None else _infer_vreg_type_from_address_source(buf) + return wrap_surface_value( + _pto.Vgather2Op( + _resolve(rt), + unwrap_surface_value(buf), + unwrap_surface_value(offsets), + unwrap_surface_value(mask), + ).result + ) -# ── Vector math (result type inferred from first operand) ───────────────────── -def vadd(lhs, rhs, mask, result_type=None): - """``pto.vadd`` – element-wise add.""" - rt = result_type if result_type is not None else lhs.type +def vgather2_bc(buf, offsets, mask, result_vreg_type=None): + """``pto.vgather2_bc`` – indexed gather from UB with masked zero-fill.""" + rt = result_vreg_type if result_vreg_type is not None else _infer_vreg_type_from_address_source(buf) return wrap_surface_value( - _pto.VaddOp( + _pto.Vgather2BcOp( _resolve(rt), - unwrap_surface_value(lhs), - unwrap_surface_value(rhs), + unwrap_surface_value(buf), + unwrap_surface_value(offsets), unwrap_surface_value(mask), ).result ) -def vmul(lhs, rhs, mask): - """``pto.vmul`` – element-wise multiply.""" +def vgatherb(buf, offsets, mask, result_vreg_type=None): + """``pto.vgatherb`` – block gather from UB using byte offsets.""" + rt = result_vreg_type if result_vreg_type is not None else _infer_vreg_type_from_address_source(buf) return wrap_surface_value( - _pto.VmulOp( - unwrap_surface_value(lhs).type, - unwrap_surface_value(lhs), - unwrap_surface_value(rhs), + _pto.VgatherbOp( + _resolve(rt), + unwrap_surface_value(buf), + unwrap_surface_value(offsets), unwrap_surface_value(mask), ).result ) -def vmax(lhs, rhs, mask): - """``pto.vmax`` – element-wise maximum.""" +def vscatter(value, destination, offsets, mask): + """``pto.vscatter`` – indexed scatter to UB.""" + _pto.VscatterOp( + unwrap_surface_value(value), + unwrap_surface_value(destination), + unwrap_surface_value(offsets), + unwrap_surface_value(mask), + ) + + +def _coerce_i16(value, *, context: str): + raw_value = unwrap_surface_value(value) + i16_type = IntegerType.get_signless(16) + if isinstance(raw_value, bool): + raise TypeError(f"{context} does not accept bool values") + if isinstance(raw_value, int): + return _materialize_integer_literal(i16_type, raw_value) + kind = classify_runtime_scalar_type(raw_value.type) + if kind == "float": + raise TypeError(f"{context} expects an integer-like scalar, got {raw_value.type}") + if kind == "index": + return arith.IndexCastOp(i16_type, raw_value).result + signless_value = _strip_integer_signedness(raw_value) + if signless_value.type == i16_type: + return signless_value + width = IntegerType(raw_value.type).width + if width < 16: + if _integer_signedness(raw_value.type) == "unsigned": + return arith.ExtUIOp(i16_type, signless_value).result + return arith.ExtSIOp(i16_type, signless_value).result + if width > 16: + return arith.TruncIOp(i16_type, signless_value).result + return signless_value + + +def vsldb(source, block_stride, repeat_stride, mask): + """``pto.vsldb`` – block-strided load.""" + result_type = ( + _infer_vreg_type_from_tile_slice(source) + if isinstance(source, TileSliceValue) + else _infer_vreg_type_from_address_source(source) + ) return wrap_surface_value( - _pto.VmaxOp( - unwrap_surface_value(lhs).type, - unwrap_surface_value(lhs), - unwrap_surface_value(rhs), + _pto.VsldbOp( + result_type, + unwrap_surface_value(source), + _coerce_i16(block_stride, context="vsldb(..., block_stride, repeat_stride, mask)"), + _coerce_i16(repeat_stride, context="vsldb(..., block_stride, repeat_stride, mask)"), unwrap_surface_value(mask), ).result ) -def vdiv(lhs, rhs, mask): - """``pto.vdiv`` – element-wise divide.""" +def vsstb(value, destination, block_stride, repeat_stride, mask): + """``pto.vsstb`` – block-strided store.""" + _pto.VsstbOp( + unwrap_surface_value(value), + unwrap_surface_value(destination), + _coerce_i16(block_stride, context="vsstb(..., block_stride, repeat_stride, mask)"), + _coerce_i16(repeat_stride, context="vsstb(..., block_stride, repeat_stride, mask)"), + unwrap_surface_value(mask), + ) + + +# ── Mask / predicate ops ────────────────────────────────────────────────────── + +_MASK_PATTERN_TOKENS = { + "PAT_ALL", + "PAT_ALLF", + "PAT_H", + "PAT_Q", + "PAT_M3", + "PAT_M4", + *(f"PAT_VL{count}" for count in range(1, 129)), +} + +_CMP_MODE_TOKENS = {"eq", "ne", "lt", "le", "gt", "ge"} +_PREDICATE_PART_TOKENS = {"LOWER", "HIGHER"} +_PREDICATE_LOAD_DIST_TOKENS = {"NORM", "US", "DS"} +_PREDICATE_STORE_DIST_TOKENS = {"NORM", "PK"} +_POST_UPDATE_TOKENS = {"NO_POST_UPDATE", "POST_UPDATE"} + + +def _normalize_mask_pattern(pattern): + token = pattern + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + token = token.strip().upper() + normalized = token if token.startswith("PAT_") else f"PAT_{token}" + if normalized not in _MASK_PATTERN_TOKENS: + raise ValueError( + f"unsupported mask pattern {pattern!r}; expected one of PAT_ALL, PAT_ALLF, " + "PAT_H, PAT_Q, PAT_VL1..PAT_VL128, PAT_M3, PAT_M4" + ) + return normalized + + +def _normalize_cmp_mode(cmp_mode): + token = cmp_mode + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + normalized = token.strip().lower() + if normalized not in _CMP_MODE_TOKENS: + raise ValueError( + f"unsupported cmp_mode {cmp_mode!r}; expected one of EQ, NE, LT, LE, GT, GE" + ) + return normalized + + +def _cmp_mode_attr(cmp_mode): + return Attribute.parse(f"#pto") + + +def _normalize_predicate_part(part): + token = part + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + normalized = token.strip().upper() + if normalized not in _PREDICATE_PART_TOKENS: + raise ValueError(f"unsupported predicate part {part!r}; expected LOWER or HIGHER") + return normalized + + +def _normalize_predicate_dist(dist, *, allowed: set[str], context: str): + token = dist + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + normalized = token.strip().upper() + if normalized not in allowed: + expected = ", ".join(sorted(allowed)) + raise ValueError(f"{context} does not support dist {dist!r}; expected one of {expected}") + return normalized + + +def _normalize_post_update_mode(mode, *, context: str): + token = mode + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + normalized = token.strip().upper() + if normalized in {"OFF", "NO_POST_UPDATE"}: + return "NO_POST_UPDATE" + if normalized in {"ON", "POST_UPDATE"}: + return "POST_UPDATE" + expected = ", ".join(sorted(_POST_UPDATE_TOKENS)) + raise ValueError(f"{context} does not support mode {mode!r}; expected one of ON/OFF ({expected})") + + +def _mask_type_from_bits(mask_bits: int): + return _resolve(mask_type(f"b{mask_bits}")) + + +def _infer_mask_metadata(mask_value, *, context: str): + raw_type = unwrap_surface_value(mask_value).type + try: + mask_ty = _pto.MaskType(raw_type) + except Exception as exc: + raise TypeError(f"{context} expects a PTO mask value, got {raw_type}") from exc + granularity = mask_ty.granularity + return int(granularity[1:]), raw_type + + +def _require_same_mask_types(values, *, context: str): + raw_types = [unwrap_surface_value(value).type for value in values] + first = raw_types[0] + for other in raw_types[1:]: + if other != first: + raise TypeError(f"{context} expects masks of the same granularity, got {first} and {other}") + return first + + +def _pointer_element_type(ptr_value, *, context: str): + raw_type = unwrap_surface_value(ptr_value).type + try: + return _pto.PtrType(raw_type).element_type + except Exception: + try: + return MemRefType(raw_type).element_type + except Exception as exc: + raise TypeError(f"{context} expects a PTO pointer or memref-backed address, got {raw_type}") from exc + + +def _coerce_index(value, *, context: str): + raw_value = unwrap_surface_value(value) + index_type = IndexType.get() + if isinstance(raw_value, bool): + raise TypeError(f"{context} does not accept bool values") + if isinstance(raw_value, int): + return arith.ConstantOp(index_type, raw_value).result + kind = classify_runtime_scalar_type(raw_value.type) + if kind == "float": + raise TypeError(f"{context} expects an index-like scalar, got {raw_value.type}") + if IndexType.isinstance(raw_value.type): + return raw_value + if IntegerType.isinstance(raw_value.type): + return arith.IndexCastOp(index_type, _strip_integer_signedness(raw_value)).result + raise TypeError(f"{context} expects an index-like scalar, got {raw_value.type}") + + +def init_align(): + """``pto.init_align`` – materialize the initial alignment state.""" + return wrap_surface_value(_pto.InitAlignOp(_pto.AlignType.get()).result) + + +def _plt_impl(mask_bits: int, scalar): + plt_op = _plt_op_for_mask_bits(mask_bits)( + _mask_type_from_bits(mask_bits), + IntegerType.get_signless(32), + _coerce_i32(scalar, context=f"plt_b{mask_bits}(scalar)"), + ) + return wrap_surface_value(plt_op.mask), wrap_surface_value(plt_op.scalar_out) + + +def plt_b8(scalar): + """``pto.plt_b8`` – predicate-load from a 32-bit scalar into a b8 mask.""" + return _plt_impl(8, scalar) + + +def plt_b16(scalar): + """``pto.plt_b16`` – predicate-load from a 32-bit scalar into a b16 mask.""" + return _plt_impl(16, scalar) + + +def plt_b32(scalar): + """ + ``pto.plt_b32`` – predicate-load from a 32-bit scalar. + + Returns ``(mask_value, scalar_out)``. ``scalar_out`` is often unused + and can be discarded with ``_``. + """ + return _plt_impl(32, scalar) + + +def _pset_impl(mask_bits: int, pattern): return wrap_surface_value( - _pto.VdivOp( - unwrap_surface_value(lhs).type, - unwrap_surface_value(lhs), - unwrap_surface_value(rhs), - unwrap_surface_value(mask), + _pset_op_for_mask_bits(mask_bits)( + _mask_type_from_bits(mask_bits), + _normalize_mask_pattern(pattern), ).result ) -def vcmax(v, mask): - """``pto.vcmax`` – cross-lane maximum reduction.""" +def pset_b8(pattern): + """``pto.pset_b8(pattern)`` → ``!pto.mask``.""" + return _pset_impl(8, pattern) + + +def pset_b16(pattern): + """``pto.pset_b16(pattern)`` → ``!pto.mask``.""" + return _pset_impl(16, pattern) + + +def pset_b32(pattern): + """``pto.pset_b32(pattern)`` → ``!pto.mask``.""" + return _pset_impl(32, pattern) + + +def _pge_op_for_mask_bits(mask_bits: int): + return { + 8: _pto.PgeB8Op, + 16: _pto.PgeB16Op, + 32: _pto.PgeB32Op, + }[mask_bits] + + +def _pge_impl(mask_bits: int, pattern): return wrap_surface_value( - _pto.VcmaxOp( - unwrap_surface_value(v).type, - unwrap_surface_value(v), - unwrap_surface_value(mask), + _pge_op_for_mask_bits(mask_bits)( + _mask_type_from_bits(mask_bits), + _normalize_mask_pattern(pattern), ).result ) -def vcadd(v, mask): - """``pto.vcadd`` – cross-lane add (sum reduction).""" +def pge_b8(pattern): + """``pto.pge_b8(pattern)`` → ``!pto.mask``.""" + return _pge_impl(8, pattern) + + +def pge_b16(pattern): + """``pto.pge_b16(pattern)`` → ``!pto.mask``.""" + return _pge_impl(16, pattern) + + +def pge_b32(pattern): + """``pto.pge_b32(pattern)`` → ``!pto.mask``.""" + return _pge_impl(32, pattern) + + +def pand(src0, src1, mask): + """``pto.pand`` – gated mask AND.""" + result_type = _require_same_mask_types((src0, src1, mask), context="pand(src0, src1, mask)") return wrap_surface_value( - _pto.VcaddOp( - unwrap_surface_value(v).type, - unwrap_surface_value(v), + _pto.PandOp( + result_type, + unwrap_surface_value(src0), + unwrap_surface_value(src1), unwrap_surface_value(mask), ).result ) -def vdup(v, mask, *, position=None): - """``pto.vdup`` – duplicate a lane value into all lanes. - - Pass ``position="LOWEST"`` to broadcast the lowest (lane-0) element. - """ +def por(src0, src1, mask): + """``pto.por`` – gated mask OR.""" + result_type = _require_same_mask_types((src0, src1, mask), context="por(src0, src1, mask)") return wrap_surface_value( - _pto.VdupOp( - unwrap_surface_value(v).type, - unwrap_surface_value(v), + _pto.PorOp( + result_type, + unwrap_surface_value(src0), + unwrap_surface_value(src1), unwrap_surface_value(mask), - position=position, ).result ) -def vexpdif(inp, ref, mask, part: str = "ODD"): - """``pto.vexpdif`` – ``exp(inp - ref)`` selecting ODD or EVEN lanes.""" +def pxor(src0, src1, mask): + """``pto.pxor`` – gated mask XOR.""" + result_type = _require_same_mask_types((src0, src1, mask), context="pxor(src0, src1, mask)") return wrap_surface_value( - _pto.VexpdifOp( - unwrap_surface_value(inp).type, - unwrap_surface_value(inp), - unwrap_surface_value(ref), + _pto.PxorOp( + result_type, + unwrap_surface_value(src0), + unwrap_surface_value(src1), unwrap_surface_value(mask), - part, ).result ) -def vexp(inp, mask): - """``pto.vexp`` – element-wise exponential.""" +def pnot(src, mask): + """``pto.pnot`` – gated mask NOT.""" + result_type = _require_same_mask_types((src, mask), context="pnot(src, mask)") return wrap_surface_value( - _pto.VexpOp( - unwrap_surface_value(inp).type, - unwrap_surface_value(inp), + _pto.PnotOp( + result_type, + unwrap_surface_value(src), unwrap_surface_value(mask), ).result ) -def vcgmax(v, mask): - """``pto.vcgmax`` – group maximum reduction, surfaced as the lowest-lane scalar.""" - reduced = _pto.VcgmaxOp( - unwrap_surface_value(v).type, - unwrap_surface_value(v), - unwrap_surface_value(mask), - ).result - return _extract_lowest_lane_scalar(reduced, mask) +def psel(src0, src1, sel): + """``pto.psel`` – per-lane mask select.""" + result_type = _require_same_mask_types((src0, src1, sel), context="psel(src0, src1, sel)") + return wrap_surface_value( + _pto.PselOp( + result_type, + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(sel), + ).result + ) -def vcgadd(v, mask): - """``pto.vcgadd`` – group sum reduction, surfaced as the lowest-lane scalar.""" - reduced = _pto.VcgaddOp( - unwrap_surface_value(v).type, - unwrap_surface_value(v), - unwrap_surface_value(mask), - ).result - return _extract_lowest_lane_scalar(reduced, mask) +def ppack(mask_value, part): + """``pto.ppack`` – pack predicate bits into the selected half.""" + _, result_type = _infer_mask_metadata(mask_value, context="ppack(mask, part)") + return wrap_surface_value( + _pto.PpackOp( + result_type, + unwrap_surface_value(mask_value), + _normalize_predicate_part(part), + ).result + ) -def vsubs(inp, scalar, mask): - """``pto.vsubs`` – vector minus scalar under mask.""" - raw_scalar = _coerce_scalar_like_vector_element(inp, scalar, context="vsubs") - neg_scalar = _negate_runtime_scalar(raw_scalar) +def punpack(mask_value, part): + """``pto.punpack`` – unpack predicate bits from the selected half.""" + _, result_type = _infer_mask_metadata(mask_value, context="punpack(mask, part)") return wrap_surface_value( - _pto.VaddsOp( - unwrap_surface_value(inp).type, - unwrap_surface_value(inp), - neg_scalar, - unwrap_surface_value(mask), + _pto.PunpackOp( + result_type, + unwrap_surface_value(mask_value), + _normalize_predicate_part(part), ).result ) -# ── Tile-domain operations ──────────────────────────────────────────────────── +def _pintlv_op_for_mask_bits(mask_bits: int): + return { + 8: _pto.PintlvB8Op, + 16: _pto.PintlvB16Op, + 32: _pto.PintlvB32Op, + }[mask_bits] -def make_tensor_view(ptr, *, shape=None, strides=None): - """ - ``pto.make_tensor_view`` – wrap a pointer as a tensor view. - Type is inferred: rank from ``len(shape)``, element type from ``ptr``. - """ - authored_ptr = ptr - if shape is None: - shape = getattr(authored_ptr, "shape", None) - if strides is None: - strides = getattr(authored_ptr, "strides", None) - if shape is None or strides is None: - raise TypeError("make_tensor_view() requires shape= and strides=, or a host tensor proxy carrying both") - ptr = resolve_tensor_data_entry(authored_ptr) - rank = len(shape) - raw_ptr = unwrap_surface_value(ptr) - elem = _pto.PtrType(raw_ptr.type).element_type - tv_type = tensor_view_type(rank, elem) - value = _pto.MakeTensorViewOp( - tv_type, - raw_ptr, - _unwrap_sequence(shape), - _unwrap_sequence(strides), - ).result - return TensorViewValue(value, shape=tuple(shape), strides=tuple(strides)) +def _pdintlv_op_for_mask_bits(mask_bits: int): + return { + 8: _pto.PdintlvB8Op, + 16: _pto.PdintlvB16Op, + 32: _pto.PdintlvB32Op, + }[mask_bits] -def _normalize_static_tile_shape(shape): - static_shape = [] - for dim in shape: - if isinstance(dim, bool) or not isinstance(dim, int): - raise TypeError( - "alloc_tile(shape=...) currently requires a static physical tile shape. " - "Use constexpr/static integers for shape and place runtime metadata in valid_shape." - ) - static_shape.append(dim) - return tuple(static_shape) +def _mask_pair_op(op_resolver, lhs, rhs, *, expected_mask_bits: int, context: str): + mask_bits, result_type = _infer_mask_metadata(lhs, context=context) + if mask_bits != expected_mask_bits: + raise TypeError(f"{context} expects mask_b{expected_mask_bits} operands, got mask_b{mask_bits}") + _require_same_mask_types((lhs, rhs), context=context) + op = op_resolver(mask_bits)( + result_type, + result_type, + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + ) + return wrap_surface_value(op.low), wrap_surface_value(op.high) -def _split_valid_shape(shape, valid_shape): - rank = len(shape) - if valid_shape is None: - return tuple(shape), None, None, tuple(shape) +def pintlv_b8(lhs, rhs): + """``pto.pintlv_b8`` – interleave two b8 masks.""" + return _mask_pair_op( + _pintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=8, + context="pintlv_b8(lhs, rhs)", + ) + + +def pintlv_b16(lhs, rhs): + """``pto.pintlv_b16`` – interleave two b16 masks.""" + return _mask_pair_op( + _pintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=16, + context="pintlv_b16(lhs, rhs)", + ) + + +def pintlv_b32(lhs, rhs): + """``pto.pintlv_b32`` – interleave two b32 masks.""" + return _mask_pair_op( + _pintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=32, + context="pintlv_b32(lhs, rhs)", + ) + + +def pdintlv_b8(lhs, rhs): + """``pto.pdintlv_b8`` – deinterleave two b8 masks.""" + return _mask_pair_op( + _pdintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=8, + context="pdintlv_b8(lhs, rhs)", + ) + + +def pdintlv_b16(lhs, rhs): + """``pto.pdintlv_b16`` – deinterleave two b16 masks.""" + return _mask_pair_op( + _pdintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=16, + context="pdintlv_b16(lhs, rhs)", + ) + + +def pdintlv_b32(lhs, rhs): + """``pto.pdintlv_b32`` – deinterleave two b32 masks.""" + return _mask_pair_op( + _pdintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=32, + context="pdintlv_b32(lhs, rhs)", + ) - if len(valid_shape) != rank: + +def vcmp(src0, src1, seed_mask, cmp_mode): + """``pto.vcmp`` – vector/vector comparison producing a predicate mask.""" + _, elem_type = _infer_vreg_metadata(src0) + result_type = _mask_type_from_bits(_mask_bits_for_dtype(elem_type)) + seed_type = unwrap_surface_value(seed_mask).type + if seed_type != result_type: raise TypeError( - f"alloc_tile(valid_shape=...) rank mismatch: expected {rank} dims, got {len(valid_shape)}" + f"vcmp(src0, src1, seed_mask, cmp_mode) expects seed_mask {result_type}, got {seed_type}" ) + return wrap_surface_value( + _pto.VcmpOp( + result_type, + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(seed_mask), + _normalize_cmp_mode(cmp_mode), + ).result + ) - type_valid_shape = [] - surface_valid_shape = [] - valid_row = None - valid_col = None - for index, dim in enumerate(valid_shape): - surface_valid_shape.append(dim) - if isinstance(dim, bool): - raise TypeError("alloc_tile(valid_shape=...) does not accept bool dimensions") - if isinstance(dim, int): - type_valid_shape.append(dim) - continue - type_valid_shape.append(-1) - if index == 0: - valid_row = dim - continue - if index == 1: - valid_col = dim - continue + +def vcmps(src, scalar, seed_mask, cmp_mode): + """``pto.vcmps`` – vector/scalar comparison producing a predicate mask.""" + _, elem_type = _infer_vreg_metadata(src) + result_type = _mask_type_from_bits(_mask_bits_for_dtype(elem_type)) + seed_type = unwrap_surface_value(seed_mask).type + if seed_type != result_type: raise TypeError( - "alloc_tile(valid_shape=...) currently only supports dynamic runtime metadata " - "for the first two dimensions" + f"vcmps(src, scalar, seed_mask, cmp_mode) expects seed_mask {result_type}, got {seed_type}" ) - return tuple(type_valid_shape), valid_row, valid_col, tuple(surface_valid_shape) + scalar_value = _coerce_scalar_like_vector_element(src, scalar, context="vcmps") + return wrap_surface_value( + _pto.VcmpsOp( + result_type, + unwrap_surface_value(src), + unwrap_surface_value(scalar_value), + unwrap_surface_value(seed_mask), + _normalize_cmp_mode(cmp_mode), + ).result + ) -def _uses_row_major_none_box_layout(blayout, slayout) -> bool: - return str(blayout).lower() == "rowmajor" and str(slayout).lower() == "nonebox" +def plds(buf, offset, *, dist="NORM"): + """``pto.plds`` – load a predicate mask from UB memory.""" + elem_type = _pointer_element_type(buf, context="plds(buf, offset)") + result_type = _mask_type_from_bits(_mask_bits_for_dtype(elem_type)) + return wrap_surface_value( + _pto.PldsOp( + result_type, + unwrap_surface_value(buf), + _coerce_index(offset, context="plds(buf, offset)"), + _normalize_predicate_dist( + dist, + allowed=_PREDICATE_LOAD_DIST_TOKENS, + context="plds(..., dist)", + ), + ).result + ) -def _validate_authored_tile_row_alignment(shape, dtype, *, blayout, slayout): - if not _uses_row_major_none_box_layout(blayout, slayout): - return - if not shape: - return - elem_bytewidth = _element_bytewidth(_resolve(dtype)) - row_bytes = shape[-1] * elem_bytewidth - required_alignment = 32 - if row_bytes % required_alignment == 0: - return - raise tile_row_alignment_error( - shape=shape, - dtype=str(_resolve(dtype)), - row_bytes=row_bytes, - required_alignment=required_alignment, +def psts(mask_value, buf, offset, *, dist="NORM"): + """``pto.psts`` – store a predicate mask to UB memory.""" + _infer_mask_metadata(mask_value, context="psts(mask, buf, offset)") + _pto.PstsOp( + unwrap_surface_value(mask_value), + unwrap_surface_value(buf), + _coerce_index(offset, context="psts(mask, buf, offset)"), + _normalize_predicate_dist( + dist, + allowed=_PREDICATE_STORE_DIST_TOKENS, + context="psts(..., dist)", + ), ) -def partition_view(tv, *, offsets, sizes): - """ - ``pto.partition_view`` – slice a tensor view. +def pstu(align_in, mask_value, buf): + """``pto.pstu`` – unaligned predicate store with threaded alignment state.""" + mask_bits, _ = _infer_mask_metadata(mask_value, context="pstu(align_in, mask, buf)") + if mask_bits not in {16, 32}: + raise TypeError("pstu(align_in, mask, buf) currently supports only mask_b16 and mask_b32") + elem_type = _pointer_element_type(buf, context="pstu(align_in, mask, buf)") + expected_bytes = mask_bits // 8 + actual_bytes = _element_bytewidth(elem_type) + if actual_bytes != expected_bytes: + raise TypeError( + f"pstu(align_in, mask, buf) expects a {expected_bytes}-byte pointer element for mask_b{mask_bits}, " + f"got {elem_type}" + ) + align_type = _pto.AlignType.get() + base_type = unwrap_surface_value(buf).type + op = _pto.PstuOp( + align_type, + base_type, + unwrap_surface_value(align_in), + unwrap_surface_value(mask_value), + unwrap_surface_value(buf), + ) + return wrap_surface_value(op.align_out), wrap_surface_value(op.base_out) - Type is inferred from the source tensor-view type. + +def vstar(align, destination): + """``pto.vstar`` – flush alignment-buffered tail bytes to the destination base.""" + _pto.VstarOp( + unwrap_surface_value(align), + unwrap_surface_value(destination), + ) + + +def vstas(align, destination, offset): + """``pto.vstas`` – flush alignment-buffered tail bytes with an explicit offset.""" + _pto.VstasOp( + unwrap_surface_value(align), + unwrap_surface_value(destination), + _coerce_i32(offset, context="vstas(align, destination, offset)"), + ) + + +def vstur(align_in, value, base, mode="NO_POST_UPDATE"): + """``pto.vstur`` – unaligned vector store that updates only alignment state.""" + return wrap_surface_value( + _pto.VsturOp( + _pto.AlignType.get(), + unwrap_surface_value(align_in), + unwrap_surface_value(value), + unwrap_surface_value(base), + _normalize_post_update_mode(mode, context="vstur(..., mode)"), + ).align_out + ) + + +def vstus(align_in, offset, value, base): + """``pto.vstus`` – scalar-offset unaligned vector store that updates alignment state.""" + return wrap_surface_value( + _pto.VstusOp( + _pto.AlignType.get(), + unwrap_surface_value(align_in), + _coerce_i32(offset, context="vstus(align, offset, value, base)"), + unwrap_surface_value(value), + unwrap_surface_value(base), + ).align_out + ) + + +# ── Vector math (result type inferred from first operand) ───────────────────── + +def _emit_unary_vec_op(op_ctor, inp, mask): + return wrap_surface_value( + op_ctor( + unwrap_surface_value(inp).type, + unwrap_surface_value(inp), + unwrap_surface_value(mask), + ).result + ) + + +def _emit_binary_vec_op(op_ctor, lhs, rhs, mask): + return wrap_surface_value( + op_ctor( + unwrap_surface_value(lhs).type, + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(mask), + ).result + ) + + +def _emit_vec_scalar_masked_op(op_ctor, inp, scalar, mask, *, context: str): + scalar_value = _coerce_scalar_like_vector_element(inp, scalar, context=context) + return wrap_surface_value( + op_ctor( + unwrap_surface_value(inp).type, + unwrap_surface_value(inp), + unwrap_surface_value(scalar_value), + unwrap_surface_value(mask), + ).result + ) + + +def vadd(lhs, rhs, mask, result_type=None): + """``pto.vadd`` – element-wise add.""" + rt = result_type if result_type is not None else lhs.type + return wrap_surface_value( + _pto.VaddOp( + _resolve(rt), + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(mask), + ).result + ) + + +def vsub(lhs, rhs, mask): + """``pto.vsub`` – element-wise subtract.""" + return _emit_binary_vec_op(_pto.VsubOp, lhs, rhs, mask) + + +def vmul(lhs, rhs, mask): + """``pto.vmul`` – element-wise multiply.""" + return _emit_binary_vec_op(_pto.VmulOp, lhs, rhs, mask) + + +def vmax(lhs, rhs, mask): + """``pto.vmax`` – element-wise maximum.""" + return _emit_binary_vec_op(_pto.VmaxOp, lhs, rhs, mask) + + +def vmin(lhs, rhs, mask): + """``pto.vmin`` – element-wise minimum.""" + return _emit_binary_vec_op(_pto.VminOp, lhs, rhs, mask) + + +def vand(lhs, rhs, mask): + """``pto.vand`` – element-wise bitwise and.""" + return _emit_binary_vec_op(_pto.VandOp, lhs, rhs, mask) + + +def vor(lhs, rhs, mask): + """``pto.vor`` – element-wise bitwise or.""" + return _emit_binary_vec_op(_pto.VorOp, lhs, rhs, mask) + + +def vxor(lhs, rhs, mask): + """``pto.vxor`` – element-wise bitwise xor.""" + return _emit_binary_vec_op(_pto.VxorOp, lhs, rhs, mask) + + +def vdiv(lhs, rhs, mask): + """``pto.vdiv`` – element-wise divide.""" + return _emit_binary_vec_op(_pto.VdivOp, lhs, rhs, mask) + + +def vshl(lhs, rhs, mask): + """``pto.vshl`` – element-wise shift left.""" + return _emit_binary_vec_op(_pto.VshlOp, lhs, rhs, mask) + + +def vshr(lhs, rhs, mask): + """``pto.vshr`` – element-wise shift right.""" + return _emit_binary_vec_op(_pto.VshrOp, lhs, rhs, mask) + + +def vcmax(v, mask): + """``pto.vcmax`` – cross-lane maximum reduction.""" + return _emit_unary_vec_op(_pto.VcmaxOp, v, mask) + + +def vcadd(v, mask): + """``pto.vcadd`` – cross-lane add (sum reduction).""" + return _emit_unary_vec_op(_pto.VcaddOp, v, mask) + + +def vcmin(v, mask): + """``pto.vcmin`` – cross-lane minimum reduction.""" + return _emit_unary_vec_op(_pto.VcminOp, v, mask) + + +def vdup(v, mask, *, position=None): + """``pto.vdup`` – duplicate a lane value into all lanes. + + Pass ``position="LOWEST"`` to broadcast the lowest (lane-0) element. """ - spec = compose_partition_spec(tv, offsets=offsets, sizes=sizes) - if spec is not None: - source = spec.root_tensor_view - offsets = spec.offsets - sizes = spec.sizes - else: - source = tv + return wrap_surface_value( + _pto.VdupOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + position=position, + ).result + ) + + +def vln(inp, mask): + """``pto.vln`` – element-wise natural logarithm.""" + return _emit_unary_vec_op(_pto.VlnOp, inp, mask) + + +def vsqrt(inp, mask): + """``pto.vsqrt`` – element-wise square root.""" + return _emit_unary_vec_op(_pto.VsqrtOp, inp, mask) + + +def vabs(inp, mask): + """``pto.vabs`` – element-wise absolute value.""" + return _emit_unary_vec_op(_pto.VabsOp, inp, mask) + + +def vneg(inp, mask): + """``pto.vneg`` – element-wise negation.""" + return _emit_unary_vec_op(_pto.VnegOp, inp, mask) + + +def vrelu(inp, mask): + """``pto.vrelu`` – element-wise ReLU.""" + return _emit_unary_vec_op(_pto.VreluOp, inp, mask) + + +def vnot(inp, mask): + """``pto.vnot`` – element-wise bitwise/logical not.""" + return _emit_unary_vec_op(_pto.VnotOp, inp, mask) + + +def vexpdif(inp, ref, mask, part: str = "ODD"): + """``pto.vexpdif`` – ``exp(inp - ref)`` selecting ODD or EVEN lanes.""" + return wrap_surface_value( + _pto.VexpdifOp( + unwrap_surface_value(inp).type, + unwrap_surface_value(inp), + unwrap_surface_value(ref), + unwrap_surface_value(mask), + part, + ).result + ) + + +def vexp(inp, mask): + """``pto.vexp`` – element-wise exponential.""" + return _emit_unary_vec_op(_pto.VexpOp, inp, mask) + + +def vrec(inp, mask): + """``pto.vrec`` – reciprocal, surfaced as ``1 / inp``.""" + zero_vec = vmuls(inp, 0, mask) + one_vec = vadds(zero_vec, 1, mask) + return vdiv(one_vec, inp, mask) + + +def vrsqrt(inp, mask): + """``pto.vrsqrt`` – inverse square root, surfaced as ``1 / sqrt(inp)``.""" + sqrt_vec = vsqrt(inp, mask) + return vrec(sqrt_vec, mask) + + +def vcgmax(v, mask): + """``pto.vcgmax`` – group maximum reduction, surfaced as the lowest-lane scalar.""" + reduced = _pto.VcgmaxOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + ).result + return _extract_lowest_lane_scalar(reduced, mask) + + +def vcgadd(v, mask): + """``pto.vcgadd`` – group sum reduction, surfaced as the lowest-lane scalar.""" + reduced = _pto.VcgaddOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + ).result + return _extract_lowest_lane_scalar(reduced, mask) + + +def vcgmin(v, mask): + """``pto.vcgmin`` – group minimum reduction, surfaced as the lowest-lane scalar.""" + reduced = _pto.VcgminOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + ).result + return _extract_lowest_lane_scalar(reduced, mask) + + +def vcpadd(v, mask): + """``pto.vcpadd`` – inclusive prefix sum.""" + return _emit_unary_vec_op(_pto.VcpaddOp, v, mask) + + +def vadds(inp, scalar, mask): + """``pto.vadds`` – vector plus scalar under mask.""" + return _emit_vec_scalar_masked_op(_pto.VaddsOp, inp, scalar, mask, context="vadds") + + +def vsubs(inp, scalar, mask): + """``pto.vsubs`` – vector minus scalar under mask.""" + raw_scalar = _coerce_scalar_like_vector_element(inp, scalar, context="vsubs") + neg_scalar = _negate_runtime_scalar(raw_scalar) + return wrap_surface_value( + _pto.VaddsOp( + unwrap_surface_value(inp).type, + unwrap_surface_value(inp), + neg_scalar, + unwrap_surface_value(mask), + ).result + ) + + +def vmuls(inp, scalar, mask): + """``pto.vmuls`` – vector times scalar under mask.""" + return _emit_vec_scalar_masked_op(_pto.VmulsOp, inp, scalar, mask, context="vmuls") + + +def vmaxs(inp, scalar, mask): + """``pto.vmaxs`` – vector/scalar maximum under mask.""" + return _emit_vec_scalar_masked_op(_pto.VmaxsOp, inp, scalar, mask, context="vmaxs") + + +def vmins(inp, scalar, mask): + """``pto.vmins`` – vector/scalar minimum under mask.""" + return _emit_vec_scalar_masked_op(_pto.VminsOp, inp, scalar, mask, context="vmins") + + +def vlrelu(inp, alpha, mask): + """``pto.vlrelu`` – vector leaky ReLU under mask.""" + return _emit_vec_scalar_masked_op(_pto.VlreluOp, inp, alpha, mask, context="vlrelu") + + +def vaddrelu(lhs, rhs, mask): + """``pto.vaddrelu`` – add, then apply ReLU.""" + return vrelu(vadd(lhs, rhs, mask), mask) + + +def vsubrelu(lhs, rhs, mask): + """``pto.vsubrelu`` – subtract, then apply ReLU.""" + return vrelu(vsub(lhs, rhs, mask), mask) + + +def vaxpy(alpha, x, y, mask): + """``pto.vaxpy`` – fused ``alpha * x + y``.""" + alpha_value = _coerce_scalar_like_vector_element(x, alpha, context="vaxpy") + return wrap_surface_value( + _pto.VaxpyOp( + unwrap_surface_value(x).type, + unwrap_surface_value(x), + unwrap_surface_value(y), + unwrap_surface_value(alpha_value), + unwrap_surface_value(mask), + ).result + ) + + +def vsel(true_v, false_v, mask): + """``pto.vsel`` – element-wise select under a predicate mask.""" + return wrap_surface_value( + _pto.VselOp( + unwrap_surface_value(true_v).type, + unwrap_surface_value(true_v), + unwrap_surface_value(false_v), + unwrap_surface_value(mask), + ).result + ) + + +# ── Tile-domain operations ──────────────────────────────────────────────────── + +def make_tensor_view(ptr, *, shape=None, strides=None): + """ + ``pto.make_tensor_view`` – wrap a pointer as a tensor view. + + Type is inferred: rank from ``len(shape)``, element type from ``ptr``. + """ + authored_ptr = ptr + if shape is None: + shape = getattr(authored_ptr, "shape", None) + if strides is None: + strides = getattr(authored_ptr, "strides", None) + if shape is None or strides is None: + raise TypeError("make_tensor_view() requires shape= and strides=, or a host tensor proxy carrying both") + ptr = resolve_tensor_data_entry(authored_ptr) + rank = len(shape) + raw_ptr = unwrap_surface_value(ptr) + elem = _pto.PtrType(raw_ptr.type).element_type + tv_type = tensor_view_type(rank, elem) + value = _pto.MakeTensorViewOp( + tv_type, + raw_ptr, + _unwrap_sequence(shape), + _unwrap_sequence(strides), + ).result + return TensorViewValue(value, shape=tuple(shape), strides=tuple(strides)) + + +def _normalize_static_tile_shape(shape): + static_shape = [] + for dim in shape: + if isinstance(dim, bool) or not isinstance(dim, int): + raise TypeError( + "alloc_tile(shape=...) currently requires a static physical tile shape. " + "Use constexpr/static integers for shape and place runtime metadata in valid_shape." + ) + static_shape.append(dim) + return tuple(static_shape) + + +def _authored_tile_physical_shape(shape): + if len(shape) == 1: + return (1, shape[0]) + return tuple(shape) + + +def _split_valid_shape(shape, valid_shape): + logical_rank = len(shape) + if valid_shape is None: + return _authored_tile_physical_shape(shape), None, None, tuple(shape) + + if len(valid_shape) != logical_rank: + raise TypeError( + f"alloc_tile(valid_shape=...) rank mismatch: expected {logical_rank} dims, got {len(valid_shape)}" + ) + + surface_valid_shape = [] + if logical_rank == 1: + dim = valid_shape[0] + surface_valid_shape.append(dim) + if isinstance(dim, bool): + raise TypeError("alloc_tile(valid_shape=...) does not accept bool dimensions") + if isinstance(dim, int): + return (1, dim), None, None, tuple(surface_valid_shape) + return (-1, -1), 1, dim, tuple(surface_valid_shape) + + type_valid_shape = [] + valid_row = None + valid_col = None + for index, dim in enumerate(valid_shape): + surface_valid_shape.append(dim) + if isinstance(dim, bool): + raise TypeError("alloc_tile(valid_shape=...) does not accept bool dimensions") + if isinstance(dim, int): + type_valid_shape.append(dim) + continue + type_valid_shape.append(-1) + if index == 0: + valid_row = dim + continue + if index == 1: + valid_col = dim + continue + raise TypeError( + "alloc_tile(valid_shape=...) currently only supports dynamic runtime metadata " + "for the first two dimensions" + ) + return tuple(type_valid_shape), valid_row, valid_col, tuple(surface_valid_shape) + + +def _uses_row_major_none_box_layout(blayout, slayout) -> bool: + return str(blayout).lower() == "rowmajor" and str(slayout).lower() == "nonebox" + + +def _validate_authored_tile_row_alignment(shape, dtype, *, blayout, slayout): + if not _uses_row_major_none_box_layout(blayout, slayout): + return + if not shape: + return + elem_bytewidth = _element_bytewidth(_resolve(dtype)) + row_bytes = shape[-1] * elem_bytewidth + required_alignment = 32 + if row_bytes % required_alignment == 0: + return + raise tile_row_alignment_error( + shape=shape, + dtype=str(_resolve(dtype)), + row_bytes=row_bytes, + required_alignment=required_alignment, + ) + + +def partition_view(tv, *, offsets, sizes): + """ + ``pto.partition_view`` – slice a tensor view. + + Type is inferred from the source tensor-view type. + """ + spec = compose_partition_spec(tv, offsets=offsets, sizes=sizes) + if spec is not None: + source = spec.root_tensor_view + offsets = spec.offsets + sizes = spec.sizes + else: + source = tv + + raw_source = unwrap_surface_value(source) + src_type = _pto.TensorViewType(raw_source.type) + rank = src_type.rank + elem = src_type.element_type + ptv_type = part_tensor_view_type(rank, elem) + value = _pto.PartitionViewOp( + ptv_type, + raw_source, + _unwrap_sequence(offsets), + _unwrap_sequence(sizes), + ).result + return wrap_surface_value( + value, + root_tensor_view=source if spec is None else spec.root_tensor_view, + offsets=tuple(offsets), + sizes=tuple(sizes), + ) + + +def alloc_tile( + tile_type=None, + *, + shape=None, + dtype=None, + memory_space="ub", + valid_shape=None, + blayout: str = "RowMajor", + slayout: str = "NoneBox", + fractal_size: int = 512, + pad: str = "Null", + addr=None, + valid_row=None, + valid_col=None, +): + """ + ``pto.alloc_tile``. + + Accepts either the authored surface form: + + ``alloc_tile(shape=[...], dtype=..., memory_space=...)`` + + or the low-level explicit-type form: + + ``alloc_tile(tile_type, addr=..., valid_row=..., valid_col=...)``. + """ + if tile_type is not None and shape is not None: + raise TypeError("alloc_tile() accepts either tile_type or shape=/dtype=, not both") + + if tile_type is None: + if shape is None or dtype is None: + raise TypeError("alloc_tile() requires either tile_type or both shape= and dtype=") + if addr is not None or valid_row is not None or valid_col is not None: + raise TypeError( + "alloc_tile(shape=..., dtype=...) uses the authored surface form; " + "addr=/valid_row=/valid_col= are only supported with an explicit tile_type" + ) + logical_shape = _normalize_static_tile_shape(shape) + physical_shape = _authored_tile_physical_shape(logical_shape) + _validate_authored_tile_row_alignment(physical_shape, dtype, blayout=blayout, slayout=slayout) + type_valid_shape, valid_row, valid_col, surface_valid_shape = _split_valid_shape(logical_shape, valid_shape) + from ._types import tile_buf_type + tile_type = tile_buf_type( + physical_shape, + dtype, + type_valid_shape, + blayout=blayout, + address_space=memory_space, + slayout=slayout, + fractal_size=fractal_size, + pad=pad, + ) + shape = logical_shape + else: + physical_shape = None + surface_valid_shape = None + + value = _pto.AllocTileOp( + _resolve(tile_type), + addr=unwrap_surface_value(addr) if addr is not None else None, + valid_row=_coerce_index(valid_row, context="alloc_tile(valid_row)") if valid_row is not None else None, + valid_col=_coerce_index(valid_col, context="alloc_tile(valid_col)") if valid_col is not None else None, + ).result + if tile_type is not None and (valid_row is not None or valid_col is not None): + parsed_tile_type = parse_tile_type_metadata(_resolve(tile_type)) + rank = len(shape) if shape is not None else len(parsed_tile_type["shape_dims"]) + surface_valid_shape = [None] * rank + if rank >= 1: + surface_valid_shape[0] = valid_row + if rank >= 2: + surface_valid_shape[1] = valid_col + surface_valid_shape = tuple(surface_valid_shape) + return wrap_surface_value( + value, + tile_metadata={ + "shape": shape, + "physical_shape": physical_shape, + "dtype": dtype, + "memory_space": memory_space, + "valid_shape": surface_valid_shape, + }, + ) + + +def set_tile_valid_shape(tile, valid_shape): + """Update the runtime valid-shape metadata of an authored dynamic tile.""" + parsed_tile_type = parse_tile_type_metadata(unwrap_surface_value(tile).type) + if parsed_tile_type is None: + raise TypeError("tile.valid_shape assignment expects a tile_buf-backed value") + if len(parsed_tile_type["shape_dims"]) != 2: + raise TypeError("tile.valid_shape assignment currently only supports rank-2 tiles") + logical_rank = len(tile.shape) if getattr(tile, "shape", None) is not None else 2 + if logical_rank == 1: + if len(valid_shape) != 1: + raise TypeError("rank-1 tile.valid_shape assignment expects exactly one dimension") + if parsed_tile_type["valid_dims"] != (None, None): + raise TypeError( + "rank-1 tile.valid_shape assignment requires a tile allocated with " + "valid_shape=[...] so the physical valid row/col metadata remain dynamic" + ) + valid_row = _coerce_index_value(1) + valid_col, = _unwrap_sequence(valid_shape) + else: + if len(valid_shape) != 2: + raise TypeError("tile.valid_shape assignment currently expects exactly two dimensions") + if parsed_tile_type["valid_dims"] != (None, None): + raise TypeError( + "tile.valid_shape assignment requires a tile allocated with fully dynamic " + "valid_shape=[..., ...]" + ) + valid_row, valid_col = _unwrap_sequence(valid_shape) + _pto.SetValidShapeOp( + unwrap_surface_value(tile), + valid_row, + valid_col, + ) + + +def tload(part, tile): + """``pto.tload ins(part) outs(tile)``.""" + _pto.TLoadOp(None, unwrap_surface_value(part), unwrap_surface_value(tile)) + + +def tstore(tile, part): + """``pto.tstore ins(tile) outs(part)``.""" + _pto.TStoreOp(None, unwrap_surface_value(tile), unwrap_surface_value(part)) + + +def tmov(src, dst): + """``pto.tmov ins(src) outs(dst)`` – move data between tile domains.""" + _pto.TMovOp(None, unwrap_surface_value(src), unwrap_surface_value(dst)) + + +def _coerce_tile_scalar_operand(tile, scalar, *, context: str): + return _constant_like(scalar, infer_tile_element_type(wrap_surface_value(tile))) + + +def tadd(src0, src1, dst): + """``pto.tadd ins(src0, src1) outs(dst)``.""" + _pto.tadd( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tsub(src0, src1, dst): + """``pto.tsub ins(src0, src1) outs(dst)``.""" + _pto.tsub( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tmul(src0, src1, dst): + """``pto.tmul ins(src0, src1) outs(dst)``.""" + _pto.tmul( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tdiv(src0, src1, dst, *, precision_mode=None): + """``pto.tdiv ins(src0, src1) outs(dst)``.""" + _pto.tdiv( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tmax(src0, src1, dst): + """``pto.tmax ins(src0, src1) outs(dst)``.""" + _pto.tmax( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tmin(src0, src1, dst): + """``pto.tmin ins(src0, src1) outs(dst)``.""" + _pto.tmin( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tadds(src, scalar, dst): + """``pto.tadds ins(src, scalar) outs(dst)``.""" + _pto.tadds( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tadds"), + unwrap_surface_value(dst), + ) + + +def tsubs(src, scalar, dst): + """``pto.tsubs ins(src, scalar) outs(dst)``.""" + _pto.tsubs( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tsubs"), + unwrap_surface_value(dst), + ) + + +def tmuls(src, scalar, dst): + """``pto.tmuls ins(src, scalar) outs(dst)``.""" + _pto.tmuls( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tmuls"), + unwrap_surface_value(dst), + ) + + +def tdivs(src, scalar, dst, *, precision_mode=None): + """``pto.tdivs ins(src, scalar) outs(dst)``.""" + _pto.tdivs( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tdivs"), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tmaxs(src, scalar, dst): + """``pto.tmaxs ins(src, scalar) outs(dst)``.""" + _pto.tmaxs( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tmaxs"), + unwrap_surface_value(dst), + ) + + +def tmins(src, scalar, dst): + """``pto.tmins ins(src, scalar) outs(dst)``.""" + _pto.tmins( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tmins"), + unwrap_surface_value(dst), + ) + + +def texp(src, dst, *, precision_mode=None): + """``pto.texp ins(src) outs(dst)``.""" + _pto.texp( + unwrap_surface_value(src), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tlog(src, dst, *, precision_mode=None): + """``pto.tlog ins(src) outs(dst)``.""" + _pto.tlog( + unwrap_surface_value(src), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tsqrt(src, dst, *, precision_mode=None): + """``pto.tsqrt ins(src) outs(dst)``.""" + _pto.tsqrt( + unwrap_surface_value(src), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def trsqrt(src, dst, *, tmp=None, precision_mode=None): + """``pto.trsqrt ins(src, tmp?) outs(dst)``.""" + _pto.trsqrt( + unwrap_surface_value(src), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + precision_mode=precision_mode, + ) + + +def trecip(src, dst, *, precision_mode=None): + """``pto.trecip ins(src) outs(dst)``.""" + _pto.trecip( + unwrap_surface_value(src), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tabs(src, dst): + """``pto.tabs ins(src) outs(dst)``.""" + _pto.tabs( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tneg(src, dst): + """``pto.tneg ins(src) outs(dst)``.""" + _pto.tneg( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def trelu(src, dst): + """``pto.trelu ins(src) outs(dst)``.""" + _pto.trelu( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tlrelu(src, slope, dst): + """``pto.tlrelu ins(src, slope) outs(dst)``.""" + _pto.tlrelu( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, slope, context="tlrelu"), + unwrap_surface_value(dst), + ) + + +def trowsum(src, tmp, dst): + """``pto.trowsum ins(src, tmp) outs(dst)``.""" + _pto.trowsum( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def trowmax(src, tmp, dst): + """``pto.trowmax ins(src, tmp) outs(dst)``.""" + _pto.trowmax( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def trowmin(src, tmp, dst): + """``pto.trowmin ins(src, tmp) outs(dst)``.""" + _pto.trowmin( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def trowprod(src, tmp, dst): + """``pto.trowprod ins(src, tmp) outs(dst)``.""" + _pto.trowprod( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def trowargmax(src, tmp, dst): + """``pto.trowargmax ins(src, tmp) outs(dst)``.""" + _pto.trowargmax( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def trowargmin(src, tmp, dst): + """``pto.trowargmin ins(src, tmp) outs(dst)``.""" + _pto.trowargmin( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def tcolsum(src, dst, *, tmp=None, is_binary=None): + """``pto.tcolsum ins(src, tmp?) outs(dst)``.""" + _pto.tcolsum( + unwrap_surface_value(src), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + is_binary=is_binary, + ) + + +def tcolmax(src, dst): + """``pto.tcolmax ins(src) outs(dst)``.""" + _pto.tcolmax( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tcolmin(src, dst): + """``pto.tcolmin ins(src) outs(dst)``.""" + _pto.tcolmin( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tcolprod(src, dst): + """``pto.tcolprod ins(src) outs(dst)``.""" + _pto.tcolprod( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tcolargmax(src, tmp, dst): + """``pto.tcolargmax ins(src, tmp) outs(dst)``.""" + _pto.tcolargmax( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def tcolargmin(src, tmp, dst): + """``pto.tcolargmin ins(src, tmp) outs(dst)``.""" + _pto.tcolargmin( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def tcmp(src0, src1, dst, *, cmp_mode=None): + """``pto.tcmp ins(src0, src1) outs(dst)``.""" + _pto.tcmp( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + cmp_mode=None if cmp_mode is None else _cmp_mode_attr(cmp_mode), + ) + + +def tcmps(src, scalar, dst, *, cmp_mode=None): + """``pto.tcmps ins(src, scalar) outs(dst)``.""" + _pto.tcmps( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tcmps"), + unwrap_surface_value(dst), + cmp_mode=None if cmp_mode is None else _cmp_mode_attr(cmp_mode), + ) + + +def texpands(scalar, dst): + """``pto.texpands ins(scalar) outs(dst)``.""" + _pto.texpands( + _coerce_tile_scalar_operand(dst, scalar, context="texpands"), + unwrap_surface_value(dst), + ) + + +def trowexpand(src, dst): + """``pto.trowexpand ins(src) outs(dst)``.""" + _pto.trowexpand( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tcolexpand(src, dst): + """``pto.tcolexpand ins(src) outs(dst)``.""" + _pto.tcolexpand( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def trowexpandadd(src0, src1, dst): + """``pto.trowexpandadd ins(src0, src1) outs(dst)``.""" + _pto.trowexpandadd( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def trowexpandsub(src0, src1, dst, *, tmp=None): + """``pto.trowexpandsub ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpandsub( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + ) + + +def trowexpandmul(src0, src1, dst, *, tmp=None): + """``pto.trowexpandmul ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpandmul( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + ) + + +def trowexpanddiv(src0, src1, dst, *, tmp=None, precision_mode=None): + """``pto.trowexpanddiv ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpanddiv( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + precision_mode=precision_mode, + ) + + +def trowexpandmax(src0, src1, dst, *, tmp=None): + """``pto.trowexpandmax ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpandmax( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + ) + + +def trowexpandmin(src0, src1, dst, *, tmp=None): + """``pto.trowexpandmin ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpandmin( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + ) + + +def trowexpandexpdif(src0, src1, dst, *, tmp=None): + """``pto.trowexpandexpdif ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpandexpdif( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + ) + + +def tcolexpandadd(src0, src1, dst): + """``pto.tcolexpandadd ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandadd( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tcolexpandsub(src0, src1, dst): + """``pto.tcolexpandsub ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandsub( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tcolexpandmul(src0, src1, dst): + """``pto.tcolexpandmul ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandmul( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tcolexpanddiv(src0, src1, dst, *, precision_mode=None): + """``pto.tcolexpanddiv ins(src0, src1) outs(dst)``.""" + _pto.tcolexpanddiv( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tcolexpandmax(src0, src1, dst): + """``pto.tcolexpandmax ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandmax( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tcolexpandmin(src0, src1, dst): + """``pto.tcolexpandmin ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandmin( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tcolexpandexpdif(src0, src1, dst): + """``pto.tcolexpandexpdif ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandexpdif( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def _resolve_selection_tmp(dst, tmp, *, context: str): + if tmp is not None: + return tmp + + session = None + try: + from ._tracing.active import current_session + session = current_session() + except Exception: + session = None + + if session is not None and getattr(session.module_spec, "target_arch", None) == "a5": + return dst + + return alloc_tile(tile_type=unwrap_surface_value(dst).type) + + +def tsel(mask, src0, src1, dst, *, tmp=None): + """``pto.tsel ins(mask, src0, src1, tmp) outs(dst)`` with synthesized scratch when omitted.""" + resolved_tmp = tmp if tmp is not None else _resolve_selection_tmp(dst, tmp, context="tsel") + _pto.tsel( + unwrap_surface_value(mask), + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(resolved_tmp), + unwrap_surface_value(dst), + ) + + +def tsels(mask, src, scalar, dst, *, tmp=None): + """``pto.tsels ins(mask, src, tmp, scalar) outs(dst)`` with synthesized scratch when omitted.""" + resolved_tmp = tmp if tmp is not None else _resolve_selection_tmp(dst, tmp, context="tsels") + _pto.tsels( + unwrap_surface_value(mask), + unwrap_surface_value(src), + unwrap_surface_value(resolved_tmp), + _coerce_tile_scalar_operand(src, scalar, context="tsels"), + unwrap_surface_value(dst), + ) + + +def tcvt(src, dst, *, tmp=None, rmode=None, sat_mode=None): + """``pto.tcvt ins(src, tmp?) outs(dst)``.""" + _pto.tcvt( + unwrap_surface_value(src), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + rmode=rmode, + sat_mode=sat_mode, + ) + + +def tnot(src, dst): + """``pto.tnot ins(src) outs(dst)``.""" + _pto.tnot( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tand(src0, src1, dst): + """``pto.tand ins(src0, src1) outs(dst)``.""" + _pto.tand( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tands(src, scalar, dst): + """``pto.tands ins(src, scalar) outs(dst)``.""" + _pto.tands( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tands"), + unwrap_surface_value(dst), + ) + + +def tor(src0, src1, dst): + """``pto.tor ins(src0, src1) outs(dst)``.""" + _pto.tor( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tors(src, scalar, dst): + """``pto.tors ins(src, scalar) outs(dst)``.""" + _pto.tors( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tors"), + unwrap_surface_value(dst), + ) + + +def txor(src0, src1, tmp, dst): + """``pto.txor ins(src0, src1, tmp) outs(dst)``.""" + _pto.txor( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def txors(src, scalar, tmp, dst): + """``pto.txors ins(src, scalar, tmp) outs(dst)``.""" + _pto.txors( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="txors"), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def tshl(src0, src1, dst): + """``pto.tshl ins(src0, src1) outs(dst)``.""" + _pto.tshl( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) - raw_source = unwrap_surface_value(source) - src_type = _pto.TensorViewType(raw_source.type) - rank = src_type.rank - elem = src_type.element_type - ptv_type = part_tensor_view_type(rank, elem) - value = _pto.PartitionViewOp( - ptv_type, - raw_source, - _unwrap_sequence(offsets), - _unwrap_sequence(sizes), - ).result - return wrap_surface_value( - value, - root_tensor_view=source if spec is None else spec.root_tensor_view, - offsets=tuple(offsets), - sizes=tuple(sizes), + +def tshls(src, scalar, dst): + """``pto.tshls ins(src, scalar) outs(dst)``.""" + _pto.tshls( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tshls"), + unwrap_surface_value(dst), ) -def alloc_tile( - tile_type=None, - *, - shape=None, - dtype=None, - memory_space="ub", - valid_shape=None, - blayout: str = "RowMajor", - slayout: str = "NoneBox", - fractal_size: int = 512, - pad: str = "Null", - addr=None, - valid_row=None, - valid_col=None, -): - """ - ``pto.alloc_tile``. +def tshr(src0, src1, dst): + """``pto.tshr ins(src0, src1) outs(dst)``.""" + _pto.tshr( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) - Accepts either the authored surface form: - ``alloc_tile(shape=[...], dtype=..., memory_space=...)`` +def tshrs(src, scalar, dst): + """``pto.tshrs ins(src, scalar) outs(dst)``.""" + _pto.tshrs( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tshrs"), + unwrap_surface_value(dst), + ) - or the low-level explicit-type form: - ``alloc_tile(tile_type, addr=..., valid_row=..., valid_col=...)``. - """ - if tile_type is not None and shape is not None: - raise TypeError("alloc_tile() accepts either tile_type or shape=/dtype=, not both") +def tpartadd(src0, src1, dst): + """``pto.tpartadd ins(src0, src1) outs(dst)``.""" + _pto.tpartadd( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) - if tile_type is None: - if shape is None or dtype is None: - raise TypeError("alloc_tile() requires either tile_type or both shape= and dtype=") - if addr is not None or valid_row is not None or valid_col is not None: - raise TypeError( - "alloc_tile(shape=..., dtype=...) uses the authored surface form; " - "addr=/valid_row=/valid_col= are only supported with an explicit tile_type" - ) - shape = _normalize_static_tile_shape(shape) - _validate_authored_tile_row_alignment(shape, dtype, blayout=blayout, slayout=slayout) - type_valid_shape, valid_row, valid_col, surface_valid_shape = _split_valid_shape(shape, valid_shape) - from ._types import tile_buf_type - tile_type = tile_buf_type( - shape, - dtype, - type_valid_shape, - blayout=blayout, - address_space=memory_space, - slayout=slayout, - fractal_size=fractal_size, - pad=pad, - ) - else: - surface_valid_shape = None - value = _pto.AllocTileOp( - _resolve(tile_type), - addr=unwrap_surface_value(addr) if addr is not None else None, - valid_row=unwrap_surface_value(valid_row) if valid_row is not None else None, - valid_col=unwrap_surface_value(valid_col) if valid_col is not None else None, - ).result - if tile_type is not None and (valid_row is not None or valid_col is not None): - parsed_tile_type = parse_tile_type_metadata(_resolve(tile_type)) - rank = len(shape) if shape is not None else len(parsed_tile_type["shape_dims"]) - surface_valid_shape = [None] * rank - if rank >= 1: - surface_valid_shape[0] = valid_row - if rank >= 2: - surface_valid_shape[1] = valid_col - surface_valid_shape = tuple(surface_valid_shape) - return wrap_surface_value( - value, - tile_metadata={ - "shape": shape, - "dtype": dtype, - "memory_space": memory_space, - "valid_shape": surface_valid_shape, - }, +def tpartmul(src0, src1, dst): + """``pto.tpartmul ins(src0, src1) outs(dst)``.""" + _pto.tpartmul( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), ) -def set_tile_valid_shape(tile, valid_shape): - """Update the runtime valid-shape metadata of a rank-2 dynamic tile.""" - if len(valid_shape) != 2: - raise TypeError( - "tile.valid_shape assignment currently expects exactly two dimensions" - ) +def tpartmax(src0, src1, dst): + """``pto.tpartmax ins(src0, src1) outs(dst)``.""" + _pto.tpartmax( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) - parsed_tile_type = parse_tile_type_metadata(unwrap_surface_value(tile).type) - if parsed_tile_type is None: - raise TypeError("tile.valid_shape assignment expects a tile_buf-backed value") - if len(parsed_tile_type["shape_dims"]) != 2: - raise TypeError("tile.valid_shape assignment currently only supports rank-2 tiles") - if parsed_tile_type["valid_dims"] != (None, None): - raise TypeError( - "tile.valid_shape assignment requires a tile allocated with fully dynamic " - "valid_shape=[..., ...]" - ) - valid_row, valid_col = _unwrap_sequence(valid_shape) - _pto.SetValidShapeOp( - unwrap_surface_value(tile), - valid_row, - valid_col, +def tpartmin(src0, src1, dst): + """``pto.tpartmin ins(src0, src1) outs(dst)``.""" + _pto.tpartmin( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), ) -def tload(part, tile): - """``pto.tload ins(part) outs(tile)``.""" - _pto.TLoadOp(None, unwrap_surface_value(part), unwrap_surface_value(tile)) +def tfillpad(src, dst): + """``pto.tfillpad ins(src) outs(dst)``.""" + _pto.tfillpad( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) -def tstore(tile, part): - """``pto.tstore ins(tile) outs(part)``.""" - _pto.TStoreOp(None, unwrap_surface_value(tile), unwrap_surface_value(part)) +def tfillpad_expand(src, dst): + """``pto.tfillpad_expand ins(src) outs(dst)``.""" + _pto.tfillpad_expand( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) -def tmov(src, dst): - """``pto.tmov ins(src) outs(dst)`` – move data between tile domains.""" - _pto.TMovOp(None, unwrap_surface_value(src), unwrap_surface_value(dst)) +def tfillpad_inplace(src, dst): + """``pto.tfillpad_inplace ins(src) outs(dst)``.""" + _pto.tfillpad_inplace( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) def as_ptr(value, result_ptr_type=None): @@ -610,6 +2323,8 @@ def _constant_like(value, mlir_type): return value if isinstance(value, float): return arith.ConstantOp(mlir_type, FloatAttr.get(mlir_type, value)).result + if IntegerType.isinstance(mlir_type): + return _materialize_integer_literal(mlir_type, value) return arith.ConstantOp(mlir_type, value).result @@ -617,6 +2332,36 @@ def _index_zero(): return arith.ConstantOp(IndexType.get(), 0).result +def _tile_slice_linear_offset(tile_slice: TileSliceValue): + offsets = tile_slice.offsets + if len(offsets) == 1: + return offsets[0] + if len(offsets) != 2: + raise RuntimeError("tile slice pointer lowering only supports rank-1 or rank-2 offsets") + + physical_shape = getattr(tile_slice.tile, "physical_shape", None) + if physical_shape is None or len(physical_shape) != 2 or physical_shape[1] is None: + raise RuntimeError("tile slice pointer lowering requires static physical column shape metadata") + + row, col = offsets + stride = physical_shape[1] + if isinstance(row, int) and isinstance(col, int): + return row * stride + col + + row_value = _coerce_index(row, context="tile slice pointer lowering") + row_stride = arith.MulIOp(row_value, arith.ConstantOp(IndexType.get(), stride).result).result + col_value = _coerce_index(col, context="tile slice pointer lowering") + return arith.AddIOp(row_stride, col_value).result + + +def _tile_slice_ptr(tile_slice: TileSliceValue): + base_ptr = emit_as_ptr(tile_slice.tile) + linear_offset = _tile_slice_linear_offset(tile_slice) + if isinstance(linear_offset, int) and linear_offset == 0: + return base_ptr + return addptr(base_ptr, _coerce_index(linear_offset, context="tile slice pointer lowering")) + + def _infer_vreg_type_from_tile_slice(tile_slice: TileSliceValue): memref_type = MemRefType(tile_slice.type) elem_type = memref_type.element_type @@ -624,17 +2369,27 @@ def _infer_vreg_type_from_tile_slice(tile_slice: TileSliceValue): return _resolve(vreg_type(lanes, elem_type)) +def _infer_vreg_type_from_address_source(src_ptr): + raw_source = unwrap_surface_value(src_ptr) + source_type = raw_source.type + try: + elem_type = _pto.PtrType(source_type).element_type + except Exception: + try: + elem_type = MemRefType(source_type).element_type + except Exception as exc: + raise TypeError( + f"vlds(ptr, offset) cannot infer a vector-register type from source {source_type}; " + "pass result_vreg_type= explicitly" + ) from exc + lanes = _elements_per_vreg(elem_type) + return _resolve(vreg_type(lanes, elem_type)) + + def _elements_per_vreg(elem_type): - if F32Type.isinstance(elem_type): - bytewidth = 4 - elif any(cls.isinstance(elem_type) for cls in (F16Type, BF16Type)): - bytewidth = 2 - elif IntegerType.isinstance(elem_type): - width = IntegerType(elem_type).width - if width % 8 != 0: - raise TypeError(f"vlds/vsts tile-slice sugar does not support sub-byte integer element type {elem_type}") - bytewidth = width // 8 - else: + try: + bytewidth = _element_bytewidth(elem_type) + except TypeError as exc: raise TypeError(f"vlds/vsts tile-slice sugar does not support element type {elem_type}") return 256 // bytewidth @@ -666,6 +2421,10 @@ def _element_bytewidth(elem_type): return 4 if any(cls.isinstance(elem_type) for cls in (F16Type, BF16Type)): return 2 + if Float8E4M3FNType.isinstance(elem_type) or Float8E5M2Type.isinstance(elem_type): + return 1 + if any(_isinstance_pto_type(elem_type, name) for name in ("HiF8Type", "F4E1M2x2Type", "F4E2M1x2Type")): + return 1 if IntegerType.isinstance(elem_type): width = IntegerType(elem_type).width if width % 8 != 0: @@ -674,6 +2433,16 @@ def _element_bytewidth(elem_type): raise TypeError(f"unsupported element type {elem_type}") +def bytewidth(dtype): + """Return the size in bytes of one element of *dtype*.""" + return _element_bytewidth(_resolve(dtype)) + + +def elements_per_vreg(dtype): + """Return how many elements of *dtype* fit in one 256-byte vector register.""" + return _elements_per_vreg(_resolve(dtype)) + + def _mask_bits_for_dtype(dtype): elem_type = _resolve(dtype) bytewidth = _element_bytewidth(elem_type) @@ -708,20 +2477,23 @@ def _coerce_i32(value, *, context: str): if isinstance(raw_value, bool): raise TypeError(f"{context} does not accept bool values") if isinstance(raw_value, int): - return arith.ConstantOp(i32_type, raw_value).result + return _materialize_integer_literal(i32_type, raw_value) kind = classify_runtime_scalar_type(raw_value.type) if kind == "float": raise TypeError(f"{context} expects an integer-like scalar, got {raw_value.type}") if kind == "index": return arith.IndexCastOp(i32_type, raw_value).result - if raw_value.type == i32_type: - return raw_value + signless_value = _strip_integer_signedness(raw_value) + if signless_value.type == i32_type: + return signless_value width = IntegerType(raw_value.type).width if width < 32: - return arith.ExtSIOp(i32_type, raw_value).result + if _integer_signedness(raw_value.type) == "unsigned": + return arith.ExtUIOp(i32_type, signless_value).result + return arith.ExtSIOp(i32_type, signless_value).result if width > 32: - return arith.TruncIOp(i32_type, raw_value).result - return raw_value + return arith.TruncIOp(i32_type, signless_value).result + return signless_value def _coerce_i64(value, *, context: str): @@ -730,20 +2502,23 @@ def _coerce_i64(value, *, context: str): if isinstance(raw_value, bool): raise TypeError(f"{context} does not accept bool values") if isinstance(raw_value, int): - return arith.ConstantOp(i64_type, raw_value).result + return _materialize_integer_literal(i64_type, raw_value) kind = classify_runtime_scalar_type(raw_value.type) if kind == "float": raise TypeError(f"{context} expects an integer-like scalar, got {raw_value.type}") if kind == "index": return arith.IndexCastOp(i64_type, raw_value).result - if raw_value.type == i64_type: - return raw_value + signless_value = _strip_integer_signedness(raw_value) + if signless_value.type == i64_type: + return signless_value width = IntegerType(raw_value.type).width if width < 64: - return arith.ExtSIOp(i64_type, raw_value).result + if _integer_signedness(raw_value.type) == "unsigned": + return arith.ExtUIOp(i64_type, signless_value).result + return arith.ExtSIOp(i64_type, signless_value).result if width > 64: - return arith.TruncIOp(i64_type, raw_value).result - return raw_value + return arith.TruncIOp(i64_type, signless_value).result + return signless_value def _i64_zero(): @@ -854,10 +2629,12 @@ def fill_tile(tile, value): def make_mask(dtype, value): """Create a predicate mask matching *dtype* granularity.""" mask_bits = _mask_bits_for_dtype(dtype) - result_type = _resolve(mask_type(f"b{mask_bits}")) + result_type = _mask_type_from_bits(mask_bits) if isinstance(value, str): - return wrap_surface_value(_pset_op_for_mask_bits(mask_bits)(result_type, value).result) + return wrap_surface_value( + _pset_op_for_mask_bits(mask_bits)(result_type, _normalize_mask_pattern(value)).result + ) raw_value = unwrap_surface_value(value) raw_value = _coerce_i32(raw_value, context="make_mask(..., value)") @@ -867,89 +2644,216 @@ def make_mask(dtype, value): # ── Hardware / sync ─────────────────────────────────────────────────────────── -def mte_load(source, destination): +def _require_pto_ptr_operand(value, *, context: str): + raw_value = unwrap_surface_value(value) + try: + _pto.PtrType(raw_value.type) + except Exception as exc: + raise TypeError(f"{context} expects PTO ptr operands, got {raw_value.type}") from exc + return raw_value + + +def mte_load(source, destination, l2_cache_ctl, len_burst, *, nburst, loops=None, pad=None): """ - Convenience GM->on-chip load surface. + Ptr-based GM->UB DMA wrapper aligned with the underlying ``pto.dma_load`` surface. - Current scope is intentionally narrow: contiguous rank-1 or squeezed-rank-2 - partition views lowering into VEC or MAT tiles. + This wrapper intentionally accepts only explicit pointer operands. It does + not infer burst shape or strides from TensorView / PartitionTensorView / + Tile metadata. """ - source = wrap_surface_value(source) - destination = wrap_surface_value(destination) - if not isinstance(source, PartitionTensorViewValue) or not isinstance(destination, TileValue): - raise TypeError("mte_load(source, destination) expects (PartitionTensorView, Tile)") - - src_ptr = emit_as_ptr(source) - dst_ptr = emit_as_ptr(destination) - row_count, valid_cols, src_row_stride, dst_row_stride = _infer_dma_2d_copy_signature( - source, destination, direction="gm_to_ub" - ) - destination_type = parse_tile_type_metadata(unwrap_surface_value(destination).type) - if destination_type is None: - raise TypeError("mte_load(source, destination) expects a tile_buf-backed destination") - destination_space = destination_type["memory_space"] - len_burst = _coerce_i64(_mul_bytes(valid_cols, infer_tile_element_type(destination)), context="mte_load len_burst") - n_burst = _coerce_i64(row_count, context="mte_load n_burst") - src_stride = _coerce_i64(src_row_stride, context="mte_load src_stride") - dst_stride = _coerce_i64(dst_row_stride, context="mte_load dst_stride") - - if destination_space == "vec": - _pto.MteGmUbOp( - unwrap_surface_value(src_ptr), - unwrap_surface_value(dst_ptr), - _i64_zero(), - len_burst, - n_burst, - src_stride, - dst_stride, - [], - [], - [], - ) - return + n_burst, nburst_src_stride, nburst_dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_load(...)", + ) + loop_counts, loop_src_strides, loop_dst_strides = _normalize_dma_loops( + loops, + context="mte_load(...)", + ) + pad_value, left_padding_count, right_padding_count = _normalize_dma_pad( + pad, + context="mte_load(...)", + ) + _pto.MteGmUbOp( + _require_pto_ptr_operand(source, context="mte_load(...)"), + _require_pto_ptr_operand(destination, context="mte_load(...)"), + _coerce_i64(l2_cache_ctl, context="mte_load l2_cache_ctl"), + _coerce_i64(len_burst, context="mte_load len_burst"), + n_burst, + nburst_src_stride, + nburst_dst_stride, + loop_counts, + loop_src_strides, + loop_dst_strides, + pad_value=pad_value, + left_padding_count=left_padding_count, + right_padding_count=right_padding_count, + ) + + +def mte_store(source, destination, len_burst, *, nburst, loops=None): + """Ptr-based UB->GM DMA wrapper aligned with the underlying ``pto.dma_store`` surface.""" + n_burst, nburst_src_stride, nburst_dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_store(...)", + ) + loop_counts, loop_src_strides, loop_dst_strides = _normalize_dma_loops( + loops, + context="mte_store(...)", + ) + _pto.MteUbGmOp( + _require_pto_ptr_operand(source, context="mte_store(...)"), + _require_pto_ptr_operand(destination, context="mte_store(...)"), + _coerce_i64(len_burst, context="mte_store len_burst"), + n_burst, + nburst_src_stride, + nburst_dst_stride, + loop_counts, + loop_src_strides, + loop_dst_strides, + ) + + +def _normalize_dma_group(name, triple, *, context: str): + if not isinstance(triple, tuple) or len(triple) != 3: + raise TypeError(f"{context} expects {name}=(count, src_stride, dst_stride)") + count, src_stride, dst_stride = triple + return ( + _coerce_i64(count, context=f"{context} {name}[0]"), + _coerce_i64(src_stride, context=f"{context} {name}[1]"), + _coerce_i64(dst_stride, context=f"{context} {name}[2]"), + ) - if destination_space == "mat": - _pto.MteGmL1Op( - unwrap_surface_value(src_ptr), - unwrap_surface_value(dst_ptr), - len_burst, - n_burst, - src_stride, - dst_stride, - [], - [], - [], - ) - return - raise TypeError( - "mte_load(source, destination) currently supports VEC or MAT tile destinations, " - f"got memory_space={destination_space!r}" +def _normalize_dma_loops(loops, *, context: str): + if loops is None: + return [], [], [] + if not isinstance(loops, (list, tuple)): + raise TypeError(f"{context} expects loops to be a list[tuple[int, int, int]] or None") + counts = [] + src_strides = [] + dst_strides = [] + for i, loop in enumerate(loops): + count, src_stride, dst_stride = _normalize_dma_group( + f"loops[{i}]", + loop, + context=context, + ) + counts.append(count) + src_strides.append(src_stride) + dst_strides.append(dst_stride) + return counts, src_strides, dst_strides + + +def _normalize_dma_pad(pad, *, context: str): + if pad is None: + return None, None, None + if not isinstance(pad, tuple): + raise TypeError(f"{context} expects pad to be tuple[ScalarType] or tuple[ScalarType, int, int]") + if len(pad) == 1: + pad_value = pad[0] + left_count = 0 + right_count = 0 + elif len(pad) == 3: + pad_value, left_count, right_count = pad + else: + raise TypeError(f"{context} expects pad to have length 1 or 3") + return ( + materialize_scalar_literal(pad_value, F32Type.get(), context=f"{context} pad[0]") + if not hasattr(pad_value, "type") else unwrap_surface_value(pad_value), + _coerce_i64(left_count, context=f"{context} pad[1]"), + _coerce_i64(right_count, context=f"{context} pad[2]"), ) -def mte_store(source, destination): - """Convenience UB->GM store surface matching ``mte_load`` scope.""" - source = wrap_surface_value(source) - destination = wrap_surface_value(destination) - if not isinstance(source, TileValue) or not isinstance(destination, PartitionTensorViewValue): - raise TypeError("mte_store(source, destination) expects (Tile, PartitionTensorView)") +def mte_gm_ub(source, destination, l2_cache_ctl, len_burst, *, nburst, loops=None, pad=None): + """``pto.mte_gm_ub`` – grouped GM-to-UB DMA surface.""" + n_burst, nburst_src_stride, nburst_dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_gm_ub(...)", + ) + loop_counts, loop_src_strides, loop_dst_strides = _normalize_dma_loops( + loops, + context="mte_gm_ub(...)", + ) + pad_value, left_padding_count, right_padding_count = _normalize_dma_pad( + pad, + context="mte_gm_ub(...)", + ) + _pto.MteGmUbOp( + unwrap_surface_value(source), + unwrap_surface_value(destination), + _coerce_i64(l2_cache_ctl, context="mte_gm_ub l2_cache_ctl"), + _coerce_i64(len_burst, context="mte_gm_ub len_burst"), + n_burst, + nburst_src_stride, + nburst_dst_stride, + loop_counts, + loop_src_strides, + loop_dst_strides, + pad_value=pad_value, + left_padding_count=left_padding_count, + right_padding_count=right_padding_count, + ) + - src_ptr = emit_as_ptr(source) - dst_ptr = emit_as_ptr(destination) - row_count, valid_cols, src_row_stride, dst_row_stride = _infer_dma_2d_copy_signature( - destination, source, direction="ub_to_gm" +def mte_ub_gm(source, destination, len_burst, *, nburst, loops=None): + """``pto.mte_ub_gm`` – grouped UB-to-GM DMA surface.""" + n_burst, nburst_src_stride, nburst_dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_ub_gm(...)", + ) + loop_counts, loop_src_strides, loop_dst_strides = _normalize_dma_loops( + loops, + context="mte_ub_gm(...)", ) _pto.MteUbGmOp( - unwrap_surface_value(src_ptr), - unwrap_surface_value(dst_ptr), - _coerce_i64(_mul_bytes(valid_cols, infer_tile_element_type(source)), context="mte_store len_burst"), - _coerce_i64(row_count, context="mte_store n_burst"), - _coerce_i64(src_row_stride, context="mte_store src_stride"), - _coerce_i64(dst_row_stride, context="mte_store dst_stride"), - [], - [], - [], + unwrap_surface_value(source), + unwrap_surface_value(destination), + _coerce_i64(len_burst, context="mte_ub_gm len_burst"), + n_burst, + nburst_src_stride, + nburst_dst_stride, + loop_counts, + loop_src_strides, + loop_dst_strides, + ) + + +def mte_ub_ub(source, destination, len_burst, *, nburst): + """``pto.mte_ub_ub`` – grouped UB-to-UB DMA surface.""" + n_burst, src_stride, dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_ub_ub(...)", + ) + _pto.MteUbUbOp( + unwrap_surface_value(source), + unwrap_surface_value(destination), + n_burst, + _coerce_i64(len_burst, context="mte_ub_ub len_burst"), + src_stride, + dst_stride, + ) + + +def mte_ub_l1(source, destination, len_burst, *, nburst): + """``pto.mte_ub_l1`` – grouped UB-to-L1 DMA surface.""" + n_burst, src_stride, dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_ub_l1(...)", + ) + _pto.MteUbL1Op( + unwrap_surface_value(source), + unwrap_surface_value(destination), + n_burst, + _coerce_i64(len_burst, context="mte_ub_l1 len_burst"), + src_stride, + dst_stride, ) @@ -1006,6 +2910,68 @@ def mad(lhs, rhs, dst, m, n, k): _coerce_i64(k, context="mad k"), ) + +def mad_acc(lhs, rhs, dst, m, n, k): + """``pto.mad_acc`` – cube matmul accumulate into an existing accumulator.""" + _pto.MadAccOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + _coerce_i64(m, context="mad_acc m"), + _coerce_i64(n, context="mad_acc n"), + _coerce_i64(k, context="mad_acc k"), + ) + + +def mad_bias(lhs, rhs, dst, bias, m, n, k): + """``pto.mad_bias`` – cube matmul initialized from a bias buffer.""" + _pto.MadBiasOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + unwrap_surface_value(bias), + _coerce_i64(m, context="mad_bias m"), + _coerce_i64(n, context="mad_bias n"), + _coerce_i64(k, context="mad_bias k"), + ) + + +def mad_mx(lhs, rhs, dst, m, n, k): + """``pto.mad_mx`` – MX-format cube matmul.""" + _pto.MadMxOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + _coerce_i64(m, context="mad_mx m"), + _coerce_i64(n, context="mad_mx n"), + _coerce_i64(k, context="mad_mx k"), + ) + + +def mad_mx_acc(lhs, rhs, dst, m, n, k): + """``pto.mad_mx_acc`` – MX-format cube matmul accumulate.""" + _pto.MadMxAccOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + _coerce_i64(m, context="mad_mx_acc m"), + _coerce_i64(n, context="mad_mx_acc n"), + _coerce_i64(k, context="mad_mx_acc k"), + ) + + +def mad_mx_bias(lhs, rhs, dst, bias, m, n, k): + """``pto.mad_mx_bias`` – MX-format cube matmul initialized from a bias buffer.""" + _pto.MadMxBiasOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + unwrap_surface_value(bias), + _coerce_i64(m, context="mad_mx_bias m"), + _coerce_i64(n, context="mad_mx_bias n"), + _coerce_i64(k, context="mad_mx_bias k"), + ) + def get_block_idx(): """``pto.get_block_idx`` → i64 block index.""" return wrap_surface_value(_pto.GetBlockIdxOp().result) @@ -1055,33 +3021,146 @@ def pipe_barrier(pipe): _pto.BarrierOp(_pipe_attr(pipe)) +def get_buf(pipe, buf_id, mode=0): + """``pto.get_buf(pipe, buf_id, mode=0)`` – acquire a buffer token.""" + _pto.GetBufOp( + _pipe_attr(pipe), + buf_id, + mode=mode, + ) + + +def rls_buf(pipe, buf_id, mode=0): + """``pto.rls_buf(pipe, buf_id, mode=0)`` – release a buffer token.""" + _pto.RlsBufOp( + _pipe_attr(pipe), + buf_id, + mode=mode, + ) + + +def _sync_event_id_operand(event_id, *, context: str): + _validate_static_event_id(event_id, context=context) + return event_id if isinstance(event_id, int) else unwrap_surface_value(event_id) + + +def _flag_event_id_operand(event_id, *, context: str): + if isinstance(event_id, int): + _validate_static_event_id(event_id, context=context) + return event_id, True + return _coerce_index(event_id, context=context), False + + +def set_cross_flag(pipe, event_id): + """``pto.set_cross_flag(pipe, event_id)`` – cross-core sync facade for ``pto.sync.set``.""" + _validate_sync_pipe(pipe, context="set_cross_flag(pipe, event_id)", allowed=("PIPE_FIX",)) + event_operand = _sync_event_id_operand(event_id, context="set_cross_flag(..., event_id=...)") + _pto.sync_set(_pipe_attr(pipe), event_operand) + + +def wait_cross_flag(pipe, event_id): + """``pto.wait_cross_flag(pipe, event_id)`` – cross-core sync facade for ``pto.sync.wait``.""" + _validate_sync_pipe(pipe, context="wait_cross_flag(pipe, event_id)", allowed=("PIPE_FIX",)) + event_operand = _sync_event_id_operand(event_id, context="wait_cross_flag(..., event_id=...)") + _pto.sync_wait(_pipe_attr(pipe), event_operand) + + +def set_intra_flag(pipe, event_id): + """``pto.set_intra_flag(pipe, event_id)`` – intra-block sync facade for ``pto.sync.set``.""" + _validate_sync_pipe(pipe, context="set_intra_flag(pipe, event_id)", allowed=("PIPE_MTE3",)) + event_operand = _sync_event_id_operand(event_id, context="set_intra_flag(..., event_id=...)") + _pto.sync_set(_pipe_attr(pipe), event_operand) + + +def wait_intra_flag(pipe, event_id): + """``pto.wait_intra_flag(pipe, event_id)`` – intra-block sync facade for ``pto.sync.wait``.""" + _validate_sync_pipe(pipe, context="wait_intra_flag(pipe, event_id)", allowed=("PIPE_V",)) + event_operand = _sync_event_id_operand(event_id, context="wait_intra_flag(..., event_id=...)") + _pto.sync_wait(_pipe_attr(pipe), event_operand) + + def set_flag(src: str, dst: str, *, event_id: int = 0): """``pto.set_flag[src, dst, event_id]``. Accepts short pipe names (``"MTE2"``, ``"V"``, …) or full ``"PIPE_MTE2"`` - names. ``event_id`` is an integer in ``[0, 7]``. + names. Static ``event_id`` values in ``[0, 7]`` lower to ``pto.set_flag``; + runtime index-like values lower to ``pto.set_flag_dyn``. """ - _pto.set_flag(_pipe_attr(src), _pipe_attr(dst), _event_attr(event_id)) + event_operand, is_static = _flag_event_id_operand( + event_id, + context="set_flag(..., event_id=...)", + ) + if is_static: + _pto.set_flag(_pipe_attr(src), _pipe_attr(dst), _event_attr(event_operand)) + return + _pto.set_flag_dyn(_pipe_attr(src), _pipe_attr(dst), event_operand) def wait_flag(src: str, dst: str, *, event_id: int = 0): - """``pto.wait_flag[src, dst, event_id]``.""" - _pto.wait_flag(_pipe_attr(src), _pipe_attr(dst), _event_attr(event_id)) + """``pto.wait_flag[src, dst, event_id]``. + + Static ``event_id`` values in ``[0, 7]`` lower to ``pto.wait_flag``; + runtime index-like values lower to ``pto.wait_flag_dyn``. + """ + event_operand, is_static = _flag_event_id_operand( + event_id, + context="wait_flag(..., event_id=...)", + ) + if is_static: + _pto.wait_flag(_pipe_attr(src), _pipe_attr(dst), _event_attr(event_operand)) + return + _pto.wait_flag_dyn(_pipe_attr(src), _pipe_attr(dst), event_operand) __all__ = [ "const", "castptr", "addptr", - "vlds", "vbrc_load", "vsts", "vsts_1pt", - "plt_b32", "pset_b32", "make_mask", - "vadd", "vmul", "vmax", "vdiv", - "vcmax", "vcadd", "vdup", "vexpdif", - "vexp", "vcgmax", "vcgadd", "vsubs", + "vlds", "vldas", "vldus", "vldsx2", "vbrc_load", "vsts", "vsts_1pt", "vstsx2", + "init_align", + "plt_b8", "plt_b16", "plt_b32", + "pset_b8", "pset_b16", "pset_b32", + "pge_b8", "pge_b16", "pge_b32", + "make_mask", + "pand", "por", "pxor", "pnot", "psel", + "pbitcast", "ppack", "punpack", + "pintlv_b8", "pintlv_b16", "pintlv_b32", + "pdintlv_b8", "pdintlv_b16", "pdintlv_b32", + "vgather2", "vgather2_bc", "vgatherb", "vscatter", "vsldb", "vsstb", + "vcmp", "vcmps", + "plds", "psts", "pstu", "vstar", "vstas", "vstur", "vstus", + "vbitcast", + "vadd", "vsub", "vmul", "vdiv", "vmax", "vmin", + "vand", "vor", "vxor", "vshl", "vshr", + "vcmax", "vcadd", "vcmin", "vdup", "vexpdif", + "vexp", "vln", "vsqrt", "vabs", "vneg", "vrec", "vrsqrt", "vrelu", "vnot", + "vcgmax", "vcgadd", "vcgmin", "vcpadd", + "vadds", "vsubs", "vmuls", "vmaxs", "vmins", "vlrelu", + "vaxpy", "vaddrelu", "vsubrelu", + "vsel", "make_tensor_view", "partition_view", - "alloc_tile", "tload", "tstore", "tmov", "as_ptr", - "mte_load", "mte_store", "mem_bar", - "mte_l1_l0a", "mte_l1_l0b", "mte_l0c_ub", "mad", + "alloc_tile", + "tload", "tstore", "tmov", + "tadd", "tsub", "tmul", "tdiv", "tmax", "tmin", + "tadds", "tsubs", "tmuls", "tdivs", "tmaxs", "tmins", + "texp", "tlog", "tsqrt", "trsqrt", "trecip", "tabs", "tneg", + "trelu", "tlrelu", + "trowsum", "trowmax", "trowmin", "trowprod", "trowargmax", "trowargmin", + "tcolsum", "tcolmax", "tcolmin", "tcolprod", "tcolargmax", "tcolargmin", + "tcmp", "tcmps", + "texpands", "trowexpand", "tcolexpand", + "trowexpandadd", "trowexpandsub", "trowexpandmul", "trowexpanddiv", "trowexpandmax", "trowexpandmin", "trowexpandexpdif", + "tcolexpandadd", "tcolexpandsub", "tcolexpandmul", "tcolexpanddiv", "tcolexpandmax", "tcolexpandmin", "tcolexpandexpdif", + "tsel", "tsels", "tcvt", + "tnot", "tand", "tands", "tor", "tors", "txor", "txors", "tshl", "tshls", "tshr", "tshrs", + "tpartadd", "tpartmul", "tpartmax", "tpartmin", + "tfillpad", "tfillpad_expand", "tfillpad_inplace", + "as_ptr", + "mte_load", "mte_store", "mte_gm_ub", "mte_ub_gm", "mte_ub_ub", "mte_ub_l1", "mem_bar", + "mte_l1_l0a", "mte_l1_l0b", "mte_l0c_ub", + "mad", "mad_acc", "mad_bias", "mad_mx", "mad_mx_acc", "mad_mx_bias", "get_block_idx", "get_block_num", "get_subblock_idx", "get_subblock_num", "store_vfsimt_info", "get_tid_x", "get_tid_y", "get_tid_z", - "pipe_barrier", "set_flag", "wait_flag", + "pipe_barrier", "get_buf", "rls_buf", + "set_cross_flag", "wait_cross_flag", "set_intra_flag", "wait_intra_flag", + "set_flag", "wait_flag", ] diff --git a/ptodsl/ptodsl/_runtime_scalar_ops.py b/ptodsl/ptodsl/_runtime_scalar_ops.py index f2e5cdfd7..989dca8bf 100644 --- a/ptodsl/ptodsl/_runtime_scalar_ops.py +++ b/ptodsl/ptodsl/_runtime_scalar_ops.py @@ -9,17 +9,16 @@ from __future__ import annotations -from mlir.dialects import arith -from mlir.ir import BF16Type, F16Type, F32Type, FloatAttr, IndexType, IntegerType +from ._types import ( + _integer_signedness, + _materialize_integer_literal, + _restore_integer_signedness, + _strip_integer_signedness, +) +from mlir.dialects import arith, math +from mlir.ir import BF16Type, F16Type, F32Type, FloatAttr, IndexType, IntegerType -_INTEGER_BINARY_OPS = { - "add": arith.AddIOp, - "sub": arith.SubIOp, - "mul": arith.MulIOp, - "floordiv": arith.FloorDivSIOp, - "mod": arith.RemSIOp, -} _FLOAT_BINARY_OPS = { "add": arith.AddFOp, @@ -33,10 +32,17 @@ def emit_runtime_binary_op(op_name: str, lhs, rhs): """Lower one authored runtime scalar binary operator.""" lhs, rhs, kind = normalize_runtime_binary_operands(lhs, rhs) if kind in {"index", "integer"}: - op_cls = _INTEGER_BINARY_OPS.get(op_name) + op_cls = _integer_binary_op(op_name, lhs.type) if op_cls is None: raise TypeError(f"runtime scalar operator '{op_name}' is not supported for integer/index values") - return op_cls(lhs, rhs).result + authored_type = lhs.type + if kind == "integer": + lhs = _strip_integer_signedness(lhs) + rhs = _strip_integer_signedness(rhs) + result = op_cls(lhs, rhs).result + if kind == "index": + return result + return _restore_runtime_integer_result(result, authored_type) if kind == "float": op_cls = _FLOAT_BINARY_OPS.get(op_name) if op_cls is None: @@ -51,13 +57,40 @@ def emit_runtime_max(lhs, rhs): if kind == "float": return arith.MaximumFOp(lhs, rhs).result if kind == "integer": - return arith.MaxSIOp(lhs, rhs).result + signedness = _integer_signedness(lhs.type) + signless_lhs = _strip_integer_signedness(lhs) + signless_rhs = _strip_integer_signedness(rhs) + if signedness == "unsigned": + result = arith.MaxUIOp(signless_lhs, signless_rhs).result + else: + result = arith.MaxSIOp(signless_lhs, signless_rhs).result + return _restore_integer_signedness(result, lhs.type) if kind == "index": cond = arith.CmpIOp(arith.CmpIPredicate.sge, lhs, rhs).result return arith.SelectOp(cond, lhs, rhs).result raise TypeError(f"unsupported runtime scalar operand category '{kind}'") +def emit_runtime_min(lhs, rhs): + """Lower one authored runtime scalar min operation.""" + lhs, rhs, kind = normalize_runtime_binary_operands(lhs, rhs) + if kind == "float": + return arith.MinimumFOp(lhs, rhs).result + if kind == "integer": + signedness = _integer_signedness(lhs.type) + signless_lhs = _strip_integer_signedness(lhs) + signless_rhs = _strip_integer_signedness(rhs) + if signedness == "unsigned": + result = arith.MinUIOp(signless_lhs, signless_rhs).result + else: + result = arith.MinSIOp(signless_lhs, signless_rhs).result + return _restore_integer_signedness(result, lhs.type) + if kind == "index": + cond = arith.CmpIOp(arith.CmpIPredicate.sle, lhs, rhs).result + return arith.SelectOp(cond, lhs, rhs).result + raise TypeError(f"unsupported runtime scalar operand category '{kind}'") + + def normalize_runtime_binary_operands(lhs, rhs): lhs_is_value = _is_mlir_value(lhs) rhs_is_value = _is_mlir_value(rhs) @@ -82,11 +115,11 @@ def _reconcile_typed_operands(lhs, rhs): return lhs, rhs, classify_runtime_scalar_type(lhs_type) if IndexType.isinstance(lhs_type) and IntegerType.isinstance(rhs_type): - rhs = arith.IndexCastOp(IndexType.get(), rhs).result + rhs = arith.IndexCastOp(IndexType.get(), _strip_integer_signedness(rhs)).result return lhs, rhs, "index" if IntegerType.isinstance(lhs_type) and IndexType.isinstance(rhs_type): - lhs = arith.IndexCastOp(IndexType.get(), lhs).result + lhs = arith.IndexCastOp(IndexType.get(), _strip_integer_signedness(lhs)).result return lhs, rhs, "index" raise TypeError( @@ -102,6 +135,8 @@ def _materialize_literal(value, anchor_type): kind = classify_runtime_scalar_type(anchor_type) if kind == "float": return arith.ConstantOp(anchor_type, FloatAttr.get(anchor_type, float(value))).result + if kind == "index": + return arith.ConstantOp(anchor_type, int(value)).result if isinstance(value, float): raise TypeError( @@ -109,7 +144,7 @@ def _materialize_literal(value, anchor_type): f"against non-floating operand type {anchor_type}" ) - return arith.ConstantOp(anchor_type, int(value)).result + return _materialize_integer_literal(anchor_type, value) def classify_runtime_scalar_type(type_obj): @@ -126,9 +161,142 @@ def _is_mlir_value(value) -> bool: return not isinstance(value, (bool, int, float)) and hasattr(value, "type") +def _restore_runtime_integer_result(result, authored_type): + if IndexType.isinstance(authored_type): + return result + if not IntegerType.isinstance(authored_type): + return result + return _restore_integer_signedness(result, authored_type) + + +def emit_runtime_compare(op_name: str, lhs, rhs): + """Lower one authored runtime scalar comparison operator.""" + lhs, rhs, kind = normalize_runtime_binary_operands(lhs, rhs) + + if kind == "float": + predicate = { + "lt": arith.CmpFPredicate.OLT, + "le": arith.CmpFPredicate.OLE, + "gt": arith.CmpFPredicate.OGT, + "ge": arith.CmpFPredicate.OGE, + "eq": arith.CmpFPredicate.OEQ, + "ne": arith.CmpFPredicate.ONE, + }.get(op_name) + if predicate is None: + raise TypeError(f"runtime scalar comparison '{op_name}' is not supported for floating-point values") + return arith.CmpFOp(predicate, lhs, rhs).result + + if kind == "index": + predicate = { + "lt": arith.CmpIPredicate.slt, + "le": arith.CmpIPredicate.sle, + "gt": arith.CmpIPredicate.sgt, + "ge": arith.CmpIPredicate.sge, + "eq": arith.CmpIPredicate.eq, + "ne": arith.CmpIPredicate.ne, + }.get(op_name) + if predicate is None: + raise TypeError(f"runtime scalar comparison '{op_name}' is not supported for index values") + return arith.CmpIOp(predicate, lhs, rhs).result + + if kind == "integer": + signedness = _integer_signedness(lhs.type) + signed_predicates = { + "lt": arith.CmpIPredicate.slt, + "le": arith.CmpIPredicate.sle, + "gt": arith.CmpIPredicate.sgt, + "ge": arith.CmpIPredicate.sge, + "eq": arith.CmpIPredicate.eq, + "ne": arith.CmpIPredicate.ne, + } + unsigned_predicates = { + "lt": arith.CmpIPredicate.ult, + "le": arith.CmpIPredicate.ule, + "gt": arith.CmpIPredicate.ugt, + "ge": arith.CmpIPredicate.uge, + "eq": arith.CmpIPredicate.eq, + "ne": arith.CmpIPredicate.ne, + } + predicate = (unsigned_predicates if signedness == "unsigned" else signed_predicates).get(op_name) + if predicate is None: + raise TypeError(f"runtime scalar comparison '{op_name}' is not supported for integer values") + return arith.CmpIOp(predicate, _strip_integer_signedness(lhs), _strip_integer_signedness(rhs)).result + + raise TypeError(f"unsupported runtime scalar operand category '{kind}'") + + +def emit_runtime_bitwise_op(op_name: str, lhs, rhs): + """Lower one authored runtime scalar bitwise operator.""" + lhs, rhs, kind = normalize_runtime_binary_operands(lhs, rhs) + if kind != "integer": + raise TypeError( + f"runtime scalar bitwise operator '{op_name}' expects integer-like operands, got {lhs.type} and {rhs.type}" + ) + + op_cls = { + "and": arith.AndIOp, + "or": arith.OrIOp, + "xor": arith.XOrIOp, + }.get(op_name) + if op_cls is None: + raise TypeError(f"unsupported runtime scalar bitwise operator '{op_name}'") + + authored_type = lhs.type + result = op_cls(_strip_integer_signedness(lhs), _strip_integer_signedness(rhs)).result + return _restore_integer_signedness(result, authored_type) + + +def emit_runtime_abs(value): + """Lower one authored runtime scalar absolute-value operation.""" + kind = classify_runtime_scalar_type(value.type) + if kind == "float": + return math.AbsFOp(value).result + if kind == "index": + return value + if kind == "integer": + signedness = _integer_signedness(value.type) + if signedness == "unsigned": + return value + result = math.AbsIOp(_strip_integer_signedness(value)).result + return _restore_integer_signedness(result, value.type) + raise TypeError(f"unsupported runtime scalar operand category '{kind}'") + + +def _integer_binary_op(op_name: str, authored_type): + if IndexType.isinstance(authored_type): + return { + "add": arith.AddIOp, + "sub": arith.SubIOp, + "mul": arith.MulIOp, + "floordiv": arith.FloorDivSIOp, + "mod": arith.RemSIOp, + }.get(op_name) + + signedness = _integer_signedness(authored_type) + if op_name in {"add", "sub", "mul"}: + return { + "add": arith.AddIOp, + "sub": arith.SubIOp, + "mul": arith.MulIOp, + }[op_name] + if op_name == "floordiv": + if signedness == "unsigned": + return arith.DivUIOp + return arith.FloorDivSIOp + if op_name == "mod": + if signedness == "unsigned": + return arith.RemUIOp + return arith.RemSIOp + return None + + __all__ = [ "classify_runtime_scalar_type", + "emit_runtime_abs", "emit_runtime_binary_op", + "emit_runtime_compare", + "emit_runtime_bitwise_op", "emit_runtime_max", + "emit_runtime_min", "normalize_runtime_binary_operands", ] diff --git a/ptodsl/ptodsl/_scalar_coercion.py b/ptodsl/ptodsl/_scalar_coercion.py index fc150a5f3..ad839977c 100644 --- a/ptodsl/ptodsl/_scalar_coercion.py +++ b/ptodsl/ptodsl/_scalar_coercion.py @@ -11,6 +11,13 @@ from ._runtime_scalar_ops import classify_runtime_scalar_type from ._surface_values import unwrap_surface_value +from ._types import ( + _integer_signedness, + _materialize_integer_literal, + _restore_integer_signedness, + _signless_integer_type, + _strip_integer_signedness, +) from mlir.dialects import arith from mlir.ir import BF16Type, F16Type, F32Type, FloatAttr, IndexType, IntegerType @@ -31,7 +38,7 @@ def coerce_scalar_to_type(value, target_type, *, context: str): if source_kind == "index" and target_kind == "integer": return _coerce_integer_like(raw_value, target_type) if source_kind == "integer" and target_kind == "index": - return arith.IndexCastOp(target_type, raw_value).result + return arith.IndexCastOp(target_type, _strip_integer_signedness(raw_value)).result if source_kind == "integer" and target_kind == "integer": return _coerce_integer_like(raw_value, target_type) if source_kind == "float" and target_kind == "float": @@ -51,6 +58,8 @@ def materialize_scalar_literal(value, target_type, *, context: str): target_kind = classify_runtime_scalar_type(target_type) if target_kind == "float": return arith.ConstantOp(target_type, FloatAttr.get(target_type, float(value))).result + if target_kind == "index": + return arith.ConstantOp(target_type, int(value)).result if isinstance(value, float): raise TypeError( @@ -58,19 +67,32 @@ def materialize_scalar_literal(value, target_type, *, context: str): f"target type {target_type}" ) - return arith.ConstantOp(target_type, int(value)).result + return _materialize_integer_literal(target_type, value) def _coerce_integer_like(raw_value, target_type): if IndexType.isinstance(raw_value.type): - return arith.IndexCastOp(target_type, raw_value).result - source_width = IntegerType(raw_value.type).width + signless_target = _signless_integer_type(target_type) + adapted = arith.IndexCastOp(signless_target, raw_value).result + return _restore_integer_signedness(adapted, target_type) + + source_type = raw_value.type + source_width = IntegerType(source_type).width target_width = IntegerType(target_type).width + signless_source = _strip_integer_signedness(raw_value) + signless_target = _signless_integer_type(target_type) + if source_width < target_width: - return arith.ExtSIOp(target_type, raw_value).result + source_signedness = _integer_signedness(source_type) + if source_signedness == "unsigned": + widened = arith.ExtUIOp(signless_target, signless_source).result + else: + widened = arith.ExtSIOp(signless_target, signless_source).result + return _restore_integer_signedness(widened, target_type) if source_width > target_width: - return arith.TruncIOp(target_type, raw_value).result - return raw_value + truncated = arith.TruncIOp(signless_target, signless_source).result + return _restore_integer_signedness(truncated, target_type) + return _restore_integer_signedness(signless_source, target_type) def _coerce_float_like(raw_value, target_type): diff --git a/ptodsl/ptodsl/_subkernels.py b/ptodsl/ptodsl/_subkernels.py index 64e4c481f..380c1ff59 100644 --- a/ptodsl/ptodsl/_subkernels.py +++ b/ptodsl/ptodsl/_subkernels.py @@ -15,6 +15,7 @@ import inspect from ._diagnostics import ( + illegal_inline_subkernel_placement_error, illegal_subkernel_placement_error, simd_value_escape_error, subkernel_host_tensor_boundary_error, @@ -79,9 +80,7 @@ def _validate_definition(self) -> None: def _validate_invocation(self, *args, **kwargs) -> None: session = current_session() outer = session.current_subkernel if session is not None else None - if outer is not None: - if self.spec.role == KernelRole.UKERNEL or outer.role != KernelRole.UKERNEL.value: - raise illegal_subkernel_placement_error(self.spec.role.value, outer.role) + _validate_subkernel_placement(self.spec.role, outer) bound = self.signature.bind_partial(*args, **kwargs) for name, value in bound.arguments.items(): @@ -121,18 +120,62 @@ def _find_transient_simd_escape(value): return None -def _subkernel_decorator(role: KernelRole, *, name: str | None = None, target: str = "a5"): - def decorator(fn): +def _validate_subkernel_placement(role: KernelRole, outer_frame, *, inline: bool = False) -> None: + if outer_frame is None: + return + if role == KernelRole.UKERNEL or outer_frame.role != KernelRole.UKERNEL.value: + if inline: + raise illegal_inline_subkernel_placement_error(role.value, outer_frame.role) + raise illegal_subkernel_placement_error(role.value, outer_frame.role) + + +class _SubkernelSurface: + """Dual-use surface that supports both decorators and inline context-manager scopes.""" + + def __init__(self, role: KernelRole, *, name: str | None = None, target: str = "a5"): + self._role = role + self._name = name + self._target = target + self._session_cm = None + + def __call__(self, fn): return SubkernelTemplate( SubkernelSpec( - role=role, - symbol_name=name or fn.__name__, - target=target, + role=self._role, + symbol_name=self._name or fn.__name__, + target=self._target, ), fn, ) - return decorator + def __enter__(self): + runtime = current_runtime() + if runtime is None: + raise RuntimeError( + f"inline pto.{self._role.value}() may only be used while tracing " + "a compatible PTODSL kernel" + ) + session = current_session() + outer = session.current_subkernel if session is not None else None + _validate_subkernel_placement(self._role, outer, inline=True) + symbol_name = self._name or f"inline_{self._role.value}" + self._session_cm = session.enter_inline_subkernel( + self._role.value, + symbol_name, + self._target, + ) + self._session_cm.__enter__() + return None + + def __exit__(self, *exc): + try: + return self._session_cm.__exit__(*exc) + finally: + self._session_cm = None + + +def _subkernel_decorator(role: KernelRole, *, name: str | None = None, target: str = "a5"): + return _SubkernelSurface(role, name=name, target=target) def _decorate_subkernel(role: KernelRole, fn=None, *, name: str | None = None, target: str = "a5"): diff --git a/ptodsl/ptodsl/_surface_types.py b/ptodsl/ptodsl/_surface_types.py index e48cedb68..8ea729eaa 100644 --- a/ptodsl/ptodsl/_surface_types.py +++ b/ptodsl/ptodsl/_surface_types.py @@ -74,6 +74,98 @@ class Pipe: ALL = _pto.PIPE.PIPE_ALL +class MaskPattern: + """Public PTODSL mask-pattern tokens.""" + + ALL = "PAT_ALL" + ALLF = "PAT_ALLF" + H = "PAT_H" + Q = "PAT_Q" + M3 = "PAT_M3" + M4 = "PAT_M4" + + +for _vl in range(1, 129): + setattr(MaskPattern, f"VL{_vl}", f"PAT_VL{_vl}") + + +class CmpMode: + """Public PTODSL compare-mode tokens.""" + + EQ = "eq" + NE = "ne" + LT = "lt" + LE = "le" + GT = "gt" + GE = "ge" + + +class PredicatePart: + """Public PTODSL predicate pack/unpack part tokens.""" + + LOWER = "LOWER" + HIGHER = "HIGHER" + + +class PredicateDist: + """Public PTODSL predicate load/store distribution tokens.""" + + NORM = "NORM" + US = "US" + DS = "DS" + PK = "PK" + + +class VStoreDist: + """Public PTODSL vector-store distribution tokens.""" + + NORM_B8 = "NORM_B8" + NORM_B16 = "NORM_B16" + NORM_B32 = "NORM_B32" + _1PT_B8 = "1PT_B8" + _1PT_B16 = "1PT_B16" + _1PT_B32 = "1PT_B32" + PK_B16 = "PK_B16" + PK_B32 = "PK_B32" + PK_B64 = "PK_B64" + PK4_B32 = "PK4_B32" + MRG4CHN_B8 = "MRG4CHN_B8" + MRG2CHN_B8 = "MRG2CHN_B8" + MRG2CHN_B16 = "MRG2CHN_B16" + + +setattr(VStoreDist, "1PT_B8", VStoreDist._1PT_B8) +setattr(VStoreDist, "1PT_B16", VStoreDist._1PT_B16) +setattr(VStoreDist, "1PT_B32", VStoreDist._1PT_B32) + + +class DeinterleaveDist: + """Public PTODSL dual-load distribution tokens.""" + + DINTLV_B8 = "DINTLV_B8" + DINTLV_B16 = "DINTLV_B16" + DINTLV_B32 = "DINTLV_B32" + BDINTLV = "BDINTLV" + + +class InterleaveDist: + """Public PTODSL dual-store distribution tokens.""" + + INTLV_B8 = "INTLV_B8" + INTLV_B16 = "INTLV_B16" + INTLV_B32 = "INTLV_B32" + + +class PostUpdate: + """Public PTODSL post-update mode tokens for stateful stores.""" + + OFF = "NO_POST_UPDATE" + ON = "POST_UPDATE" + + +AlignType = _pto.AlignType + + class TensorView: """Authoring-time marker for a tensor-view descriptor value.""" @@ -92,6 +184,15 @@ class Tile: "MemorySpace", "BarrierType", "Pipe", + "MaskPattern", + "CmpMode", + "PredicatePart", + "PredicateDist", + "VStoreDist", + "DeinterleaveDist", + "InterleaveDist", + "PostUpdate", + "AlignType", "TensorView", "PartitionTensorView", "Tile", diff --git a/ptodsl/ptodsl/_surface_values.py b/ptodsl/ptodsl/_surface_values.py index ca45f98a3..f86fe640a 100644 --- a/ptodsl/ptodsl/_surface_values.py +++ b/ptodsl/ptodsl/_surface_values.py @@ -13,14 +13,14 @@ from dataclasses import dataclass from ._diagnostics import native_python_control_flow_error -from ._runtime_scalar_ops import emit_runtime_binary_op +from ._runtime_scalar_ops import emit_runtime_binary_op, emit_runtime_bitwise_op, emit_runtime_compare from ._surface_types import PartitionTensorView, TensorView, Tile from ._types import _normalize_address_space, _resolve, ptr from mlir.dialects import arith from mlir.dialects import memref from mlir.dialects import pto as _pto -from mlir.ir import IndexType, MemRefType, ShapedType, StridedLayoutAttr, Type +from mlir.ir import IndexType, IntegerType, MemRefType, ShapedType, StridedLayoutAttr, Type def unwrap_surface_value(value): @@ -184,6 +184,42 @@ def __mod__(self, other): def __rmod__(self, other): return wrap_surface_value(emit_runtime_binary_op("mod", unwrap_surface_value(other), self.value)) + def __lt__(self, other): + return wrap_surface_value(emit_runtime_compare("lt", self.value, unwrap_surface_value(other))) + + def __le__(self, other): + return wrap_surface_value(emit_runtime_compare("le", self.value, unwrap_surface_value(other))) + + def __gt__(self, other): + return wrap_surface_value(emit_runtime_compare("gt", self.value, unwrap_surface_value(other))) + + def __ge__(self, other): + return wrap_surface_value(emit_runtime_compare("ge", self.value, unwrap_surface_value(other))) + + def __eq__(self, other): + return wrap_surface_value(emit_runtime_compare("eq", self.value, unwrap_surface_value(other))) + + def __ne__(self, other): + return wrap_surface_value(emit_runtime_compare("ne", self.value, unwrap_surface_value(other))) + + def __and__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("and", self.value, unwrap_surface_value(other))) + + def __rand__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("and", unwrap_surface_value(other), self.value)) + + def __or__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("or", self.value, unwrap_surface_value(other))) + + def __ror__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("or", unwrap_surface_value(other), self.value)) + + def __xor__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("xor", self.value, unwrap_surface_value(other))) + + def __rxor__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("xor", unwrap_surface_value(other), self.value)) + class MaskResultValue(_SurfaceValue): """Mask value that also supports `(mask, remained)` unpacking.""" @@ -303,7 +339,11 @@ def __init__(self, tile: "TileValue"): self._cache: dict[int, object] = {} def __getitem__(self, index: int): - if index not in {0, 1}: + logical_rank = len(self._tile.shape) if self._tile.shape is not None else 2 + allowed = {0} if logical_rank == 1 else {0, 1} + if index not in allowed: + if logical_rank == 1: + raise IndexError("PTODSL rank-1 tile.valid_shape currently supports only index 0") raise IndexError("PTODSL tile.valid_shape currently supports indices 0 and 1") cached = self._cache.get(index) if cached is not None: @@ -316,7 +356,9 @@ def __getitem__(self, index: int): self._cache[index] = value return value try: - if index == 0: + if logical_rank == 1: + value = wrap_surface_value(_pto.TileValidColsOp(self._tile.value).result) + elif index == 0: value = wrap_surface_value(_pto.TileValidRowsOp(self._tile.value).result) else: value = wrap_surface_value(_pto.TileValidColsOp(self._tile.value).result) @@ -341,6 +383,7 @@ def __init__( value, *, shape=None, + physical_shape=None, dtype=None, memory_space=None, valid_shape=None, @@ -350,6 +393,11 @@ def __init__( self.shape = tuple(shape) if shape is not None else ( parsed["shape_dims"] if parsed is not None else None ) + self.physical_shape = tuple(physical_shape) if physical_shape is not None else ( + tuple(shape) if shape is not None else ( + parsed["shape_dims"] if parsed is not None else None + ) + ) self.dtype = dtype if dtype is not None else ( parsed["element_type"] if parsed is not None else None ) @@ -377,6 +425,7 @@ def valid_shape(self, dims): def surface_metadata(self): return { "shape": self.shape, + "physical_shape": self.physical_shape, "dtype": self.dtype, "memory_space": self.memory_space, "valid_shape": self.static_valid_shape, @@ -622,12 +671,13 @@ def infer_memref_type_from_surface_value(surface_value): return surface_value.type if isinstance(surface_value, TileValue): - if surface_value.shape is not None and surface_value.dtype is not None and surface_value.memory_space is not None: + physical_shape = getattr(surface_value, "physical_shape", None) + if physical_shape is not None and surface_value.dtype is not None and surface_value.memory_space is not None: space_enum = _normalize_address_space(surface_value.memory_space) if space_enum is None: raise RuntimeError("unsupported tile memory space for memref address view") return MemRefType.get( - list(surface_value.shape), + list(physical_shape), _resolve(surface_value.dtype), memory_space=_pto.AddressSpaceAttr.get(space_enum), ) @@ -703,7 +753,11 @@ def _materialize_tile_slice(tile: TileValue, key): if start_slice.stop is not None or start_slice.step is not None: raise TypeError("tile[start:] only supports an open-ended slice") start = 0 if start_slice.start is None else start_slice.start - return _build_tile_slice_view(tile, raw_offsets=[start], shape=[_dynamic_extent(tile.shape[0], start)]) + return _build_tile_slice_view( + tile, + raw_offsets=[0, start], + shape=[_dynamic_extent(tile.shape[0], start)], + ) row, col_slice = key if col_slice.stop is not None or col_slice.step is not None: @@ -741,14 +795,14 @@ def _build_tile_slice_view(tile: TileValue, *, raw_offsets, shape): ).result return TileSliceValue(slice_value, tile=tile, offsets=tuple(raw_offsets), shape=shape) - row_type = _make_strided_memref_type( - [1, _static_extent_if_known(shape[0])], + slice_type = _make_strided_memref_type( + [_static_extent_if_known(shape[0])], base_type.element_type, - [base_type.shape[1], 1], + [1], base_type.memory_space, ) - row_view = memref.SubViewOp( - row_type, + slice_value = memref.SubViewOp( + slice_type, base_memref, offset_operands, shape_operands, @@ -757,13 +811,6 @@ def _build_tile_slice_view(tile: TileValue, *, raw_offsets, shape): [1, static_shape[0]], [1, 1], ).result - flat_type = _make_strided_memref_type( - [_static_extent_if_known(shape[0])], - base_type.element_type, - [1], - base_type.memory_space, - ) - slice_value = memref.CollapseShapeOp(flat_type, row_view, [[0, 1]]).result return TileSliceValue(slice_value, tile=tile, offsets=tuple(raw_offsets), shape=shape) @@ -775,7 +822,7 @@ def _emit_tile_memref(tile: TileValue): def _dynamic_extent(static_dim, start): if isinstance(start, int): return static_dim - start - return arith.SubIOp(_index_const(static_dim), start).result + return arith.SubIOp(_index_const(static_dim), _coerce_index_value(start)).result def _static_extent_if_known(extent): @@ -821,7 +868,13 @@ def _mul_index(lhs, rhs): def _coerce_index_value(value): value = _normalize_index(value) - return _index_const(value) if isinstance(value, int) else value + if isinstance(value, int): + return _index_const(value) + if IndexType.isinstance(value.type): + return value + if IntegerType.isinstance(value.type): + return arith.IndexCastOp(IndexType.get(), value).result + raise TypeError(f"expected an index-like value, got {value.type}") __all__ = [ diff --git a/ptodsl/ptodsl/_tile_namespace.py b/ptodsl/ptodsl/_tile_namespace.py new file mode 100644 index 000000000..66d01224e --- /dev/null +++ b/ptodsl/ptodsl/_tile_namespace.py @@ -0,0 +1,128 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from . import _ops + + +def _resolve_row_reduction_tmp(src, tmp): + if tmp is not None: + return tmp + return _ops.alloc_tile(tile_type=_ops.unwrap_surface_value(src).type) + + +class _TileNamespace: + load = staticmethod(_ops.tload) + store = staticmethod(_ops.tstore) + mov = staticmethod(_ops.tmov) + + add = staticmethod(_ops.tadd) + sub = staticmethod(_ops.tsub) + mul = staticmethod(_ops.tmul) + div = staticmethod(_ops.tdiv) + max = staticmethod(_ops.tmax) + min = staticmethod(_ops.tmin) + + adds = staticmethod(_ops.tadds) + subs = staticmethod(_ops.tsubs) + muls = staticmethod(_ops.tmuls) + divs = staticmethod(_ops.tdivs) + maxs = staticmethod(_ops.tmaxs) + mins = staticmethod(_ops.tmins) + + exp = staticmethod(_ops.texp) + log = staticmethod(_ops.tlog) + sqrt = staticmethod(_ops.tsqrt) + rsqrt = staticmethod(_ops.trsqrt) + recip = staticmethod(_ops.trecip) + abs = staticmethod(_ops.tabs) + neg = staticmethod(_ops.tneg) + + relu = staticmethod(_ops.trelu) + lrelu = staticmethod(_ops.tlrelu) + + @staticmethod + def rowsum(src, dst, *, tmp=None): + return _ops.trowsum(src, _resolve_row_reduction_tmp(src, tmp), dst) + + @staticmethod + def rowmax(src, dst, *, tmp=None): + return _ops.trowmax(src, _resolve_row_reduction_tmp(src, tmp), dst) + + @staticmethod + def rowmin(src, dst, *, tmp=None): + return _ops.trowmin(src, _resolve_row_reduction_tmp(src, tmp), dst) + + @staticmethod + def rowprod(src, dst, *, tmp=None): + return _ops.trowprod(src, _resolve_row_reduction_tmp(src, tmp), dst) + + @staticmethod + def rowargmax(src, dst, *, tmp=None): + return _ops.trowargmax(src, _resolve_row_reduction_tmp(src, tmp), dst) + + @staticmethod + def rowargmin(src, dst, *, tmp=None): + return _ops.trowargmin(src, _resolve_row_reduction_tmp(src, tmp), dst) + + colsum = staticmethod(_ops.tcolsum) + colmax = staticmethod(_ops.tcolmax) + colmin = staticmethod(_ops.tcolmin) + colprod = staticmethod(_ops.tcolprod) + colargmax = staticmethod(_ops.tcolargmax) + colargmin = staticmethod(_ops.tcolargmin) + + cmp = staticmethod(_ops.tcmp) + cmps = staticmethod(_ops.tcmps) + + expands = staticmethod(_ops.texpands) + rowexpand = staticmethod(_ops.trowexpand) + colexpand = staticmethod(_ops.tcolexpand) + + rowexpandadd = staticmethod(_ops.trowexpandadd) + rowexpandsub = staticmethod(_ops.trowexpandsub) + rowexpandmul = staticmethod(_ops.trowexpandmul) + rowexpanddiv = staticmethod(_ops.trowexpanddiv) + rowexpandmax = staticmethod(_ops.trowexpandmax) + rowexpandmin = staticmethod(_ops.trowexpandmin) + rowexpandexpdif = staticmethod(_ops.trowexpandexpdif) + + colexpandadd = staticmethod(_ops.tcolexpandadd) + colexpandsub = staticmethod(_ops.tcolexpandsub) + colexpandmul = staticmethod(_ops.tcolexpandmul) + colexpanddiv = staticmethod(_ops.tcolexpanddiv) + colexpandmax = staticmethod(_ops.tcolexpandmax) + colexpandmin = staticmethod(_ops.tcolexpandmin) + colexpandexpdif = staticmethod(_ops.tcolexpandexpdif) + + sel = staticmethod(_ops.tsel) + sels = staticmethod(_ops.tsels) + cvt = staticmethod(_ops.tcvt) + + bit_not = staticmethod(_ops.tnot) + bit_and = staticmethod(_ops.tand) + bit_ands = staticmethod(_ops.tands) + bit_or = staticmethod(_ops.tor) + bit_ors = staticmethod(_ops.tors) + bit_xor = staticmethod(_ops.txor) + bit_xors = staticmethod(_ops.txors) + bit_shl = staticmethod(_ops.tshl) + bit_shls = staticmethod(_ops.tshls) + bit_shr = staticmethod(_ops.tshr) + bit_shrs = staticmethod(_ops.tshrs) + + partadd = staticmethod(_ops.tpartadd) + partmul = staticmethod(_ops.tpartmul) + partmax = staticmethod(_ops.tpartmax) + partmin = staticmethod(_ops.tpartmin) + + fillpad = staticmethod(_ops.tfillpad) + fillpad_expand = staticmethod(_ops.tfillpad_expand) + fillpad_inplace = staticmethod(_ops.tfillpad_inplace) + + +tile = _TileNamespace() diff --git a/ptodsl/ptodsl/_tracing/session.py b/ptodsl/ptodsl/_tracing/session.py index e12fd2a36..ac2da5f9d 100644 --- a/ptodsl/ptodsl/_tracing/session.py +++ b/ptodsl/ptodsl/_tracing/session.py @@ -93,12 +93,12 @@ def enter_function(self, ir_fn): raise RuntimeError("PTODSL trace-session function stack corruption detected") @contextmanager - def enter_subkernel(self, subkernel): - """Push *subkernel* as the current active inline-lowering frame.""" + def enter_inline_subkernel(self, role: str, symbol_name: str, target: str): + """Push one inline subkernel frame onto the active tracing stack.""" frame = SubkernelTraceFrame( - role=subkernel.spec.role.value, - symbol_name=subkernel.spec.symbol_name, - target=subkernel.spec.target, + role=role, + symbol_name=symbol_name, + target=target, ) self._subkernel_stack.append(frame) try: @@ -108,6 +108,16 @@ def enter_subkernel(self, subkernel): if popped is not frame: raise RuntimeError("PTODSL trace-session subkernel stack corruption detected") + @contextmanager + def enter_subkernel(self, subkernel): + """Push *subkernel* as the current active inline-lowering frame.""" + with self.enter_inline_subkernel( + subkernel.spec.role.value, + subkernel.spec.symbol_name, + subkernel.spec.target, + ) as frame: + yield frame + def lower_inline_subkernel(self, subkernel, *args, **kwargs): """Lower one inline PTODSL subkernel call through the shared session.""" with self.enter_subkernel(subkernel): diff --git a/ptodsl/ptodsl/_types.py b/ptodsl/ptodsl/_types.py index 5d822f5a7..9edf9d575 100644 --- a/ptodsl/ptodsl/_types.py +++ b/ptodsl/ptodsl/_types.py @@ -22,10 +22,15 @@ def softmax(arg0: pto.ptr(pto.float32, "GM"), ...): from ._bootstrap import make_context # ensure MLIR is on sys.path from mlir.dialects import pto as _pto +from mlir.dialects import arith +from mlir.dialects.builtin import UnrealizedConversionCastOp from mlir.ir import ( BF16Type, F16Type, F32Type, + Float8E4M3FNType, + Float8E5M2Type, + FloatAttr, IndexType, IntegerType, ShapedType, @@ -66,6 +71,15 @@ def __init__(self, factory): def resolve(self) -> Type: return self._factory() + def __call__(self, value): + target_type = self.resolve() + kind = _classify_scalar_type(target_type) + if kind == "float": + return arith.ConstantOp(target_type, _parse_float_attr(target_type, value)).result + if kind == "integer": + return _materialize_integer_literal(target_type, value) + raise TypeError(f"unsupported eager constructor target type {target_type}") + def __repr__(self): return f"" @@ -76,7 +90,7 @@ def __init__(self, elem, space: str): self._space = space def resolve(self) -> Type: - elem = _resolve(self._elem) + elem = _ensure_non_storage_only_dtype(self._elem, context="pto.ptr(...)") space_enum = _normalize_address_space(self._space) if space_enum is None: raise ValueError( @@ -108,7 +122,7 @@ def __init__(self, lanes: int, elem): self._elem = elem def resolve(self) -> Type: - elem = _resolve(self._elem) + elem = _ensure_non_storage_only_dtype(self._elem, context="pto.vreg_type(...)") vreg_type_cls = getattr(_pto, "VRegType", None) if vreg_type_cls is None: raise TypeError( @@ -145,6 +159,180 @@ def _resolve(dtype) -> Type: return dtype # already an mlir.ir.Type +def _classify_scalar_type(type_obj): + if F32Type.isinstance(type_obj) or F16Type.isinstance(type_obj) or BF16Type.isinstance(type_obj): + return "float" + if IndexType.isinstance(type_obj) or IntegerType.isinstance(type_obj): + return "integer" + return None + + +def _isinstance_pto_type(type_obj, type_name: str) -> bool: + cls = getattr(_pto, type_name, None) + if cls is None: + return False + try: + return cls.isinstance(type_obj) + except Exception: + return False + + +def _classify_storage_dtype(type_obj): + if _classify_scalar_type(type_obj) is not None: + return "compute" + if Float8E4M3FNType.isinstance(type_obj) or Float8E5M2Type.isinstance(type_obj): + return "storage_only" + if any(_isinstance_pto_type(type_obj, name) for name in ("HiF8Type", "F4E1M2x2Type", "F4E2M1x2Type")): + return "storage_only" + return "other" + + +def _is_storage_only_dtype(type_obj): + return _classify_storage_dtype(type_obj) == "storage_only" + + +def _is_storage_only_authored_dtype(dtype) -> bool: + if isinstance(dtype, _DType): + return dtype in _STORAGE_ONLY_DTYPE_DESCRIPTORS + return _is_storage_only_dtype(_resolve(dtype)) + + +def _ensure_tensor_storage_dtype(dtype, *, context: str): + type_obj = _resolve(dtype) + category = _classify_storage_dtype(type_obj) + if category not in {"compute", "storage_only"}: + raise TypeError(f"{context} does not support element type {type_obj}") + return type_obj + + +def _ensure_non_storage_only_dtype(dtype, *, context: str): + type_obj = _resolve(dtype) + if _is_storage_only_dtype(type_obj): + raise TypeError( + f"{context} does not accept storage-only low-precision type {type_obj}; " + "these dtypes are only supported in Tile / TensorView / PartitionTensorView construction" + ) + return type_obj + + +def _ensure_non_storage_only_authored_dtype(dtype, *, context: str): + if _is_storage_only_authored_dtype(dtype): + raise TypeError( + f"{context} does not accept storage-only low-precision types; " + "these dtypes are only supported in Tile / TensorView / PartitionTensorView construction" + ) + return dtype + + +def _integer_signedness(type_obj): + if not IntegerType.isinstance(type_obj): + raise TypeError(f"expected integer type, got {type_obj}") + text = str(type_obj) + if text.startswith("si"): + return "signed" + if text.startswith("ui"): + return "unsigned" + return "signless" + + +def _signless_integer_type(type_obj): + if not IntegerType.isinstance(type_obj): + raise TypeError(f"expected integer type, got {type_obj}") + return IntegerType.get_signless(IntegerType(type_obj).width) + + +def _strip_integer_signedness(value): + value_type = getattr(value, "type", None) + if value_type is None or not IntegerType.isinstance(value_type): + return value + signless_type = _signless_integer_type(value_type) + if value_type == signless_type: + return value + return UnrealizedConversionCastOp([signless_type], [value]).results[0] + + +def _restore_integer_signedness(value, target_type): + if not IntegerType.isinstance(target_type): + raise TypeError(f"expected integer target type, got {target_type}") + signless_type = _signless_integer_type(target_type) + if target_type == signless_type: + return value + return UnrealizedConversionCastOp([target_type], [value]).results[0] + + +def _materialize_integer_literal(target_type, value): + if not IntegerType.isinstance(target_type): + raise TypeError(f"unsupported eager integer constructor target type {target_type}") + signless_type = _signless_integer_type(target_type) + raw_value = _parse_integer_value(value, target_type=target_type) + constant = arith.ConstantOp(signless_type, raw_value).result + if target_type == signless_type: + return constant + return _restore_integer_signedness(constant, target_type) + + +def _parse_integer_value(value, *, target_type=None): + if isinstance(value, bool): + raise TypeError("eager scalar constructors do not accept bool values") + if isinstance(value, int): + return value + if isinstance(value, str): + text = value.strip() + return _parse_integer_text(text) + raise TypeError(f"cannot materialize {value!r} as an integer constant of type {target_type}") + + +def _parse_integer_text(text: str): + if text.startswith(("0x", "0X", "-0x", "-0X")): + return int(text, 16) + return int(text, 0) + + +def _parse_float_attr(target_type, value): + if isinstance(value, bool): + raise TypeError("eager scalar constructors do not accept bool values") + if isinstance(value, str): + text = value.strip() + lower = text.lower() + if lower in {"inf", "+inf", "-inf", "nan"}: + numeric = float(lower) + elif text.startswith(("0x", "0X")): + return _float_attr_from_bit_pattern(target_type, text) + else: + numeric = float(text) + else: + numeric = float(value) + return FloatAttr.get(target_type, numeric) + + +def _float_attr_from_bit_pattern(target_type, text): + import math + import struct + + if F16Type.isinstance(target_type): + bits = int(text, 16) & 0xFFFF + as_bytes = bits.to_bytes(2, byteorder="little", signed=False) + numeric = struct.unpack(" Type: """``!pto.tensor_view`` with *rank* all-dynamic dims.""" - return _pto.TensorViewType.get(rank, _resolve(elem)) + return _pto.TensorViewType.get(rank, _ensure_tensor_storage_dtype(elem, context="pto.tensor_view_type(...)")) def part_tensor_view_type(rank: int, elem) -> Type: """``!pto.partition_tensor_view`` with *rank* all-dynamic dims.""" kDynamic = ShapedType.get_dynamic_size() - return _pto.PartitionTensorViewType.get([kDynamic] * rank, _resolve(elem)) + return _pto.PartitionTensorViewType.get( + [kDynamic] * rank, + _ensure_tensor_storage_dtype(elem, context="pto.part_tensor_view_type(...)"), + ) __all__ = [ "_DType", "_resolve", - "float32", "float16", "bf16", "int1", "int8", "int16", "int32", "int64", "index", + "float32", "float16", "bf16", + "f8e4m3", "f8e5m2", "hif8", "f4e1m2x2", "f4e2m1x2", + "int1", "int8", "int16", "int32", "int64", + "si8", "si16", "si32", "si64", + "ui8", "ui16", "ui32", "ui64", + "index", "ptr", "vreg_type", "mask_type", "tile_buf_type", "tensor_view_type", "part_tensor_view_type", ] diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py index 36f473fe1..9de6159cd 100644 --- a/ptodsl/ptodsl/pto.py +++ b/ptodsl/ptodsl/pto.py @@ -23,10 +23,13 @@ # ── Types ───────────────────────────────────────────────────────────────────── from ._types import ( # noqa: F401 float32, float16, bf16, + f8e4m3, f8e5m2, hif8, f4e1m2x2, f4e2m1x2, int1, int8, int16, int32, int64, + si8, si16, si32, si64, + ui8, ui16, ui32, ui64, index, ptr, vreg_type, mask_type, - tile_buf_type, tensor_view_type, part_tensor_view_type, + tile_buf_type, _resolve, ) from ._surface_types import ( # noqa: F401 @@ -36,29 +39,59 @@ BarrierType, Pipe, MemorySpace, + MaskPattern, + CmpMode, + PredicatePart, + PredicateDist, + VStoreDist, + DeinterleaveDist, + InterleaveDist, + PostUpdate, + AlignType, TensorView, PartitionTensorView, Tile, ) from ._tensor_factories import empty_like # noqa: F401 +from ._tile_namespace import tile # noqa: F401 # ── Operations ──────────────────────────────────────────────────────────────── from ._ops import ( # noqa: F401 const, castptr, addptr, - vlds, vbrc_load, vsts, vsts_1pt, - plt_b32, pset_b32, - make_mask, - vadd, vmul, vmax, vdiv, - vcmax, vcadd, vdup, vexpdif, - vexp, vcgmax, vcgadd, vsubs, + vlds, vldas, vldus, vldsx2, vbrc_load, vsts, vsts_1pt, vstsx2, + init_align, + plt_b8, plt_b16, plt_b32, + pset_b8, pset_b16, pset_b32, + pge_b8, pge_b16, pge_b32, + make_mask, bytewidth, elements_per_vreg, + pand, por, pxor, pnot, psel, + pbitcast, + ppack, punpack, + pintlv_b8, pintlv_b16, pintlv_b32, + pdintlv_b8, pdintlv_b16, pdintlv_b32, + vgather2, vgather2_bc, vgatherb, vscatter, vsldb, vsstb, + vcmp, vcmps, + plds, psts, pstu, vstar, vstas, vstur, vstus, + vbitcast, + vadd, vsub, vmul, vdiv, vmax, vmin, + vand, vor, vxor, vshl, vshr, + vcmax, vcadd, vcmin, vdup, vexpdif, + vexp, vln, vsqrt, vabs, vneg, vrec, vrsqrt, vrelu, vnot, + vcgmax, vcgadd, vcgmin, vcpadd, + vadds, vsubs, vmuls, vmaxs, vmins, vlrelu, + vaxpy, vaddrelu, vsubrelu, + vsel, make_tensor_view, partition_view, - alloc_tile, tload, tstore, tmov, as_ptr, - mte_load, mte_store, mem_bar, - mte_l1_l0a, mte_l1_l0b, mte_l0c_ub, mad, + alloc_tile, as_ptr, + mte_load, mte_store, mte_gm_ub, mte_ub_gm, mte_ub_ub, mte_ub_l1, mem_bar, + mte_l1_l0a, mte_l1_l0b, mte_l0c_ub, + mad, mad_acc, mad_bias, mad_mx, mad_mx_acc, mad_mx_bias, get_block_idx, get_block_num, get_subblock_idx, get_subblock_num, store_vfsimt_info, get_tid_x, get_tid_y, get_tid_z, pipe_barrier, + get_buf, rls_buf, + set_cross_flag, wait_cross_flag, set_intra_flag, wait_intra_flag, set_flag, wait_flag, ) @@ -73,9 +106,6 @@ from ._jit import jit, KernelHandle # noqa: F401 from ._subkernels import ukernel, cube, simd, simt # noqa: F401 -# ── Scalar sub-namespace ────────────────────────────────────────────────────── -from . import scalar # noqa: F401 - # ── Shorthand dtype aliases ─────────────────────────────────────────────────── f32 = float32 f16 = float16 @@ -84,3 +114,6 @@ i16 = int16 i32 = int32 i64 = int64 +mask_b8 = mask_type("b8") +mask_b16 = mask_type("b16") +mask_b32 = mask_type("b32") diff --git a/ptodsl/ptodsl/scalar.py b/ptodsl/ptodsl/scalar.py index 40c973517..fef02ff96 100644 --- a/ptodsl/ptodsl/scalar.py +++ b/ptodsl/ptodsl/scalar.py @@ -6,7 +6,8 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. """ -Scalar arithmetic helpers – exposed as ``pto.scalar.*`` (or ``s = pto.scalar``). +Scalar arithmetic helpers – exposed as top-level ``scalar.*`` from the +``ptodsl`` package (for example ``from ptodsl import scalar``). Arithmetic helpers operate on raw ``mlir.ir.Value`` objects and emit the corresponding arith dialect operations at the active insertion point. @@ -18,43 +19,33 @@ from ._scalar_coercion import coerce_scalar_to_type from ._runtime_scalar_ops import ( classify_runtime_scalar_type, + emit_runtime_abs, + emit_runtime_binary_op, emit_runtime_max, + emit_runtime_min, ) from ._surface_values import resolve_address_access, unwrap_surface_value, wrap_surface_value from ._types import _resolve from mlir.dialects import arith from mlir.dialects import math -from mlir.dialects import pto as _pto from mlir.ir import IndexType, MemRefType, Operation - -_CMPI_PREDICATES = { - "eq": arith.CmpIPredicate.eq, - "ne": arith.CmpIPredicate.ne, - "slt": arith.CmpIPredicate.slt, - "sle": arith.CmpIPredicate.sle, - "sgt": arith.CmpIPredicate.sgt, - "sge": arith.CmpIPredicate.sge, - "ult": arith.CmpIPredicate.ult, - "ule": arith.CmpIPredicate.ule, - "ugt": arith.CmpIPredicate.ugt, - "uge": arith.CmpIPredicate.uge, -} +from mlir.dialects import pto as _pto def muli(lhs, rhs): """arith.muli""" - return wrap_surface_value(arith.MulIOp(unwrap_surface_value(lhs), unwrap_surface_value(rhs)).result) + return wrap_surface_value(emit_runtime_binary_op("mul", unwrap_surface_value(lhs), unwrap_surface_value(rhs))) def addi(lhs, rhs): """arith.addi""" - return wrap_surface_value(arith.AddIOp(unwrap_surface_value(lhs), unwrap_surface_value(rhs)).result) + return wrap_surface_value(emit_runtime_binary_op("add", unwrap_surface_value(lhs), unwrap_surface_value(rhs))) def subi(lhs, rhs): """arith.subi""" - return wrap_surface_value(arith.SubIOp(unwrap_surface_value(lhs), unwrap_surface_value(rhs)).result) + return wrap_surface_value(emit_runtime_binary_op("sub", unwrap_surface_value(lhs), unwrap_surface_value(rhs))) def index_cast(type_or_val, val=None): @@ -72,32 +63,6 @@ def index_cast(type_or_val, val=None): return wrap_surface_value(arith.IndexCastOp(_resolve(type_or_val), unwrap_surface_value(val)).result) -def cmpi(pred: str, lhs, rhs): - """ - arith.cmpi with a named predicate string. - - ``pred`` is one of: ``"eq"``, ``"ne"``, ``"slt"``, ``"sle"``, - ``"sgt"``, ``"sge"``, ``"ult"``, ``"ule"``, ``"ugt"``, ``"uge"``. - """ - predicate = _CMPI_PREDICATES.get(pred) - if predicate is None: - raise ValueError( - f"Unknown cmpi predicate '{pred}'; known: {list(_CMPI_PREDICATES)}" - ) - return wrap_surface_value( - arith.CmpIOp(predicate, unwrap_surface_value(lhs), unwrap_surface_value(rhs)).result - ) - - -def cmpi_sgt(lhs, rhs): - """arith.cmpi sgt (signed greater-than).""" - return wrap_surface_value(arith.CmpIOp( - arith.CmpIPredicate.sgt, - unwrap_surface_value(lhs), - unwrap_surface_value(rhs), - ).result) - - def select(cond, true_val, false_val): """arith.select""" return wrap_surface_value(arith.SelectOp( @@ -115,6 +80,14 @@ def max(lhs, rhs): )) +def min(lhs, rhs): + """Runtime scalar minimum across float / integer / index values.""" + return wrap_surface_value(emit_runtime_min( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + )) + + def exp(value): """Runtime scalar exponential for floating-point values.""" raw_value = unwrap_surface_value(value) @@ -124,6 +97,29 @@ def exp(value): return wrap_surface_value(math.ExpOp(raw_value).result) +def log(value): + """Runtime scalar natural logarithm for floating-point values.""" + raw_value = unwrap_surface_value(value) + kind = classify_runtime_scalar_type(raw_value.type) + if kind != "float": + raise TypeError(f"scalar.log(...) expects a floating-point runtime scalar, got {raw_value.type}") + return wrap_surface_value(math.LogOp(raw_value).result) + + +def sqrt(value): + """Runtime scalar square root for floating-point values.""" + raw_value = unwrap_surface_value(value) + kind = classify_runtime_scalar_type(raw_value.type) + if kind != "float": + raise TypeError(f"scalar.sqrt(...) expects a floating-point runtime scalar, got {raw_value.type}") + return wrap_surface_value(math.SqrtOp(raw_value).result) + + +def abs(value): + """Runtime scalar absolute value across float / integer / index values.""" + return wrap_surface_value(emit_runtime_abs(unwrap_surface_value(value))) + + def load(ptr_or_ref, offset=None): """Load one scalar element from a PTODSL address view or tile element.""" buffer_value, index_value = resolve_address_access(ptr_or_ref, offset) @@ -155,8 +151,7 @@ def _infer_buffer_element_type(buffer_type): __all__ = [ "muli", "addi", "subi", "index_cast", - "cmpi", "cmpi_sgt", "select", - "max", "exp", + "max", "min", "exp", "log", "sqrt", "abs", "load", "store", ] diff --git a/ptodsl/tests/test_vector_cube_ops.py b/ptodsl/tests/test_vector_cube_ops.py new file mode 100644 index 000000000..49782bc04 --- /dev/null +++ b/ptodsl/tests/test_vector_cube_ops.py @@ -0,0 +1,381 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import unittest +import inspect +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from ptodsl.ptodsl import _ops, pto + + +def _identity(value): + return value + + +class VectorCubeSurfaceTest(unittest.TestCase): + def test_public_namespace_exports_new_vector_and_cube_apis(self): + names = [ + "vsub", "vmin", "vand", "vor", "vxor", "vshl", "vshr", + "vln", "vsqrt", "vabs", "vneg", "vrec", "vrsqrt", "vrelu", "vnot", + "vcmin", "vcgmin", "vcpadd", + "vadds", "vmuls", "vmaxs", "vmins", "vlrelu", + "vaxpy", "vaddrelu", "vsubrelu", "vsel", + "mad_acc", "mad_bias", "mad_mx", "mad_mx_acc", "mad_mx_bias", + ] + + for name in names: + self.assertTrue(hasattr(pto, name), name) + + def test_tile_bitwise_aliases_are_exposed_without_legacy_names(self): + preferred_names = [ + "bit_not", "bit_and", "bit_ands", "bit_or", "bit_ors", + "bit_xor", "bit_xors", "bit_shl", "bit_shls", "bit_shr", "bit_shrs", + ] + legacy_names = [ + "not_", "and_", "ands", "or_", "ors", "xor", "xors", "shl", "shls", "shr", "shrs", + ] + + for name in preferred_names: + with self.subTest(name=name): + self.assertTrue(hasattr(pto.tile, name), name) + + for name in legacy_names: + with self.subTest(name=name): + self.assertFalse(hasattr(pto.tile, name), name) + + def test_tile_partial_and_fillpad_names_are_exposed_without_legacy_names(self): + preferred_names = [ + "partadd", "partmul", "partmax", "partmin", + "fillpad", "fillpad_expand", "fillpad_inplace", + ] + legacy_names = [ + "part_add", "part_mul", "part_max", "part_min", + "fill_pad", "fill_pad_expand", "fill_pad_inplace", + ] + + for name in preferred_names: + with self.subTest(name=name): + self.assertTrue(hasattr(pto.tile, name), name) + + for name in legacy_names: + with self.subTest(name=name): + self.assertFalse(hasattr(pto.tile, name), name) + + def test_sync_flag_names_are_exposed_without_legacy_aliases(self): + preferred_names = [ + "set_cross_flag", "wait_cross_flag", + "set_intra_flag", "wait_intra_flag", + ] + legacy_names = [ + "set_cross_core", "wait_flag_dev", + "set_intra_block", "wait_intra_core", + ] + + for name in preferred_names: + with self.subTest(name=name): + self.assertTrue(hasattr(pto, name), name) + + for name in legacy_names: + with self.subTest(name=name): + self.assertFalse(hasattr(pto, name), name) + + def test_direct_vector_wrappers_dispatch_to_generated_ops(self): + lhs = SimpleNamespace(type="vec_ty") + rhs = SimpleNamespace(type="vec_ty") + mask = SimpleNamespace(type="mask_ty") + result = object() + + binary_cases = [ + ("vsub", "VsubOp", (lhs, rhs, mask)), + ("vmin", "VminOp", (lhs, rhs, mask)), + ("vand", "VandOp", (lhs, rhs, mask)), + ("vor", "VorOp", (lhs, rhs, mask)), + ("vxor", "VxorOp", (lhs, rhs, mask)), + ("vshl", "VshlOp", (lhs, rhs, mask)), + ("vshr", "VshrOp", (lhs, rhs, mask)), + ] + unary_cases = [ + ("vln", "VlnOp", (lhs, mask)), + ("vsqrt", "VsqrtOp", (lhs, mask)), + ("vabs", "VabsOp", (lhs, mask)), + ("vneg", "VnegOp", (lhs, mask)), + ("vrelu", "VreluOp", (lhs, mask)), + ("vnot", "VnotOp", (lhs, mask)), + ("vcmin", "VcminOp", (lhs, mask)), + ("vcpadd", "VcpaddOp", (lhs, mask)), + ] + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "wrap_surface_value", side_effect=_identity): + for func_name, op_name, args in binary_cases + unary_cases: + with self.subTest(func=func_name): + fake_op = SimpleNamespace(result=result) + with patch.object(_ops._pto, op_name, return_value=fake_op) as op_ctor: + output = getattr(_ops, func_name)(*args) + self.assertIs(output, result) + self.assertEqual(op_ctor.call_args.args[0], "vec_ty") + + def test_vec_scalar_wrappers_and_vaxpy_coerce_scalar_operands(self): + vec = SimpleNamespace(type="vec_ty") + other = SimpleNamespace(type="vec_ty") + mask = SimpleNamespace(type="mask_ty") + scalar = object() + coerced_scalar = object() + result = object() + + vec_scalar_cases = [ + ("vadds", "VaddsOp"), + ("vmuls", "VmulsOp"), + ("vmaxs", "VmaxsOp"), + ("vmins", "VminsOp"), + ("vlrelu", "VlreluOp"), + ] + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "wrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "_coerce_scalar_like_vector_element", return_value=coerced_scalar) as coerce_scalar: + for func_name, op_name in vec_scalar_cases: + with self.subTest(func=func_name): + fake_op = SimpleNamespace(result=result) + with patch.object(_ops._pto, op_name, return_value=fake_op) as op_ctor: + output = getattr(_ops, func_name)(vec, scalar, mask) + self.assertIs(output, result) + self.assertEqual(op_ctor.call_args.args, ("vec_ty", vec, coerced_scalar, mask)) + + fake_op = SimpleNamespace(result=result) + with patch.object(_ops._pto, "VaxpyOp", return_value=fake_op) as op_ctor: + output = _ops.vaxpy(scalar, vec, other, mask) + self.assertIs(output, result) + self.assertEqual(op_ctor.call_args.args, ("vec_ty", vec, other, coerced_scalar, mask)) + self.assertGreaterEqual(coerce_scalar.call_count, len(vec_scalar_cases) + 1) + + def test_composed_vector_wrappers_chain_existing_primitives(self): + vec = object() + rhs = object() + mask = object() + zero_vec = object() + one_vec = object() + sqrt_vec = object() + add_vec = object() + sub_vec = object() + relu_vec = object() + reciprocal_vec = object() + + with patch.object(_ops, "vmuls", return_value=zero_vec) as vmuls, \ + patch.object(_ops, "vadds", return_value=one_vec) as vadds, \ + patch.object(_ops, "vdiv", return_value=reciprocal_vec) as vdiv: + self.assertIs(_ops.vrec(vec, mask), reciprocal_vec) + vmuls.assert_called_once_with(vec, 0, mask) + vadds.assert_called_once_with(zero_vec, 1, mask) + vdiv.assert_called_once_with(one_vec, vec, mask) + + with patch.object(_ops, "vsqrt", return_value=sqrt_vec) as vsqrt, \ + patch.object(_ops, "vrec", return_value=reciprocal_vec) as vrec: + self.assertIs(_ops.vrsqrt(vec, mask), reciprocal_vec) + vsqrt.assert_called_once_with(vec, mask) + vrec.assert_called_once_with(sqrt_vec, mask) + + with patch.object(_ops, "vadd", return_value=add_vec) as vadd, \ + patch.object(_ops, "vrelu", return_value=relu_vec) as vrelu: + self.assertIs(_ops.vaddrelu(vec, rhs, mask), relu_vec) + vadd.assert_called_once_with(vec, rhs, mask) + vrelu.assert_called_once_with(add_vec, mask) + + with patch.object(_ops, "vsub", return_value=sub_vec) as vsub, \ + patch.object(_ops, "vrelu", return_value=relu_vec) as vrelu: + self.assertIs(_ops.vsubrelu(vec, rhs, mask), relu_vec) + vsub.assert_called_once_with(vec, rhs, mask) + vrelu.assert_called_once_with(sub_vec, mask) + + def test_vcgmin_and_vsel_dispatch_correctly(self): + vec = SimpleNamespace(type="vec_ty") + other = SimpleNamespace(type="vec_ty") + mask = SimpleNamespace(type="mask_ty") + reduced = object() + scalar = object() + selected = object() + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "wrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "_extract_lowest_lane_scalar", return_value=scalar) as extract_scalar, \ + patch.object(_ops._pto, "VcgminOp", return_value=SimpleNamespace(result=reduced)) as vcgmin_op: + output = _ops.vcgmin(vec, mask) + self.assertIs(output, scalar) + self.assertEqual(vcgmin_op.call_args.args, ("vec_ty", vec, mask)) + extract_scalar.assert_called_once_with(reduced, mask) + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "wrap_surface_value", side_effect=_identity), \ + patch.object(_ops._pto, "VselOp", return_value=SimpleNamespace(result=selected)) as vsel_op: + output = _ops.vsel(vec, other, mask) + self.assertIs(output, selected) + self.assertEqual(vsel_op.call_args.args, ("vec_ty", vec, other, mask)) + + def test_cube_variant_wrappers_dispatch_to_generated_ops(self): + lhs = object() + rhs = object() + dst = object() + bias = object() + + cube_cases = [ + ("mad_acc", "MadAccOp", (lhs, rhs, dst, 1, 2, 3), (lhs, rhs, dst, "i64:1", "i64:2", "i64:3")), + ("mad_bias", "MadBiasOp", (lhs, rhs, dst, bias, 1, 2, 3), (lhs, rhs, dst, bias, "i64:1", "i64:2", "i64:3")), + ("mad_mx", "MadMxOp", (lhs, rhs, dst, 1, 2, 3), (lhs, rhs, dst, "i64:1", "i64:2", "i64:3")), + ("mad_mx_acc", "MadMxAccOp", (lhs, rhs, dst, 1, 2, 3), (lhs, rhs, dst, "i64:1", "i64:2", "i64:3")), + ("mad_mx_bias", "MadMxBiasOp", (lhs, rhs, dst, bias, 1, 2, 3), (lhs, rhs, dst, bias, "i64:1", "i64:2", "i64:3")), + ] + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "_coerce_i64", side_effect=lambda value, *, context: f"i64:{value}"): + for func_name, op_name, args, expected_call in cube_cases: + with self.subTest(func=func_name): + op_ctor = MagicMock() + with patch.object(_ops._pto, op_name, op_ctor): + getattr(_ops, func_name)(*args) + self.assertEqual(op_ctor.call_args.args, expected_call) + + def test_tile_selection_surface_exposes_optional_tmp(self): + for func, expected in [ + (_ops.tsel, ["mask", "src0", "src1", "dst", "tmp"]), + (_ops.tsels, ["mask", "src", "scalar", "dst", "tmp"]), + (pto.tile.sel, ["mask", "src0", "src1", "dst", "tmp"]), + (pto.tile.sels, ["mask", "src", "scalar", "dst", "tmp"]), + ]: + with self.subTest(func=func): + signature = inspect.signature(func) + self.assertEqual(list(signature.parameters.keys()), expected) + self.assertEqual(signature.parameters["tmp"].kind, inspect.Parameter.KEYWORD_ONLY) + self.assertIsNone(signature.parameters["tmp"].default) + + def test_tile_selection_wrappers_use_explicit_tmp_or_synthesize_one(self): + mask = object() + src0 = object() + src1 = object() + src = object() + dst = object() + tmp = object() + scalar = object() + coerced_scalar = object() + synthesized_tmp = object() + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "_coerce_tile_scalar_operand", return_value=coerced_scalar): + with patch.object(_ops, "_resolve_selection_tmp", return_value=synthesized_tmp) as resolve_tmp, \ + patch.object(_ops._pto, "tsel") as tsel_op: + _ops.tsel(mask, src0, src1, dst) + resolve_tmp.assert_called_once_with(dst, None, context="tsel") + self.assertEqual(tsel_op.call_args.args, (mask, src0, src1, synthesized_tmp, dst)) + + with patch.object(_ops, "_resolve_selection_tmp", side_effect=AssertionError("should not synthesize")), \ + patch.object(_ops._pto, "tsel") as tsel_op: + _ops.tsel(mask, src0, src1, dst, tmp=tmp) + self.assertEqual(tsel_op.call_args.args, (mask, src0, src1, tmp, dst)) + + with patch.object(_ops, "_resolve_selection_tmp", return_value=synthesized_tmp) as resolve_tmp, \ + patch.object(_ops._pto, "tsels") as tsels_op: + _ops.tsels(mask, src, scalar, dst) + resolve_tmp.assert_called_once_with(dst, None, context="tsels") + self.assertEqual(tsels_op.call_args.args, (mask, src, synthesized_tmp, coerced_scalar, dst)) + + with patch.object(_ops, "_resolve_selection_tmp", side_effect=AssertionError("should not synthesize")), \ + patch.object(_ops._pto, "tsels") as tsels_op: + _ops.tsels(mask, src, scalar, dst, tmp=tmp) + self.assertEqual(tsels_op.call_args.args, (mask, src, tmp, coerced_scalar, dst)) + + def test_tile_row_reductions_expose_optional_tmp_and_synthesize_one(self): + src = SimpleNamespace(type="src_ty") + dst = object() + tmp = object() + synthesized_tmp = object() + + row_cases = [ + ("rowsum", "trowsum"), + ("rowmax", "trowmax"), + ("rowmin", "trowmin"), + ("rowprod", "trowprod"), + ("rowargmax", "trowargmax"), + ("rowargmin", "trowargmin"), + ] + + for name, low_level_name in row_cases: + with self.subTest(func=name): + signature = inspect.signature(getattr(pto.tile, name)) + self.assertEqual(list(signature.parameters.keys()), ["src", "dst", "tmp"]) + self.assertEqual(signature.parameters["tmp"].kind, inspect.Parameter.KEYWORD_ONLY) + self.assertIsNone(signature.parameters["tmp"].default) + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "alloc_tile", return_value=synthesized_tmp) as alloc_tile, \ + patch.object(_ops, low_level_name) as low_level_op: + getattr(pto.tile, name)(src, dst) + alloc_tile.assert_called_once_with(tile_type="src_ty") + low_level_op.assert_called_once_with(src, synthesized_tmp, dst) + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "alloc_tile", side_effect=AssertionError("should not synthesize")), \ + patch.object(_ops, low_level_name) as low_level_op: + getattr(pto.tile, name)(src, dst, tmp=tmp) + low_level_op.assert_called_once_with(src, tmp, dst) + + def test_sync_event_id_rejects_out_of_range_static_values(self): + cases = [ + (_ops.set_flag, ("MTE2", "V"), {"event_id": 8}, "set_flag(..., event_id=...)"), + (_ops.wait_flag, ("MTE2", "V"), {"event_id": -1}, "wait_flag(..., event_id=...)"), + (_ops.set_cross_flag, (pto.Pipe.FIX, 8), {}, "set_cross_flag(..., event_id=...)"), + (_ops.wait_cross_flag, (pto.Pipe.FIX, -1), {}, "wait_cross_flag(..., event_id=...)"), + (_ops.set_intra_flag, (pto.Pipe.MTE3, 9), {}, "set_intra_flag(..., event_id=...)"), + (_ops.wait_intra_flag, (pto.Pipe.V, -2), {}, "wait_intra_flag(..., event_id=...)"), + ] + + with patch.object(_ops._pto, "set_flag") as set_flag_op, \ + patch.object(_ops._pto, "set_flag_dyn") as set_flag_dyn_op, \ + patch.object(_ops._pto, "wait_flag") as wait_flag_op, \ + patch.object(_ops._pto, "wait_flag_dyn") as wait_flag_dyn_op, \ + patch.object(_ops._pto, "sync_set") as sync_set_op, \ + patch.object(_ops._pto, "sync_wait") as sync_wait_op: + for func, args, kwargs, context in cases: + with self.subTest(func=func.__name__, event_id=kwargs.get("event_id", args[-1])): + with self.assertRaises(ValueError) as exc: + func(*args, **kwargs) + message = str(exc.exception) + self.assertIn(context, message) + self.assertIn("[0, 7]", message) + + set_flag_op.assert_not_called() + set_flag_dyn_op.assert_not_called() + wait_flag_op.assert_not_called() + wait_flag_dyn_op.assert_not_called() + sync_set_op.assert_not_called() + sync_wait_op.assert_not_called() + + def test_sync_facades_reject_illegal_pipe_endpoints(self): + cases = [ + (_ops.set_cross_flag, (pto.Pipe.V, 0), "set_cross_flag(pipe, event_id)", "", ""), + (_ops.wait_cross_flag, (pto.Pipe.MTE3, 0), "wait_cross_flag(pipe, event_id)", "", ""), + (_ops.set_intra_flag, (pto.Pipe.FIX, 0), "set_intra_flag(pipe, event_id)", "", ""), + (_ops.wait_intra_flag, (pto.Pipe.MTE3, 0), "wait_intra_flag(pipe, event_id)", "", ""), + ] + + with patch.object(_ops._pto, "sync_set") as sync_set_op, \ + patch.object(_ops._pto, "sync_wait") as sync_wait_op: + for func, args, context, expected, actual in cases: + with self.subTest(func=func.__name__, pipe=args[0]): + with self.assertRaises(ValueError) as exc: + func(*args) + message = str(exc.exception) + self.assertIn(context, message) + self.assertIn(expected, message) + self.assertIn(actual, message) + + sync_set_op.assert_not_called() + sync_wait_op.assert_not_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/python/ptodsl_docs_as_test.py b/test/python/ptodsl_docs_as_test.py new file mode 100644 index 000000000..699167855 --- /dev/null +++ b/test/python/ptodsl_docs_as_test.py @@ -0,0 +1,463 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable +import json +import re +import shutil +import subprocess +import sys +import tempfile + + +REPO_ROOT = Path(__file__).resolve().parents[2] +USER_GUIDE_ROOT = REPO_ROOT / "ptodsl" / "docs" / "user_guide" +sys.path.insert(0, str(REPO_ROOT / "ptodsl")) + +from ptodsl import pto, scalar +from ptodsl._bootstrap import make_context +from mlir.ir import Module +from ptodsl_docs_fragment_fixtures import FRAGMENT_FIXTURES, render_fragment_fixture + +FENCE_RE = re.compile(r"^```(?P[A-Za-z0-9_+-]*)\s*$") +META_RE = re.compile(r"^\s*\s*$") + + +@dataclass(frozen=True) +class MarkdownCodeBlock: + path: Path + start_line: int + end_line: int + language: str + lines: tuple[str, ...] + metadata: "DocBlockMetadata | None" + + @property + def text(self) -> str: + return "".join(self.lines) + + +@dataclass(frozen=True) +class MarkdownScanResult: + path: Path + blocks: tuple[MarkdownCodeBlock, ...] + + +@dataclass(frozen=True) +class DocBlockMetadata: + kind: str + body: str + line: int + raw: str + + +@dataclass(frozen=True) +class DocTestDirective: + mode: str + symbol: str | None = None + compile_kwargs: dict[str, object] | None = None + fixture: str | None = None + + +def expect(condition: bool, message: str) -> None: + if not condition: + raise AssertionError(message) + + +def format_doc_context(path: Path, start_line: int, symbol: str | None = None) -> str: + symbol_text = symbol if symbol is not None else "" + return f"{path}:{start_line} [symbol={symbol_text}]" + + +def fail_doc(path: Path, start_line: int, message: str, symbol: str | None = None) -> None: + raise AssertionError(f"{format_doc_context(path, start_line, symbol)}: {message}") + + +def iter_markdown_files(root: Path) -> Iterable[Path]: + yield from sorted(root.glob("*.md")) + + +def parse_metadata_line(path: Path, line: str, line_number: int) -> DocBlockMetadata | None: + match = META_RE.match(line) + if match is None: + return None + + kind = match.group("kind") + body = match.group("body").strip() + expect(body, f"{format_doc_context(path, line_number)}: ptodsl-doc-{kind} metadata must not be empty") + if kind == "test": + try: + json.loads(body) + except json.JSONDecodeError as exc: + raise AssertionError( + f"{format_doc_context(path, line_number)}: ptodsl-doc-test metadata must be valid JSON: {exc.msg}" + ) from exc + return DocBlockMetadata(kind=kind, body=body, line=line_number, raw=line.rstrip("\n")) + + +def find_block_metadata(path: Path, lines: list[str], fence_line: int) -> DocBlockMetadata | None: + candidate = fence_line - 2 + while candidate >= 0 and not lines[candidate].strip(): + candidate -= 1 + if candidate < 0: + return None + line = lines[candidate] + if line.lstrip().startswith(" ```python @@ -80,7 +81,7 @@ def flash_attention(Q, K, V, *, O=None, causal=False): return O ``` -### L1 — `@pto.jit` +### `@pto.jit` — the kernel entry Decorating a function with `@pto.jit` marks it as a launchable PTO kernel. This decoration means: @@ -110,22 +111,28 @@ def flash_attention_kernel( return ``` -L1 is the primary layer for expressing **tile-level semantics**. Inside `@pto.jit`, you allocate tile buffers (`alloc_tile`), move data between GM and UB at block granularity (`tile.load`, `tile.store`), and perform tile-level compute (`tile.add`, `tile.exp`, `tile.rowsum`, etc.). When the built-in Tile Ops are not sufficient, you can drop down to `@pto.ukernel` to write custom tile-level semantics with micro-instructions. +`@pto.jit` is the only host-visible kernel entry. Its `mode` selects the +programming model: + +- `mode="auto"` (the default) is **tile-centric**. You allocate tiles, partition + GM views, use Tile Ops (`tile.load`, `tile.store`, `tile.add`, ...), and call + compute sub-kernels. The compiler manages staging and scheduling around the + tile abstraction. +- `mode="explicit"` is **tile + micro-instruction**. You keep the same tile + surface from `auto`, but also gain access to the full micro-instruction + set — MTE ops (`mte_load`, `mte_store`, ...), explicit synchronization, + and direct pointer manipulation — so you can reach below the tile abstraction + and control individual instructions when needed. + +In both modes, `@pto.jit` is where you allocate tiles (`alloc_tile`) and use +Tile Ops. The difference is that `explicit` additionally opens up the +micro-instruction surface — MTE ops, explicit sync, and pointer-level +control — so you can mix tile operations with hand-authored instructions in +the same kernel. The SPMD launch contract is also owned here: the runtime grid (e.g., `batch * heads` blocks) is declared at the call site, and block/subblock indices are queried via `pto.get_block_idx()` and friends. -### L2 — `@pto.ukernel` - -`@pto.ukernel` (short for *micro-instruction kernel*) is the entry point for expressing **PTO micro-instruction semantics**. Where L1 works with tile buffers as opaque wholes, L2 gives you direct control over individual MTE, vector, and scalar instructions. This layer is intended for users who pursue peak performance and need precise control over low-level hardware details — instruction ordering, DMA scheduling, per-byte data placement, and synchronization. - -Inside a ukernel, you write instructions targeting the three hardware units, and orchestrate data movement between them via **MTE Ops**: - -- **MTE Ops** (`mte_load`, `mte_store`, `copy_gm_to_ubuf`, etc.) move data between GM and UB, or between UB regions, at the DMA engine level. -- **`@pto.cube`**, **`@pto.simd`**, and **`@pto.simt`** sub-kernels execute the actual compute on their respective hardware units. - -The ukernel manages the execution sandwich for one block: staging data with MTE Ops, issuing synchronization barriers, dispatching sub-kernels, and managing loop-carried state between invocations. - -### L3 — `@pto.cube` / `@pto.simd` / `@pto.simt` +### Sub-kernels — `@pto.cube` / `@pto.simd` / `@pto.simt` These are hardware-bound compute sub-kernels, each mapped to a specific NPU compute unit: @@ -135,7 +142,9 @@ These are hardware-bound compute sub-kernels, each mapped to a specific NPU comp - **`@pto.simt`** is a scalar-programmable processor group that executes scalar instructions across many work-items in parallel. Typical operations: `lds`, `sts`, scalar arithmetic and comparison. Well-suited for per-element tile walks, boundary metadata, and pointwise blends. -L3 sub-kernels can be invoked in two ways: as named decorated functions (`@pto.cube` / `@pto.simd` / `@pto.simt`) — reusable and callable from `@pto.ukernel` or directly from `@pto.jit` — or inline as context managers (`with pto.cube():` / `with pto.simd():` / `with pto.simt():`) for quick prototyping. When called directly from `@pto.jit`, you stage data with `tile.load`/`tile.store` instead of `mte_load`/`mte_store`; PTOAS handles the synchronization between Tile Ops and L3 compute automatically. +Each can be invoked as a named decorated function (`@pto.cube` / +`@pto.simd` / `@pto.simt`) or inline as a context manager +(`with pto.cube():`, `with pto.simd():`, `with pto.simt():`). The boundary contract is strict: vreg values do not escape a simd kernel, cube-local state does not leak into UB, and data crosses layer boundaries only through UB-backed tiles or typed UB pointers. @@ -159,17 +168,27 @@ Chapter 5 (Control Flow) and Chapter 6 (Scalar & Pointer Operations) cover this ## 1.4 A worked example -The flash attention kernel from Section 1.2 is not just an architectural diagram — it is a complete, runnable design sketch distributed with PTODSL (`examplesflash_attention_sketch.py`). Here is how the layers map to actual code: +The flash attention kernel from Section 1.2 is not just an architectural diagram — it is a complete, runnable design sketch distributed with PTODSL (`examples/flash_attention_sketch.py`). Here is how the layers map to actual code: -**L1 (`@pto.jit`)** allocates tiles for the Q block, KV block, online-softmax state (m/l/o ping-pong tiles), and cube-local scratch. It loops over Q blocks (outer `pto.for_`) and KV blocks (inner `pto.for_` with carry state), calling `kv_block_process` for each KV block and using `tile.load`/`tile.store` at the GM boundary. +**Top-level `@pto.jit` schedule** allocates tiles for the Q block, KV block, +online-softmax state (m/l/o ping-pong tiles), and cube-local scratch. It loops +over Q blocks (outer `pto.for_`) and KV blocks (inner `pto.for_` with carry +state), and uses `tile.load`/`tile.store` at the GM boundary. -**L2 (`@pto.ukernel`)** stages the current K and V blocks with `mte_load`, issues `pipe_barrier(Pipe.ALL)` at phase boundaries, then sequences four sub-kernel calls: `qk_matmul` (cube), `online_softmax_rows` (simd), `pv_matmul` (cube), `blend_output_rows` (simt). +**`mode="explicit"` orchestration path** stages the current K and V blocks with +`mte_load`, issues `pipe_barrier(Pipe.ALL)` at phase boundaries, then +sequences four sub-kernel calls: `qk_matmul` (cube), +`online_softmax_rows` (simd), `pv_matmul` (cube), `blend_output_rows` (simt). -**L3a (`@pto.cube`)** performs `mte_l1_l0a` / `mte_l1_l0b` / `mad` / `mte_l0c_ub` for both QK^T and P@V products. +**`@pto.cube`** performs `mte_l1_l0a` / `mte_l1_l0b` / `mad` / +`mte_l0c_ub` for both QK^T and P@V products. -**L3b (`@pto.simd`)** implements the online softmax update: per-row max, exp, sum, and alpha/beta computation using vector ops (`vlds`, `vcgmax`, `vexp`, `vcgadd`, `vsts`). +**`@pto.simd`** implements the online softmax update: per-row max, exp, sum, +and alpha/beta computation using vector ops (`vlds`, `vcgmax`, `vexp`, +`vcgadd`, `vsts`). -**L3c (`@pto.simt`)** blends the old and new output accumulators with per-element `lds`/`sts` and scalar arithmetic. +**`@pto.simt`** blends the old and new output accumulators with per-element +`lds`/`sts` and scalar arithmetic. Chapter 11 walks through this example in full detail. @@ -189,7 +208,7 @@ Chapter 11 walks through this example in full detail. |---------|-------| | 1 | Introduction (this chapter) | | 2 | Quick Start — a minimal working kernel | -| 3 | Kernel entry points: `@pto.jit`, `@pto.ukernel`, `@pto.cube`, `@pto.simd`, `@pto.simt` | +| 3 | Kernel entry and sub-kernels: `@pto.jit(mode=...)`, `@pto.cube`, `@pto.simd`, `@pto.simt` | | 4 | Type system and buffer management: scalars, tiles, views, allocation | | 5 | Control flow: trace-time Python vs device-side `pto.for_` / `pto.if_` | | 6 | Scalar and pointer operations | diff --git a/ptodsl/docs/user_guide/02-quick-start.md b/ptodsl/docs/user_guide/02-quick-start.md index 2a219c386..27b77701d 100644 --- a/ptodsl/docs/user_guide/02-quick-start.md +++ b/ptodsl/docs/user_guide/02-quick-start.md @@ -172,13 +172,18 @@ block_num = pto.get_block_num() This lets you map different data slices to different blocks — for example, one block per (batch, head) pair in flash attention. -## 2.5 Dropping down to micro-instructions +## 2.5 Adding sub-kernels and explicit orchestration -The examples above used Tile Ops (`tile.load` / `tile.store` here, and arithmetic Tile Ops in later chapters), which operate on entire tiles at once. When you need finer control — for instance, writing a custom softmax or an activation that maps directly to vector hardware — you can drop down to the micro-instruction level. This involves three layers working together: +The examples above used Tile Ops (`tile.load` / `tile.store` here, and +arithmetic Tile Ops in later chapters), which operate on entire tiles at once. +When you need finer control — for instance, writing a custom softmax or an +activation that maps directly to vector hardware — you can keep the same +`@pto.jit` entry and add sub-kernels. If you also need micro-instruction control, +switch that kernel to `mode="explicit"`: ```python -# L3: hardware-bound SIMD kernel — vector instructions on individual rows. +# SIMD sub-kernel — vector instructions on individual rows. @pto.simd def add_rows(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, rows: pto.index, cols: pto.index): @@ -196,30 +201,8 @@ def add_rows(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, pto.vsts(o_vec, o_tile[r, c:], mask) col_loop.update(remained=remained) - -# L2: ukernel — DMA staging, then dispatch the SIMD kernel. -@pto.ukernel -def add_block(a_part: pto.PartitionTensorView, - b_part: pto.PartitionTensorView, - o_part: pto.PartitionTensorView, - a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, - rows: pto.index, cols: pto.index): - row_bytes = cols * pto.bytewidth(pto.f32) - pto.mte_load(a_part.as_ptr(), a_tile.as_ptr(), 0, row_bytes, - nburst=(rows, 0, 0)) - pto.mte_load(b_part.as_ptr(), b_tile.as_ptr(), 0, row_bytes, - nburst=(rows, 0, 0)) - pto.pipe_barrier(pto.Pipe.ALL) - - add_rows(a_tile, b_tile, o_tile, rows, cols) - pto.pipe_barrier(pto.Pipe.ALL) - - pto.mte_store(o_tile.as_ptr(), o_part.as_ptr(), row_bytes, - nburst=(rows, 0, 0)) - - -# L1: JIT entry — tile allocation, partitioning, launch. -@pto.jit(target="a5") +# Single kernel entry in explicit mode — micro-instruction staging plus SIMD sub-kernel. +@pto.jit(target="a5", mode="explicit") def vec_add_micro( A: pto.tensor_spec(rank=1, dtype=pto.f32), B: pto.tensor_spec(rank=1, dtype=pto.f32), @@ -243,13 +226,34 @@ def vec_add_micro( a_part = pto.partition_view(a_view, offsets=[offset], sizes=[this_block]) b_part = pto.partition_view(b_view, offsets=[offset], sizes=[this_block]) o_part = pto.partition_view(o_view, offsets=[offset], sizes=[this_block]) - add_block(a_part, b_part, o_part, a_tile, b_tile, o_tile, 1, this_block) + row_bytes = this_block * pto.bytewidth(pto.f32) + pto.mte_load(a_part.as_ptr(), a_tile.as_ptr(), 0, row_bytes, + nburst=(1, 0, 0)) + pto.mte_load(b_part.as_ptr(), b_tile.as_ptr(), 0, row_bytes, + nburst=(1, 0, 0)) + pto.pipe_barrier(pto.Pipe.ALL) + add_rows(a_tile, b_tile, o_tile, 1, this_block) + pto.pipe_barrier(pto.Pipe.ALL) + pto.mte_store(o_tile.as_ptr(), o_part.as_ptr(), row_bytes, + nburst=(1, 0, 0)) ``` -- **L1 `@pto.jit`**: allocates tiles, partitions the GM views, and loops over blocks — the same tile-level orchestration as Section 2.2, but now calling a ukernel instead of Tile Ops. - -- **L2 `@pto.ukernel`**: stages data with ptr-based `mte_load`, inserts explicit `pipe_barrier` phase boundaries, dispatches the SIMD kernel, synchronizes again, then writes back with `mte_store`. The ukernel owns the hardware-level sequencing. - -- **L3 `@pto.simd`**: the outer `pto.for_` iterates over rows, the inner `pto.for_` iterates over column chunks of the hardware vector width (`elements_per_vreg`). Each iteration loads a vector-width slice into a `vreg`, does the addition under a mask (for tail elements), and stores the result back. Both loops are recorded as structured control flow IR — the compiler decides whether to keep them or unroll them. - -Chapter 3 covers the full decorator family; Chapters 7–10 cover each operation family in detail. +- **`@pto.jit(mode="explicit")`**: allocates tiles, partitions the GM views, + loops over blocks, and directly authors the micro-instruction schedule for + each block. + +- **`@pto.simd` sub-kernel**: the top-level kernel calls a SIMD sub-kernel + for the row-wise vector work while keeping instruction staging in the + explicit entry body. + +- **Inside `@pto.simd`**: the outer `pto.for_` iterates over rows, the inner + `pto.for_` iterates over column chunks of the hardware vector width + (`elements_per_vreg`). Each iteration loads a vector-width slice into a + `vreg`, does the addition under a mask (for tail elements), and stores the + result back. Both loops are recorded as structured control flow IR — the + compiler decides whether to keep them or unroll them. + +The same pattern also has an `auto` counterpart: keep `@pto.jit` in its +default mode and replace the explicit `mte_*` sequence with `tile.load` / +`tile.store`. Chapter 3 covers the full entry model; Chapters 7–10 cover each +operation family in detail. diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index b04a4c682..0148c60f0 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -1,35 +1,35 @@ -# 3. Kernel Entry Points and Sub-Kernels +# 3. Kernel Entry and Sub-Kernels -PTODSL provides five decorators that mark functions as PTO kernels, plus three context managers for inline use. This chapter is a reference for each entry point — its role, parameter contract, and boundary constraints. +PTODSL provides one host-visible kernel decorator (`@pto.jit`) and three +compute-unit sub-kernel decorators (`@pto.cube`, `@pto.simd`, `@pto.simt`), +plus matching context managers for inline use. This chapter covers the kernel +entry, the two programming models, sub-kernel reference, parameter contracts, +and boundary constraints. -## 3.1 Decorator family overview +## 3.1 `@pto.jit` — the only kernel entry -``` -@pto.jit L1 Top-level JIT entry — compile, cache, launch -@pto.ukernel L2 Micro-instruction orchestration (MTE + sync) -@pto.cube L3 Matrix multiplication on the Cube unit -@pto.simd L3 Vector math on the SIMD unit -@pto.simt L3 Scalar compute on the SIMT unit -``` - -L3 sub-kernels can be invoked in two ways: - -1. **As decorated functions** (`@pto.cube` / `@pto.simd` / `@pto.simt`) — reusable, named sub-kernels that can be called from `@pto.ukernel` or directly from `@pto.jit`. -2. **As context managers** (`with pto.cube():` / `with pto.simd():` / `with pto.simt():`) — inline L3 blocks for quick prototyping or one-off compute snippets inside `@pto.jit` or `@pto.ukernel`. - -Calling an L3 sub-kernel directly from `@pto.jit` skips the ukernel layer: you stage data with `tile.load`/`tile.store` instead of `mte_load`/`mte_store`, and PTOAS handles the synchronization between Tile Ops and L3 compute automatically. This is the recommended path for most users — drop down to `@pto.ukernel` only when you need explicit control over micro-instruction ordering and synchronization. +Decorator overview: -## 3.2 `@pto.jit` — top-level JIT entry +```text +@pto.jit(mode="auto") tile-first authoring, compiler-managed staging +@pto.jit(mode="explicit") micro-instruction authoring, user-managed staging +@pto.cube Cube-unit matrix sub-kernel +@pto.simd SIMD-unit vector sub-kernel +@pto.simt SIMT-unit scalar sub-kernel +``` ### Role -`@pto.jit` marks a function as a launchable PTO kernel. It owns compilation (tracing + lowering), caching, and runtime launch binding. This is the only decorator that can be invoked directly from the host — all other decorators define sub-kernels that are called from within `@pto.jit` or `@pto.ukernel`. +`@pto.jit` marks a function as a launchable PTO kernel. It owns compilation +(tracing + lowering), caching, and runtime launch binding. This is the only +decorator that can be invoked directly from the host; the compute-unit +decorators define sub-kernels that are called from within `@pto.jit`. ### Signature ```python -@pto.jit(target="a5") +@pto.jit(target="a5", mode="auto") def kernel_name( tensor_arg_1: pto.tensor_spec(rank=1, dtype=pto.f32), # Python-native tensor (positional) tensor_arg_2: pto.tensor_spec(rank=1, dtype=pto.f32), # Python-native tensor (positional) @@ -41,9 +41,28 @@ def kernel_name( return ``` -**Positional parameters** are Python-native tensors — they arrive from NumPy, torch-npu, or any framework with `.shape` and `.strides`. Inside the body, wrap them with `make_tensor_view` to create GM descriptors. +**Positional parameters** are Python-native tensors — they arrive from NumPy, +torch-npu, or any framework with `.shape` and `.strides`. Inside the body, wrap +them with `make_tensor_view` to create GM descriptors. + +**Keyword-only parameters** annotated with `pto.constexpr` are compile-time +constants. They must be provided at `.compile()` time and cannot change between +launches of the same compiled kernel. Use them for tile sizes, algorithmic knobs +(e.g., `CAUSAL`), and other values that the compiler can specialize against. + +### `mode`: auto vs explicit -**Keyword-only parameters** annotated with `pto.constexpr` are compile-time constants. They must be provided at `.compile()` time and cannot change between launches of the same compiled kernel. Use them for tile sizes, algorithmic knobs (e.g., `CAUSAL`), and other values that the compiler can specialize against. +`mode` is a keyword on the decorator, not a function parameter. It selects the +programming model: + +- `mode="auto"` (the default) is **tile-centric**. You write kernels in terms + of tiles and Tile Ops. The compiler manages staging, scheduling, and + synchronization around the tile abstraction. +- `mode="explicit"` adds the full **micro-instruction** surface — MTE ops, + explicit synchronization, and direct pointer manipulation — on top of + everything available in `auto`. + +Section 3.2 covers the two models in detail. ### Compilation and launch @@ -56,8 +75,13 @@ compiled = kernel_name.compile(CONST_A=128, CONST_B=64) compiled[grid, stream](tensor_1, tensor_2, ...) ``` -- `.compile(**constexprs)` — traces the kernel body with the given constexpr values, lowers the IR, and returns a compiled handle. Subsequent calls with the same specialization key (function identity, tensor ABI signature, constexpr values) hit the cache. -- `compiled[grid, stream](args...)` — launches the compiled kernel. `grid` is the number of SPMD blocks (an integer); `stream` is the NPU stream (`None` for default). +- `.compile(**constexprs)` — traces the kernel body with the given constexpr + values, lowers the IR, and returns a compiled handle. Subsequent calls with + the same specialization key (function identity, tensor ABI signature, + constexpr values) hit the cache. +- `compiled[grid, stream](args...)` — launches the compiled kernel. `grid` is + the number of SPMD blocks (an integer); `stream` is the NPU stream (`None` + for default). ### SPMD built-ins @@ -70,10 +94,10 @@ Available inside a `@pto.jit` body: | `pto.get_subblock_idx()` | `int` | Index of the current sub-block | | `pto.get_subblock_num()` | `int` | Total number of sub-blocks | -### Typical body +### Typical body (auto mode) ```python -@pto.jit(target="a5") +@pto.jit(target="a5", mode="auto") def my_kernel( A: pto.tensor_spec(rank=2, dtype=pto.f32), B: pto.tensor_spec(rank=2, dtype=pto.f32), @@ -102,9 +126,14 @@ def my_kernel( pto.tile.store(o_tile, o_part) ``` -### Calling L3 sub-kernels directly +### Custom sub-kernels -When you call an L3 sub-kernel directly from `@pto.jit`, data movement is handled by Tile Ops (`tile.load`/`tile.store`) instead of MTE micro-instructions. PTOAS handles the synchronization between Tile Ops and L3 compute — the sub-kernel itself is unchanged: +When Tile Ops don't cover the computation you need — a custom softmax, a +specialized activation, per-element blending — you write a sub-kernel in +`@pto.simd`, `@pto.simt`, or `@pto.cube` and call it directly from +`@pto.jit`. In auto mode, data movement stays with Tile Ops +(`tile.load`/`tile.store`) and PTOAS handles the synchronization between Tile +Ops and the sub-kernel: ```python @@ -130,7 +159,7 @@ def add_rows( pto.vsts(o_vec, o_tile[r, c:], mask) col_loop.update(remained=remained) -@pto.jit(target="a5") +@pto.jit(target="a5", mode="auto") def my_kernel( A: pto.tensor_spec(rank=2, dtype=pto.f32), B: pto.tensor_spec(rank=2, dtype=pto.f32), @@ -153,30 +182,57 @@ def my_kernel( b_part = pto.partition_view(b_view, offsets=[row, 0], sizes=[1, cols]) o_part = pto.partition_view(o_view, offsets=[row, 0], sizes=[1, cols]) - # Tile Ops stage data from GM to UB (replaces mte_load at L1) pto.tile.load(a_part, a_tile) pto.tile.load(b_part, b_tile) - # Direct L3 call — PTOAS handles sync between tile.load and compute add_rows(a_tile, b_tile, o_tile, 1, cols) pto.tile.store(o_tile, o_part) ``` -This is the recommended path for users who want hardware-unit compute without writing explicit MTE Ops and manual sync. Mixing direct L3 calls with Tile Ops and ukernel calls in the same `@pto.jit` body is supported — the compiler unifies the lowering. +Sub-kernels are the mechanism for custom compute in PTODSL — when Tile Ops +cover your needs, you don't need one; when they don't, a sub-kernel gives you +direct access to the hardware unit. In auto mode, a sub-kernel's parameters +are restricted to `Tile` and PTO scalar types — the compiler owns staging and +sync. In explicit mode, sub-kernels may also accept `PartitionTensorView` and +`pto.ptr` parameters, matching the richer type surface available there. +Section 3.3 covers each sub-kernel decorator in detail. -## 3.3 `@pto.ukernel` — micro-instruction orchestration +## 3.2 Programming models: auto vs explicit -### Role +`@pto.jit` exposes a single entry with two programming models. The entry's +host ABI, compilation flow, and launch mechanism are identical in both — the +difference is what you can write inside the kernel body. -`@pto.ukernel` (short for *micro-instruction kernel*) is the entry point for writing PTO micro-instructions directly. Unlike `@pto.jit` where you work with tile-level ops (`tile.load`, `tile.add`, etc.), a ukernel lets you write explicit MTE, SIMD, SIMT, and Cube instructions — staging data with `mte_load`, synchronizing with `mem_bar`, and dispatching L3 sub-kernels. This is an advanced programming mode for expert users who need precise control over instruction ordering and hardware-level data movement. +### `mode="auto"` — tile-centric -### Signature +In auto mode you think in tiles. You allocate tiles, partition GM views, move +data with `tile.load` and `tile.store`, compute with Tile Ops like +`tile.add` and `tile.exp`, and call sub-kernels for hardware-specific compute. +The compiler handles the lowering of tiles to micro-instructions: inferring +staging, inserting synchronization between Tile Ops and sub-kernels, and +managing tile-level scheduling. + +Use auto mode for the majority of kernels. It gives you the full performance +of the NPU without requiring you to reason about instruction-level ordering. + +### `mode="explicit"` — tile + micro-instruction + +Explicit mode extends auto mode with direct micro-instruction access. You keep +everything available in auto — tiles, Tile Ops, sub-kernels — and additionally +gain access to MTE ops, explicit synchronization, and pointer manipulation. +When you need precise control over individual instructions and phase ordering, +you can drop below the tile abstraction without leaving the `@pto.jit` entry. - +The richer type surface also applies to sub-kernels: in auto mode, a +sub-kernel's parameters are restricted to `Tile` and PTO scalar types; in +explicit mode they may also accept `PartitionTensorView` and `pto.ptr`, +matching the types available in the enclosing orchestration code. Organize +orchestration logic into helper functions that accept these types: + + ```python -@pto.ukernel -def my_ukernel( +def my_orchestration_helper( part: pto.PartitionTensorView, # GM partition descriptors tile: pto.Tile, # UB tile buffers scratch: pto.Tile, # cube-local scratch (LEFT, RIGHT, ...) @@ -186,13 +242,12 @@ def my_ukernel( return ``` -Parameters are PTO-specific types — `Tile`, `PartitionTensorView`, `pto.ptr`, and PTO scalar types. Unlike `@pto.jit`, a ukernel does not accept Python-native tensors. - -### Typical body +**Typical pattern**: GM↔UB movement uses ptr-based `mte_load`/`mte_store` +rather than `tile.load`/`tile.store`. The user places `pipe_barrier` at phase +boundaries and explicitly sequences sub-kernel calls: - + ```python -@pto.ukernel def process_block(q_tile, k_part, v_part, k_tile, v_tile, s_tile, o_tile, o_part, rows: pto.i32, cols: pto.i32): in_row_bytes = cols * pto.bytewidth(pto.f16) @@ -219,15 +274,41 @@ def process_block(q_tile, k_part, v_part, k_tile, v_tile, nburst=(rows, ub_row_stride, gm_row_stride)) ``` -A ukernel stays below the tile-op boundary — GM↔UB movement is expressed with ptr-based `mte_load`/`mte_store` (MTE Ops) rather than `tile.load`/`tile.store`. +Sub-kernel calls and inline sub-kernel scopes (`with pto.simd():`, etc.) work +identically in both modes. -## 3.4 `@pto.cube` — Cube unit sub-kernel +### Choosing between modes -### Role +| | `mode="auto"` | `mode="explicit"` | +|---|---|---| +| Abstraction | Tiles | Tiles + micro-instructions | +| Data movement | `tile.load` / `tile.store` | `mte_load` / `mte_store` (ptr-based) | +| Sync | Compiler-managed | User-authored | +| Use case | Most kernels | Hand-tuned instruction scheduling | -`@pto.cube` marks a function that executes on the Cube unit (matrix multiplication engine). It consumes UB-resident tiles and explicit cube-local scratch buffers. +Start with auto. Move to explicit when you need to control the exact sequence +of micro-instructions — for example, to overlap DMA and compute with +double-buffering, or to hand-optimize a phase boundary that the compiler +doesn't fuse as aggressively as you need. -### Signature +## 3.3 Sub-kernels + +Sub-kernels are functions decorated with `@pto.cube`, `@pto.simd`, or +`@pto.simt` that execute on a specific NPU compute unit. They can be invoked +in two ways: + +1. **As decorated functions** — reusable, named sub-kernels called from + `@pto.jit`. +2. **As context managers** (`with pto.cube():`, etc.) — inline blocks for + one-off snippets (see Section 3.4). + +### 3.3.1 `@pto.cube` — Cube unit + +**Role**: `@pto.cube` marks a function that executes on the Cube unit (matrix +multiplication engine). It consumes UB-resident tiles and explicit cube-local +scratch buffers. + +**Signature**: ```python @@ -242,9 +323,11 @@ def my_cube_kernel( return ``` -All parameters are `Tile` references. Tiles marked as cube-local must be allocated with the appropriate `memory_space` (e.g., `pto.MemorySpace.LEFT`, `pto.MemorySpace.ACC`). +All parameters are `Tile` references. Tiles marked as cube-local must be +allocated with the appropriate `memory_space` (e.g., `pto.MemorySpace.LEFT`, +`pto.MemorySpace.ACC`). -### Typical body +**Typical body**: ```python @@ -267,20 +350,21 @@ def qk_matmul( pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) ``` -Cube-local state (LEFT, RIGHT, ACC, BIAS) never leaks into UB — it is the caller's responsibility to allocate scratch buffers and pass them in explicitly. +Cube-local state (LEFT, RIGHT, ACC, BIAS) never leaks into UB — it is the +caller's responsibility to allocate scratch buffers and pass them in +explicitly. -**Invocation modes**: `@pto.cube` functions can be: -- Called from `@pto.ukernel` (manual MTE + sync in the ukernel's hands). -- Called directly from `@pto.jit` (compiler infers MTE + sync). -- Used inline as a context manager: `with pto.cube():` (see Section 3.7). +**Invocation modes**: can be called from `@pto.jit` in either mode, or used +inline with `with pto.cube():` (Section 3.4). -## 3.5 `@pto.simd` — SIMD unit sub-kernel +### 3.3.2 `@pto.simd` — SIMD unit -### Role +**Role**: `@pto.simd` marks a function that executes on the SIMD unit (vector +engine). It operates on vector registers (`vreg`) loaded from UB tiles and +stores results back to UB tiles. Vector registers are local to the function +and never cross its boundary. -`@pto.simd` marks a function that executes on the SIMD unit (vector engine). It operates on vector registers (`vreg`) loaded from UB tiles and stores results back to UB tiles. Vector registers are local to the function and never cross its boundary. - -### Signature +**Signature**: ```python @@ -294,9 +378,11 @@ def my_simd_kernel( return ``` -Parameters are UB `Tile` references and PTO scalar values (`pto.i32`, `pto.f32`, etc.). Scalar parameters may come from `lds` reads or compile-time constants. +Parameters are UB `Tile` references and PTO scalar values (`pto.i32`, +`pto.f32`, etc.). Scalar parameters may come from `lds` reads or compile-time +constants. -### Typical body +**Typical body**: ```python @@ -318,20 +404,24 @@ def add_rows(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, col_loop.update(remained=remained) ``` -The boundary contract: `vreg` values (`a_vec`, `b_vec`, `o_vec`) are local to the function. The only way to persist data across a `@pto.simd` call is to write it back to a UB tile via `vsts` (or `psts`, etc.). +The boundary contract: `vreg` values (`a_vec`, `b_vec`, `o_vec`) are local to +the function. The only way to persist data across a `@pto.simd` call is to +write it back to a UB tile via `vsts` (or `psts`, etc.). -**Invocation modes**: `@pto.simd` functions can be: -- Called from `@pto.ukernel` (manual MTE + sync in the ukernel's hands). -- Called directly from `@pto.jit` (compiler infers MTE + sync). -- Used inline as a context manager: `with pto.simd():` (see Section 3.7). +**Invocation modes**: can be called from `@pto.jit` in either mode, or used +inline with `with pto.simd():` (Section 3.4). -## 3.6 `@pto.simt` — SIMT unit sub-kernel +### 3.3.3 `@pto.simt` — SIMT unit -### Role +**Role**: `@pto.simt` marks a function that executes on the SIMT unit. SIMT +(Single Instruction, Multiple Threads) is a programming model where you write +instructions in scalar syntax, and the hardware executes them in parallel +across many threads — analogous to how a GPU SM runs a CUDA kernel. Each +instruction appears to operate on a single element (`lds`, `sts`, `a + b`), +but the same instruction is issued across a large number of work-items +simultaneously. -`@pto.simt` marks a function that executes on the SIMT unit. SIMT (Single Instruction, Multiple Threads) is a programming model where you write instructions in scalar syntax, and the hardware executes them in parallel across many threads — analogous to how a GPU SM runs a CUDA kernel. Each instruction appears to operate on a single element (`lds`, `sts`, `a + b`), but the same instruction is issued across a large number of work-items simultaneously. - -### Signature +**Signature**: ```python @@ -344,7 +434,7 @@ def my_simt_kernel( return ``` -### Typical body +**Typical body**: ```python @@ -365,23 +455,26 @@ def blend_output_rows( scalar.store(o_next, o_next_tile[row, col]) ``` -SIMT kernels read and write individual scalar elements from tiles. The unit executes the same scalar instruction across many work-items in parallel, making it efficient for per-element operations. +SIMT kernels read and write individual scalar elements from tiles. The unit +executes the same scalar instruction across many work-items in parallel, making +it efficient for per-element operations. -**Invocation modes**: `@pto.simt` functions can be: -- Called from `@pto.ukernel` (manual MTE + sync in the ukernel's hands). -- Called directly from `@pto.jit` (compiler infers MTE + sync). -- Used inline as a context manager: `with pto.simt():` (see Section 3.7). +**Invocation modes**: can be called from `@pto.jit` in either mode, or used +inline with `with pto.simt():` (Section 3.4). -## 3.7 Context manager syntax for L3 sub-kernels +## 3.4 Inline context manager syntax -In addition to the decorator form, each L3 sub-kernel unit provides a context manager: `with pto.cube():`, `with pto.simd():`, and `with pto.simt():`. These open an inline L3 block without requiring a separate named function — useful for quick prototyping, one-off compute snippets, or when the logic is too trivial to extract. The inline form is supported in top-level `@pto.jit` bodies and inside `@pto.ukernel`. +In addition to the decorator form, each sub-kernel unit provides a context +manager: `with pto.cube():`, `with pto.simd():`, and `with pto.simt():`. These +open inline blocks without requiring a separate named function — useful for +quick prototyping, one-off hardware-unit snippets, or code that is too small to +extract. Inline scopes are supported in top-level `@pto.jit` bodies. ### Syntax ```python with pto.simd(): - # Direct L3 instructions — vreg ops, scalar loads/stores a_vec = pto.vlds(a_tile[r, c:]) b_vec = pto.vlds(b_tile[r, c:]) o_vec = pto.vadd(a_vec, b_vec, mask) @@ -408,40 +501,48 @@ with pto.cube(): ### Semantics -- Inside the `with` block, instructions execute on the corresponding hardware unit. -- `vreg` values created inside `with pto.simd():` are scoped to the block — they do not escape. -- Cube-local scratch (`l0a`, `l0b`, `acc`) must be allocated by the caller before entering the block. -- The context manager form is equivalent to defining an inline anonymous sub-kernel. The compiler treats it identically to a named `@pto.simd` / `@pto.cube` / `@pto.simt` function. +- Inside the `with` block, instructions execute on the corresponding hardware + unit. +- `vreg` values created inside `with pto.simd():` are scoped to the block — + they do not escape. +- Cube-local scratch (`l0a`, `l0b`, `acc`) must be allocated by the caller + before entering the block. +- The context manager form is equivalent to an inline anonymous sub-kernel. The + compiler treats it identically to a named `@pto.simd` / `@pto.cube` / + `@pto.simt` function. ### Comparison | | Decorator form | Context manager form | |---|---|---| -| Reuse | Named, callable from multiple call sites | Inline, single-use | +| Reuse | Named, callable from multiple sites | Inline, single-use | | Readability | Good for complex, multi-step logic | Good for short (3-10 line) snippets | | Testing | Can be unit-tested independently | Tested only through the enclosing kernel | | Cube-local args | Explicit parameters | Captured from enclosing scope | -The two forms can be freely mixed in the same `@pto.jit` or `@pto.ukernel` body. +The two forms can be freely mixed in the same `@pto.jit` body. -## 3.8 Boundary contracts +## 3.5 Boundary contracts -Data crosses decorator boundaries only through UB-backed tiles or typed UB pointers: +Data crosses decorator boundaries only through UB-backed tiles or typed UB +pointers: | Boundary | Allowed | |----------|---------| | Host → `@pto.jit` | Python-native tensors | -| `@pto.jit` → `@pto.ukernel` | `Tile`, `PartitionTensorView`, `pto.ptr`, PTO scalars | -| `@pto.jit` → L3 sub-kernel (direct call) | `Tile`, PTO scalars (compiler handles MTE + sync) | -| `@pto.jit` → `with pto.{cube,sid,sitm}:` | `Tile` captured from enclosing scope | -| `@pto.ukernel` → L3 sub-kernel | `Tile`, PTO scalars | -| L3 sub-kernel → L3 sub-kernel | Not allowed (go through UB tiles via the caller) | +| `@pto.jit(mode="auto")` → sub-kernel | `Tile`, PTO scalars (compiler handles staging + sync) | +| `@pto.jit(mode="explicit")` → sub-kernel | `Tile`, `PartitionTensorView`, `pto.ptr`, PTO scalars | +| `@pto.jit` → `with pto.{cube,simd,simt}:` | `Tile` captured from enclosing scope | +| Sub-kernel → sub-kernel | Not allowed (go through UB tiles via the caller) | | `@pto.simd` → caller | Only via `vsts`/`psts` to UB tiles; `vreg` cannot escape | | Cube-local → UB | Only via `mte_l0c_ub`; LEFT/RIGHT/ACC/BIAS are private | -## 3.9 `pto.constexpr` +## 3.6 `pto.constexpr` -`pto.constexpr` marks a `@pto.jit` keyword-only parameter as a compile-time constant. The compiler specializes the kernel for each combination of constexpr values, and the compiled artifact is cached by specialization key together with the kernel's tensor ABI contract. +`pto.constexpr` marks a `@pto.jit` keyword-only parameter as a compile-time +constant. The compiler specializes the kernel for each combination of constexpr +values, and the compiled artifact is cached by specialization key together with +the kernel's tensor ABI contract. ```python @@ -458,9 +559,16 @@ def kernel( - Must appear as a keyword-only argument (after `*`). - Must have a default value. -- Must be provided at `.compile()` time if the caller needs to override the default. -- Cannot change between launches of the same compiled instance — compile a new variant for a different value. - -`pto.constexpr` parameters can be used anywhere in the kernel body where a Python value is expected: tile shapes, loop bounds that are known at compile time, dtype arguments, etc. They are evaluated at trace time, so `for i in range(BLOCK)` would unroll `BLOCK` times. - -In contrast, values derived from runtime tensor shapes (e.g., `A.shape[0]`) are dynamic — they vary per launch and should be used with `pto.for_` to produce device-side loops. +- Must be provided at `.compile()` time if the caller needs to override the + default. +- Cannot change between launches of the same compiled instance — compile a new + variant for a different value. + +`pto.constexpr` parameters can be used anywhere in the kernel body where a +Python value is expected: tile shapes, loop bounds that are known at compile +time, dtype arguments, etc. They are evaluated at trace time, so `for i in +range(BLOCK)` would unroll `BLOCK` times. + +In contrast, values derived from runtime tensor shapes (e.g., `A.shape[0]`) +are dynamic — they vary per launch and should be used with `pto.for_` to +produce device-side loops. diff --git a/ptodsl/docs/user_guide/04-type-system-and-buffer.md b/ptodsl/docs/user_guide/04-type-system-and-buffer.md index 0c2b74b54..19d4c38e4 100644 --- a/ptodsl/docs/user_guide/04-type-system-and-buffer.md +++ b/ptodsl/docs/user_guide/04-type-system-and-buffer.md @@ -197,7 +197,7 @@ def kernel( | `element_type` | `Type` | Element dtype (e.g., `pto.f32`) | | `strides` | `tuple[int, ...]` | Stride of each dimension, in elements | -Strides support non-contiguous tensors. Pass `strides=A.strides` from the source tensor for the default row-major layout, or supply explicit strides for sub-views. Use `tv.as_ptr()` to obtain a typed GM pointer for use with MTE Ops in a ukernel. +Strides support non-contiguous tensors. Pass `strides=A.strides` from the source tensor for the default row-major layout, or supply explicit strides for sub-views. Use `tv.as_ptr()` to obtain a typed GM pointer for use with MTE Ops in explicit-mode orchestration. ## 4.6 PartitionTensorView @@ -208,7 +208,7 @@ Strides support non-contiguous tensors. Pass `strides=A.strides` from the source part = pto.partition_view(tv, offsets=[row_offset, 0], sizes=[BLOCK, dim]) ``` -The result is a `PartitionTensorView` — a lightweight descriptor, not a data buffer. It carries the partition's shape, strides, and element type (inherited from the source TensorView). Use `part.as_ptr()` to obtain a typed GM pointer for MTE Ops in a ukernel. +The result is a `PartitionTensorView` — a lightweight descriptor, not a data buffer. It carries the partition's shape, strides, and element type (inherited from the source TensorView). Use `part.as_ptr()` to obtain a typed GM pointer for MTE Ops in explicit-mode orchestration. ## 4.7 Tile diff --git a/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md index b520b89d2..85617e17e 100644 --- a/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md +++ b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md @@ -234,7 +234,7 @@ m_next = scalar.max(m_prev, row_max) l_scaled = l_prev * scalar.exp(m_prev - m_next) ``` -These are the scalar-path counterparts of the vector math operations covered in Chapter 8. Use them inside `@pto.simt` kernels and in `@pto.ukernel` orchestration code where you need to compute a loop bound or a scalar coefficient from runtime data. +These are the scalar-path counterparts of the vector math operations covered in Chapter 8. Use them inside `@pto.simt` kernels and in explicit-mode orchestration code where you need to compute a loop bound or a scalar coefficient from runtime data. ## 6.4 Pointer operations diff --git a/ptodsl/docs/user_guide/07-data-movement-ops.md b/ptodsl/docs/user_guide/07-data-movement-ops.md index 225cdb25d..e3d3514a7 100644 --- a/ptodsl/docs/user_guide/07-data-movement-ops.md +++ b/ptodsl/docs/user_guide/07-data-movement-ops.md @@ -1,6 +1,6 @@ # 7. Data Movement Operations -This chapter covers every operation that moves data between memory spaces in PTODSL — tile-level transfers, DMA micro-instructions, vector loads and stores, and cube data movement. Operations are organized by abstraction level: tile ops (L1), DMA ops (L2), vector memory ops (L3 SIMD), and cube memory ops (L3 cube). +This chapter covers every operation that moves data between memory spaces in PTODSL — tile-level transfers, DMA micro-instructions, vector loads and stores, and cube data movement. Operations are organized by abstraction level: tile ops for auto mode, DMA orchestration for explicit mode, vector memory ops on the SIMD unit, and cube memory ops on the Cube unit. ## 7.1 Tile-level movement: tile.load and tile.store @@ -52,11 +52,12 @@ pto.tile.store(o_tile, o_part) --- -Both `tile.load` and `tile.store` operate at **tile granularity** — they are the idiomatic choice inside `@pto.jit` loops. When you need finer control over DMA scheduling, drop down to the micro-instruction level. +Both `tile.load` and `tile.store` operate at **tile granularity** — they are the idiomatic choice inside `@pto.jit` loops. When you need finer control over DMA scheduling, switch to +`mode="explicit"` and use the DMA micro-instructions covered in the next section. -## 7.2 DMA micro-instructions (ukernel) +## 7.2 DMA micro-instructions (explicit mode) -Inside `@pto.ukernel`, data movement between memory spaces is expressed with grouped DMA instructions on typed pointers. There are four operations covering the four data-movement directions: +Inside explicit-mode orchestration, data movement between memory spaces is expressed with grouped DMA instructions on typed pointers. There are four operations covering the four data-movement directions: | Operation | Direction | Stride unit | Padding | |-----------|-----------|-------------|---------| @@ -250,11 +251,11 @@ For `mte_ub_ub` and `mte_ub_l1`, the parameters are in **32-byte units**. Each b **UB address alignment**: For all four operations, every UB address (source and destination) must be 32-byte aligned. The `pad(...)` on `mte_gm_ub` ensures each UB row is padded to the 32B-aligned boundary of `dst_stride`, so subsequent rows stay aligned. -### 7.2.6 Typical ukernel DMA pattern +### 7.2.6 Typical explicit-mode DMA pattern - + ```python -@pto.ukernel +# Inside a @pto.jit(mode="explicit") body: def process_block(k_part, v_part, k_tile, v_tile, o_tile, o_part, rows: pto.i32, cols: pto.i32): # Stage K and V blocks from GM to UB diff --git a/ptodsl/docs/user_guide/10-sync-ops.md b/ptodsl/docs/user_guide/10-sync-ops.md index 06416d587..de727ee89 100644 --- a/ptodsl/docs/user_guide/10-sync-ops.md +++ b/ptodsl/docs/user_guide/10-sync-ops.md @@ -2,7 +2,7 @@ Chapters 7 and 8 covered data movement and computation. This chapter covers the synchronization primitives that keep those operations correctly ordered across the NPU's concurrent hardware pipelines. -The Ascend NPU executes work across multiple independent pipelines — MTE (DMA), Vector, and Cube — each with its own instruction stream. Synchronization operations coordinate these pipelines: a DMA must finish loading data before the vector unit starts computing on it; a matrix multiply must complete before the result is stored. Without explicit synchronization, pipelines race, and results are undefined. +The Ascend NPU executes work across multiple independent pipelines — MTE (DMA), Vector, and Cube — each with its own instruction stream. Synchronization operations coordinate these pipelines: a DMA must finish loading data before the vector unit starts computing on it; a matrix multiply must complete before the result is stored. These operations are available in both `mode="auto"` and `mode="explicit"` when the kernel needs them. Without correct synchronization, pipelines race, and results are undefined. ## 10.1 Enum types for synchronization @@ -127,13 +127,14 @@ pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) pto.pipe_barrier(pto.Pipe.ALL) ``` -### Typical usage pattern +### Typical explicit-mode usage pattern -A common ukernel pattern interleaves DMA and compute with `set_flag` / `wait_flag` pairs: +A common explicit-mode pattern interleaves DMA and compute with `set_flag` / +`wait_flag` pairs: - + ```python -@pto.ukernel +# Inside a @pto.jit(mode="explicit") body: def gemm_block( q_tile: pto.Tile, k_part: pto.PartitionTensorView, @@ -266,14 +267,14 @@ The most commonly used barrier types in practice: | Vector → scalar handoff | `BarrierType.VS_ALL` | | Scalar → vector handoff | `BarrierType.SV_ALL` | -### Usage in ukernel blocks +### Usage in explicit orchestration blocks -In flash attention, phase boundaries use `pipe_barrier(Pipe.ALL)`, while +In explicit-mode kernels, phase boundaries use `pipe_barrier(Pipe.ALL)`, while `mem_bar` remains the tool for narrower intra-pipeline ordering: - + ```python -@pto.ukernel +# Inside a @pto.jit(mode="explicit") body: def flash_attention_block( q_tile: pto.Tile, k_part: pto.PartitionTensorView, @@ -418,21 +419,23 @@ pto.set_intra_flag(pto.Pipe.MTE3, 0) pto.wait_intra_flag(pto.Pipe.V, 0) ``` -## 10.6 Synchronization in the abstraction hierarchy +## 10.6 Synchronization in the authoring model -Where do sync operations belong in PTODSL's layered model? +Where do sync operations belong in PTODSL's public entry model? -| Layer | Sync responsibility | -|-------|---------------------| -| L1 `@pto.jit` | Tile ops require sync, but PTOAS **auto-inserts** `set_flag`/`wait_flag` pairs based on op-to-pipe mapping — the user does not write sync explicitly | -| L2 `@pto.ukernel` | User writes micro-instructions directly and takes full responsibility for sync: `set_flag`/`wait_flag` between DMA and compute, `mem_bar` between compute phases, `pipe_barrier` at block boundaries | -| L3 `@pto.cube` / `@pto.simd` | Cross-pipeline sync (`set_flag`/`wait_flag`) is managed by the calling ukernel. Sub-kernels may still use `mem_bar` for intra-pipeline ordering (e.g., store-then-load to the same UB region) | +| Surface | Sync responsibility | +|---------|---------------------| +| `@pto.jit(mode="auto")` | Users can write sync explicitly when needed. PTOAS also provides an `--enable-insert-sync` option that auto-inserts `set_flag`/`wait_flag` pairs based on op-to-pipe mapping. | +| `@pto.jit(mode="explicit")` | The compiler does not insert sync — the user is fully responsible. Place `set_flag`/`wait_flag` between MTE and compute, `mem_bar` between compute phases, `pipe_barrier` at phase boundaries. | +| Shared `@pto.cube` / `@pto.simd` / `@pto.simt` helpers | Cross-pipeline ordering is provided by the surrounding `@pto.jit` schedule. Helpers may still use `mem_bar` for intra-pipeline ordering when UB addresses alias. | -**Rule of thumb**: at L1, sync can be manual or auto-inserted (`--enable-insert-sync`). At L2, sync is always explicit. +**Rule of thumb**: in `mode="auto"`, think in tiles and let the compiler handle +orchestration. In `mode="explicit"`, think in micro-instructions and place the +required sync yourself. ### Auto-sync at the tile level -When writing `@pto.jit` code with tile ops (`tile.load`, `tile.store`, `tile.add`, etc.), each op carries a pipe assignment (e.g., `tile.load` → `PIPE_MTE2`, `tile.add` → `PIPE_V`). PTOAS's sync-insertion pass analyzes the op sequence, infers the necessary `set_flag`/`wait_flag` pairs from the pipe transitions, and injects them into the lowered code. The tile ops themselves still require synchronization — the difference is that the compiler, not the user, writes it. +In auto mode, users can still write sync operations directly — `set_flag`/`wait_flag`, `pipe_barrier`, `mem_bar` are available in both modes. For convenience, PTOAS also provides an `--enable-insert-sync` pass: each tile op carries a pipe assignment (e.g., `tile.load` → `PIPE_MTE2`, `tile.add` → `PIPE_V`), and the pass analyzes the op sequence, infers the necessary `set_flag`/`wait_flag` pairs from pipe transitions, and injects them into the lowered code. ### Quick reference: which sync for which scenario diff --git a/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md index 845017eba..6e5a5e98f 100644 --- a/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md +++ b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md @@ -1,6 +1,6 @@ # 11. Flash Attention Complete Walkthrough -This chapter walks through `examplesflash_attention_sketch.py` layer by layer, tracing a complete flash attention implementation from the user-facing Python wrapper down to hardware-bound sub-kernels. Every API discussed in Chapters 1–10 appears in context here. +This chapter walks through `examples/flash_attention_sketch.py` layer by layer, tracing a complete flash attention implementation from the user-facing Python wrapper down to hardware-bound sub-kernels. Every API discussed in Chapters 1–10 appears in context here. The sketch computes **online-softmax flash attention** for one `(batch, head)` slice per launch instance. It partitions Q into blocks along the sequence dimension, iterates over KV blocks for each Q block, and maintains rolling softmax state across KV iterations. @@ -8,20 +8,18 @@ The sketch computes **online-softmax flash attention** for one `(batch, head)` s ``` flash_attention(...) L0 user-facing wrapper - └─ @pto.jit flash_attention_kernel + └─ @pto.jit(mode="explicit") flash_attention_kernel ├─ Tile Ops tile.load / tile.store at the GM↔UB boundary - └─ @pto.ukernel kv_block_process - ├─ @pto.simt materialize_tile_bounds - ├─ @pto.cube qk_matmul - ├─ @pto.simd online_softmax_rows - ├─ @pto.cube pv_matmul - └─ @pto.simt blend_output_rows + ├─ explicit orchestration mte_load / pipe_barrier / pointer sequencing + ├─ @pto.cube qk_matmul / pv_matmul + ├─ @pto.simd online_softmax_rows + └─ @pto.simt materialize_tile_bounds / blend_output_rows ``` The dataflow for one KV block: ``` -ukernel loads K/V block and sequences the pipeline +explicit-mode orchestration loads the K/V block and sequences the pipeline │ ├─ cube: Q + K ───────────────► S ├─ simd: S + (m_prev, l_prev) ─► P, (m_next, l_next), alpha, beta @@ -32,7 +30,7 @@ After each KV block: (m_prev, l_prev, o_prev) := (m_next, l_next, o_next) ``` -## 11.2 L0 — Python wrapper +## 11.2 The Python wrapper ```python def flash_attention(Q, K, V, *, O=None, causal=False, @@ -56,13 +54,13 @@ This is plain Python — no PTO types, no IR. It handles ergonomic runtime conce - **Shape extraction**: reads `batch`, `seq_q`, `heads`, `dim` from the framework tensors. - **Compile + launch**: `flash_attention_kernel.compile(...)` JIT-compiles the kernel with the given constexpr parameters, then launches it with a `[batch * heads]` grid — one block per `(batch, head)` slice. -L0 knows nothing about tiles, UB, or pipelines. It is the boundary between the user's tensor world and the PTO device world. +The wrapper knows nothing about tiles, UB, or pipelines. It is the boundary between the user's tensor world and the PTO device world. -## 11.3 L1 — `@pto.jit` kernel entry +## 11.3 Top-level `@pto.jit(mode="explicit")` kernel entry ```python -@pto.jit(target="a5") +@pto.jit(target="a5", mode="explicit") def flash_attention_kernel( Q: pto.tensor_spec(rank=4, dtype=pto.f32), K: pto.tensor_spec(rank=4, dtype=pto.f32), @@ -78,7 +76,7 @@ def flash_attention_kernel( return ``` -The `@pto.jit` decorator marks the compile + launch boundary. Inputs are Python-native tensors; outputs are written in-place to `O`. Keyword-only `constexpr` parameters (`BLOCK_Q`, `BLOCK_KV`, `CAUSAL`) are baked at compile time. +The `@pto.jit(mode="explicit")` decorator marks the compile + launch boundary. Inputs are Python-native tensors; outputs are written in-place to `O`. Keyword-only `constexpr` parameters (`BLOCK_Q`, `BLOCK_KV`, `CAUSAL`) are baked at compile time. ### 11.3.1 TensorView construction @@ -194,7 +192,7 @@ alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") ``` -The walkthrough keeps Q/K/V/P on the MAT path so the cube sub-kernels consume the same tile objects that the L1 schedule owns. Runtime tails still live in `valid_shape`; the physical tile shapes stay static. +The walkthrough keeps Q/K/V/P on the MAT path so the cube sub-kernels consume the same tile objects that the top-level kernel owns. Runtime tails still live in `valid_shape`; the physical tile shapes stay static. **UB-resident state and scratch tiles** — the online-softmax state plus intermediate outputs: @@ -241,7 +239,7 @@ meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 3]) meta_ptr = meta_tile.as_ptr() ``` -A small UB tile stores three scalar loop bounds (`row_start`, `row_stop`, `valid_cols`). `meta_tile.as_ptr()` materializes a typed UB pointer into it, which is passed to the ukernel as scalar control metadata. +A small UB tile stores three scalar loop bounds (`row_start`, `row_stop`, `valid_cols`). `meta_tile.as_ptr()` materializes a typed UB pointer into it, which is passed to the explicit-mode orchestration as scalar control metadata. Notice that the row-wise softmax state tiles (`m_*`, `l_*`, `alpha_tile`, `beta_tile`) are authored as `blayout="ColMajor"`. This is the intended public @@ -323,15 +321,15 @@ with pto.for_(0, q_blocks, step=1) as qi: Key points: - **Static physical shape, dynamic valid extent**: `alloc_tile(shape=...)` stays constexpr. Tail handling is expressed by updating `valid_shape` before each block load and sub-kernel call. -- **`tile.load` at the L1 boundary**: Q is loaded once per Q block using a tile op into the MAT-backed bridge tile `q_mat`. The compiler auto-inserts the necessary `set_flag`/`wait_flag` pairs. +- **`tile.load` at the kernel entry boundary**: Q is loaded once per Q block using a tile op into the MAT-backed bridge tile `q_mat`. The compiler auto-inserts the necessary `set_flag`/`wait_flag` pairs. - **State initialization**: `fill(float("-inf"))` and `fill(0.0)` initialize the online-softmax accumulators before the first KV block. - **Carry state**: the inner `kv_loop` carries three ping-pong tiles (`m`, `l`, `o`) across iterations using `.carry(...)` / `.update(...)` / `.final(...)`. After each KV block, the loop updates the carried values to the `_next` tiles. After the loop, `.final("o")` extracts the final output accumulator. -- **`tile.store` at the L1 boundary**: writes the final result for this Q block back to GM. +- **`tile.store` at the kernel entry boundary**: writes the final result for this Q block back to GM. -## 11.4 L2 — `@pto.ukernel` +## 11.4 Explicit orchestration ```python -@pto.ukernel +# Explicit orchestration helper used by flash_attention_kernel: def kv_block_process( q_mat, k_part, v_part, k_mat, v_mat, o_prev_tile, o_next_tile, @@ -344,11 +342,11 @@ def kv_block_process( ): ``` -The ukernel processes one KV block against an already-loaded Q tile. It owns the execution sandwich: +The explicit-mode body processes one KV block against an already-loaded Q tile. It owns the execution sandwich: ### Phase 0 — Stage K/V data - + ```python rows = k_mat.valid_shape[0] cols = k_mat.valid_shape[1] @@ -362,11 +360,11 @@ pto.mte_load(v_part.as_ptr(), v_mat.as_ptr(), 0, row_bytes, pto.pipe_barrier(pto.Pipe.ALL) ``` -`mte_load` is the ptr-based GM→MAT DMA wrapper used by this walkthrough. The ukernel passes explicit GM/MAT pointers plus the DMA grouping parameters, and `pipe_barrier(Pipe.ALL)` makes the phase boundary explicit before the cube unit reads `k_mat`/`v_mat`. +`mte_load` is the ptr-based GM→MAT DMA wrapper used by this walkthrough. Explicit mode passes GM/MAT pointers plus the DMA grouping parameters, and `pipe_barrier(Pipe.ALL)` makes the phase boundary explicit before the cube unit reads `k_mat`/`v_mat`. ### Phase 0b — Materialize loop bounds - + ```python materialize_tile_bounds(meta_ptr, q_mat.valid_shape[0], @@ -376,11 +374,11 @@ row_stop = scalar.load(meta_ptr + 1) valid_cols = scalar.load(meta_ptr + 2) ``` -The SIMT sub-kernel `materialize_tile_bounds` writes `{0, valid_rows, valid_cols}` into the metadata buffer. The ukernel then loads these scalars. They control the row iteration range in subsequent sub-kernels, handling partial tail blocks. +The SIMT sub-kernel `materialize_tile_bounds` writes `{0, valid_rows, valid_cols}` into the metadata buffer. The explicit-mode body then loads these scalars. They control the row iteration range in subsequent sub-kernels, handling partial tail blocks. ### Phase 1 — `S = Q @ K^T` - + ```python qk_matmul(q_mat, k_mat, q_l0a, rhs_l0b, qk_acc_tile, s_tile) pto.pipe_barrier(pto.Pipe.ALL) @@ -390,7 +388,7 @@ Dispatches the cube sub-kernel. `pipe_barrier(Pipe.ALL)` separates the matrix mu ### Phase 2 — Online softmax - + ```python online_softmax_rows( s_tile, p_tile, @@ -406,7 +404,7 @@ The simd sub-kernel computes per-row softmax on `S`, updates the running `m`/`l` ### Phase 3 — `PV = P @ V` - + ```python pto.tile.mov(p_tile, p_mat) pto.pipe_barrier(pto.Pipe.ALL) @@ -419,7 +417,7 @@ The probability tile is first staged onto the MAT path with `pto.tile.mov(p_tile ### Phase 4 — Blend output - + ```python blend_output_rows( o_prev_tile, pv_tile, alpha_tile, beta_tile, @@ -431,11 +429,11 @@ pto.pipe_barrier(pto.Pipe.ALL) The simt sub-kernel blends the old output accumulator with the new PV contribution, weighted by `alpha` and `beta`. -### Why the ukernel owns sync +### Why explicit mode owns sync ordering -Each `pipe_barrier(Pipe.ALL)` between phases is explicit in the ukernel body. This is intentional: at the L2 micro-instruction level, the user controls pipeline ordering. There is no auto-sync insertion — the ukernel is the single place where the hardware execution sequence is spelled out. +Each `pipe_barrier(Pipe.ALL)` between phases is explicit in the orchestration body. This is intentional: at the orchestration boundary, the user controls pipeline ordering. Auto mode may still use synchronization primitives where needed, but it does so around compiler-managed tile staging rather than user-authored instruction scheduling. -## 11.5 L3a — `@pto.cube` sub-kernels +## 11.5 Cube sub-kernel — `@pto.cube` ### `qk_matmul` — `S = Q @ K^T` @@ -460,7 +458,7 @@ Four cube ops: 3. **`mad`**: matrix multiply-accumulate — `s_acc = q_l0a @ k_l0b`. 4. **`mte_l0c_ub`**: write the accumulator result to the UB output tile `s_tile`. -The cube kernel does not allocate scratch — the caller (L1) owns scratch lifetime. The cube kernel only expresses dataflow. +The cube kernel does not allocate scratch — the caller (top-level kernel) owns scratch lifetime. The cube kernel only expresses dataflow. ### `pv_matmul` — `PV = P @ V` @@ -478,9 +476,9 @@ def pv_matmul(p_mat, v_mat, p_l0a, v_l0b, pv_acc, pv_tile): pto.mte_l0c_ub(pv_acc.as_ptr(), pv_tile.as_ptr(), m, n, n, n, 0) ``` -Structurally identical to `qk_matmul`, but without transposition and with different input/output tiles. The scratch tiles `p_l0a`, `v_l0b`, and `pv_acc` are reused across KV blocks — the caller (L1) allocates them once. +Structurally identical to `qk_matmul`, but without transposition and with different input/output tiles. The scratch tiles `p_l0a`, `v_l0b`, and `pv_acc` are reused across KV blocks — the caller (top-level kernel) allocates them once. -## 11.6 L3b — `@pto.simd` online softmax +## 11.6 SIMD sub-kernel — online softmax ```python @pto.simd @@ -553,7 +551,7 @@ This implements the online-softmax update from the Flash Attention paper: **Boundary contract**: vreg values (`s_row`, `p_row`, `row_max`, `row_sum`) never escape the simd kernel. All persistent state is written to UB tiles. -## 11.7 L3c — `@pto.simt` sub-kernels +## 11.7 SIMT sub-kernel — blend output ### `materialize_tile_bounds` — scalar metadata @@ -616,32 +614,32 @@ For one KV block, the full execution sequence is: | Step | Layer | Operation | Hardware | |------|-------|-----------|----------| -| 1 | L1 | `tile.load(q_part, q_mat)` | GM → MAT | -| 2 | L2 | `mte_load(k_part.as_ptr(), k_mat.as_ptr(), ...)` | GM → MAT | -| 3 | L2 | `mte_load(v_part.as_ptr(), v_mat.as_ptr(), ...)` | GM → MAT | -| 4 | L2 | `pipe_barrier(Pipe.ALL)` | — | -| 5 | L3c | `materialize_tile_bounds` | SIMT | -| 6 | L3a | `qk_matmul` (mte_l1_l0a, mte_l1_l0b, mad, mte_l0c_ub) | Cube | -| 7 | L2 | `pipe_barrier(Pipe.ALL)` | — | -| 8 | L3b | `online_softmax_rows` (vlds, vcgmax, vexp, vcgadd, vsts, ...) | SIMD | -| 9 | L2 | `pipe_barrier(Pipe.ALL)` | — | -| 10 | L2 | `tile.mov(p_tile, p_mat)` | Tile copy | -| 11 | L2 | `pipe_barrier(Pipe.ALL)` | — | -| 12 | L3a | `pv_matmul` | Cube | -| 13 | L2 | `pipe_barrier(Pipe.ALL)` | — | -| 14 | L3c | `blend_output_rows` | SIMT | -| 15 | L2 | `pipe_barrier(Pipe.ALL)` | — | - -After all KV blocks: L1 issues `tile.store(o_final_tile, o_part)` to write the result back to GM. +| 1 | explicit | `tile.load(q_part, q_mat)` | GM → MAT | +| 2 | explicit | `mte_load(k_part.as_ptr(), k_mat.as_ptr(), ...)` | GM → MAT | +| 3 | explicit | `mte_load(v_part.as_ptr(), v_mat.as_ptr(), ...)` | GM → MAT | +| 4 | explicit | `pipe_barrier(Pipe.ALL)` | — | +| 5 | simt | `materialize_tile_bounds` | SIMT | +| 6 | cube | `qk_matmul` (mte_l1_l0a, mte_l1_l0b, mad, mte_l0c_ub) | Cube | +| 7 | explicit | `pipe_barrier(Pipe.ALL)` | — | +| 8 | simd | `online_softmax_rows` (vlds, vcgmax, vexp, vcgadd, vsts, ...) | SIMD | +| 9 | explicit | `pipe_barrier(Pipe.ALL)` | — | +| 10 | explicit | `tile.mov(p_tile, p_mat)` | Tile copy | +| 11 | explicit | `pipe_barrier(Pipe.ALL)` | — | +| 12 | cube | `pv_matmul` | Cube | +| 13 | explicit | `pipe_barrier(Pipe.ALL)` | — | +| 14 | simt | `blend_output_rows` | SIMT | +| 15 | explicit | `pipe_barrier(Pipe.ALL)` | — | + +After all KV blocks: the top-level kernel issues `tile.store(o_final_tile, o_part)` to write the result back to GM. ## 11.9 Design patterns in this sketch **Ping-pong state for online accumulators**: `m_prev`/`m_next`, `l_prev`/`l_next`, `o_prev`/`o_next` make the state transition explicit. After each KV block, the caller swaps the ping-pong pair (via `kv_loop.update(...)`) rather than aliasing in place. -**Scratch reuse**: `rhs_l0b` serves both `K` (in `qk_matmul`) and `V` (in `pv_matmul`). `pv_acc_tile` reuses the accumulator from QK^T. The caller (L1) allocates once; the ukernel passes them to both cube sub-kernels. +**Scratch reuse**: `rhs_l0b` serves both `K` (in `qk_matmul`) and `V` (in `pv_matmul`). `pv_acc_tile` reuses the accumulator from QK^T. The caller (top-level kernel) allocates once; the explicit-mode body passes them to both cube sub-kernels. -**Tile-level boundary vs micro-instruction boundary**: `tile.load`/`tile.store` appear only in `@pto.jit`. `mte_load` appears only in `@pto.ukernel`, and it is authored in the explicit ptr-based DMA form. This is the key abstraction split: L1 operates on tiles, L2 operates on micro-instructions. +**Tile-level boundary vs micro-instruction boundary**: `tile.load`/`tile.store` are the tile-atomic surface used in auto mode and at the top-level tile boundary of this sketch. `mte_load` appears in explicit orchestration, authored as individual pointer-based instructions. The abstraction split is auto mode as tile-centric authoring, explicit mode as user-ordered orchestration. **No vreg across sub-kernel boundaries**: vector registers are local to each `@pto.simd` kernel. Data crosses sub-kernel boundaries through UB tiles — the boundary contract is enforced by the type system. -**L3 invocation flexibility**: This sketch uses the explicit `@pto.ukernel` → L3 path for full control over MTE and sync. For simpler kernels that don't need that control, L3 sub-kernels can be called directly from `@pto.jit` (the compiler handles MTE + sync) or written inline as context managers (`with pto.simd():`, etc.). See Chapter 3 for details. +**Invocation flexibility**: This sketch uses the explicit `@pto.jit(mode="explicit")` path for full micro-instruction control. The same named sub-kernels can also be reused from `@pto.jit(mode="auto")` when the body stays within the auto-mode contract, or written inline as context managers (`with pto.simd():`, etc.). See Chapter 3 for details. diff --git a/ptodsl/docs/user_guide/12-additional-examples.md b/ptodsl/docs/user_guide/12-additional-examples.md index fbd518580..cff564364 100644 --- a/ptodsl/docs/user_guide/12-additional-examples.md +++ b/ptodsl/docs/user_guide/12-additional-examples.md @@ -44,7 +44,7 @@ def mat_add(A, B, O, *, BLOCK_M: pto.constexpr = 64, BLOCK_N: pto.constexpr = 12 - `partition_view` takes 2D offsets and sizes. - `BLOCK_M` and `BLOCK_N` are `constexpr` — the compiler specializes the kernel per tile shape. -The L0 wrapper follows the same pattern as Chapter 2: +The Python wrapper follows the same pattern as Chapter 2: ```python @@ -156,7 +156,7 @@ def vec_add_with_tail( This example demonstrates a complete GEMM kernel: `C = A @ B` where A is `[M, K]` and B is `[K, N]`. It uses `@pto.jit` for tile allocation and loop scheduling, and `@pto.cube` for the actual matrix multiply. -### 12.3.1 L3: Cube sub-kernel +### 12.3.1 Cube sub-kernel ```python @@ -175,11 +175,11 @@ def gemm_tile(a_mat: pto.Tile, b_mat: pto.Tile, o_tile: pto.Tile, The cube sub-kernel consumes MAT staging tiles plus cube-local scratch buffers. The four-step sequence — stage left operand, stage right operand, multiply, writeback — is the canonical cube compute pattern. -### 12.3.2 L1: Tile orchestration +### 12.3.2 Tile orchestration ```python -@pto.jit(target="a5") +@pto.jit(target="a5", mode="explicit") def gemm( A: pto.tensor_spec(rank=2, dtype=pto.f32), B: pto.tensor_spec(rank=2, dtype=pto.f32), @@ -243,10 +243,10 @@ def gemm( - **Triply nested loops**: M, N, and K dimensions are all blocked. The K loop accumulates partial results into `o_tile`. - **Accumulation**: `o_tile.fill(0.0)` resets the accumulator before the K loop. Each K-block calls `gemm_tile` which writes its partial product back to `o_tile`. The Cube unit accumulates implicitly via `mad` — each K-block's partial result is added to the running total in `o_acc`. - **MAT staging + cube-local scratch**: `a_mat` and `b_mat` are explicit MAT tiles that satisfy the `mte_l1_l0a` / `mte_l1_l0b` source contract. `a_l0a`, `b_l0b`, and `o_acc` are cube-local scratch (`LEFT`, `RIGHT`, `ACC`). -- **Direct L3 call**: `gemm_tile` is called directly from `@pto.jit` — no ukernel needed. The compiler handles sync between `tile.load` and the Cube sub-kernel. +- **Direct sub-kernel call**: `gemm_tile` is called directly from `@pto.jit` — no separate orchestration layer needed. The compiler handles sync between `tile.load` and the Cube sub-kernel. - **Cube sub-kernel reuse**: the same `gemm_tile` function is called for every K-block — the named decorator form enables reuse. -### 12.3.3 L0 wrapper +### 12.3.3 Python wrapper ```python @@ -260,9 +260,9 @@ def gemm_wrapper(A, B, O=None, stream=None): This pattern extends directly to batch-GEMM: pass a grid of `batch` and use `pto.get_block_idx()` to select the per-batch slice from `A` and `B`. -### 12.3.4 Comparison with ukernel path +### 12.3.4 Comparison with explicit-mode orchestration -For reference, the same GEMM could be written using `@pto.ukernel` for explicit MTE control. The ukernel would replace the inner `tile.load`/`tile.store` calls with `mte_load`/`mte_store` and add `mem_bar` synchronization between DMA and compute. The direct-call path used above is recommended for most users — the ukernel path is for cases that need hand-tuned DMA scheduling. +For reference, the same GEMM could be written in `mode="explicit"` when the kernel needs micro-instruction control. The direct-call path used above is recommended for most users; explicit mode is for cases that need hand-authored instruction scheduling and ordering. ## 12.4 Online normalization with loop-carried state @@ -362,7 +362,7 @@ def online_layernorm( ## 12.5 Design guidelines -**Start simple, refine later.** Begin with `@pto.jit` + Tile Ops. If Tile Ops don't cover the computation (e.g., custom softmax, specialized activation), add an L3 sub-kernel. If you need explicit DMA scheduling or inter-pipeline sync, drop to `@pto.ukernel`. +**Start simple, refine later.** Begin with `@pto.jit` + Tile Ops. If Tile Ops don't cover the computation (e.g., custom softmax, specialized activation), add a sub-kernel. If you need micro-instruction-level control, switch the kernel to `mode="explicit"`. **Choose the right entry for each piece:** @@ -373,7 +373,7 @@ def online_layernorm( | Custom row-wise vector math | `@pto.simd` | | Custom per-element logic | `@pto.simt` | | Matrix multiply | `@pto.cube` | -| Explicit DMA + sync ordering | `@pto.ukernel` | -| Inline L3 for quick prototyping | `with pto.simd():` etc. | +| Micro-instruction-level control | `mode="explicit"` | +| Inline compute for quick prototyping | `with pto.simd():` etc. | -**Respect boundary contracts.** Vregs don't cross `@pto.simd` boundaries. Cube-local state doesn't leak into UB. Tile Ops and MTE Ops live at different abstraction levels — keep them in their respective layers. +**Respect boundary contracts.** Vregs don't cross `@pto.simd` boundaries. Cube-local state doesn't leak into UB. Tile Ops and MTE Ops belong to different programming models — use Tile Ops in `mode="auto"`, and micro-instructions in `mode="explicit"`. diff --git a/ptodsl/examples/flash_attention_sketch.py b/ptodsl/examples/flash_attention_sketch.py index 0a3819a5b..0a9dcaf60 100644 --- a/ptodsl/examples/flash_attention_sketch.py +++ b/ptodsl/examples/flash_attention_sketch.py @@ -13,12 +13,12 @@ layering explicit and keep the semantic contracts clean: emit_flash_attention_mlir(...) compile/inspect wrapper - └─ @pto.jit flash_attention_kernel + └─ flash_attention_kernel (@pto.jit, mode="explicit") ├─ Tile Ops tile.load / tile.store at the GM↔UB boundary - └─ @pto.ukernel one KV-block worth of MTE/sync orchestration - ├─ @pto.cube matrix products (QK^T and P@V) - ├─ @pto.simd row-wise online softmax - └─ @pto.simt scalar metadata and output blending + ├─ explicit orchestration mte_load / pipe_barrier / pointer sequencing + ├─ @pto.cube matrix products (QK^T and P@V) + ├─ @pto.simd row-wise online softmax + └─ @pto.simt scalar metadata and output blending Design rules illustrated here: @@ -28,25 +28,24 @@ 2. The Python wrapper owns compile/inspection concerns such as selecting specialization knobs and returning the emitted MLIR text for review. 3. ``@pto.jit`` also owns the top-level logical tiling, tile allocation, and - loop scheduling for one already-selected per-head 2D slice. It should not - manually spell low-level DMA details for every micro step. -4. ``ukernel`` owns the per-block execution sandwich: stage the current K/V + loop scheduling for one already-selected per-head 2D slice. The per-block + DMA and barrier choreography is delegated to explicit orchestration. +4. explicit mode owns the per-block execution sandwich: stage the current K/V block with explicit micro-instructions, synchronize, call hardware-bound sub-kernels, and manage scratch/state. 5. ``@pto.jit`` may use tile ops such as ``tile.load`` / ``tile.store`` at the logical - scheduling boundary, but ``ukernel`` stays below that abstraction level. - Once execution enters ``ukernel``, GM<->UB movement is expressed with - MTE micro-instructions such as ``mte_load`` instead of tile ops. + scheduling boundary, but explicit mode can also express GM<->UB movement + directly. Once execution enters explicit orchestration, MTE micro-instructions + such as ``mte_load`` are used instead of tile ops where needed. ``mte_load`` / ``mte_store`` accept partitions and tiles directly, deriving strides and burst sizes from the type metadata. 6. ``simd`` / ``simt`` / ``cube`` are hardware boundaries. They do not expose vreg values across the function boundary. Data crosses the boundary through UB-backed tiles or typed UB pointers only. -7. L3 sub-kernels can also be called directly from ``@pto.jit`` (compiler - handles MTE + sync) or written inline as context managers - (``with pto.simd():`` etc.). This sketch uses the explicit - ``@pto.ukernel`` → L3 path for full micro-instruction control, but - simpler kernels can skip the ukernel layer. +7. Named sub-kernels are reusable wherever their parameter contract is + satisfied. This sketch uses the explicit ``@pto.jit(mode="explicit")`` path + because it needs user-ordered DMA and phase barriers; smaller kernels can + stay in auto mode and rely on tile-atomic staging instead. 8. Online-softmax state is made explicit with ping-pong tiles (``m_prev``/``m_next``, ``l_prev``/``l_next``, ``o_prev``/``o_next``). Hiding these dependencies with in-place aliases makes the algorithm harder @@ -135,7 +134,7 @@ def emit_flash_attention_mlir( ) return compiled.mlir_text() -@pto.jit(target="a5") +@pto.jit(target="a5", mode="explicit") def flash_attention_kernel( Q: pto.tensor_spec(rank=4, dtype=pto.f32), # Python/framework tensor, logical [batch, seq_q, heads, dim] K: pto.tensor_spec(rank=4, dtype=pto.f32), # Python/framework tensor, logical [batch, seq_k, heads, dim] @@ -276,7 +275,8 @@ def flash_attention_kernel( pv_acc_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, memory_space=pto.MemorySpace.ACC, valid_shape=[full_br, dim]) # SIMT metadata buffer. A tiny raw-pointer island is acceptable at the - # ukernel boundary because this is scalar control data, not user-facing math. + # explicit-orchestration boundary because this is scalar control data, not + # user-facing math. meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 3]) meta_ptr = meta_tile.as_ptr() @@ -371,7 +371,7 @@ def flash_attention_kernel( # ═══════════════════════════════════════════════════════════════════════════════ -# Level 3: hardware-bound sub-kernels +# Hardware-bound sub-kernels # ═══════════════════════════════════════════════════════════════════════════════ # # Boundary contract: @@ -539,11 +539,10 @@ def materialize_tile_bounds( # ═══════════════════════════════════════════════════════════════════════════════ -# Level 2: ukernel — one KV block worth of execution orchestration +# Level 2: explicit orchestration — one KV block worth of execution # ═══════════════════════════════════════════════════════════════════════════════ -@pto.ukernel def kv_block_process( q_mat: pto.Tile, # MAT, reused across inner KV loop k_part: pto.PartitionTensorView, # GM view for current K block @@ -572,7 +571,7 @@ def kv_block_process( """ Process one KV block against an already-loaded Q tile. - The ukernel owns: + The explicit-mode body owns: - staging the current K/V block into reusable UB scratch with explicit DMA-style micro-instructions, - synchronizing the hand-off between MTE, cube, simd, and simt stages, @@ -663,7 +662,7 @@ def kv_block_process( # │ │ # │ Key idea: current demo goal is compile/inspect, not runtime launch. │ # ├──────────────────────────────────────────────────────────────────────────┤ -# │ L1 @pto.jit compile + cache + top-level orchestration │ +# │ L1 @pto.jit(mode="explicit") flash_attention_kernel │ # │ │ # │ flash_attention_kernel.compile(...).mlir_text() │ # │ TensorView metadata / alloc_tile / partition_view / tile.load / tile.store │ @@ -672,15 +671,15 @@ def kv_block_process( # │ Key idea: one launchable entry owns both runtime binding and logical │ # │ tile scheduling. │ # ├──────────────────────────────────────────────────────────────────────────┤ -# │ L2 @pto.ukernel Per-block execution sandwich │ +# │ L2 explicit orchestration Per-block execution sandwich │ # │ │ -# │ explicit mte_load(part, tile) staging for current K/V block, mem_bar, │ -# │ call cube/simd/simt sub-kernels, │ +# │ explicit mte_load(part, tile) staging for current K/V block, │ +# │ pipe_barrier, call cube/simd/simt sub-kernels, │ # │ manage scratch/state hand-off │ # │ │ # │ Key idea: one place owns the "how this block runs on hardware" story. │ # ├──────────────────────────────────────────────────────────────────────────┤ -# │ L3a @pto.cube Matrix-product kernels │ +# │ @pto.cube Matrix-product kernels │ # │ │ # │ qk_matmul: Q @ K^T │ # │ pv_matmul: P @ V │ @@ -688,14 +687,14 @@ def kv_block_process( # │ │ # │ Key idea: UB tiles are inputs/outputs; cube-local state is explicit. │ # ├──────────────────────────────────────────────────────────────────────────┤ -# │ L3b @pto.simd Row-wise vector math │ +# │ @pto.simd Row-wise vector math │ # │ │ # │ online_softmax_rows │ # │ vreg stays local; persistent state is written back to UB tiles │ # │ │ # │ Key idea: no cross-kernel vreg values, only UB-backed state. │ # ├──────────────────────────────────────────────────────────────────────────┤ -# │ L3c @pto.simt Scalar metadata and pointwise blend │ +# │ @pto.simt Scalar metadata and pointwise blend │ # │ │ # │ materialize_tile_bounds / blend_output_rows │ # │ │ @@ -707,7 +706,7 @@ def kv_block_process( # jit kernel alloc/schedule # │ # ▼ -# ukernel loads K/V block and sequences the pipeline +# explicit orchestration loads K/V block and sequences the pipeline # │ # ├─ cube: Q + K ───────────────► S # ├─ simd: S + (m_prev, l_prev) ─► P, (m_next, l_next), alpha, beta diff --git a/ptodsl/examples/softmax_dsl.py b/ptodsl/examples/softmax_dsl.py index f4f6ebbf4..913009bb6 100644 --- a/ptodsl/examples/softmax_dsl.py +++ b/ptodsl/examples/softmax_dsl.py @@ -11,7 +11,7 @@ Generates the same IR as test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto -using the ``@pto.jit`` decorator and the ``pto.*`` namespace. +using the ``@pto.jit(mode="explicit")`` decorator and the ``pto.*`` namespace. The Python maps almost line-for-line to the target MLIR: @@ -47,6 +47,7 @@ kernel_kind="vector", target="a5", func_attr="pto.aicore", + mode="explicit", ) def online_softmax_update_kernel_2d( arg0: pto.ptr(pto.float32, "gm"), diff --git a/ptodsl/ptodsl/_diagnostics.py b/ptodsl/ptodsl/_diagnostics.py index 4b4693719..ecc6fe3e1 100644 --- a/ptodsl/ptodsl/_diagnostics.py +++ b/ptodsl/ptodsl/_diagnostics.py @@ -14,6 +14,20 @@ class PTODSLTracingMisuseError(TypeError): """Raised when authored Python misuses PTODSL runtime values during tracing.""" +def _format_source_context(function_name: str | None, source_file: str | None, source_line: int | None) -> str: + details = [] + if function_name: + details.append(f"kernel {function_name!r}") + if source_file is not None: + location = source_file + if source_line is not None: + location = f"{location}:{source_line}" + details.append(location) + if not details: + return "" + return f" ({', '.join(details)})" + + def native_python_control_flow_error(usage: str) -> PTODSLTracingMisuseError: """Return one actionable diagnostic for native Python control-flow misuse.""" return PTODSLTracingMisuseError( @@ -51,18 +65,13 @@ def subkernel_signature_boundary_error(role: str, name: str) -> TypeError: def illegal_subkernel_placement_error(role: str, outer_role: str | None) -> RuntimeError: """Return one diagnostic for a subkernel call placed outside the supported layer graph.""" - if role == "ukernel": - return RuntimeError( - "@pto.ukernel may only be called from the top-level @pto.jit body; " - f"nested invocation inside @pto.{outer_role} is not part of the PTODSL layer contract." - ) if role == "simt": return RuntimeError( - "@pto.simt helper materialization is only supported from the top-level @pto.jit body " - f"or inside @pto.ukernel; it cannot be materialized inside @pto.{outer_role}." + "@pto.simt helper materialization is only supported from the top-level @pto.jit body; " + f"it cannot be materialized inside @pto.{outer_role}." ) return RuntimeError( - f"@pto.{role} may only be called from the top-level @pto.jit body or inside @pto.ukernel; " + f"@pto.{role} may only be called from the top-level @pto.jit body; " f"nested invocation inside @pto.{outer_role} is not part of the PTODSL layer contract." ) @@ -70,7 +79,7 @@ def illegal_subkernel_placement_error(role: str, outer_role: str | None) -> Runt def illegal_inline_subkernel_placement_error(role: str, outer_role: str | None) -> RuntimeError: """Return one diagnostic for an inline subkernel scope placed outside the supported layer graph.""" return RuntimeError( - f"inline pto.{role}() may only be used from the top-level @pto.jit body or inside @pto.ukernel; " + f"inline pto.{role}() may only be used from the top-level @pto.jit body; " f"nested use inside @pto.{outer_role} is not part of the PTODSL layer contract." ) @@ -95,12 +104,65 @@ def tile_row_alignment_error(*, shape, dtype, row_bytes: int, required_alignment ) +def explicit_mode_required_error(surface: str, current_mode: str | None) -> RuntimeError: + """Return one diagnostic for explicit-only surfaces used outside explicit mode.""" + observed_mode = "unknown" if current_mode is None else current_mode + return RuntimeError( + f"{surface} is an auto-mode contract violation: it is only available in " + f'@pto.jit(mode="explicit"); current kernel mode is {observed_mode!r}. ' + "Move the kernel to explicit mode before authoring this surface." + ) + + +def explicit_mode_required_with_context_error(surface: str, module_spec) -> RuntimeError: + """Return one diagnostic for explicit-only surfaces used outside explicit mode with source context.""" + observed_mode = getattr(module_spec, "mode", None) + context = _format_source_context( + getattr(module_spec, "function_name", None), + getattr(module_spec, "source_file", None), + getattr(module_spec, "source_line", None), + ) + observed_mode = "unknown" if observed_mode is None else observed_mode + return RuntimeError( + f"{surface} is an auto-mode contract violation{context}: it is only available in " + f'@pto.jit(mode="explicit"); current kernel mode is {observed_mode!r}. ' + "Move the kernel to explicit mode before authoring this surface." + ) + + +def invalid_jit_mode_error( + mode: str, + *, + function_name: str | None = None, + source_file: str | None = None, + source_line: int | None = None, +) -> ValueError: + """Return one diagnostic for unsupported ``@pto.jit(mode=...)`` values.""" + context = _format_source_context(function_name, source_file, source_line) + return ValueError( + f"unsupported PTODSL jit mode {mode!r}{context}; expected 'auto' or 'explicit'" + ) + + +def removed_ukernel_surface_error() -> AttributeError: + """Return one diagnostic for the removed ``pto.ukernel`` public surface.""" + return AttributeError( + 'pto.ukernel has been removed from the PTODSL public surface. ' + 'Use @pto.jit(mode="explicit") for explicit DMA orchestration, and call or inline ' + "@pto.simd/@pto.simt/@pto.cube directly from that kernel." + ) + + __all__ = [ "PTODSLTracingMisuseError", + "explicit_mode_required_error", + "explicit_mode_required_with_context_error", "host_tensor_metadata_error", "illegal_inline_subkernel_placement_error", "illegal_subkernel_placement_error", + "invalid_jit_mode_error", "native_python_control_flow_error", + "removed_ukernel_surface_error", "simd_value_escape_error", "subkernel_host_tensor_boundary_error", "subkernel_signature_boundary_error", diff --git a/ptodsl/ptodsl/_jit.py b/ptodsl/ptodsl/_jit.py index ca21938f2..bae9020da 100644 --- a/ptodsl/ptodsl/_jit.py +++ b/ptodsl/ptodsl/_jit.py @@ -9,6 +9,9 @@ from __future__ import annotations +import inspect + +from ._diagnostics import invalid_jit_mode_error from ._kernel_compilation import CompiledKernelHandle, KernelCompiler from ._kernel_signature import parse_jit_kernel_signature from ._tracing import ( @@ -20,7 +23,28 @@ from mlir.ir import InsertionPoint -_MODULE_ATTRS = ("pto.target_arch", "pto.kernel_kind") +_MODULE_ATTRS = ("pto.target_arch", "pto.kernel_kind", "pto.mode") + + +def _normalize_mode(mode: str, *, fn=None) -> str: + if mode not in {"auto", "explicit"}: + source_file = None + source_line = None + function_name = None + if fn is not None: + function_name = fn.__name__ + try: + source_file = inspect.getsourcefile(fn) or inspect.getfile(fn) + except (OSError, TypeError): + source_file = None + source_line = getattr(getattr(fn, "__code__", None), "co_firstlineno", None) + raise invalid_jit_mode_error( + mode, + function_name=function_name, + source_file=source_file, + source_line=source_line, + ) + return mode def _module_attr_map(module): @@ -32,8 +56,9 @@ def merge_jit_modules(*kernels: KernelHandle): """ Merge multiple ``@pto.jit`` flat-module kernels into one MLIR module. - Each handle must have been compiled with the same ``target`` and - ``kernel_kind`` module attributes. Function order follows *kernels*. + Each handle must have been compiled with the same ``target``, + ``kernel_kind``, and ``mode`` module attributes. Function order follows + *kernels*. """ if not kernels: raise ValueError("merge_jit_modules() requires at least one kernel handle") @@ -62,6 +87,7 @@ def jit( *, target: str = "a5", kernel_kind: str = "vector", + mode: str = "auto", func_attr: str = None, ): """ @@ -72,6 +98,7 @@ def jit( name: IR function name (defaults to the Python function name). target: Target architecture string, e.g. ``"a5"``. kernel_kind: ``"vector"`` or ``"cube"`` – sets ``pto.kernel_kind``. + mode: ``"auto"`` or ``"explicit"`` – sets ``pto.mode``. func_attr: Optional function attribute. Pass ``"pto.aicore"`` to select the flat-module structure with the aicore attribute. @@ -86,6 +113,12 @@ def jit( def decorator(fn): fn_name = name or fn.__name__ kernel_signature = parse_jit_kernel_signature(fn) + normalized_mode = _normalize_mode(mode, fn=fn) + source_file = None + try: + source_file = inspect.getsourcefile(fn) or inspect.getfile(fn) + except (OSError, TypeError): + source_file = None module_style = ( ModuleStyle.FLAT_AICORE if func_attr == "pto.aicore" @@ -97,7 +130,10 @@ def decorator(fn): function_name=fn_name, target_arch=target, kernel_kind=kernel_kind, + mode=normalized_mode, module_style=module_style, + source_file=source_file, + source_line=getattr(fn.__code__, "co_firstlineno", None), ), kernel_signature, fn, diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py index d632fd319..b3a473ca0 100644 --- a/ptodsl/ptodsl/_ops.py +++ b/ptodsl/ptodsl/_ops.py @@ -22,8 +22,10 @@ - ``partition_view`` infers the PartitionTensorViewType from the source type. """ +from functools import wraps + from ._bootstrap import make_context # noqa: F401 – ensure MLIR on sys.path -from ._diagnostics import tile_row_alignment_error +from ._diagnostics import explicit_mode_required_with_context_error, tile_row_alignment_error from ._host_tensors import resolve_tensor_data_entry from ._scalar_coercion import coerce_scalar_to_type, materialize_scalar_literal from ._runtime_scalar_ops import classify_runtime_scalar_type, emit_runtime_binary_op @@ -129,6 +131,31 @@ def _validate_sync_pipe(pipe, *, context: str, allowed: tuple[str, ...]): raise ValueError(f"{context} expects pipe to be one of {expected}, got <{canonical}>") +def _require_explicit_mode(surface: str): + try: + from ._tracing.active import current_session + session = current_session() + except Exception: + session = None + if session is None: + return + current_mode = getattr(session.module_spec, "mode", None) + if current_mode != "explicit": + raise explicit_mode_required_with_context_error(surface, session.module_spec) + + +def _explicit_mode_only(surface: str): + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + _require_explicit_mode(surface) + return fn(*args, **kwargs) + + return wrapper + + return decorator + + # ── Constants ──────────────────────────────────────────────────────────────── def const(value: int, *, dtype=None): @@ -2666,6 +2693,7 @@ def _require_pto_ptr_operand(value, *, context: str): return raw_value +@_explicit_mode_only("pto.mte_load(...)") def mte_load(source, destination, l2_cache_ctl, len_burst, *, nburst, loops=None, pad=None): """ Ptr-based GM->UB DMA wrapper aligned with the underlying ``pto.dma_load`` surface. @@ -2704,6 +2732,7 @@ def mte_load(source, destination, l2_cache_ctl, len_burst, *, nburst, loops=None ) +@_explicit_mode_only("pto.mte_store(...)") def mte_store(source, destination, len_burst, *, nburst, loops=None): """Ptr-based UB->GM DMA wrapper aligned with the underlying ``pto.dma_store`` surface.""" n_burst, nburst_src_stride, nburst_dst_stride = _normalize_dma_group( @@ -2780,6 +2809,7 @@ def _normalize_dma_pad(pad, *, context: str): ) +@_explicit_mode_only("pto.mte_gm_ub(...)") def mte_gm_ub(source, destination, l2_cache_ctl, len_burst, *, nburst, loops=None, pad=None): """``pto.mte_gm_ub`` – grouped GM-to-UB DMA surface.""" n_burst, nburst_src_stride, nburst_dst_stride = _normalize_dma_group( @@ -2812,6 +2842,7 @@ def mte_gm_ub(source, destination, l2_cache_ctl, len_burst, *, nburst, loops=Non ) +@_explicit_mode_only("pto.mte_ub_gm(...)") def mte_ub_gm(source, destination, len_burst, *, nburst, loops=None): """``pto.mte_ub_gm`` – grouped UB-to-GM DMA surface.""" n_burst, nburst_src_stride, nburst_dst_stride = _normalize_dma_group( @@ -2836,6 +2867,7 @@ def mte_ub_gm(source, destination, len_burst, *, nburst, loops=None): ) +@_explicit_mode_only("pto.mte_ub_ub(...)") def mte_ub_ub(source, destination, len_burst, *, nburst): """``pto.mte_ub_ub`` – grouped UB-to-UB DMA surface.""" n_burst, src_stride, dst_stride = _normalize_dma_group( @@ -2853,6 +2885,7 @@ def mte_ub_ub(source, destination, len_burst, *, nburst): ) +@_explicit_mode_only("pto.mte_ub_l1(...)") def mte_ub_l1(source, destination, len_burst, *, nburst): """``pto.mte_ub_l1`` – grouped UB-to-L1 DMA surface.""" n_burst, src_stride, dst_stride = _normalize_dma_group( @@ -2876,6 +2909,7 @@ def mem_bar(barrier_type): _pto.MemBarOp(kind=_membar_attr(barrier_name)) +@_explicit_mode_only("pto.mte_l1_l0a(...)") def mte_l1_l0a(source, destination, m, k, *, transpose=False): """``pto.mte_l1_l0a`` – cube-side LEFT staging.""" _pto.MteL1L0aOp( @@ -2887,6 +2921,7 @@ def mte_l1_l0a(source, destination, m, k, *, transpose=False): ) +@_explicit_mode_only("pto.mte_l1_l0b(...)") def mte_l1_l0b(source, destination, k, n, *, transpose=False): """``pto.mte_l1_l0b`` – cube-side RIGHT staging.""" _pto.MteL1L0bOp( @@ -2898,6 +2933,7 @@ def mte_l1_l0b(source, destination, k, n, *, transpose=False): ) +@_explicit_mode_only("pto.mte_l0c_ub(...)") def mte_l0c_ub(source, destination, m, n, src_stride, dst_stride, sub_blockid=0, *, dst_mode="single"): """``pto.mte_l0c_ub`` – ACC to UB store.""" _pto.MteL0cUbOp( diff --git a/ptodsl/ptodsl/_subkernels.py b/ptodsl/ptodsl/_subkernels.py index 380c1ff59..66d44e442 100644 --- a/ptodsl/ptodsl/_subkernels.py +++ b/ptodsl/ptodsl/_subkernels.py @@ -27,7 +27,6 @@ class KernelRole(str, Enum): - UKERNEL = "ukernel" CUBE = "cube" SIMD = "simd" SIMT = "simt" @@ -123,10 +122,9 @@ def _find_transient_simd_escape(value): def _validate_subkernel_placement(role: KernelRole, outer_frame, *, inline: bool = False) -> None: if outer_frame is None: return - if role == KernelRole.UKERNEL or outer_frame.role != KernelRole.UKERNEL.value: - if inline: - raise illegal_inline_subkernel_placement_error(role.value, outer_frame.role) - raise illegal_subkernel_placement_error(role.value, outer_frame.role) + if inline: + raise illegal_inline_subkernel_placement_error(role.value, outer_frame.role) + raise illegal_subkernel_placement_error(role.value, outer_frame.role) class _SubkernelSurface: @@ -184,10 +182,6 @@ def _decorate_subkernel(role: KernelRole, fn=None, *, name: str | None = None, t return _subkernel_decorator(role, name=name, target=target) -def ukernel(fn=None, *, name: str | None = None, target: str = "a5"): - return _decorate_subkernel(KernelRole.UKERNEL, fn, name=name, target=target) - - def cube(fn=None, *, name: str | None = None, target: str = "a5"): return _decorate_subkernel(KernelRole.CUBE, fn, name=name, target=target) @@ -204,7 +198,6 @@ def simt(fn=None, *, name: str | None = None, target: str = "a5"): "KernelRole", "SubkernelSpec", "SubkernelTemplate", - "ukernel", "cube", "simd", "simt", diff --git a/ptodsl/ptodsl/_tile_template_tracing.py b/ptodsl/ptodsl/_tile_template_tracing.py index 59bcad95c..971e8ece3 100644 --- a/ptodsl/ptodsl/_tile_template_tracing.py +++ b/ptodsl/ptodsl/_tile_template_tracing.py @@ -288,7 +288,10 @@ def __init__(self, descriptor: "TileTemplate", tile_specs: dict[str, TileSpec]): function_name=descriptor.name, target_arch=descriptor.target, kernel_kind="vector", + mode="auto", module_style=ModuleStyle.NESTED, + source_file=inspect.getsourcefile(descriptor.py_fn) or inspect.getfile(descriptor.py_fn), + source_line=getattr(descriptor.py_fn.__code__, "co_firstlineno", None), ) ) self.descriptor = descriptor diff --git a/ptodsl/ptodsl/_tracing/module_builder.py b/ptodsl/ptodsl/_tracing/module_builder.py index 87a1c2d0f..7012ddfa1 100644 --- a/ptodsl/ptodsl/_tracing/module_builder.py +++ b/ptodsl/ptodsl/_tracing/module_builder.py @@ -30,7 +30,10 @@ class KernelModuleSpec: function_name: str target_arch: str kernel_kind: str + mode: str = "auto" module_style: ModuleStyle = ModuleStyle.NESTED + source_file: str | None = None + source_line: int | None = None def _kernel_kind_attr(kernel_kind: str): @@ -41,6 +44,7 @@ def _build_flat_aicore_module(spec: KernelModuleSpec, arg_types): module = Module.create() module.operation.attributes["pto.target_arch"] = StringAttr.get(spec.target_arch) module.operation.attributes["pto.kernel_kind"] = _kernel_kind_attr(spec.kernel_kind) + module.operation.attributes["pto.mode"] = StringAttr.get(spec.mode) fn_ty = func.FunctionType.get(arg_types, []) with InsertionPoint(module.body): ir_fn = func.FuncOp(spec.function_name, fn_ty) @@ -51,11 +55,13 @@ def _build_flat_aicore_module(spec: KernelModuleSpec, arg_types): def _build_nested_module(spec: KernelModuleSpec, arg_types): outer = Module.create() outer.operation.attributes["pto.target_arch"] = StringAttr.get(spec.target_arch) + outer.operation.attributes["pto.mode"] = StringAttr.get(spec.mode) with InsertionPoint(outer.body): inner_op = Operation.create("builtin.module", regions=1) inner_op.attributes["pto.target_arch"] = StringAttr.get(spec.target_arch) inner_op.attributes["pto.kernel_kind"] = _kernel_kind_attr(spec.kernel_kind) + inner_op.attributes["pto.mode"] = StringAttr.get(spec.mode) inner_body = inner_op.regions[0].blocks.append() with InsertionPoint(inner_body): diff --git a/ptodsl/ptodsl/_tracing/runtime.py b/ptodsl/ptodsl/_tracing/runtime.py index 358ce6b10..630cd9931 100644 --- a/ptodsl/ptodsl/_tracing/runtime.py +++ b/ptodsl/ptodsl/_tracing/runtime.py @@ -62,7 +62,7 @@ def finalize_session(self, session): def dispatch_subkernel_call(self, subkernel, *args, **kwargs): """Dispatch a decorated PTODSL subkernel call in the active trace.""" session = require_active_session(f"@pto.{subkernel.spec.role.value}") - if subkernel.spec.role.value in {"ukernel", "cube", "simd"}: + if subkernel.spec.role.value in {"cube", "simd"}: return session.lower_inline_subkernel(subkernel, *args, **kwargs) if subkernel.spec.role.value == "simt": return session.lower_simt_helper_subkernel(subkernel, *args, **kwargs) diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py index 9ad179199..293fab0c4 100644 --- a/ptodsl/ptodsl/pto.py +++ b/ptodsl/ptodsl/pto.py @@ -20,6 +20,8 @@ internally as ``_pto`` (``from mlir.dialects import pto as _pto``). """ +from ._diagnostics import removed_ukernel_surface_error + # ── Types ───────────────────────────────────────────────────────────────────── from ._types import ( # noqa: F401 float32, float16, bf16, @@ -104,7 +106,7 @@ # ── Decorator ───────────────────────────────────────────────────────────────── from ._jit import jit, KernelHandle, merge_jit_modules # noqa: F401 -from ._subkernels import ukernel, cube, simd, simt # noqa: F401 +from ._subkernels import cube, simd, simt # noqa: F401 # ── Shorthand dtype aliases ─────────────────────────────────────────────────── f32 = float32 @@ -117,3 +119,9 @@ mask_b8 = mask_type("b8") mask_b16 = mask_type("b16") mask_b32 = mask_type("b32") + + +def __getattr__(name): + if name == "ukernel": + raise removed_ukernel_surface_error() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/test/python/ptodsl_docs_fragment_fixtures.py b/test/python/ptodsl_docs_fragment_fixtures.py index 41f419de3..7dc5d6c94 100644 --- a/test/python/ptodsl_docs_fragment_fixtures.py +++ b/test/python/ptodsl_docs_fragment_fixtures.py @@ -344,13 +344,13 @@ def tail_simd_helper_probe(*, BLOCK: pto.constexpr = 128): kernel_entry_direct_l3_call_probe = my_kernel """ ), - "kernel_entry.ukernel_signature": _fixture( + "kernel_entry.explicit_signature": _fixture( f""" {SNIPPET_PLACEHOLDER} - @pto.jit(target="a5") - def kernel_entry_ukernel_signature_probe( + @pto.jit(target="a5", mode="explicit") + def kernel_entry_explicit_signature_probe( A: pto.tensor_spec(rank=2, dtype=pto.f32), *, BLOCK: pto.constexpr = 16, @@ -359,10 +359,10 @@ def kernel_entry_ukernel_signature_probe( part = pto.partition_view(view, offsets=[0, 0], sizes=[1, BLOCK]) tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) scratch = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, memory_space=pto.MemorySpace.LEFT) - my_ukernel(part, tile, scratch, tile.as_ptr(), pto.const(0, dtype=pto.i32)) + my_orchestration_helper(part, tile, scratch, tile.as_ptr(), pto.const(0, dtype=pto.i32)) """ ), - "kernel_entry.ukernel_body": _fixture( + "kernel_entry.explicit_body": _fixture( f""" @pto.cube def qk_matmul(q_tile: pto.Tile, k_tile: pto.Tile, s_tile: pto.Tile): @@ -377,8 +377,8 @@ def online_softmax(s_tile: pto.Tile, o_tile: pto.Tile, rows: pto.i32, cols: pto. {SNIPPET_PLACEHOLDER} - @pto.jit(target="a5") - def kernel_entry_ukernel_body_probe( + @pto.jit(target="a5", mode="explicit") + def kernel_entry_explicit_body_probe( K: pto.tensor_spec(rank=2, dtype=pto.f16), V: pto.tensor_spec(rank=2, dtype=pto.f16), O: pto.tensor_spec(rank=2, dtype=pto.f32), @@ -400,12 +400,34 @@ def kernel_entry_ukernel_body_probe( process_block(q_tile, k_part, v_part, k_tile, v_tile, s_tile, o_tile, o_part, ROWS, COLS) """ ), + "kernel_entry.inline_explicit_scope": _fixture( + f""" + @pto.jit(target="a5", mode="explicit") + def kernel_entry_inline_explicit_scope_probe( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 16, + ): + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + part = pto.partition_view(a_view, offsets=[0, 0], sizes=[1, BLOCK]) + out_part = pto.partition_view(o_view, offsets=[0, 0], sizes=[1, BLOCK]) + tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32, valid_shape=[1, BLOCK]) + row_bytes = BLOCK * pto.bytewidth(pto.f32) + pto.mte_load(part.as_ptr(), tile.as_ptr(), 0, row_bytes, + nburst=(1, row_bytes, row_bytes)) + pto.pipe_barrier(pto.Pipe.ALL) + pto.mte_store(tile.as_ptr(), out_part.as_ptr(), row_bytes, + nburst=(1, row_bytes, row_bytes)) + """ + ), "kernel_entry.cube_signature": _fixture( f""" {SNIPPET_PLACEHOLDER} - @pto.jit(target="a5") + @pto.jit(target="a5", mode="explicit") def kernel_entry_cube_signature_probe( *, BLOCK_M: pto.constexpr = 16, @@ -464,7 +486,6 @@ def kernel_entry_simt_signature_probe(*, BLOCK: pto.constexpr = 8): ), "kernel_entry.inline_simd_scope": _fixture( f""" - @pto.ukernel def kernel_entry_inline_simd_scope( a_tile: pto.Tile, b_tile: pto.Tile, @@ -486,7 +507,6 @@ def kernel_entry_inline_simd_scope_probe(*, BLOCK: pto.constexpr = 128): ), "kernel_entry.inline_simt_scope": _fixture( f""" - @pto.ukernel def kernel_entry_inline_simt_scope( o_prev_tile: pto.Tile, pv_tile: pto.Tile, @@ -514,7 +534,7 @@ def kernel_entry_inline_simt_scope_probe(*, BLOCK: pto.constexpr = 8): ), "kernel_entry.inline_cube_scope": _fixture( f""" - @pto.jit(target="a5") + @pto.jit(target="a5", mode="explicit") def kernel_entry_inline_cube_scope_probe( *, BLOCK_M: pto.constexpr = 16, @@ -661,9 +681,8 @@ def data_movement_tload_probe( {SNIPPET_PLACEHOLDER} """ ), - "data_movement.ukernel_dma": _fixture( + "data_movement.explicit_dma": _fixture( f""" - @pto.ukernel def process_block( k_part: pto.PartitionTensorView, v_part: pto.PartitionTensorView, @@ -677,8 +696,8 @@ def process_block( {SNIPPET_PLACEHOLDER} - @pto.jit(target="a5") - def data_movement_ukernel_dma_probe( + @pto.jit(target="a5", mode="explicit") + def data_movement_explicit_dma_probe( K: pto.tensor_spec(rank=2, dtype=pto.f16), V: pto.tensor_spec(rank=2, dtype=pto.f16), O: pto.tensor_spec(rank=2, dtype=pto.f32), @@ -698,9 +717,119 @@ def data_movement_ukernel_dma_probe( process_block(k_part, v_part, k_tile, v_tile, o_tile, o_part, ROWS, COLS) """ ), + "sync_ops.flag_pattern_explicit": _fixture( + f""" + @pto.cube + def qk_matmul(q_tile: pto.Tile, k_tile: pto.Tile, p_tile: pto.Tile): + return + + + @pto.cube + def pv_matmul(p_tile: pto.Tile, v_tile: pto.Tile, o_tile: pto.Tile): + return + + + {SNIPPET_PLACEHOLDER} + + + @pto.jit(target="a5", mode="explicit") + def sync_ops_flag_pattern_explicit_probe( + K: pto.tensor_spec(rank=2, dtype=pto.f16), + V: pto.tensor_spec(rank=2, dtype=pto.f16), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + ROWS: pto.constexpr = 8, + COLS: pto.constexpr = 16, + ): + k_view = pto.make_tensor_view(K, shape=K.shape, strides=K.strides) + v_view = pto.make_tensor_view(V, shape=V.shape, strides=V.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + q_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) + k_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) + v_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) + p_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) + k_part = pto.partition_view(k_view, offsets=[0, 0], sizes=[ROWS, COLS]) + v_part = pto.partition_view(v_view, offsets=[0, 0], sizes=[ROWS, COLS]) + o_part = pto.partition_view(o_view, offsets=[0, 0], sizes=[ROWS, COLS]) + gemm_block( + q_tile, + k_part, + v_part, + k_tile, + v_tile, + p_tile, + o_tile, + o_part, + pto.const(ROWS, dtype=pto.i32), + pto.const(COLS, dtype=pto.i32), + ) + """ + ), + "sync_ops.phase_barrier_explicit": _fixture( + f""" + @pto.cube + def qk_matmul(q_tile: pto.Tile, k_tile: pto.Tile, s_tile: pto.Tile): + return + + + @pto.simd + def online_softmax(s_tile: pto.Tile, p_tile: pto.Tile, rows: pto.i32, cols: pto.i32): + return + + + @pto.cube + def pv_matmul(p_tile: pto.Tile, v_tile: pto.Tile, pv_tile: pto.Tile): + return + + + @pto.simt + def blend_output(o_prev_tile: pto.Tile, pv_tile: pto.Tile, o_next_tile: pto.Tile, rows: pto.i32, cols: pto.i32): + return + + + {SNIPPET_PLACEHOLDER} + + + @pto.jit(target="a5", mode="explicit") + def sync_ops_phase_barrier_explicit_probe( + K: pto.tensor_spec(rank=2, dtype=pto.f16), + V: pto.tensor_spec(rank=2, dtype=pto.f16), + *, + ROWS: pto.constexpr = 8, + COLS: pto.constexpr = 16, + ): + k_view = pto.make_tensor_view(K, shape=K.shape, strides=K.strides) + v_view = pto.make_tensor_view(V, shape=V.shape, strides=V.strides) + q_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) + k_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) + v_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) + s_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) + p_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) + pv_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) + o_prev_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) + o_next_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) + k_part = pto.partition_view(k_view, offsets=[0, 0], sizes=[ROWS, COLS]) + v_part = pto.partition_view(v_view, offsets=[0, 0], sizes=[ROWS, COLS]) + flash_attention_block( + q_tile, + k_part, + v_part, + k_tile, + v_tile, + s_tile, + p_tile, + pv_tile, + o_prev_tile, + o_next_tile, + pto.const(ROWS, dtype=pto.i32), + pto.const(COLS, dtype=pto.i32), + ) + """ + ), "data_movement.grouped_dma_ptrs": _fixture( f""" - @pto.jit(target="a5") + @pto.jit(target="a5", mode="explicit") def data_movement_grouped_dma_ptrs_probe(): gm_src = pto.castptr(pto.ui64(0), pto.ptr(pto.f16, "gm")) gm_dst = pto.castptr(pto.ui64(0), pto.ptr(pto.f16, "gm")) @@ -753,7 +882,7 @@ def qk_matmul( {SNIPPET_PLACEHOLDER} - @pto.jit(target="a5") + @pto.jit(target="a5", mode="explicit") def data_movement_cube_helper_probe( *, BLOCK_M: pto.constexpr = 16, @@ -833,116 +962,6 @@ def sync_ops_basic_probe(): {SNIPPET_PLACEHOLDER} """ ), - "sync_ops.flag_pattern_ukernel": _fixture( - f""" - @pto.cube - def qk_matmul(q_tile: pto.Tile, k_tile: pto.Tile, p_tile: pto.Tile): - return - - - @pto.cube - def pv_matmul(p_tile: pto.Tile, v_tile: pto.Tile, o_tile: pto.Tile): - return - - - {SNIPPET_PLACEHOLDER} - - - @pto.jit(target="a5") - def sync_ops_flag_pattern_ukernel_probe( - K: pto.tensor_spec(rank=2, dtype=pto.f16), - V: pto.tensor_spec(rank=2, dtype=pto.f16), - O: pto.tensor_spec(rank=2, dtype=pto.f32), - *, - ROWS: pto.constexpr = 8, - COLS: pto.constexpr = 16, - ): - k_view = pto.make_tensor_view(K, shape=K.shape, strides=K.strides) - v_view = pto.make_tensor_view(V, shape=V.shape, strides=V.strides) - o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) - q_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) - k_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) - v_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) - p_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) - o_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) - k_part = pto.partition_view(k_view, offsets=[0, 0], sizes=[ROWS, COLS]) - v_part = pto.partition_view(v_view, offsets=[0, 0], sizes=[ROWS, COLS]) - o_part = pto.partition_view(o_view, offsets=[0, 0], sizes=[ROWS, COLS]) - gemm_block( - q_tile, - k_part, - v_part, - k_tile, - v_tile, - p_tile, - o_tile, - o_part, - pto.const(ROWS, dtype=pto.i32), - pto.const(COLS, dtype=pto.i32), - ) - """ - ), - "sync_ops.phase_barrier_ukernel": _fixture( - f""" - @pto.cube - def qk_matmul(q_tile: pto.Tile, k_tile: pto.Tile, s_tile: pto.Tile): - return - - - @pto.simd - def online_softmax(s_tile: pto.Tile, p_tile: pto.Tile, rows: pto.i32, cols: pto.i32): - return - - - @pto.cube - def pv_matmul(p_tile: pto.Tile, v_tile: pto.Tile, pv_tile: pto.Tile): - return - - - @pto.simt - def blend_output(o_prev_tile: pto.Tile, pv_tile: pto.Tile, o_next_tile: pto.Tile, rows: pto.i32, cols: pto.i32): - return - - - {SNIPPET_PLACEHOLDER} - - - @pto.jit(target="a5") - def sync_ops_phase_barrier_ukernel_probe( - K: pto.tensor_spec(rank=2, dtype=pto.f16), - V: pto.tensor_spec(rank=2, dtype=pto.f16), - *, - ROWS: pto.constexpr = 8, - COLS: pto.constexpr = 16, - ): - k_view = pto.make_tensor_view(K, shape=K.shape, strides=K.strides) - v_view = pto.make_tensor_view(V, shape=V.shape, strides=V.strides) - q_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) - k_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) - v_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f16) - s_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) - p_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) - pv_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) - o_prev_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) - o_next_tile = pto.alloc_tile(shape=[ROWS, COLS], dtype=pto.f32) - k_part = pto.partition_view(k_view, offsets=[0, 0], sizes=[ROWS, COLS]) - v_part = pto.partition_view(v_view, offsets=[0, 0], sizes=[ROWS, COLS]) - flash_attention_block( - q_tile, - k_part, - v_part, - k_tile, - v_tile, - s_tile, - p_tile, - pv_tile, - o_prev_tile, - o_next_tile, - pto.const(ROWS, dtype=pto.i32), - pto.const(COLS, dtype=pto.i32), - ) - """ - ), "flash_attention.l1_tensor_views": _fixture( f""" @pto.jit(target="a5") @@ -1037,7 +1056,6 @@ def _block_valid_extent(total, block_index, block_size): return _min_index(total - block_index * block_size, pto.const(block_size)) - @pto.ukernel def kv_block_process( q_mat: pto.Tile, k_part: pto.PartitionTensorView, @@ -1066,7 +1084,7 @@ def kv_block_process( pto.pipe_barrier(pto.Pipe.ALL) - @pto.jit(target="a5") + @pto.jit(target="a5", mode="explicit") def flash_attention_l1_loop_body_probe( Q: pto.tensor_spec(rank=4, dtype=pto.f32), K: pto.tensor_spec(rank=4, dtype=pto.f32), @@ -1158,7 +1176,7 @@ def flash_attention_l1_loop_body_probe( {SNIPPET_PLACEHOLDER} """ ), - "flash_attention.ukernel_phase": _fixture( + "flash_attention.explicit_phase": _fixture( f""" @pto.cube def qk_matmul( @@ -1222,8 +1240,7 @@ def materialize_tile_bounds(meta_ptr, valid_rows: pto.i32, valid_cols: pto.i32): scalar.store(valid_cols, meta_ptr + 2) - @pto.ukernel - def flash_attention_ukernel_phase( + def flash_attention_explicit_phase( q_mat: pto.Tile, k_part: pto.PartitionTensorView, v_part: pto.PartitionTensorView, @@ -1254,8 +1271,8 @@ def flash_attention_ukernel_phase( {SNIPPET_PLACEHOLDER} - @pto.jit(target="a5") - def flash_attention_ukernel_phase_probe( + @pto.jit(target="a5", mode="explicit") + def flash_attention_explicit_phase_probe( K: pto.tensor_spec(rank=4, dtype=pto.f32), V: pto.tensor_spec(rank=4, dtype=pto.f32), *, @@ -1292,7 +1309,7 @@ def flash_attention_ukernel_phase_probe( pv_acc_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, memory_space=pto.MemorySpace.ACC, valid_shape=[Br, D]) meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 3]) meta_ptr = meta_tile.as_ptr() - flash_attention_ukernel_phase( + flash_attention_explicit_phase( q_mat, k_part, v_part, k_mat, v_mat, o_prev_tile, o_next_tile, m_prev_tile, l_prev_tile, m_next_tile, l_next_tile, @@ -1309,7 +1326,7 @@ def flash_attention_ukernel_phase_probe( {SNIPPET_PLACEHOLDER} - @pto.jit(target="a5") + @pto.jit(target="a5", mode="explicit") def flash_attention_qk_cube_helper_probe(*, BLOCK_Q: pto.constexpr = 16, BLOCK_KV: pto.constexpr = 16): Br = BLOCK_Q Bc = BLOCK_KV @@ -1328,7 +1345,7 @@ def flash_attention_qk_cube_helper_probe(*, BLOCK_Q: pto.constexpr = 16, BLOCK_K {SNIPPET_PLACEHOLDER} - @pto.jit(target="a5") + @pto.jit(target="a5", mode="explicit") def flash_attention_pv_cube_helper_probe(*, BLOCK_Q: pto.constexpr = 16, BLOCK_KV: pto.constexpr = 16): Br = BLOCK_Q Bc = BLOCK_KV @@ -1344,7 +1361,6 @@ def flash_attention_pv_cube_helper_probe(*, BLOCK_Q: pto.constexpr = 16, BLOCK_K ), "flash_attention.inline_simt_scope": _fixture( f""" - @pto.ukernel def flash_attention_inline_simt_scope( q_mat: pto.Tile, k_mat: pto.Tile, @@ -1546,7 +1562,7 @@ def gemm_tile( {SNIPPET_PLACEHOLDER} - @pto.jit(target="a5") + @pto.jit(target="a5", mode="explicit") def gemm_tile_probe(*, BLOCK_M: pto.constexpr = 64, BLOCK_K: pto.constexpr = 64, BLOCK_N: pto.constexpr = 64): a_mat = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f32, memory_space=pto.MemorySpace.MAT, valid_shape=[BLOCK_M, BLOCK_K]) b_mat = pto.alloc_tile(shape=[BLOCK_K, BLOCK_N], dtype=pto.f32, memory_space=pto.MemorySpace.MAT, valid_shape=[BLOCK_K, BLOCK_N]) diff --git a/test/python/ptodsl_flash_attention_demo_compile.py b/test/python/ptodsl_flash_attention_demo_compile.py index 521f0fd96..2d433ef6b 100644 --- a/test/python/ptodsl_flash_attention_demo_compile.py +++ b/test/python/ptodsl_flash_attention_demo_compile.py @@ -35,8 +35,18 @@ def expect_parse_roundtrip_and_verify(text: str, label: str) -> None: def load_flash_attention_demo(): - demo_path = REPO_ROOT / "ptodsl" / "demos" / "flash_attention_sketch.py" - expect(demo_path.is_file(), f"canonical flash attention demo is missing: {demo_path}") + demo_candidates = [ + REPO_ROOT / "ptodsl" / "examples" / "flash_attention_sketch.py", + REPO_ROOT / "ptodsl" / "demos" / "flash_attention_sketch.py", + ] + for demo_path in demo_candidates: + if demo_path.is_file(): + break + else: + raise AssertionError( + "canonical flash attention demo is missing: " + + ", ".join(str(path) for path in demo_candidates) + ) spec = spec_from_file_location("ptodsl_flash_attention_demo", demo_path) expect(spec is not None and spec.loader is not None, f"unable to create import spec for {demo_path}") @@ -54,6 +64,7 @@ def main() -> None: wrapper_text = demo.emit_flash_attention_mlir(head_dim=128, causal=False, block_q=128, block_kv=128) expect_parse_roundtrip_and_verify(wrapper_text, "flash attention wrapper-emitted MLIR") expect("func.func @flash_attention_kernel" in wrapper_text, "wrapper compile should emit the flash_attention_kernel entry") + expect('pto.mode = "explicit"' in wrapper_text, "flash attention wrapper compile should carry explicit mode metadata") expect("func.func @materialize_tile_bounds" in wrapper_text, "wrapper compile should emit the SIMT helper function") expect("pto.store_vfsimt_info" in wrapper_text, "wrapper compile should materialize SIMT caller metadata setup") expect("pto.barrier " in wrapper_text, "demo phase boundaries should lower to pipe_barrier(Pipe.ALL)") @@ -80,6 +91,7 @@ def main() -> None: specialized_text = compiled.mlir_text() expect_parse_roundtrip_and_verify(specialized_text, "flash attention specialized MLIR") expect("func.func @flash_attention_kernel" in specialized_text, "direct compile should emit the flash_attention_kernel entry") + expect('pto.mode = "explicit"' in specialized_text, "direct compile should carry explicit mode metadata") expect("!pto.tile_buf None: runtime_metadata_kernel.verify() tile_surface_compute_probe.verify() shared_subkernel_lowering_probe.verify() - inline_subkernel_scope_probe.verify() simt_helper_lowering_probe.verify() carry_loop_lowering_probe.verify() branch_handle_then_only_probe.verify() @@ -1089,11 +1110,37 @@ def main() -> None: default_text = default_compiled.mlir_text() block64_text = block64.mlir_text() + explicit_text = host_vec_copy_explicit.compile().mlir_text() expect_parse_roundtrip_and_verify(default_text, "default host_vec_copy specialization") expect_parse_roundtrip_and_verify(block64_text, "BLOCK=64 host_vec_copy specialization") + expect_parse_roundtrip_and_verify(explicit_text, "explicit host_vec_copy specialization") expect("!pto.tile_buf" in default_text, "default specialization MLIR missing BLOCK=128 tile") expect("!pto.tile_buf" in block64_text, "BLOCK=64 specialization MLIR missing specialized tile") + expect('pto.mode = "auto"' in default_text, "default specialization should carry auto mode module metadata") + expect('pto.mode = "explicit"' in explicit_text, "explicit specialization should carry explicit mode module metadata") expect("valid=?" not in default_text, "default alloc_tile() should keep full static valid-shape when valid_shape= is omitted") + auto_mode_violation = expect_raises( + RuntimeError, + auto_mode_explicit_surface_violation_probe.compile, + '@pto.jit(mode="explicit")', + ) + expect( + "auto-mode contract violation" in str(auto_mode_violation), + "explicit-only surface use in auto mode should be diagnosed as an auto-mode contract violation", + ) + expect( + "auto_mode_explicit_surface_violation_probe" in str(auto_mode_violation), + "auto-mode DMA violation should identify the authored kernel name", + ) + expect( + __file__ in str(auto_mode_violation), + "auto-mode DMA violation should preserve the authored source file", + ) + expect_raises( + ValueError, + lambda: pto.merge_jit_modules(host_vec_copy.compile(), host_vec_copy_explicit.compile()), + "compatible module attributes", + ) runtime_metadata_text = runtime_metadata_kernel.compile().mlir_text() expect_parse_roundtrip_and_verify(runtime_metadata_text, "runtime metadata specialization") @@ -1142,8 +1189,7 @@ def main() -> None: expect( SUBKERNEL_OBSERVATIONS == [ ("cube", "top_level_cube_probe", 1), - ("ukernel", "ukernel_probe", 1), - ("simd", "nested_simd_probe", 2), + ("simd", "top_level_simd_probe", 1), ("simd", "nested_simd_probe", 1), ], f"unexpected shared subkernel lowering observations: {SUBKERNEL_OBSERVATIONS!r}", @@ -1154,9 +1200,9 @@ def main() -> None: expect_parse_roundtrip_and_verify(inline_subkernel_scope_text, "inline subkernel scope specialization") expect( INLINE_SUBKERNEL_SCOPE_OBSERVATIONS == [ - ("simt", "inline_simt", 2), - ("simd", "inline_simd", 2), - ("cube", "inline_cube", 2), + ("simt", "inline_simt", 1), + ("simd", "inline_simd", 1), + ("cube", "inline_cube", 1), ], f"unexpected inline subkernel scope observations: {INLINE_SUBKERNEL_SCOPE_OBSERVATIONS!r}", ) diff --git a/test/python/ptodsl_subkernel_diagnostics.py b/test/python/ptodsl_subkernel_diagnostics.py index 84da4bfa0..84cb556af 100644 --- a/test/python/ptodsl_subkernel_diagnostics.py +++ b/test/python/ptodsl_subkernel_diagnostics.py @@ -33,14 +33,26 @@ def expect_raises(callback, exc_type, *message_fragments: str) -> None: def define_bad_subkernel_signature_probe(): - @pto.ukernel + @pto.simd def bad_tensor_formal(A: pto.tensor_spec(rank=2, dtype=pto.f32)): pto.pipe_barrier(pto.Pipe.ALL) return bad_tensor_formal -@pto.ukernel +def define_removed_ukernel_surface_probe(): + return pto.ukernel + + +def define_invalid_jit_mode_probe(): + @pto.jit(target="a5", mode="hybrid") + def bad_mode_probe(): + pass + + return bad_mode_probe + + +@pto.simd def host_tensor_operand_probe(tensor): pto.pipe_barrier(pto.Pipe.ALL) @@ -87,28 +99,43 @@ def simd_value_escape_entry(*, TRACE_TOKEN: pto.constexpr = 0): def main() -> None: + expect_raises( + define_removed_ukernel_surface_probe, + AttributeError, + "pto.ukernel has been removed from the PTODSL public surface", + '@pto.jit(mode="explicit")', + "@pto.simd/@pto.simt/@pto.cube", + ) + expect_raises( + define_invalid_jit_mode_probe, + ValueError, + "unsupported PTODSL jit mode 'hybrid'", + "bad_mode_probe", + __file__, + "expected 'auto' or 'explicit'", + ) expect_raises( define_bad_subkernel_signature_probe, TypeError, - "@pto.ukernel parameter 'A' cannot be annotated with pto.tensor_spec(...)", + "@pto.simd parameter 'A' cannot be annotated with pto.tensor_spec(...)", "@pto.jit positional parameters", ) expect_raises( host_tensor_into_subkernel_probe.compile, TypeError, - "@pto.ukernel parameter 'tensor' uses a host tensor value", + "@pto.simd parameter 'tensor' uses a host tensor value", "host tensors only belong at the @pto.jit boundary", ) expect_raises( nested_simt_from_simd_entry.compile, RuntimeError, - "@pto.simt helper materialization is only supported from the top-level @pto.jit body or inside @pto.ukernel", + "@pto.simt helper materialization is only supported from the top-level @pto.jit body", "inside @pto.simd", ) expect_raises( nested_inline_simt_from_simd_entry.compile, RuntimeError, - "inline pto.simt() may only be used from the top-level @pto.jit body or inside @pto.ukernel", + "inline pto.simt() may only be used from the top-level @pto.jit body", "inside @pto.simd", ) expect_raises( From 5ba043dd093062461dcda00836bb4bc4d7e4d9bd Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 23 May 2026 21:02:50 +0800 Subject: [PATCH 22/31] Clean up the pending docs-as-test in the user guide --- ptodsl/docs/user_guide/01-introduction.md | 2 +- ptodsl/docs/user_guide/02-quick-start.md | 2 +- .../03-kernel-entry-and-subkernels.md | 11 +- .../user_guide/04-type-system-and-buffer.md | 4 +- .../docs/user_guide/12-additional-examples.md | 9 +- test/python/ptodsl_docs_as_test.py | 138 ++++++++++++- test/python/ptodsl_docs_fragment_fixtures.py | 186 ++++++++++++++++++ 7 files changed, 335 insertions(+), 17 deletions(-) diff --git a/ptodsl/docs/user_guide/01-introduction.md b/ptodsl/docs/user_guide/01-introduction.md index c705191ae..9de4d02e9 100644 --- a/ptodsl/docs/user_guide/01-introduction.md +++ b/ptodsl/docs/user_guide/01-introduction.md @@ -69,7 +69,7 @@ Python Wrapper L0 user-facing wrapper (NumPy, torch-npu, pure Pyth The outermost layer is plain Python. It handles ergonomic runtime concerns: allocating output tensors, extracting shapes and strides from framework tensors, compiling the JIT kernel, and launching it. Because the wrapper is just Python, you can freely mix in NumPy, torch-npu, or any other Python framework for pre- and post-processing, data preparation, or composing multiple kernel launches. It knows nothing about NPU internals — it is just a convenience function that most end users will call. - + ```python def flash_attention(Q, K, V, *, O=None, causal=False): if O is None: diff --git a/ptodsl/docs/user_guide/02-quick-start.md b/ptodsl/docs/user_guide/02-quick-start.md index 27b77701d..986fc5884 100644 --- a/ptodsl/docs/user_guide/02-quick-start.md +++ b/ptodsl/docs/user_guide/02-quick-start.md @@ -137,7 +137,7 @@ Here `rows` and `cols` are dynamic — they come from `A.shape` and can differ a Once the kernel is defined, you compile it and then launch it: - + ```python # Compile once, cache the result. compiled = blocked_copy.compile(BLOCK=128) diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 0148c60f0..3c05058b5 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -66,13 +66,20 @@ Section 3.2 covers the two models in detail. ### Compilation and launch - + ```python +import numpy as np + + # Compile (traces the body, lowers through PTOAS, caches the result) compiled = kernel_name.compile(CONST_A=128, CONST_B=64) +# Allocate or obtain concrete tensors that match the declared host ABI. +A = np.random.randn(4, 128).astype(np.float32) +O = np.empty_like(A) + # Launch on NPU -compiled[grid, stream](tensor_1, tensor_2, ...) +compiled[grid, stream](A, O) ``` - `.compile(**constexprs)` — traces the kernel body with the given constexpr diff --git a/ptodsl/docs/user_guide/04-type-system-and-buffer.md b/ptodsl/docs/user_guide/04-type-system-and-buffer.md index 19d4c38e4..6f1dcd8a9 100644 --- a/ptodsl/docs/user_guide/04-type-system-and-buffer.md +++ b/ptodsl/docs/user_guide/04-type-system-and-buffer.md @@ -127,9 +127,9 @@ Masks are typed by bit granularity and must match the vector element width: Use `make_mask` to generate a mask from a pattern or scalar — it automatically selects the correct bit width from the element dtype: - + ```python -active = pto.make_mask(pto.f16, "PAT_ALL") # pattern-based full mask +active = pto.make_mask(pto.f16, pto.MaskPattern.ALL) # pattern-based full mask tail_mask, _ = pto.make_mask(pto.f32, tail_count) # load mask from tail count scalar ``` diff --git a/ptodsl/docs/user_guide/12-additional-examples.md b/ptodsl/docs/user_guide/12-additional-examples.md index cff564364..ff4937dda 100644 --- a/ptodsl/docs/user_guide/12-additional-examples.md +++ b/ptodsl/docs/user_guide/12-additional-examples.md @@ -46,7 +46,7 @@ def mat_add(A, B, O, *, BLOCK_M: pto.constexpr = 64, BLOCK_N: pto.constexpr = 12 The Python wrapper follows the same pattern as Chapter 2: - + ```python def mat_add_wrapper(A, B, O=None, stream=None): if O is None: @@ -248,11 +248,14 @@ def gemm( ### 12.3.3 Python wrapper - + ```python +import numpy as np + + def gemm_wrapper(A, B, O=None, stream=None): if O is None: - O = pto.empty([A.shape[0], B.shape[1]], dtype=A.dtype) + O = np.empty((A.shape[0], B.shape[1]), dtype=A.dtype) compiled = gemm.compile(BLOCK_M=64, BLOCK_K=64, BLOCK_N=64) compiled[1, stream](A, B, O) return O diff --git a/test/python/ptodsl_docs_as_test.py b/test/python/ptodsl_docs_as_test.py index e732bae75..0f7bfd2fb 100644 --- a/test/python/ptodsl_docs_as_test.py +++ b/test/python/ptodsl_docs_as_test.py @@ -7,6 +7,7 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. +from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path from typing import Iterable @@ -16,6 +17,7 @@ import subprocess import sys import tempfile +from unittest import mock REPO_ROOT = Path(__file__).resolve().parents[2] @@ -24,6 +26,7 @@ from ptodsl import pto, scalar from ptodsl._bootstrap import make_context +from ptodsl._runtime.launch import LaunchHandle, _marshal_launch_args from mlir.ir import Module from ptodsl_docs_fragment_fixtures import FRAGMENT_FIXTURES, render_fragment_fixture @@ -67,6 +70,15 @@ class DocTestDirective: fixture: str | None = None +@dataclass(frozen=True) +class LaunchRecord: + compiled: object + grid: int + stream: object + args: tuple[object, ...] + marshaled_arg_count: int + + def expect(condition: bool, message: str) -> None: if not condition: raise AssertionError(message) @@ -216,15 +228,38 @@ def parse_test_directive(block: MarkdownCodeBlock) -> DocTestDirective: ) return DocTestDirective(mode=mode, symbol=symbol, compile_kwargs=compile_kwargs) + if mode == "launch_fragment": + expect( + isinstance(fixture, str) and fixture, + f"{block_label(block)}: ptodsl-doc-test launch_fragment metadata must define a non-empty string 'fixture'", + ) + if symbol is not None: + expect( + isinstance(symbol, str) and symbol, + f"{block_label(block)}: ptodsl-doc-test launch_fragment 'symbol' must be a non-empty string when present", + ) + expect( + compile_kwargs is None, + f"{block_label(block, symbol if isinstance(symbol, str) and symbol else None)}: " + "ptodsl-doc-test launch_fragment does not accept a 'compile' object; the snippet owns its compile/launch flow", + ) + return DocTestDirective(mode=mode, symbol=symbol, fixture=fixture) + expect( False, f"{block_label(block, symbol if isinstance(symbol, str) and symbol else None)}: " - f"unsupported ptodsl-doc-test mode {mode!r}; only 'compile' and 'compile_fragment' are supported", + f"unsupported ptodsl-doc-test mode {mode!r}; only 'compile', 'compile_fragment', and 'launch_fragment' are supported", ) return DocTestDirective(mode=mode) -def execute_source(source: str, block: MarkdownCodeBlock, symbol: str | None = None) -> dict[str, object]: +def execute_source( + source: str, + block: MarkdownCodeBlock, + symbol: str | None = None, + *, + extra_namespace: dict[str, object] | None = None, +) -> dict[str, object]: namespace: dict[str, object] = { "__builtins__": __builtins__, "__name__": "__ptodsl_doc_snippet__", @@ -232,6 +267,8 @@ def execute_source(source: str, block: MarkdownCodeBlock, symbol: str | None = N "pto": pto, "scalar": scalar, } + if extra_namespace is not None: + namespace.update(extra_namespace) try: exec(compile(source, str(block.path), "exec"), namespace, namespace) except Exception as exc: @@ -241,6 +278,27 @@ def execute_source(source: str, block: MarkdownCodeBlock, symbol: str | None = N return namespace +@contextmanager +def capture_launch_records(): + records: list[LaunchRecord] = [] + + def fake_launch_call(self, *args): + marshaled = _marshal_launch_args(self._compiled._kernel_signature, args) + records.append( + LaunchRecord( + compiled=self._compiled, + grid=self._grid, + stream=self._stream, + args=tuple(args), + marshaled_arg_count=len(marshaled), + ) + ) + return None + + with mock.patch.object(LaunchHandle, "__call__", new=fake_launch_call): + yield records + + def verify_compiled_target( block: MarkdownCodeBlock, directive: DocTestDirective, @@ -311,6 +369,60 @@ def run_compile_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> Non namespace = execute_source(rendered_source, block, directive.symbol) verify_compiled_target(block, directive, namespace, ptoas_bin) + +def run_launch_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None: + directive = parse_test_directive(block) + expect( + directive.fixture is not None, + f"{block_label(block, directive.symbol)}: launch_fragment mode requires a fixture id", + ) + expect( + directive.fixture in FRAGMENT_FIXTURES, + f"{block_label(block, directive.symbol)}: unknown fragment fixture {directive.fixture!r}", + ) + try: + rendered_source = render_fragment_fixture(FRAGMENT_FIXTURES[directive.fixture], block.text) + except ValueError as exc: + raise AssertionError( + f"{block_label(block, directive.symbol)}: fragment fixture {directive.fixture!r} is invalid: {exc}" + ) from exc + + with capture_launch_records() as launch_records: + execute_source( + rendered_source, + block, + directive.symbol, + extra_namespace={"PTODSL_DOC_LAUNCH_RECORDS": launch_records}, + ) + + expect( + bool(launch_records), + f"{block_label(block, directive.symbol)}: launch_fragment snippet did not execute any compiled[grid, stream](...) launch", + ) + + seen_compiled_ids: set[int] = set() + for record in launch_records: + compiled = record.compiled + compiled_id = id(compiled) + if compiled_id in seen_compiled_ids: + continue + seen_compiled_ids.add(compiled_id) + try: + compiled.verify() + except Exception as exc: + raise AssertionError( + f"{block_label(block, directive.symbol)}: compiled launch target verify() failed: " + f"{exc.__class__.__name__}: {exc}" + ) from exc + mlir_text = compiled.mlir_text() + expect( + isinstance(mlir_text, str) and mlir_text.strip(), + f"{block_label(block, directive.symbol)}: compiled launch target should expose non-empty mlir_text()", + ) + label = block_label(block, directive.symbol or getattr(compiled, "ir_function_name", None)) + expect_parse_roundtrip_and_verify(mlir_text, label) + run_ptoas_frontend_verify(ptoas_bin, mlir_text, label) + def scan_markdown_file(path: Path) -> MarkdownScanResult: lines = path.read_text(encoding="utf-8").splitlines(keepends=True) blocks: list[MarkdownCodeBlock] = [] @@ -394,18 +506,21 @@ def collect_test_blocks(blocks: Iterable[MarkdownCodeBlock]) -> tuple[MarkdownCo ) -def summarize_test_modes(blocks: Iterable[MarkdownCodeBlock]) -> tuple[int, int]: +def summarize_test_modes(blocks: Iterable[MarkdownCodeBlock]) -> tuple[int, int, int]: compile_count = 0 compile_fragment_count = 0 + launch_fragment_count = 0 for block in blocks: directive = parse_test_directive(block) if directive.mode == "compile": compile_count += 1 elif directive.mode == "compile_fragment": compile_fragment_count += 1 + elif directive.mode == "launch_fragment": + launch_fragment_count += 1 else: raise AssertionError(f"{block_label(block)}: unsupported docs-as-test mode {directive.mode!r}") - return compile_count, compile_fragment_count + return compile_count, compile_fragment_count, launch_fragment_count def main() -> None: @@ -416,19 +531,19 @@ def main() -> None: tagged_python_blocks = collect_tagged_python_blocks(python_blocks) test_count, pending_count = summarize_metadata(tagged_python_blocks) test_blocks = collect_test_blocks(tagged_python_blocks) - compile_test_count, compile_fragment_test_count = summarize_test_modes(test_blocks) + compile_test_count, compile_fragment_test_count, launch_fragment_test_count = summarize_test_modes(test_blocks) expect(bool(results), f"no markdown files found under {USER_GUIDE_ROOT}") expect(bool(python_blocks), f"no Python fenced code blocks found under {USER_GUIDE_ROOT}") - if compile_test_count or compile_fragment_test_count: + if compile_test_count or compile_fragment_test_count or launch_fragment_test_count: try: ptoas_bin = resolve_ptoas_binary() except FileNotFoundError as exc: compile_blocks = [ block for block in test_blocks - if parse_test_directive(block).mode in ("compile", "compile_fragment") + if parse_test_directive(block).mode in ("compile", "compile_fragment", "launch_fragment") ] fail_doc(compile_blocks[0].path, compile_blocks[0].start_line, str(exc)) else: @@ -444,6 +559,12 @@ def main() -> None: f"{block_label(block, directive.symbol)}: missing ptoas binary for compile_fragment-mode docs test", ) run_compile_fragment_block(block, ptoas_bin) + elif directive.mode == "launch_fragment": + expect( + ptoas_bin is not None, + f"{block_label(block, directive.symbol)}: missing ptoas binary for launch_fragment-mode docs test", + ) + run_launch_fragment_block(block, ptoas_bin) else: raise AssertionError(f"{block_label(block)}: unsupported docs-as-test mode {directive.mode!r}") @@ -455,7 +576,8 @@ def main() -> None: "ptodsl_docs_as_test: scanned " f"{markdown_count} markdown files, {block_count} fenced blocks, {python_count} python blocks " f"({test_count} test = {compile_test_count} compile + " - f"{compile_fragment_test_count} compile_fragment, {pending_count} pending, {untracked_count} untracked)" + f"{compile_fragment_test_count} compile_fragment + {launch_fragment_test_count} launch_fragment, " + f"{pending_count} pending, {untracked_count} untracked)" ) diff --git a/test/python/ptodsl_docs_fragment_fixtures.py b/test/python/ptodsl_docs_fragment_fixtures.py index 7dc5d6c94..6a4a4d39d 100644 --- a/test/python/ptodsl_docs_fragment_fixtures.py +++ b/test/python/ptodsl_docs_fragment_fixtures.py @@ -167,6 +167,14 @@ def type_system_mask_bitcast_probe(): {SNIPPET_PLACEHOLDER} """ ), + "type_system.make_mask": _fixture( + f""" + @pto.jit(target="a5") + def type_system_make_mask_probe(): + tail_count = pto.const(16, dtype=pto.i32) + {SNIPPET_PLACEHOLDER} + """ + ), "quick_start.make_tensor_view": _fixture( f""" @pto.jit(target="a5") @@ -219,6 +227,184 @@ def quick_start_tile_io_probe( {SNIPPET_PLACEHOLDER} """ ), + "launch.flash_attention_wrapper": _fixture( + f""" + import numpy as np + + + @pto.jit(target="a5") + def flash_attention_kernel( + Q: pto.tensor_spec(rank=4, dtype=pto.f32), + K: pto.tensor_spec(rank=4, dtype=pto.f32), + V: pto.tensor_spec(rank=4, dtype=pto.f32), + O: pto.tensor_spec(rank=4, dtype=pto.f32), + *, + BLOCK_Q: pto.constexpr = 128, + BLOCK_KV: pto.constexpr = 128, + CAUSAL: pto.constexpr = False, + ): + pto.get_block_idx() + + + batch = 2 + heads = 3 + seq_q = 4 + seq_k = 4 + dim = 8 + stream = object() + Q = np.random.randn(batch, seq_q, heads, dim).astype(np.float32) + K = np.random.randn(batch, seq_k, heads, dim).astype(np.float32) + V = np.random.randn(batch, seq_k, heads, dim).astype(np.float32) + + {SNIPPET_PLACEHOLDER} + + O = flash_attention(Q, K, V, causal=False) + assert O.shape == Q.shape + assert len(PTODSL_DOC_LAUNCH_RECORDS) == 1 + record = PTODSL_DOC_LAUNCH_RECORDS[0] + assert record.grid == batch * heads + assert record.stream is stream + assert len(record.args) == 4 + assert record.args[0] is Q + assert record.args[1] is K + assert record.args[2] is V + assert record.args[3] is O + assert record.marshaled_arg_count == 36 + """ + ), + "launch.blocked_copy_compile_and_launch": _fixture( + f""" + @pto.jit(target="a5") + def blocked_copy( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, + ): + rows = A.shape[0] + cols = A.shape[1] + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + with pto.for_(0, rows, step=1) as row: + a_part = pto.partition_view(a_view, offsets=[row, 0], sizes=[1, cols]) + o_part = pto.partition_view(o_view, offsets=[row, 0], sizes=[1, cols]) + pto.tile.load(a_part, tile) + pto.tile.store(tile, o_part) + + + {SNIPPET_PLACEHOLDER} + + assert len(PTODSL_DOC_LAUNCH_RECORDS) == 1 + record = PTODSL_DOC_LAUNCH_RECORDS[0] + assert record.grid == 1 + assert record.stream is None + assert len(record.args) == 2 + assert record.args[0] is A + assert record.args[1] is O + assert record.marshaled_arg_count == 10 + """ + ), + "launch.generic_compile_and_launch": _fixture( + f""" + import numpy as np + + + @pto.jit(target="a5") + def kernel_name( + tensor_1: pto.tensor_spec(rank=2, dtype=pto.f32), + tensor_2: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + CONST_A: pto.constexpr = 128, + CONST_B: pto.constexpr = 64, + ): + pto.get_block_idx() + + + grid = 2 + stream = object() + + {SNIPPET_PLACEHOLDER} + + assert len(PTODSL_DOC_LAUNCH_RECORDS) == 1 + record = PTODSL_DOC_LAUNCH_RECORDS[0] + assert record.grid == grid + assert record.stream is stream + assert len(record.args) == 2 + assert record.args[0] is A + assert record.args[1] is O + assert record.marshaled_arg_count == 10 + """ + ), + "launch.mat_add_wrapper": _fixture( + f""" + import numpy as np + + + @pto.jit(target="a5") + def mat_add( + A: pto.tensor_spec(rank=3, dtype=pto.f32), + B: pto.tensor_spec(rank=3, dtype=pto.f32), + O: pto.tensor_spec(rank=3, dtype=pto.f32), + *, + BLOCK_M: pto.constexpr = 64, + BLOCK_N: pto.constexpr = 128, + ): + pto.get_block_idx() + + + {SNIPPET_PLACEHOLDER} + + A = np.random.randn(2, 64, 128).astype(np.float32) + B = np.random.randn(2, 64, 128).astype(np.float32) + O = mat_add_wrapper(A, B, stream=None) + assert O.shape == A.shape + assert len(PTODSL_DOC_LAUNCH_RECORDS) == 1 + record = PTODSL_DOC_LAUNCH_RECORDS[0] + assert record.grid == A.shape[0] + assert record.stream is None + assert len(record.args) == 3 + assert record.args[0] is A + assert record.args[1] is B + assert record.args[2] is O + assert record.marshaled_arg_count == 21 + """ + ), + "launch.gemm_wrapper": _fixture( + f""" + import numpy as np + + + @pto.jit(target="a5") + def gemm( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + B: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK_M: pto.constexpr = 64, + BLOCK_K: pto.constexpr = 64, + BLOCK_N: pto.constexpr = 64, + ): + pto.get_block_idx() + + + {SNIPPET_PLACEHOLDER} + + A = np.random.randn(64, 32).astype(np.float32) + B = np.random.randn(32, 16).astype(np.float32) + O = gemm_wrapper(A, B, stream=None) + assert O.shape == (A.shape[0], B.shape[1]) + assert len(PTODSL_DOC_LAUNCH_RECORDS) == 1 + record = PTODSL_DOC_LAUNCH_RECORDS[0] + assert record.grid == 1 + assert record.stream is None + assert len(record.args) == 3 + assert record.args[0] is A + assert record.args[1] is B + assert record.args[2] is O + assert record.marshaled_arg_count == 15 + """ + ), "control_flow.basic_for": _fixture( f""" @pto.jit(target="a5") From ade5bf77ed54dee130283bb37126800a7425c6f0 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sun, 24 May 2026 00:51:15 +0800 Subject: [PATCH 23/31] Clarify the pto.jit kernel signature --- ptodsl/README.md | 138 +++++- .../03-kernel-entry-and-subkernels.md | 123 +++++- ptodsl/examples/jit/README.md | 77 ---- .../jit/flash_attention_softmax_launch.py | 411 ++++++++++++++++++ ptodsl/examples/jit/tadd_launch.py | 55 ++- ptodsl/examples/softmax_dsl.py | 264 +++++------ ptodsl/examples/tilelang_codegen.py | 315 ++++++++++++++ ptodsl/ptodsl/_diagnostics.py | 24 + ptodsl/ptodsl/_jit.py | 18 +- ptodsl/ptodsl/_kernel_signature.py | 39 +- ptodsl/ptodsl/_runtime/codegen.py | 57 ++- ptodsl/ptodsl/_runtime/launch.py | 44 +- ptodsl/ptodsl/_runtime/native_build.py | 38 +- ptodsl/ptodsl/_tracing/module_builder.py | 1 + scripts/sim_dsl.sh | 113 +++++ test/python/ptodsl_docs_as_test.py | 9 +- test/python/ptodsl_jit_compile.py | 86 ++++ test/python/ptodsl_jit_diagnostics.py | 30 ++ 18 files changed, 1564 insertions(+), 278 deletions(-) delete mode 100644 ptodsl/examples/jit/README.md create mode 100644 ptodsl/examples/jit/flash_attention_softmax_launch.py create mode 100644 ptodsl/examples/tilelang_codegen.py create mode 100755 scripts/sim_dsl.sh diff --git a/ptodsl/README.md b/ptodsl/README.md index c16c430f6..a4e232ae7 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -55,6 +55,110 @@ pip install -e . --- +## JIT examples + +`ptodsl/examples/jit/` contains self-contained `@pto.jit` examples that cover +both compile-only and end-to-end launch flows. + +### Prerequisites for launch examples + +- `ptoas` + `ptodsl` installed as above +- CANN 9.0+ with `ASCEND_HOME_PATH` set +- For end-to-end launch: `torch`, `torch_npu`, `numpy` +- `bisheng` on `PATH` + +Set up the environment in each new shell: + +```bash +cd $PTOAS_REPO_ROOT +source set_ptoas_env.sh +source "${ASCEND_HOME_PATH}/bin/setenv.bash" +``` + +For CPU simulation with `msprof`, the wrapper script below will set the +simulator library path and `ulimit` for you. The normal PTOAS + CANN shell +setup above is still required. + +### `tadd_launch.py` + +Single script: kernel definition, compile, launch, and accuracy check. +Equivalent IR to the TileLang ST `tadd.pto` testcase. + +Compile-only: + +```bash +python3 ptodsl/examples/jit/tadd_launch.py --emit-mlir +``` + +Expected: MLIR containing `@TADD_f32_16x64` and `@TADD_f32_32x32`. + +Optional PTOAS frontend smoke: + +```bash +python3 ptodsl/examples/jit/tadd_launch.py --emit-mlir > /tmp/tadd_dsl.mlir +ptoas --emit-pto-ir /tmp/tadd_dsl.mlir -o - | head +``` + +End-to-end under the `msprof` CPU simulator: + +```bash +scripts/sim_dsl.sh ptodsl/examples/jit/tadd_launch.py +``` + +Expected output: + +```text +PASS f32_16x64 compile=0.024s launch=35.193s +PASS f32_32x32 compile=0.022s launch=35.926s +All cases passed. +``` + +Direct run on a real NPU: + +```bash +python3 ptodsl/examples/jit/tadd_launch.py +``` + +### `flash_attention_softmax_launch.py` + +Launchable flash-attention softmax-stage demo. It intentionally keeps the +online softmax update stage only, so the runtime path can be validated without +depending on the still-incomplete SIMT/cube coverage needed for the full +flash-attention stack. + +Compile-only: + +```bash +python3 ptodsl/examples/jit/flash_attention_softmax_launch.py --emit-mlir +``` + +End-to-end under the `msprof` CPU simulator: + +```bash +scripts/sim_dsl.sh ptodsl/examples/jit/flash_attention_softmax_launch.py +``` + +Expected output: + +```text +PASS rows8_seq128 +PASS rows17_seq96 +All cases passed. +``` + +Direct run on a real NPU: + +```bash +python3 ptodsl/examples/jit/flash_attention_softmax_launch.py +``` + +### Launch artifacts + +- `~/.cache/ptodsl/` — JIT-compiled kernel `.so` cache +- `build/msprof_res/` — `msprof` simulator trace output + +--- + ## Running regression checks ```bash @@ -147,17 +251,37 @@ it is intentionally not exported as `pto.scalar`. def MyKernel(): ... -@pto.jit(name="Softmax", kernel_kind="vector", target="a5", func_attr="pto.aicore") -def Softmax(arg0: pto.ptr(pto.float32, "gm"), n: pto.int32): +@pto.jit(name="Softmax", kernel_kind="vector", target="a5") +def Softmax( + X: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): ... -print(MyKernel) # prints MLIR text -mod = MyKernel.mlir_module() # returns mlir.ir.Module +print(MyKernel) # prints MLIR text +mod = MyKernel.mlir_module() # returns mlir.ir.Module ``` -`func_attr="pto.aicore"` selects a flat single-module structure with the -`pto.aicore` function attribute (softmax style). Without it, a nested -double-module is emitted (TADD style). +`@pto.jit` now emits a flat aicore launch-entry module by default. The traced +entry function carries the `pto.aicore` attribute and lives directly under the +top-level module, which matches the runtime-launch path and merged-MLIR example +flow. + +PTODSL v1 keeps the public `@pto.jit` entry ABI intentionally narrow: + +- positional parameters are Python-native tensors declared with + `pto.tensor_spec(...)` +- positional runtime scalars use PTO scalar annotations such as `pto.i32`, + `pto.f32`, and `pto.i1`, while launch-time values remain ordinary Python + scalars +- keyword-only parameters annotated with `pto.constexpr` are compile-time + specialization knobs + +Typed pointers such as `pto.ptr(...)` remain valid PTODSL surface types inside +kernel bodies and explicit-mode sub-kernels, but they are not the recommended +host-visible `@pto.jit` parameter contract. Additional layered kernel entry modes and shared compute decorators are also exported on the public surface: `@pto.jit(mode="auto")`, diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 3c05058b5..4ec1fc0fb 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -41,14 +41,116 @@ def kernel_name( return ``` -**Positional parameters** are Python-native tensors — they arrive from NumPy, -torch-npu, or any framework with `.shape` and `.strides`. Inside the body, wrap -them with `make_tensor_view` to create GM descriptors. +### How to declare and pass parameters -**Keyword-only parameters** annotated with `pto.constexpr` are compile-time -constants. They must be provided at `.compile()` time and cannot change between -launches of the same compiled kernel. Use them for tile sizes, algorithmic knobs -(e.g., `CAUSAL`), and other values that the compiler can specialize against. +A `@pto.jit` kernel accepts three kinds of parameters. Each has a distinct role, +position in the signature, and way to supply the value: + +| Parameter kind | Position | Annotation | Pass the value at | +|---|---|---|---| +| **Tensor** | positional (before `*`) | `pto.tensor_spec(rank=N, dtype=...)` | launch time | +| **Runtime scalar** | positional (before `*`) | `pto.i32`, `pto.f32`, `pto.i1`, etc. | launch time | +| **Compile-time constant** | keyword-only (after `*`) | `pto.constexpr = ` | compile time | + +#### 1. Tensor parameters + +Declare a positional parameter with `pto.tensor_spec(rank=..., dtype=...)`. +At launch time, pass a **Python-native tensor** — a NumPy array, a torch-npu +tensor, or any object with `.shape`, `.dtype`, `.strides` (or `.stride()`), and +a data pointer (`.data_ptr()` or `.ptr`): + +```python +@pto.jit(target="a5") +def my_kernel( + X: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), +): + # Inside the body, access shape/strides/dtype directly: + rows, cols = X.shape[0], X.shape[1] + # Then wrap with make_tensor_view(...) to build a GM descriptor: + x_view = pto.make_tensor_view(X, shape=X.shape, strides=X.strides) +``` + +#### 2. Runtime scalar parameters + +Declare a positional parameter with a PTO scalar annotation (`pto.i32`, +`pto.f32`, `pto.i1`, etc.). At launch time, pass an ordinary Python +`int`, `float`, or `bool`: + +```python +@pto.jit(target="a5") +def my_kernel( + X: pto.tensor_spec(rank=2, dtype=pto.f32), + n: pto.i32, # pass an int at launch + alpha: pto.f32, # pass a float at launch +): + # Scalars arrive as PTO values and can be used directly in + # index math, loop bounds, comparisons, and sub-kernel calls: + limit = n // 2 +``` + +#### 3. Compile-time constants + +Declare after `*` with `pto.constexpr` and a default value. +Pass the value to `.compile(...)` — **not** at launch time: + +```python +@pto.jit(target="a5") +def my_kernel( + X: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + # BLOCK is a Python value at trace time — use it for tile shapes, + # unrolled loops, or dtype arguments: + tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) +``` + +The compiler specializes the kernel for each combination of constexpr values. +Once compiled, the values are baked in — they cannot change between launches of +the same compiled instance. To use a different value, call `.compile(...)` again. + +### Full example: declare and launch + +Bringing all three kinds together: + +```python +@pto.jit(target="a5", mode="auto") +def scaled_bias_add( + X: pto.tensor_spec(rank=2, dtype=pto.f32), # tensor + O: pto.tensor_spec(rank=2, dtype=pto.f32), # tensor + alpha: pto.f32, # runtime scalar + bias: pto.f32, # runtime scalar + *, + BLOCK: pto.constexpr = 128, # compile-time constant +): + rows, cols = X.shape[0], X.shape[1] + # ... use alpha, bias, BLOCK inside the kernel body ... + return +``` + +```python +# Step 1 — compile: constexpr values go to .compile() +compiled = scaled_bias_add.compile(BLOCK=64) + +# Step 2 — launch: tensors and runtime scalars go to compiled[grid, stream](...) +import numpy as np +X = np.random.randn(4, 128).astype(np.float32) +O = np.empty_like(X) +compiled[1, None](X, O, 2.0, 1.0) # alpha=2.0, bias=1.0 +``` + +### What is NOT accepted at the entry + +The following types are intentionally **not** accepted as `@pto.jit` parameters: + +- `pto.ptr(...)` — typed pointers are available inside the kernel body and + across sub-kernel boundaries, but not at the host/kernel entry. +- `Tile`, `PartitionTensorView`, `VReg` — these are created inside the kernel + body, not passed from the host. + +They are valid **inside** the kernel and across sub-kernel calls, just not at +the public host/kernel boundary. ### `mode`: auto vs explicit @@ -62,6 +164,11 @@ programming model: explicit synchronization, and direct pointer manipulation — on top of everything available in `auto`. +`mode` changes what you can write **inside the kernel body**. It does **not** +change the recommended host-visible entry ABI: both modes use the same +`tensor_spec(...)` + runtime scalar + `constexpr` contract at the `@pto.jit` +boundary. + Section 3.2 covers the two models in detail. ### Compilation and launch @@ -203,6 +310,8 @@ direct access to the hardware unit. In auto mode, a sub-kernel's parameters are restricted to `Tile` and PTO scalar types — the compiler owns staging and sync. In explicit mode, sub-kernels may also accept `PartitionTensorView` and `pto.ptr` parameters, matching the richer type surface available there. +This richer pointer surface belongs to the **in-kernel orchestration and +sub-kernel boundary**, not to the public `@pto.jit` host entry ABI. Section 3.3 covers each sub-kernel decorator in detail. ## 3.2 Programming models: auto vs explicit diff --git a/ptodsl/examples/jit/README.md b/ptodsl/examples/jit/README.md deleted file mode 100644 index a5cf7dbe5..000000000 --- a/ptodsl/examples/jit/README.md +++ /dev/null @@ -1,77 +0,0 @@ -# JIT examples - -Python `@pto.jit` kernels with compile-only and end-to-end launch smoke tests. - -## Prerequisites - -- ptoas + ptodsl installed per [ptodsl README](../README.md) (`quick_install.sh`, `pip install -e .`) -- CANN 9.0+ with `ASCEND_HOME_PATH` set -- For end-to-end launch: `torch`, `torch_npu`, `numpy`; `bisheng` on PATH - -## Environment (every shell) - -```bash -cd $PTOAS_REPO_ROOT # e.g. /workdir/ptoas_a5 -source set_ptoas_env.sh -source "${ASCEND_HOME_PATH}/bin/setenv.bash" -``` - -For CPU simulation (msprof), also: - -```bash -export LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/tools/simulator/Ascend950PR_9599/lib:${LD_LIBRARY_PATH}" -ulimit -n 65535 -``` - -## `tadd_launch.py` - -Single script: kernel definition, compile, launch, and accuracy check. Equivalent IR to the TileLang ST `tadd.pto` testcase. - -### Compile-only: DSL → MLIR - -```bash -cd ptodsl/examples/jit -python3 tadd_launch.py --emit-mlir -``` - -Expected: MLIR module text containing `@TADD_f32_16x64` and `@TADD_f32_32x32`. - -Optional — run through the ptoas frontend: - -```bash -python3 tadd_launch.py --emit-mlir > /tmp/tadd_dsl.mlir -ptoas --emit-pto-ir /tmp/tadd_dsl.mlir -o - | head -``` - -### End-to-end: DSL → IR → binary → launch → accuracy - -Runs under the msprof CPU simulator — no physical NPU required. - -```bash -cd ptodsl/examples/jit -msprof op simulator --soc-version=Ascend950PR_9599 \ - --output=msprof_res/tadd \ - python3 tadd_launch.py -``` - -Expected output: - -``` -PASS f32_16x64 compile=0.024s launch=35.193s -PASS f32_32x32 compile=0.022s launch=35.926s -All cases passed. -``` - -(Timing varies by machine; launch includes msprof simulator overhead and one-time native build on first run per kernel.) - -Direct run on a real NPU (omit the msprof wrapper when hardware is available): - -```bash -cd ptodsl/examples/jit -python3 tadd_launch.py -``` - -## Artifacts (gitignored) - -- `~/.cache/ptodsl/` — JIT-compiled kernel `.so` cache (override with `PTODSL_CACHE_DIR`) -- `msprof_res/` — msprof simulator trace output diff --git a/ptodsl/examples/jit/flash_attention_softmax_launch.py b/ptodsl/examples/jit/flash_attention_softmax_launch.py new file mode 100644 index 000000000..cb538de77 --- /dev/null +++ b/ptodsl/examples/jit/flash_attention_softmax_launch.py @@ -0,0 +1,411 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +Flash-attention softmax stage — end-to-end launch demo. + +This example is the launchable counterpart to the compile-only +``flash_attention_sketch.py`` demo. It intentionally keeps only the online +softmax update stage from flash attention because the current PTODSL runtime +path is already strong enough for vector-heavy softmax, while the full +flash-attention stack still depends on simt/cube capabilities that are not yet +complete for an end-to-end runtime demo. + +Each kernel instance updates one block of up to 8 rows: + + m_next = max(m_prev, row_max(scores)) + p = exp(scores - m_next) + l_next = l_prev * exp(m_prev - m_next) + row_sum(p) + expmax = l_prev * exp(m_prev - m_next) / l_next + out = p / l_next + +The demo offers two fixed-shape launchable kernels so the current launch ABI +does not need runtime scalar parameters: + +- ``rows8_seq128``: full-width 128-column softmax +- ``rows17_seq96``: multi-block + tail-mask coverage +""" + +import argparse +import time +from pathlib import Path +import sys + +import numpy as np + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from flash_attention_softmax_launch.py" + ) + +from ptodsl import pto, scalar + +s = scalar + +_DEVICE = "npu:0" +_ROWS_PER_BLOCK = 8 +_PHYSICAL_COLS = 128 + + +def _make_flash_attention_softmax_kernel(name: str, *, rows: int, seq: int): + if rows <= 0: + raise ValueError("rows must be positive") + if not 0 < seq <= _PHYSICAL_COLS: + raise ValueError(f"seq must be in [1, {_PHYSICAL_COLS}]") + + @pto.jit( + name=name, + kernel_kind="vector", + target="a5", + mode="explicit", + insert_sync=False + ) + def kernel( + oldmax: pto.tensor_spec(rank=2, dtype=pto.f32), + oldsum: pto.tensor_spec(rank=2, dtype=pto.f32), + scores: pto.tensor_spec(rank=2, dtype=pto.f32), + newmax: pto.tensor_spec(rank=2, dtype=pto.f32), + newsum: pto.tensor_spec(rank=2, dtype=pto.f32), + expmax: pto.tensor_spec(rank=2, dtype=pto.f32), + out: pto.tensor_spec(rank=2, dtype=pto.f32), + ): + c0 = pto.const(0) + c1 = pto.const(1) + c8 = pto.const(_ROWS_PER_BLOCK) + c64 = pto.const(64) + c128 = pto.const(_PHYSICAL_COLS) + c_rows = pto.const(rows) + c_seq = pto.const(seq) + c_rows_x_128 = pto.const(rows * _PHYSICAL_COLS) + + c0_i64 = pto.const(0, dtype=pto.int64) + c128_i64 = pto.const(128, dtype=pto.int64) + c256_i64 = pto.const(256, dtype=pto.int64) + c8448_i64 = pto.const(8448, dtype=pto.int64) + c16640_i64 = pto.const(16640, dtype=pto.int64) + c16768_i64 = pto.const(16768, dtype=pto.int64) + c16896_i64 = pto.const(16896, dtype=pto.int64) + + c0_i32 = pto.const(0, dtype=pto.int32) + c1_i32 = pto.const(1, dtype=pto.int32) + c8_i32 = pto.const(_ROWS_PER_BLOCK, dtype=pto.int32) + c_seq_i32 = pto.const(seq, dtype=pto.int32) + c_rows_i32 = pto.const(rows, dtype=pto.int32) + + block_i64 = pto.get_block_idx() + block_idx = s.index_cast(block_i64) + row_base = s.muli(block_idx, c8) + row_base_i32 = s.index_cast(pto.int32, row_base) + remaining_rows = s.subi(c_rows_i32, row_base_i32) + has_rows = remaining_rows > c0_i32 + too_many_rows = remaining_rows > c8_i32 + row_count_i32 = s.select(too_many_rows, c8_i32, remaining_rows) + row_count = s.index_cast(row_count_i32) + + with pto.if_(has_rows) as has_rows_br: + with has_rows_br.then_: + s1 = [c_rows, c_rows, c_rows, c1, c_rows] + s128 = [c_rows_x_128, c_rows_x_128, c_rows_x_128, c128, c1] + sh1 = [c1, c1, c1, c_rows, c1] + sh128 = [c1, c1, c1, c_rows, c128] + + oldmax_view = pto.make_tensor_view(oldmax, shape=sh1, strides=s1) + oldsum_view = pto.make_tensor_view(oldsum, shape=sh1, strides=s1) + scores_view = pto.make_tensor_view(scores, shape=sh128, strides=s128) + newmax_view = pto.make_tensor_view(newmax, shape=sh1, strides=s1) + newsum_view = pto.make_tensor_view(newsum, shape=sh1, strides=s1) + expmax_view = pto.make_tensor_view(expmax, shape=sh1, strides=s1) + out_view = pto.make_tensor_view(out, shape=sh128, strides=s128) + + off = [c0, c0, c0, row_base, c0] + z1 = [c1, c1, c1, row_count, c1] + zs = [c1, c1, c1, row_count, c_seq] + + oldmax_part = pto.partition_view(oldmax_view, offsets=off, sizes=z1) + oldsum_part = pto.partition_view(oldsum_view, offsets=off, sizes=z1) + scores_part = pto.partition_view(scores_view, offsets=off, sizes=zs) + newmax_part = pto.partition_view(newmax_view, offsets=off, sizes=z1) + newsum_part = pto.partition_view(newsum_view, offsets=off, sizes=z1) + expmax_part = pto.partition_view(expmax_view, offsets=off, sizes=z1) + out_part = pto.partition_view(out_view, offsets=off, sizes=zs) + + tile_col = pto.tile_buf_type([8, 1], pto.float32, [-1, 1], blayout="ColMajor") + tile_w = pto.tile_buf_type([8, 128], pto.float32, [-1, -1]) + + oldmax_tile = pto.alloc_tile(tile_col, addr=c0_i64, valid_row=row_count) + oldsum_tile = pto.alloc_tile(tile_col, addr=c128_i64, valid_row=row_count) + scores_tile = pto.alloc_tile(tile_w, addr=c256_i64, valid_row=row_count, valid_col=c_seq) + out_tile = pto.alloc_tile(tile_w, addr=c8448_i64, valid_row=row_count, valid_col=c_seq) + newmax_tile = pto.alloc_tile(tile_col, addr=c16640_i64, valid_row=row_count) + newsum_tile = pto.alloc_tile(tile_col, addr=c16768_i64, valid_row=row_count) + expmax_tile = pto.alloc_tile(tile_col, addr=c16896_i64, valid_row=row_count) + + pto.tile.load(oldmax_part, oldmax_tile) + pto.tile.load(oldsum_part, oldsum_tile) + pto.tile.load(scores_part, scores_tile) + + pto.set_flag("MTE2", "V", event_id=0) + pto.wait_flag("MTE2", "V", event_id=0) + + with pto.vecscope(): + ptr_ub = pto.ptr(pto.float32, "ub") + vf32 = pto.vreg_type(64, pto.float32) + + ub_om = pto.as_ptr(oldmax_tile, ptr_ub) + ub_os = pto.as_ptr(oldsum_tile, ptr_ub) + ub_scores = pto.as_ptr(scores_tile, ptr_ub) + ub_out = pto.as_ptr(out_tile, ptr_ub) + ub_nm = pto.as_ptr(newmax_tile, ptr_ub) + ub_ns = pto.as_ptr(newsum_tile, ptr_ub) + ub_em = pto.as_ptr(expmax_tile, ptr_ub) + + active = pto.pset_b32(pto.MaskPattern.ALL) + one_mask, _ = pto.plt_b32(c1_i32) + + with pto.for_(c0, row_count, step=c1) as row: + row_scores = s.muli(row, c128) + oldmax_bc = pto.vbrc_load(ub_om, row, vf32) + oldsum_bc = pto.vbrc_load(ub_os, row, vf32) + + with pto.for_(c0, c128, step=c64, iter_args=(oldmax_bc, oldsum_bc)) as softmax_loop: + chunk = softmax_loop.iv + running_max, running_sum = softmax_loop.iter_args + + chunk_i32 = s.index_cast(pto.int32, chunk) + remaining_cols = s.subi(c_seq_i32, chunk_i32) + has_chunk = remaining_cols > c0_i32 + + with pto.if_(has_chunk) as br: + with br.then_: + chunk_mask, _ = pto.plt_b32(remaining_cols) + chunk_base = s.addi(row_scores, chunk) + vec = pto.vlds(ub_scores, chunk_base, vf32) + chunk_max = pto.vcmax(vec, chunk_mask) + chunk_max_bc = pto.vdup(chunk_max, active, position="LOWEST") + merged_max = pto.vmax(running_max, chunk_max_bc, active) + scaled_running = pto.vexpdif(running_max, merged_max, active) + running_sum_scaled = pto.vmul(scaled_running, running_sum, active) + chunk_exp = pto.vexpdif(vec, merged_max, chunk_mask) + chunk_sum = pto.vcadd(chunk_exp, chunk_mask) + chunk_sum_bc = pto.vdup(chunk_sum, active, position="LOWEST") + merged_sum = pto.vadd(running_sum_scaled, chunk_sum_bc, active) + br.assign(next_max=merged_max, next_sum=merged_sum) + with br.else_: + br.assign(next_max=running_max, next_sum=running_sum) + pto.yield_(br.next_max, br.next_sum) + + final_max, final_sum = softmax_loop.results + + raw_em = pto.vexpdif(oldmax_bc, final_max, active) + scaled_oldsum = pto.vmul(raw_em, oldsum_bc, active) + expmax = pto.vdiv(scaled_oldsum, final_sum, active) + + pto.vsts_1pt(final_max, ub_nm, row, one_mask) + pto.vsts_1pt(final_sum, ub_ns, row, one_mask) + pto.vsts_1pt(expmax, ub_em, row, one_mask) + + with pto.for_(c0, c128, step=c64) as chunk2: + rem2 = s.subi(c_seq_i32, s.index_cast(pto.int32, chunk2)) + has_chunk2 = rem2 > c0_i32 + with pto.if_(has_chunk2) as br2: + with br2.then_: + cmask2, _ = pto.plt_b32(rem2) + cbase2 = s.addi(row_scores, chunk2) + vec2 = pto.vlds(ub_scores, cbase2, vf32) + exp2 = pto.vexpdif(vec2, final_max, cmask2) + out2 = pto.vdiv(exp2, final_sum, cmask2) + pto.vsts(out2, ub_out, cbase2, cmask2) + + pto.set_flag("V", "MTE3", event_id=0) + pto.wait_flag("V", "MTE3", event_id=0) + + pto.tile.store(newmax_tile, newmax_part) + pto.tile.store(newsum_tile, newsum_part) + pto.tile.store(expmax_tile, expmax_part) + pto.tile.store(out_tile, out_part) + + pto.pipe_barrier(pto.Pipe.ALL) + + return kernel + + +FLASH_SOFTMAX_ROWS8_SEQ128 = _make_flash_attention_softmax_kernel( + "flash_attention_softmax_rows8_seq128", + rows=8, + seq=128, +) +FLASH_SOFTMAX_ROWS17_SEQ96 = _make_flash_attention_softmax_kernel( + "flash_attention_softmax_rows17_seq96", + rows=17, + seq=96, +) + +KERNELS = ( + FLASH_SOFTMAX_ROWS8_SEQ128, + FLASH_SOFTMAX_ROWS17_SEQ96, +) + +CASES = [ + { + "name": "rows8_seq128", + "kernel": FLASH_SOFTMAX_ROWS8_SEQ128, + "rows": 8, + "seq": 128, + }, + { + "name": "rows17_seq96", + "kernel": FLASH_SOFTMAX_ROWS17_SEQ96, + "rows": 17, + "seq": 96, + }, +] + + +def emit_mlir(): + return pto.merge_jit_modules(*KERNELS) + + +def reference_online_softmax_update(oldmax: np.ndarray, oldsum: np.ndarray, scores: np.ndarray, seq: int): + rows = oldmax.shape[0] + newmax = np.empty_like(oldmax) + newsum = np.empty_like(oldsum) + expmax = np.empty_like(oldsum) + out = np.full_like(scores, np.nan) + + for row in range(rows): + m_prev = float(oldmax[row, 0]) + l_prev = float(oldsum[row, 0]) + row_scores = scores[row, :seq] + m_next = max(m_prev, float(np.max(row_scores))) + shifted = np.exp(row_scores - m_next) + l_scaled = l_prev * np.exp(m_prev - m_next) + l_next = l_scaled + float(np.sum(shifted)) + + newmax[row, 0] = m_next + newsum[row, 0] = l_next + expmax[row, 0] = l_scaled / l_next + out[row, :seq] = shifted / l_next + + return newmax, newsum, expmax, out + + +def init_runtime(): + import torch + import torch_npu # noqa: F401 + + torch.npu.config.allow_internal_format = False + torch_npu.npu.set_compile_mode(jit_compile=False) + torch.npu.set_device(_DEVICE) + return torch + + +def npu_stream(torch): + return torch.npu.current_stream()._as_parameter_ # noqa: SLF001 + + +def make_case_inputs(case: dict[str, object]): + rows = int(case["rows"]) + seq = int(case["seq"]) + rng = np.random.RandomState(hash(case["name"]) & 0xFFFFFFFF) + + oldmax = rng.uniform(-2.0, 2.0, size=(rows, 1)).astype(np.float32) + oldsum = rng.uniform(0.25, 3.0, size=(rows, 1)).astype(np.float32) + scores = np.full((rows, _PHYSICAL_COLS), -1000.0, dtype=np.float32) + scores[:, :seq] = rng.uniform(-4.0, 4.0, size=(rows, seq)).astype(np.float32) + + newmax = np.full((rows, 1), np.nan, dtype=np.float32) + newsum = np.full((rows, 1), np.nan, dtype=np.float32) + expmax = np.full((rows, 1), np.nan, dtype=np.float32) + out = np.full((rows, _PHYSICAL_COLS), np.nan, dtype=np.float32) + + return oldmax, oldsum, scores, newmax, newsum, expmax, out + + +def run_case(case: dict[str, object], torch) -> None: + rows = int(case["rows"]) + seq = int(case["seq"]) + grid = (rows + _ROWS_PER_BLOCK - 1) // _ROWS_PER_BLOCK + oldmax, oldsum, scores, newmax, newsum, expmax, out = make_case_inputs(case) + ref_newmax, ref_newsum, ref_expmax, ref_out = reference_online_softmax_update( + oldmax, + oldsum, + scores, + seq, + ) + + oldmax_t = torch.from_numpy(oldmax).to(_DEVICE) + oldsum_t = torch.from_numpy(oldsum).to(_DEVICE) + scores_t = torch.from_numpy(scores).to(_DEVICE) + newmax_t = torch.from_numpy(newmax).to(_DEVICE) + newsum_t = torch.from_numpy(newsum).to(_DEVICE) + expmax_t = torch.from_numpy(expmax).to(_DEVICE) + out_t = torch.from_numpy(out).to(_DEVICE) + stream = npu_stream(torch) + + t0 = time.perf_counter() + compiled = case["kernel"].compile() + compile_s = time.perf_counter() - t0 + + t0 = time.perf_counter() + compiled[grid, stream]( + oldmax_t, + oldsum_t, + scores_t, + newmax_t, + newsum_t, + expmax_t, + out_t, + ) + torch.npu.synchronize() + launch_s = time.perf_counter() - t0 + + np.testing.assert_allclose(newmax_t.cpu().numpy(), ref_newmax, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(newsum_t.cpu().numpy(), ref_newsum, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(expmax_t.cpu().numpy(), ref_expmax, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(out_t.cpu().numpy()[:, :seq], ref_out[:, :seq], rtol=1e-5, atol=1e-5) + if seq < _PHYSICAL_COLS: + assert np.isnan(out_t.cpu().numpy()[:, seq:]).all(), "tail columns should remain untouched" + + print( + f"PASS {case['name']} " + f"compile={compile_s:.3f}s launch={launch_s:.3f}s" + ) + + +def test_flash_attention_softmax() -> None: + torch = init_runtime() + for case in CASES: + run_case(case, torch) + print("All cases passed.") + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--emit-mlir", + action="store_true", + help="print the merged MLIR module and exit", + ) + args = parser.parse_args(argv) + + if args.emit_mlir: + print(emit_mlir()) + return 0 + + test_flash_attention_softmax() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ptodsl/examples/jit/tadd_launch.py b/ptodsl/examples/jit/tadd_launch.py index d4a4f5f46..734041689 100644 --- a/ptodsl/examples/jit/tadd_launch.py +++ b/ptodsl/examples/jit/tadd_launch.py @@ -15,10 +15,21 @@ import argparse import time +from pathlib import Path +import sys import numpy as np -import torch -import torch_npu # noqa: F401 + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from tadd_launch.py" + ) from ptodsl import pto @@ -29,7 +40,7 @@ # Kernel # --------------------------------------------------------------------------- -def _tadd_tile(a_ptr, b_ptr, c_ptr, rows: int, cols: int) -> None: +def _tadd_tile(A, B, C, rows: int, cols: int) -> None: c0 = pto.const(0) c1 = pto.const(1) c_rows = pto.const(rows) @@ -40,9 +51,9 @@ def _tadd_tile(a_ptr, b_ptr, c_ptr, rows: int, cols: int) -> None: strides = [c_elems, c_elems, c_elems, c_cols, c1] off = [c0, c0, c0, c0, c0] - a_view = pto.make_tensor_view(a_ptr, shape=shape, strides=strides) - b_view = pto.make_tensor_view(b_ptr, shape=shape, strides=strides) - c_view = pto.make_tensor_view(c_ptr, shape=shape, strides=strides) + a_view = pto.make_tensor_view(A, shape=shape, strides=strides) + b_view = pto.make_tensor_view(B, shape=shape, strides=strides) + c_view = pto.make_tensor_view(C, shape=shape, strides=strides) a_part = pto.partition_view(a_view, offsets=off, sizes=shape) b_part = pto.partition_view(b_view, offsets=off, sizes=shape) @@ -62,28 +73,26 @@ def _tadd_tile(a_ptr, b_ptr, c_ptr, rows: int, cols: int) -> None: name="TADD_f32_16x64", kernel_kind="vector", target="a5", - func_attr="pto.aicore", ) def TADD_f32_16x64( - a_ptr: pto.ptr(pto.float32, "gm"), - b_ptr: pto.ptr(pto.float32, "gm"), - c_ptr: pto.ptr(pto.float32, "gm"), + A: pto.tensor_spec(rank=2, dtype=pto.f32), + B: pto.tensor_spec(rank=2, dtype=pto.f32), + C: pto.tensor_spec(rank=2, dtype=pto.f32), ): - _tadd_tile(a_ptr, b_ptr, c_ptr, 16, 64) + _tadd_tile(A, B, C, 16, 64) @pto.jit( name="TADD_f32_32x32", kernel_kind="vector", target="a5", - func_attr="pto.aicore", ) def TADD_f32_32x32( - a_ptr: pto.ptr(pto.float32, "gm"), - b_ptr: pto.ptr(pto.float32, "gm"), - c_ptr: pto.ptr(pto.float32, "gm"), + A: pto.tensor_spec(rank=2, dtype=pto.f32), + B: pto.tensor_spec(rank=2, dtype=pto.f32), + C: pto.tensor_spec(rank=2, dtype=pto.f32), ): - _tadd_tile(a_ptr, b_ptr, c_ptr, 32, 32) + _tadd_tile(A, B, C, 32, 32) KERNELS = (TADD_f32_16x64, TADD_f32_32x32) @@ -104,16 +113,20 @@ def emit_mlir(): def init_torch_npu() -> None: + import torch + import torch_npu # noqa: F401 + torch.npu.config.allow_internal_format = False torch_npu.npu.set_compile_mode(jit_compile=False) torch.npu.set_device(_DEVICE) + return torch -def npu_stream(): +def npu_stream(torch): return torch.npu.current_stream()._as_parameter_ # noqa: SLF001 -def run_case(case: dict) -> None: +def run_case(case: dict, torch) -> None: shape = case["shape"] rng = np.random.RandomState(hash(case["name"]) & 0xFFFFFFFF) x = rng.randint(1, 10, size=shape).astype(np.float32) @@ -123,7 +136,7 @@ def run_case(case: dict) -> None: a = torch.from_numpy(x).to(_DEVICE) b = torch.from_numpy(y).to(_DEVICE) c = torch.empty(shape, dtype=torch.float32, device=_DEVICE) - stream = npu_stream() + stream = npu_stream(torch) t0 = time.perf_counter() compiled = case["kernel"].compile() @@ -142,9 +155,9 @@ def run_case(case: dict) -> None: def test_tadd() -> None: - init_torch_npu() + torch = init_torch_npu() for case in CASES: - run_case(case) + run_case(case, torch) print("All cases passed.") diff --git a/ptodsl/examples/softmax_dsl.py b/ptodsl/examples/softmax_dsl.py index 913009bb6..d55a271dd 100644 --- a/ptodsl/examples/softmax_dsl.py +++ b/ptodsl/examples/softmax_dsl.py @@ -46,7 +46,6 @@ name="online_softmax_update_kernel_2d", kernel_kind="vector", target="a5", - func_attr="pto.aicore", mode="explicit", ) def online_softmax_update_kernel_2d( @@ -103,137 +102,138 @@ def online_softmax_update_kernel_2d( rows = s.index_cast(arg8) # → index rows_x_128 = s.muli(rows, c128) - with pto.if_(has_rows): - # ── Tensor views ───────────────────────────────────────────────────── - s1 = [rows, rows, rows, c1, rows] - s128 = [rows_x_128, rows_x_128, rows_x_128, c128, c1] - sh1 = [c1, c1, c1, rows, c1] - sh128= [c1, c1, c1, rows, c128] - - oldmax_view = pto.make_tensor_view(arg0, shape=sh1, strides=s1) - oldsum_view = pto.make_tensor_view(arg1, shape=sh1, strides=s1) - qk_view = pto.make_tensor_view(arg2, shape=sh128, strides=s128) - newmax_view = pto.make_tensor_view(arg3, shape=sh1, strides=s1) - newsum_view = pto.make_tensor_view(arg4, shape=sh1, strides=s1) - expmax_view = pto.make_tensor_view(arg5, shape=sh1, strides=s1) - out_view = pto.make_tensor_view(arg6, shape=sh128, strides=s128) - - # ── Partition views ─────────────────────────────────────────────────── - off = [c0, c0, c0, row_base, c0] - z1 = [c1, c1, c1, row_count, c1] - zs = [c1, c1, c1, row_count, seq] - - oldmax_part = pto.partition_view(oldmax_view, offsets=off, sizes=z1) - oldsum_part = pto.partition_view(oldsum_view, offsets=off, sizes=z1) - qk_part = pto.partition_view(qk_view, offsets=off, sizes=zs) - newmax_part = pto.partition_view(newmax_view, offsets=off, sizes=z1) - newsum_part = pto.partition_view(newsum_view, offsets=off, sizes=z1) - expmax_part = pto.partition_view(expmax_view, offsets=off, sizes=z1) - out_part = pto.partition_view(out_view, offsets=off, sizes=zs) - - # ── UB tile allocation ──────────────────────────────────────────────── - tile_col = pto.tile_buf_type([8, 1], pto.float32, [-1, 1], blayout="ColMajor") - tile_w = pto.tile_buf_type([8, 128], pto.float32, [-1, -1]) - - oldmax_tile = pto.alloc_tile(tile_col, addr=c0_i64, valid_row=row_count) - oldsum_tile = pto.alloc_tile(tile_col, addr=c128_i64, valid_row=row_count) - qk_tile = pto.alloc_tile(tile_w, addr=c256_i64, valid_row=row_count, valid_col=seq) - out_tile = pto.alloc_tile(tile_w, addr=c8448_i64, valid_row=row_count, valid_col=seq) - newmax_tile = pto.alloc_tile(tile_col, addr=c16640_i64, valid_row=row_count) - newsum_tile = pto.alloc_tile(tile_col, addr=c16768_i64, valid_row=row_count) - expmax_tile = pto.alloc_tile(tile_col, addr=c16896_i64, valid_row=row_count) - - # ── Tile loads from GM ──────────────────────────────────────────────── - pto.tile.load(oldmax_part, oldmax_tile) - pto.tile.load(oldsum_part, oldsum_tile) - pto.tile.load(qk_part, qk_tile) - - pto.set_flag("MTE2", "V", event_id=0) - pto.wait_flag("MTE2", "V", event_id=0) - - with pto.vecscope(): - # Materialise typed UB pointers from tile handles - ptr_ub = pto.ptr(pto.float32, "ub") - vf32 = pto.vreg_type(64, pto.float32) - - ub_om = pto.as_ptr(oldmax_tile, ptr_ub) - ub_os = pto.as_ptr(oldsum_tile, ptr_ub) - ub_qk = pto.as_ptr(qk_tile, ptr_ub) - ub_out = pto.as_ptr(out_tile, ptr_ub) - ub_nm = pto.as_ptr(newmax_tile, ptr_ub) - ub_ns = pto.as_ptr(newsum_tile, ptr_ub) - ub_em = pto.as_ptr(expmax_tile, ptr_ub) - - active = pto.pset_b32("PAT_ALL") - one_mask, _ = pto.plt_b32(c1_i32) - - with pto.for_(c0, row_count, step=c1) as row: - row_qk = s.muli(row, c128) - oldmax_bc = pto.vbrc_load(ub_om, row, vf32) - oldsum_bc = pto.vbrc_load(ub_os, row, vf32) - - # scf.for with iter_args: accumulate (running_max, running_sum) - with pto.for_(c0, c128, step=c64, iter_args=(oldmax_bc, oldsum_bc)) as loop: - chunk = loop.iv - running_max, running_sum = loop.iter_args - - chunk_i32 = s.index_cast(pto.int32, chunk) - remaining_cols = s.subi(arg7, chunk_i32) - has_chunk = remaining_cols > c0_i32 - - # scf.if with results – produce (next_max, next_sum) - with pto.if_(has_chunk, results=(vf32, vf32)) as br: - with br.then_: - chunk_mask, _ = pto.plt_b32(remaining_cols) - chunk_base = s.addi(row_qk, chunk) - vec = pto.vlds(ub_qk, chunk_base, vf32) - chunk_max = pto.vcmax(vec, chunk_mask) - chunk_max_bc = pto.vdup(chunk_max, active, position="LOWEST") - merged_max = pto.vmax(running_max, chunk_max_bc, active) - scaled_running = pto.vexpdif(running_max, merged_max, active) - running_sum_scaled = pto.vmul(scaled_running, running_sum, active) - chunk_exp = pto.vexpdif(vec, merged_max, chunk_mask) - chunk_sum = pto.vcadd(chunk_exp, chunk_mask) - chunk_sum_bc = pto.vdup(chunk_sum, active, position="LOWEST") - merged_sum = pto.vadd(running_sum_scaled, chunk_sum_bc, active) - pto.yield_(merged_max, merged_sum) - with br.else_: - pto.yield_(running_max, running_sum) - - next_max, next_sum = br.results - pto.yield_(next_max, next_sum) - - final_max, final_sum = loop.results - - # Compute per-row expmax scalar - raw_em = pto.vexpdif(oldmax_bc, final_max, active) - sc_os = pto.vmul(raw_em, oldsum_bc, active) - expmax = pto.vdiv(sc_os, final_sum, active) - - pto.vsts_1pt(final_max, ub_nm, row, one_mask) - pto.vsts_1pt(final_sum, ub_ns, row, one_mask) - pto.vsts_1pt(expmax, ub_em, row, one_mask) - - # Output normalisation loop - with pto.for_(c0, c128, step=c64) as chunk2: - rem2 = s.subi(arg7, s.index_cast(pto.int32, chunk2)) - has_chunk2= rem2 > c0_i32 - with pto.if_(has_chunk2): - cmask2, _ = pto.plt_b32(rem2) - cbase2 = s.addi(row_qk, chunk2) - vec2 = pto.vlds(ub_qk, cbase2, vf32) - exp2 = pto.vexpdif(vec2, final_max, cmask2) - out2 = pto.vdiv(exp2, final_sum, cmask2) - pto.vsts(out2, ub_out, cbase2, cmask2) - - pto.set_flag("V", "MTE3", event_id=0) - pto.wait_flag("V", "MTE3", event_id=0) - - # Tile stores to GM - pto.tile.store(newmax_tile, newmax_part) - pto.tile.store(newsum_tile, newsum_part) - pto.tile.store(expmax_tile, expmax_part) - pto.tile.store(out_tile, out_part) + with pto.if_(has_rows) as has_rows_br: + with has_rows_br.then_: + # ── Tensor views ───────────────────────────────────────────────────── + s1 = [rows, rows, rows, c1, rows] + s128 = [rows_x_128, rows_x_128, rows_x_128, c128, c1] + sh1 = [c1, c1, c1, rows, c1] + sh128= [c1, c1, c1, rows, c128] + + oldmax_view = pto.make_tensor_view(arg0, shape=sh1, strides=s1) + oldsum_view = pto.make_tensor_view(arg1, shape=sh1, strides=s1) + qk_view = pto.make_tensor_view(arg2, shape=sh128, strides=s128) + newmax_view = pto.make_tensor_view(arg3, shape=sh1, strides=s1) + newsum_view = pto.make_tensor_view(arg4, shape=sh1, strides=s1) + expmax_view = pto.make_tensor_view(arg5, shape=sh1, strides=s1) + out_view = pto.make_tensor_view(arg6, shape=sh128, strides=s128) + + # ── Partition views ─────────────────────────────────────────────────── + off = [c0, c0, c0, row_base, c0] + z1 = [c1, c1, c1, row_count, c1] + zs = [c1, c1, c1, row_count, seq] + + oldmax_part = pto.partition_view(oldmax_view, offsets=off, sizes=z1) + oldsum_part = pto.partition_view(oldsum_view, offsets=off, sizes=z1) + qk_part = pto.partition_view(qk_view, offsets=off, sizes=zs) + newmax_part = pto.partition_view(newmax_view, offsets=off, sizes=z1) + newsum_part = pto.partition_view(newsum_view, offsets=off, sizes=z1) + expmax_part = pto.partition_view(expmax_view, offsets=off, sizes=z1) + out_part = pto.partition_view(out_view, offsets=off, sizes=zs) + + # ── UB tile allocation ──────────────────────────────────────────────── + tile_col = pto.tile_buf_type([8, 1], pto.float32, [-1, 1], blayout="ColMajor") + tile_w = pto.tile_buf_type([8, 128], pto.float32, [-1, -1]) + + oldmax_tile = pto.alloc_tile(tile_col, addr=c0_i64, valid_row=row_count) + oldsum_tile = pto.alloc_tile(tile_col, addr=c128_i64, valid_row=row_count) + qk_tile = pto.alloc_tile(tile_w, addr=c256_i64, valid_row=row_count, valid_col=seq) + out_tile = pto.alloc_tile(tile_w, addr=c8448_i64, valid_row=row_count, valid_col=seq) + newmax_tile = pto.alloc_tile(tile_col, addr=c16640_i64, valid_row=row_count) + newsum_tile = pto.alloc_tile(tile_col, addr=c16768_i64, valid_row=row_count) + expmax_tile = pto.alloc_tile(tile_col, addr=c16896_i64, valid_row=row_count) + + # ── Tile loads from GM ──────────────────────────────────────────────── + pto.tile.load(oldmax_part, oldmax_tile) + pto.tile.load(oldsum_part, oldsum_tile) + pto.tile.load(qk_part, qk_tile) + + pto.set_flag("MTE2", "V", event_id=0) + pto.wait_flag("MTE2", "V", event_id=0) + + with pto.vecscope(): + # Materialise typed UB pointers from tile handles + ptr_ub = pto.ptr(pto.float32, "ub") + vf32 = pto.vreg_type(64, pto.float32) + + ub_om = pto.as_ptr(oldmax_tile, ptr_ub) + ub_os = pto.as_ptr(oldsum_tile, ptr_ub) + ub_qk = pto.as_ptr(qk_tile, ptr_ub) + ub_out = pto.as_ptr(out_tile, ptr_ub) + ub_nm = pto.as_ptr(newmax_tile, ptr_ub) + ub_ns = pto.as_ptr(newsum_tile, ptr_ub) + ub_em = pto.as_ptr(expmax_tile, ptr_ub) + + active = pto.pset_b32("PAT_ALL") + one_mask, _ = pto.plt_b32(c1_i32) + + with pto.for_(c0, row_count, step=c1) as row: + row_qk = s.muli(row, c128) + oldmax_bc = pto.vbrc_load(ub_om, row, vf32) + oldsum_bc = pto.vbrc_load(ub_os, row, vf32) + + # scf.for with iter_args: accumulate (running_max, running_sum) + with pto.for_(c0, c128, step=c64, iter_args=(oldmax_bc, oldsum_bc)) as loop: + chunk = loop.iv + running_max, running_sum = loop.iter_args + + chunk_i32 = s.index_cast(pto.int32, chunk) + remaining_cols = s.subi(arg7, chunk_i32) + has_chunk = remaining_cols > c0_i32 + + # scf.if with merged branch values – produce (next_max, next_sum) + with pto.if_(has_chunk) as br: + with br.then_: + chunk_mask, _ = pto.plt_b32(remaining_cols) + chunk_base = s.addi(row_qk, chunk) + vec = pto.vlds(ub_qk, chunk_base, vf32) + chunk_max = pto.vcmax(vec, chunk_mask) + chunk_max_bc = pto.vdup(chunk_max, active, position="LOWEST") + merged_max = pto.vmax(running_max, chunk_max_bc, active) + scaled_running = pto.vexpdif(running_max, merged_max, active) + running_sum_scaled = pto.vmul(scaled_running, running_sum, active) + chunk_exp = pto.vexpdif(vec, merged_max, chunk_mask) + chunk_sum = pto.vcadd(chunk_exp, chunk_mask) + chunk_sum_bc = pto.vdup(chunk_sum, active, position="LOWEST") + merged_sum = pto.vadd(running_sum_scaled, chunk_sum_bc, active) + br.assign(next_max=merged_max, next_sum=merged_sum) + with br.else_: + br.assign(next_max=running_max, next_sum=running_sum) + + pto.yield_(br.next_max, br.next_sum) + + final_max, final_sum = loop.results + + # Compute per-row expmax scalar + raw_em = pto.vexpdif(oldmax_bc, final_max, active) + sc_os = pto.vmul(raw_em, oldsum_bc, active) + expmax = pto.vdiv(sc_os, final_sum, active) + + pto.vsts_1pt(final_max, ub_nm, row, one_mask) + pto.vsts_1pt(final_sum, ub_ns, row, one_mask) + pto.vsts_1pt(expmax, ub_em, row, one_mask) + + # Output normalisation loop + with pto.for_(c0, c128, step=c64) as chunk2: + rem2 = s.subi(arg7, s.index_cast(pto.int32, chunk2)) + has_chunk2= rem2 > c0_i32 + with pto.if_(has_chunk2) as br2: + with br2.then_: + cmask2, _ = pto.plt_b32(rem2) + cbase2 = s.addi(row_qk, chunk2) + vec2 = pto.vlds(ub_qk, cbase2, vf32) + exp2 = pto.vexpdif(vec2, final_max, cmask2) + out2 = pto.vdiv(exp2, final_sum, cmask2) + pto.vsts(out2, ub_out, cbase2, cmask2) + + pto.set_flag("V", "MTE3", event_id=0) + pto.wait_flag("V", "MTE3", event_id=0) + + # Tile stores to GM + pto.tile.store(newmax_tile, newmax_part) + pto.tile.store(newsum_tile, newsum_part) + pto.tile.store(expmax_tile, expmax_part) + pto.tile.store(out_tile, out_part) pto.pipe_barrier(pto.Pipe.ALL) diff --git a/ptodsl/examples/tilelang_codegen.py b/ptodsl/examples/tilelang_codegen.py new file mode 100644 index 000000000..7353979a7 --- /dev/null +++ b/ptodsl/examples/tilelang_codegen.py @@ -0,0 +1,315 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +TileLang-generated explicit PTODSL kernel. + +This file keeps the original generated kernel body essentially intact and only +adds the minimum wrapper needed to make it usable as a compile/test target: + +- public `@pto.jit` host ABI via `tensor_spec(...)` +- `--emit-mlir` entry point +- compile smoke path for regression tests +""" + +import argparse +from pathlib import Path +import sys +import time + +import numpy as np + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from tilelang_codegen.py" + ) + +from ptodsl import pto + + +_DEVICE = "npu:0" + + +def _tilelang_generated_body( + A, + B, + C, +): + bx = pto.get_block_idx() + buf_dyn_shmem = pto.const(0, dtype=pto.int64) + with pto.for_(0, 2, step=1) as f: + pto.set_flag("MTE3", "V", event_id=f) + pto.set_flag("V", "MTE2", event_id=f) + with pto.for_(0, 2048, step=1) as iter: + pto.wait_flag("V", "MTE2", event_id=iter % 2) + pto.mte_gm_ub( + pto.addptr(A, (iter * 524288) + (bx * 8192)), + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (iter % 2) * 8192, + ), + 0, + 32768, + nburst=(1, 0, 0), + ) + pto.mte_gm_ub( + pto.addptr(B, (iter * 524288) + (bx * 8192)), + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 8192) + 16384, + ), + 0, + 32768, + nburst=(1, 0, 0), + ) + pto.set_flag("MTE2", "V", event_id=iter % 2) + pto.wait_flag("MTE2", "V", event_id=iter % 2) + pto.wait_flag("MTE3", "V", event_id=iter % 2) + with pto.simd(): + mask_cnt = 8192 + _ = mask_cnt + with pto.for_(0, 128, step=1) as i: + mask = pto.pset_b32("PAT_ALL") + r0 = pto.vlds( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 8192) + (i * 64), + ), + pto.const(0), + pto.vreg_type(64, pto.float32), + ) + r1 = pto.vlds( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (((iter % 2) * 8192) + (i * 64)) + 16384, + ), + pto.const(0), + pto.vreg_type(64, pto.float32), + ) + r0 = pto.vadd(r0, r1, mask) + pto.vsts( + r0, + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (((iter % 2) * 8192) + (i * 64)) + 32768, + ), + pto.const(0), + mask, + ) + pto.set_flag("V", "MTE3", event_id=iter % 2) + pto.set_flag("V", "MTE2", event_id=iter % 2) + pto.wait_flag("V", "MTE3", event_id=iter % 2) + pto.mte_ub_gm( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 8192) + 32768, + ), + pto.addptr(C, (iter * 524288) + (bx * 8192)), + 32768, + nburst=(1, 0, 0), + ) + pto.set_flag("MTE3", "V", event_id=iter % 2) + with pto.for_(0, 2, step=1) as f_1: + pto.wait_flag("MTE3", "V", event_id=f_1) + pto.wait_flag("V", "MTE2", event_id=f_1) + + +@pto.jit( + name="main_kernel", + kernel_kind="vector", + target="a5", + mode="explicit", + insert_sync=False, +) +def main_kernel( + A: pto.tensor_spec(rank=1, dtype=pto.f32), + B: pto.tensor_spec(rank=1, dtype=pto.f32), + C: pto.tensor_spec(rank=1, dtype=pto.f32), +): + _tilelang_generated_body(A.data_handle, B.data_handle, C.data_handle) + + +def _tilelang_generated_body_small(A, B, C): + bx = pto.get_block_idx() + buf_dyn_shmem = pto.const(0, dtype=pto.int64) + with pto.for_(0, 2, step=1) as f: + pto.set_flag("MTE3", "V", event_id=f) + pto.set_flag("V", "MTE2", event_id=f) + with pto.for_(0, 2, step=1) as iter: + pto.wait_flag("V", "MTE2", event_id=iter % 2) + pto.mte_gm_ub( + pto.addptr(A, (iter * 128) + (bx * 128)), + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (iter % 2) * 128, + ), + 0, + 512, + nburst=(1, 0, 0), + ) + pto.mte_gm_ub( + pto.addptr(B, (iter * 128) + (bx * 128)), + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 128) + 256, + ), + 0, + 512, + nburst=(1, 0, 0), + ) + pto.set_flag("MTE2", "V", event_id=iter % 2) + pto.wait_flag("MTE2", "V", event_id=iter % 2) + pto.wait_flag("MTE3", "V", event_id=iter % 2) + with pto.simd(): + with pto.for_(0, 2, step=1) as i: + mask = pto.pset_b32("PAT_ALL") + r0 = pto.vlds( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 128) + (i * 64), + ), + pto.const(0), + pto.vreg_type(64, pto.float32), + ) + r1 = pto.vlds( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (((iter % 2) * 128) + (i * 64)) + 256, + ), + pto.const(0), + pto.vreg_type(64, pto.float32), + ) + r0 = pto.vadd(r0, r1, mask) + pto.vsts( + r0, + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (((iter % 2) * 128) + (i * 64)) + 512, + ), + pto.const(0), + mask, + ) + pto.set_flag("V", "MTE3", event_id=iter % 2) + pto.set_flag("V", "MTE2", event_id=iter % 2) + pto.wait_flag("V", "MTE3", event_id=iter % 2) + pto.mte_ub_gm( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 128) + 512, + ), + pto.addptr(C, (iter * 128) + (bx * 128)), + 512, + nburst=(1, 0, 0), + ) + pto.set_flag("MTE3", "V", event_id=iter % 2) + with pto.for_(0, 2, step=1) as f_1: + pto.wait_flag("MTE3", "V", event_id=f_1) + pto.wait_flag("V", "MTE2", event_id=f_1) + + +@pto.jit( + name="main_kernel_precision_test", + kernel_kind="vector", + target="a5", + mode="explicit", + insert_sync=False, +) +def main_kernel_precision_test( + A: pto.tensor_spec(rank=1, dtype=pto.f32), + B: pto.tensor_spec(rank=1, dtype=pto.f32), + C: pto.tensor_spec(rank=1, dtype=pto.f32), +): + _tilelang_generated_body_small(A.data_handle, B.data_handle, C.data_handle) + + +def emit_mlir(): + return main_kernel.mlir_text() + + +def compile_kernel(): + compiled = main_kernel.compile() + compiled.verify() + return compiled + + +def init_torch_npu(): + import torch + import torch_npu # noqa: F401 + + torch.npu.config.allow_internal_format = False + torch_npu.npu.set_compile_mode(jit_compile=False) + torch.npu.set_device(_DEVICE) + return torch + + +def npu_stream(torch): + return torch.npu.current_stream()._as_parameter_ # noqa: SLF001 + + +def make_case_inputs(): + total = 256 + rng = np.random.RandomState(20260524) + a = rng.uniform(-3.0, 3.0, size=(total,)).astype(np.float32) + b = rng.uniform(-3.0, 3.0, size=(total,)).astype(np.float32) + c = np.full((total,), np.nan, dtype=np.float32) + return a, b, c + + +def run_precision_case(torch) -> None: + a_np, b_np, c_np = make_case_inputs() + ref = a_np + b_np + + a_t = torch.from_numpy(a_np).to(_DEVICE) + b_t = torch.from_numpy(b_np).to(_DEVICE) + c_t = torch.from_numpy(c_np).to(_DEVICE) + stream = npu_stream(torch) + + t0 = time.perf_counter() + compiled = main_kernel_precision_test.compile() + compile_s = time.perf_counter() - t0 + + t0 = time.perf_counter() + compiled[1, stream](a_t, b_t, c_t) + torch.npu.synchronize() + launch_s = time.perf_counter() - t0 + + np.testing.assert_allclose(c_t.cpu().numpy(), ref, rtol=1e-6, atol=1e-6) + print(f"PASS tilelang_codegen compile={compile_s:.3f}s launch={launch_s:.3f}s") + + +def test_tilelang_codegen() -> None: + torch = init_torch_npu() + run_precision_case(torch) + print("All cases passed.") + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--emit-mlir", + action="store_true", + help="print compiled MLIR and exit", + ) + args = parser.parse_args(argv) + + if args.emit_mlir: + print(emit_mlir()) + return 0 + + test_tilelang_codegen() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ptodsl/ptodsl/_diagnostics.py b/ptodsl/ptodsl/_diagnostics.py index ecc6fe3e1..9087a861d 100644 --- a/ptodsl/ptodsl/_diagnostics.py +++ b/ptodsl/ptodsl/_diagnostics.py @@ -46,6 +46,28 @@ def host_tensor_metadata_error(message: str, *, param_name: str | None = None) - return TypeError(f"{prefix}: {message}") +def jit_missing_annotation_error(name: str) -> TypeError: + """Return one diagnostic for missing ``@pto.jit`` positional ABI annotations.""" + return TypeError( + f"@pto.jit positional parameter '{name}' does not declare an entry ABI annotation. " + "Use pto.tensor_spec(...) for runtime tensors, a PTO scalar type such as " + "pto.i32/pto.f32/pto.i1 for runtime scalars, or move compile-time values " + "to keyword-only pto.constexpr parameters." + ) + + +def jit_illegal_formal_annotation_error(name: str, annotation: object) -> TypeError: + """Return one diagnostic for unsupported ``@pto.jit`` positional annotations.""" + return TypeError( + f"@pto.jit positional parameter '{name}' uses unsupported entry annotation {annotation!r}. " + "The public @pto.jit entry ABI accepts pto.tensor_spec(...) runtime tensors, " + "PTO scalar annotations such as pto.i32/pto.f32/pto.i1 for runtime scalars, " + "and keyword-only pto.constexpr compile-time parameters. " + "Low-level PTODSL types such as pto.ptr(...), Tile, PartitionTensorView, and VReg " + "belong inside the kernel body or across sub-kernel boundaries, not at the host/kernel entry." + ) + + def subkernel_host_tensor_boundary_error(role: str, name: str) -> TypeError: """Return one diagnostic for host-tensor usage outside the JIT boundary.""" return TypeError( @@ -158,6 +180,8 @@ def removed_ukernel_surface_error() -> AttributeError: "explicit_mode_required_error", "explicit_mode_required_with_context_error", "host_tensor_metadata_error", + "jit_illegal_formal_annotation_error", + "jit_missing_annotation_error", "illegal_inline_subkernel_placement_error", "illegal_subkernel_placement_error", "invalid_jit_mode_error", diff --git a/ptodsl/ptodsl/_jit.py b/ptodsl/ptodsl/_jit.py index bae9020da..4c7012518 100644 --- a/ptodsl/ptodsl/_jit.py +++ b/ptodsl/ptodsl/_jit.py @@ -54,7 +54,7 @@ def _module_attr_map(module): def merge_jit_modules(*kernels: KernelHandle): """ - Merge multiple ``@pto.jit`` flat-module kernels into one MLIR module. + Merge multiple ``@pto.jit`` kernels into one MLIR module. Each handle must have been compiled with the same ``target``, ``kernel_kind``, and ``mode`` module attributes. Function order follows @@ -88,7 +88,7 @@ def jit( target: str = "a5", kernel_kind: str = "vector", mode: str = "auto", - func_attr: str = None, + insert_sync: bool | None = None, ): """ Decorator that wraps a Python function as a PTODSL JIT kernel template. @@ -99,8 +99,9 @@ def jit( target: Target architecture string, e.g. ``"a5"``. kernel_kind: ``"vector"`` or ``"cube"`` – sets ``pto.kernel_kind``. mode: ``"auto"`` or ``"explicit"`` – sets ``pto.mode``. - func_attr: Optional function attribute. Pass ``"pto.aicore"`` to - select the flat-module structure with the aicore attribute. + insert_sync: ``True``/``False`` to explicitly control PTOAS sync insertion + for launch builds. ``None`` keeps the mode-based default + behavior. The decorated function is replaced by a :class:`KernelHandle` that: @@ -108,6 +109,7 @@ def jit( - prints as the default-specialization MLIR text, - exposes ``my_kernel.mlir_module()`` / ``verify()`` / ``emit()`` on the default specialization for convenience. + - emits a flat aicore launch-entry module by default. """ def decorator(fn): @@ -119,11 +121,6 @@ def decorator(fn): source_file = inspect.getsourcefile(fn) or inspect.getfile(fn) except (OSError, TypeError): source_file = None - module_style = ( - ModuleStyle.FLAT_AICORE - if func_attr == "pto.aicore" - else ModuleStyle.NESTED - ) compiler = KernelCompiler( fn.__name__, KernelModuleSpec( @@ -131,7 +128,8 @@ def decorator(fn): target_arch=target, kernel_kind=kernel_kind, mode=normalized_mode, - module_style=module_style, + insert_sync=insert_sync, + module_style=ModuleStyle.FLAT_AICORE, source_file=source_file, source_line=getattr(fn.__code__, "co_firstlineno", None), ), diff --git a/ptodsl/ptodsl/_kernel_signature.py b/ptodsl/ptodsl/_kernel_signature.py index a31f865c3..a0c891494 100644 --- a/ptodsl/ptodsl/_kernel_signature.py +++ b/ptodsl/ptodsl/_kernel_signature.py @@ -12,10 +12,14 @@ import inspect from dataclasses import dataclass +from ._diagnostics import ( + jit_illegal_formal_annotation_error, + jit_missing_annotation_error, +) from ._host_tensors import bind_host_tensor_argument, infer_jit_host_tensor_spec from ._surface_values import wrap_surface_value from ._surface_types import constexpr as _constexpr_marker -from ._types import _resolve +from ._types import _DType, _MaskDescriptor, _PtrDescriptor, _VRegDescriptor, _resolve @dataclass(frozen=True) @@ -42,6 +46,23 @@ def abi_signature(self): return ("device", self.name, _hashable_signature_atom(self.annotation)) +@dataclass(frozen=True) +class RuntimeScalarParameterSpec: + name: str + annotation: object + + def entry_arg_types(self): + return (_resolve(self.annotation),) + + def bind_entry_arguments(self, entry_arguments): + if not entry_arguments: + raise RuntimeError(f"entry ABI for runtime scalar parameter '{self.name}' is incomplete") + return wrap_surface_value(entry_arguments[0]), entry_arguments[1:] + + def abi_signature(self): + return ("scalar", self.name, _hashable_signature_atom(self.annotation)) + + @dataclass(frozen=True) class TensorSpecParameterSpec: name: str @@ -82,6 +103,13 @@ def _hashable_signature_atom(value): return value +def _is_supported_runtime_scalar_annotation(annotation) -> bool: + return ( + isinstance(annotation, _DType) + and not isinstance(annotation, (_PtrDescriptor, _VRegDescriptor, _MaskDescriptor)) + ) + + @dataclass(frozen=True) class KernelSignature: positional_parameters: tuple @@ -145,15 +173,19 @@ def parse_jit_kernel_signature(py_fn) -> KernelSignature: inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, }: + if param.annotation is inspect.Parameter.empty: + raise jit_missing_annotation_error(param.name) host_tensor_spec = infer_jit_host_tensor_spec(param) if host_tensor_spec is not None: positional_parameters.append( TensorSpecParameterSpec(param.name, host_tensor_spec) ) - else: + elif _is_supported_runtime_scalar_annotation(param.annotation): positional_parameters.append( - DeviceParameterSpec(param.name, param.annotation) + RuntimeScalarParameterSpec(param.name, param.annotation) ) + else: + raise jit_illegal_formal_annotation_error(param.name, param.annotation) continue if param.kind is inspect.Parameter.KEYWORD_ONLY: @@ -186,6 +218,7 @@ def parse_jit_kernel_signature(py_fn) -> KernelSignature: "DeviceParameterSpec", "KernelSpecializationKey", "KernelSignature", + "RuntimeScalarParameterSpec", "TensorSpecParameterSpec", "parse_jit_kernel_signature", ] diff --git a/ptodsl/ptodsl/_runtime/codegen.py b/ptodsl/ptodsl/_runtime/codegen.py index 936d56c01..75eb4ed0f 100644 --- a/ptodsl/ptodsl/_runtime/codegen.py +++ b/ptodsl/ptodsl/_runtime/codegen.py @@ -9,8 +9,10 @@ from __future__ import annotations -from .._kernel_signature import DeviceParameterSpec, TensorSpecParameterSpec -from .._types import _PtrDescriptor +from mlir.ir import BF16Type, F16Type, F32Type, IndexType, IntegerType + +from .._kernel_signature import DeviceParameterSpec, RuntimeScalarParameterSpec, TensorSpecParameterSpec +from .._types import _PtrDescriptor, _resolve def _elem_cpp_type(elem) -> str: @@ -49,10 +51,46 @@ def _device_param_cpp_type(annotation) -> str: return "float" +def _runtime_scalar_cpp_type(annotation) -> str: + type_obj = _resolve(annotation) + if IndexType.isinstance(type_obj): + return "int64_t" + if IntegerType.isinstance(type_obj): + width = IntegerType(type_obj).width + if width == 1: + return "bool" + signedness = str(type_obj) + if signedness.startswith("ui"): + return { + 8: "uint8_t", + 16: "uint16_t", + 32: "uint32_t", + 64: "uint64_t", + }[width] + return { + 8: "int8_t", + 16: "int16_t", + 32: "int32_t", + 64: "int64_t", + }[width] + if F32Type.isinstance(type_obj): + return "float" + if F16Type.isinstance(type_obj): + return "__fp16" + if BF16Type.isinstance(type_obj): + return "__bf16" + raise TypeError(f"unsupported @pto.jit runtime scalar codegen type {type_obj}") + + def launch_symbol_name(ir_function_name: str) -> str: return f"ptodsl_launch_{ir_function_name}" +def _tensor_metadata_cpp_type() -> str: + # Host-visible tensor shape/stride metadata is marshaled as 64-bit integers. + return "int64_t" + + def generate_launch_cpp(*, ir_function_name: str, kernel_signature) -> str: """Return C++ source for one extern-C launch entry point.""" gm_params = [] @@ -66,19 +104,26 @@ def generate_launch_cpp(*, ir_function_name: str, kernel_signature) -> str: host_params.append(f"{cpp_type} *{param.name}") kernel_args.append(f"(__gm__ {cpp_type} *){param.name}") continue + if isinstance(param, RuntimeScalarParameterSpec): + cpp_type = _runtime_scalar_cpp_type(param.annotation) + gm_params.append(f"{cpp_type} {param.name}") + host_params.append(f"{cpp_type} {param.name}") + kernel_args.append(param.name) + continue if isinstance(param, TensorSpecParameterSpec): cpp_type = _elem_cpp_type(param.tensor_spec.dtype) + meta_cpp_type = _tensor_metadata_cpp_type() rank = param.tensor_spec.rank gm_params.append(f"__gm__ {cpp_type} *{param.name}_ptr") host_params.append(f"{cpp_type} *{param.name}_ptr") kernel_args.append(f"(__gm__ {cpp_type} *){param.name}_ptr") for idx in range(rank): - gm_params.append(f"index {param.name}_shape_{idx}") - host_params.append(f"int64_t {param.name}_shape_{idx}") + gm_params.append(f"{meta_cpp_type} {param.name}_shape_{idx}") + host_params.append(f"{meta_cpp_type} {param.name}_shape_{idx}") kernel_args.append(f"{param.name}_shape_{idx}") for idx in range(rank): - gm_params.append(f"index {param.name}_stride_{idx}") - host_params.append(f"int64_t {param.name}_stride_{idx}") + gm_params.append(f"{meta_cpp_type} {param.name}_stride_{idx}") + host_params.append(f"{meta_cpp_type} {param.name}_stride_{idx}") kernel_args.append(f"{param.name}_stride_{idx}") continue raise TypeError(f"unsupported launch parameter spec: {param!r}") diff --git a/ptodsl/ptodsl/_runtime/launch.py b/ptodsl/ptodsl/_runtime/launch.py index 5a04c88ed..c517e7037 100644 --- a/ptodsl/ptodsl/_runtime/launch.py +++ b/ptodsl/ptodsl/_runtime/launch.py @@ -16,9 +16,12 @@ inspect_host_tensor_metadata, looks_like_host_tensor, ) -from .._kernel_signature import DeviceParameterSpec, TensorSpecParameterSpec +from .._kernel_signature import DeviceParameterSpec, RuntimeScalarParameterSpec, TensorSpecParameterSpec +from .._types import _resolve from .native_build import build_native_library +from mlir.ir import BF16Type, F16Type, F32Type, IndexType, IntegerType + if TYPE_CHECKING: from .._kernel_compilation import CompiledKernelHandle @@ -52,6 +55,39 @@ def _as_void_ptr(value): raise TypeError(f"expected a pointer-like launch argument, got {type(value)!r}") +def _ctype_for_runtime_scalar(annotation): + type_obj = _resolve(annotation) + if IndexType.isinstance(type_obj): + return ctypes.c_int64 + if IntegerType.isinstance(type_obj): + width = IntegerType(type_obj).width + if width == 1: + return ctypes.c_bool + if width == 8: + return ctypes.c_int8 + if width == 16: + return ctypes.c_int16 + if width == 32: + return ctypes.c_int32 + if width == 64: + return ctypes.c_int64 + if F32Type.isinstance(type_obj): + return ctypes.c_float + if F16Type.isinstance(type_obj) or BF16Type.isinstance(type_obj): + raise TypeError( + f"runtime launch does not yet support host scalar marshaling for {type_obj}; " + "use pto.f32 / integer scalar parameters or tensorize this value for now" + ) + raise TypeError(f"unsupported @pto.jit runtime scalar launch type {type_obj}") + + +def _marshal_runtime_scalar(annotation, value): + ctype = _ctype_for_runtime_scalar(annotation) + if ctype is ctypes.c_bool: + return ctype(bool(value)) + return ctype(value) + + def _marshal_launch_args(kernel_signature, args): if len(args) != len(kernel_signature.positional_parameters): raise TypeError( @@ -64,6 +100,9 @@ def _marshal_launch_args(kernel_signature, args): if isinstance(param, DeviceParameterSpec): marshaled.append(_as_void_ptr(value)) continue + if isinstance(param, RuntimeScalarParameterSpec): + marshaled.append(_marshal_runtime_scalar(param.annotation, value)) + continue if isinstance(param, TensorSpecParameterSpec): if not looks_like_host_tensor(value): raise TypeError( @@ -126,6 +165,9 @@ def _launch_argtypes(kernel_signature): if isinstance(param, DeviceParameterSpec): argtypes.append(ctypes.c_void_p) continue + if isinstance(param, RuntimeScalarParameterSpec): + argtypes.append(_ctype_for_runtime_scalar(param.annotation)) + continue if isinstance(param, TensorSpecParameterSpec): argtypes.append(ctypes.c_void_p) rank = param.tensor_spec.rank diff --git a/ptodsl/ptodsl/_runtime/native_build.py b/ptodsl/ptodsl/_runtime/native_build.py index b78094b68..38fc59286 100644 --- a/ptodsl/ptodsl/_runtime/native_build.py +++ b/ptodsl/ptodsl/_runtime/native_build.py @@ -36,19 +36,33 @@ def _run(cmd: list[str], *, cwd: Path | None = None) -> None: ) -def _run_ptoas(mlir_path: Path, kernel_object: Path, *, target_arch: str) -> None: +def _run_ptoas( + mlir_path: Path, + kernel_object: Path, + *, + target_arch: str, + mode: str, + insert_sync: bool | None, +) -> None: ptoas = resolve_ptoas_binary() + cmd = [ + str(ptoas), + f"--pto-arch={target_arch}", + "--pto-backend=vpto", + ] + effective_insert_sync = (mode != "explicit") if insert_sync is None else insert_sync + if mode == "explicit": + cmd.append("--pto-level=level3") + if effective_insert_sync: + cmd.append("--enable-insert-sync") + cmd.extend([ + "--enable-tile-op-expand", + str(mlir_path), + "-o", + str(kernel_object), + ]) _run( - [ - str(ptoas), - f"--pto-arch={target_arch}", - "--pto-backend=vpto", - "--enable-insert-sync", - "--enable-tile-op-expand", - str(mlir_path), - "-o", - str(kernel_object), - ] + cmd ) @@ -169,6 +183,8 @@ def build_native_library( artifacts.mlir_path, artifacts.kernel_object, target_arch=module_spec.target_arch, + mode=module_spec.mode, + insert_sync=module_spec.insert_sync, ) launch_object = artifacts.cache_dir / "launch.o" diff --git a/ptodsl/ptodsl/_tracing/module_builder.py b/ptodsl/ptodsl/_tracing/module_builder.py index 7012ddfa1..baba789c6 100644 --- a/ptodsl/ptodsl/_tracing/module_builder.py +++ b/ptodsl/ptodsl/_tracing/module_builder.py @@ -31,6 +31,7 @@ class KernelModuleSpec: target_arch: str kernel_kind: str mode: str = "auto" + insert_sync: bool | None = None module_style: ModuleStyle = ModuleStyle.NESTED source_file: str | None = None source_line: int | None = None diff --git a/scripts/sim_dsl.sh b/scripts/sim_dsl.sh new file mode 100755 index 000000000..26fe4c658 --- /dev/null +++ b/scripts/sim_dsl.sh @@ -0,0 +1,113 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)" + +usage() { + cat <<'EOF' +Run a PTODSL JIT example under `msprof op simulator`. + +Usage: + scripts/sim_dsl.sh [options] [-- ] + +Options: + --output Override msprof output directory. + --soc-version Override simulator soc version. Default: Ascend950PR_9599 + -h, --help Show this help. + +Examples: + scripts/sim_dsl.sh ptodsl/examples/jit/tadd_launch.py + scripts/sim_dsl.sh \ + --output "$PWD/build/msprof_res/flash_softmax" \ + ptodsl/examples/jit/flash_attention_softmax_launch.py +EOF +} + +die() { + echo "error: $*" >&2 + exit 1 +} + +SOC_VERSION="Ascend950PR_9599" +OUTPUT_DIR="" +EXAMPLE_PATH="" +EXAMPLE_ARGS=() + +while [[ $# -gt 0 ]]; do + case "$1" in + --output) + [[ $# -ge 2 ]] || die "--output requires a value" + OUTPUT_DIR="$2" + shift 2 + ;; + --soc-version) + [[ $# -ge 2 ]] || die "--soc-version requires a value" + SOC_VERSION="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + --) + shift + EXAMPLE_ARGS=("$@") + break + ;; + -*) + die "unknown option: $1" + ;; + *) + if [[ -z "${EXAMPLE_PATH}" ]]; then + EXAMPLE_PATH="$1" + else + EXAMPLE_ARGS+=("$1") + fi + shift + ;; + esac +done + +[[ -n "${EXAMPLE_PATH}" ]] || die "missing " + +if [[ "${EXAMPLE_PATH}" != /* ]]; then + EXAMPLE_PATH="${REPO_ROOT}/${EXAMPLE_PATH}" +fi +[[ -f "${EXAMPLE_PATH}" ]] || die "example script not found: ${EXAMPLE_PATH}" + +if [[ -z "${ASCEND_HOME_PATH:-}" ]]; then + die "ASCEND_HOME_PATH is not set; source CANN setenv or export it first" +fi + +if [[ -z "${OUTPUT_DIR}" ]]; then + EXAMPLE_STEM="$(basename -- "${EXAMPLE_PATH}" .py)" + OUTPUT_DIR="${REPO_ROOT}/build/msprof_res/${EXAMPLE_STEM}" +fi + +SIM_LIB_DIR="${ASCEND_HOME_PATH}/tools/simulator/${SOC_VERSION}/lib" +[[ -d "${SIM_LIB_DIR}" ]] || die "simulator library directory not found: ${SIM_LIB_DIR}" + +mkdir -p "${OUTPUT_DIR}" + +source "${ASCEND_HOME_PATH}/bin/setenv.bash" +source "${REPO_ROOT}/set_ptoas_env.sh" +export LD_LIBRARY_PATH="${SIM_LIB_DIR}:${LD_LIBRARY_PATH:-}" +ulimit -n 65535 + +# msprof rejects group/other-writable working directories, so always launch +# from a private directory and use an absolute path for the example script. +cd "${HOME}" + +exec msprof op simulator \ + --soc-version="${SOC_VERSION}" \ + --output="${OUTPUT_DIR}" \ + python3 "${EXAMPLE_PATH}" "${EXAMPLE_ARGS[@]}" diff --git a/test/python/ptodsl_docs_as_test.py b/test/python/ptodsl_docs_as_test.py index 0f7bfd2fb..c920dc5ff 100644 --- a/test/python/ptodsl_docs_as_test.py +++ b/test/python/ptodsl_docs_as_test.py @@ -304,6 +304,8 @@ def verify_compiled_target( directive: DocTestDirective, namespace: dict[str, object], ptoas_bin: Path, + *, + frontend_verify: bool, ) -> None: expect(directive.symbol is not None, f"{block_label(block)}: compile mode requires a symbol") expect(directive.compile_kwargs is not None, f"{block_label(block, directive.symbol)}: compile mode requires compile kwargs") @@ -341,13 +343,14 @@ def verify_compiled_target( label = block_label(block, directive.symbol) expect_parse_roundtrip_and_verify(mlir_text, label) - run_ptoas_frontend_verify(ptoas_bin, mlir_text, label) + if frontend_verify: + run_ptoas_frontend_verify(ptoas_bin, mlir_text, label) def run_compile_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None: directive = parse_test_directive(block) namespace = execute_source(block.text, block, directive.symbol) - verify_compiled_target(block, directive, namespace, ptoas_bin) + verify_compiled_target(block, directive, namespace, ptoas_bin, frontend_verify=False) def run_compile_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None: @@ -367,7 +370,7 @@ def run_compile_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> Non f"{block_label(block, directive.symbol)}: fragment fixture {directive.fixture!r} is invalid: {exc}" ) from exc namespace = execute_source(rendered_source, block, directive.symbol) - verify_compiled_target(block, directive, namespace, ptoas_bin) + verify_compiled_target(block, directive, namespace, ptoas_bin, frontend_verify=False) def run_launch_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None: diff --git a/test/python/ptodsl_jit_compile.py b/test/python/ptodsl_jit_compile.py index 3c8cd562b..020a0e41c 100644 --- a/test/python/ptodsl_jit_compile.py +++ b/test/python/ptodsl_jit_compile.py @@ -91,6 +91,36 @@ def host_vec_copy_explicit( pto.tile.store(o_tile, out) +@pto.jit(target="a5", insert_sync=False) +def host_vec_copy_no_insert_sync( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), +): + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + a_tile = pto.alloc_tile(shape=[1, 16], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1, 16], dtype=pto.f32) + part = pto.partition_view(a_view, offsets=[0, 0], sizes=[A.shape[0], A.shape[1]]) + out = pto.partition_view(o_view, offsets=[0, 0], sizes=[O.shape[0], O.shape[1]]) + pto.tile.load(part, a_tile) + pto.tile.store(o_tile, out) + + +@pto.jit(target="a5", mode="explicit", insert_sync=True) +def host_vec_copy_explicit_insert_sync( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), +): + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + a_tile = pto.alloc_tile(shape=[1, 16], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1, 16], dtype=pto.f32) + part = pto.partition_view(a_view, offsets=[0, 0], sizes=[A.shape[0], A.shape[1]]) + out = pto.partition_view(o_view, offsets=[0, 0], sizes=[O.shape[0], O.shape[1]]) + pto.tile.load(part, a_tile) + pto.tile.store(o_tile, out) + + @pto.jit(target="a5") def runtime_metadata_kernel( A: pto.tensor_spec(rank=2, dtype=pto.f32), @@ -296,6 +326,29 @@ def runtime_scalar_operator_probe( _ = in_range +@pto.jit(target="a5") +def host_runtime_scalar_entry_probe( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + limit: pto.i32, + alpha: pto.f32, +): + rows = A.shape[0] + cols = A.shape[1] + a_view = pto.make_tensor_view(A) + o_view = pto.make_tensor_view(O) + a_part = pto.partition_view(a_view, offsets=[0, 0], sizes=[rows, cols]) + o_part = pto.partition_view(o_view, offsets=[0, 0], sizes=[rows, cols]) + a_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.f32, valid_shape=[1, cols]) + o_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.f32, valid_shape=[1, cols]) + pto.tile.load(a_part, a_tile) + row_limit = limit // pto.const(2, dtype=pto.i32) + scaled = alpha + 1.0 + _ = row_limit + _ = scaled + pto.tile.store(o_tile, o_part) + + @pto.simd def tile_slice_vector_probe(inp_tile: pto.Tile, out_tile: pto.Tile, row: pto.index): mask, _ = pto.plt_b32(pto.const(64, dtype=pto.i32)) @@ -1116,8 +1169,27 @@ def main() -> None: expect_parse_roundtrip_and_verify(explicit_text, "explicit host_vec_copy specialization") expect("!pto.tile_buf" in default_text, "default specialization MLIR missing BLOCK=128 tile") expect("!pto.tile_buf" in block64_text, "BLOCK=64 specialization MLIR missing specialized tile") + expect("attributes {pto.aicore}" in default_text, "default @pto.jit should emit a flat aicore entry by default") + expect("attributes {pto.aicore}" in explicit_text, "explicit @pto.jit should emit a flat aicore entry by default") + expect("builtin.module" not in default_text, "default @pto.jit should no longer emit a nested builtin.module container") expect('pto.mode = "auto"' in default_text, "default specialization should carry auto mode module metadata") expect('pto.mode = "explicit"' in explicit_text, "explicit specialization should carry explicit mode module metadata") + expect( + host_vec_copy.compile()._module_spec.insert_sync is None, + "default @pto.jit insert_sync should stay unset and follow mode defaults", + ) + expect( + host_vec_copy_explicit.compile()._module_spec.insert_sync is None, + "explicit @pto.jit insert_sync should stay unset and follow mode defaults", + ) + expect( + host_vec_copy_no_insert_sync.compile()._module_spec.insert_sync is False, + "@pto.jit(insert_sync=False) should preserve the explicit override", + ) + expect( + host_vec_copy_explicit_insert_sync.compile()._module_spec.insert_sync is True, + "@pto.jit(insert_sync=True) should preserve the explicit override", + ) expect("valid=?" not in default_text, "default alloc_tile() should keep full static valid-shape when valid_shape= is omitted") auto_mode_violation = expect_raises( RuntimeError, @@ -1296,6 +1368,20 @@ def main() -> None: expect("arith.cmpf ole" in runtime_scalar_text, "float runtime '<=' should lower to arith.cmpf ole") expect("arith.andi" in runtime_scalar_text, "i1 conjunction from native '&' should lower to arith.andi") + host_runtime_scalar_entry_text = host_runtime_scalar_entry_probe.compile().mlir_text() + expect_parse_roundtrip_and_verify( + host_runtime_scalar_entry_text, + "host runtime scalar entry specialization", + ) + expect( + "func.func @host_runtime_scalar_entry_probe" in host_runtime_scalar_entry_text, + "host runtime scalar entry probe should compile into a launchable kernel", + ) + expect( + "i32" in host_runtime_scalar_entry_text and "f32" in host_runtime_scalar_entry_text, + "host runtime scalar entry probe should preserve scalar ABI argument types in MLIR", + ) + signed_integer_scalar_text = signed_integer_scalar_probe.compile().mlir_text() expect_parse_roundtrip_and_verify(signed_integer_scalar_text, "signed integer scalar specialization") expect( diff --git a/test/python/ptodsl_jit_diagnostics.py b/test/python/ptodsl_jit_diagnostics.py index 76fb84f50..f5652306b 100644 --- a/test/python/ptodsl_jit_diagnostics.py +++ b/test/python/ptodsl_jit_diagnostics.py @@ -104,6 +104,22 @@ def bad_probe(*, BLOCK: pto.constexpr): return bad_probe +def define_missing_entry_annotation_probe(): + @pto.jit(target="a5") + def bad_probe(A): + pto.pipe_barrier(pto.Pipe.ALL) + + return bad_probe + + +def define_ptr_entry_annotation_probe(): + @pto.jit(target="a5") + def bad_probe(A: pto.ptr(pto.f32, "gm")): + pto.pipe_barrier(pto.Pipe.ALL) + + return bad_probe + + @pto.jit(target="a5") def missing_if_branch_probe(): with pto.if_(pto.const(1, dtype=pto.i1)) as br: @@ -238,6 +254,20 @@ def main() -> None: TypeError, "@pto.jit constexpr parameter 'BLOCK' must declare a default value", ) + expect_raises( + define_missing_entry_annotation_probe, + TypeError, + "@pto.jit positional parameter 'A' does not declare an entry ABI annotation", + "pto.tensor_spec(...)", + "pto.i32/pto.f32/pto.i1", + ) + expect_raises( + define_ptr_entry_annotation_probe, + TypeError, + "@pto.jit positional parameter 'A' uses unsupported entry annotation", + "pto.ptr(", + "not at the host/kernel entry", + ) expect_raises( missing_if_branch_probe.compile, RuntimeError, From c236b2491df12382ca453e2ece90b08a8c5a7e40 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 25 May 2026 10:58:16 +0800 Subject: [PATCH 24/31] Refine the online softmax demo --- lib/TileOps/tadd_template_tracing_poc.py | 78 ---- ptodsl/README.md | 38 +- ptodsl/docs/user_guide/02-quick-start.md | 3 +- ptodsl/docs/user_guide/05-control-flow.md | 2 +- .../docs/user_guide/07-data-movement-ops.md | 10 +- .../docs/user_guide/08-compute-operations.md | 2 +- .../user_guide/09-predicate-and-mask-ops.md | 10 +- .../jit/flash_attention_softmax_launch.py | 397 ++++++------------ ptodsl/examples/softmax_dsl.py | 375 +++++++---------- ptodsl/examples/tadd_dsl.py | 4 +- ptodsl/examples/tilelang_codegen.py | 1 - ptodsl/ptodsl/_control_flow.py | 31 +- ptodsl/ptodsl/_diagnostics.py | 34 +- ptodsl/ptodsl/_ops.py | 122 ++++-- ptodsl/ptodsl/_surface_values.py | 27 +- ptodsl/ptodsl/pto.py | 13 +- quick_install.sh | 8 + set_ptoas_env.sh | 8 + test/python/ptodsl_jit_compile.py | 132 +++++- 19 files changed, 608 insertions(+), 687 deletions(-) delete mode 100644 lib/TileOps/tadd_template_tracing_poc.py diff --git a/lib/TileOps/tadd_template_tracing_poc.py b/lib/TileOps/tadd_template_tracing_poc.py deleted file mode 100644 index 0b711f6da..000000000 --- a/lib/TileOps/tadd_template_tracing_poc.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -"""Experimental `ptodsl.vpto` POC version of the TileLang tadd template. - -This keeps the authored kernel body intentionally close to -`lib/TileOps/tadd_template.py`, but routes it through the experimental -`ptodsl.vpto` path instead of the TileLang AST frontend. -""" - -from __future__ import annotations - -import sys -from pathlib import Path - - -REPO_ROOT = Path(__file__).resolve().parents[2] -PTODSL_DIR = REPO_ROOT / "ptodsl" -if str(PTODSL_DIR) not in sys.path: - sys.path.insert(0, str(PTODSL_DIR)) - -from ptodsl import vpto as pto - - -@pto.vkernel( - target="a5", - op="pto.tadd", -) -def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): - dtype = dst.element_type - valid_rows, valid_cols = dst.valid_shape - mask_scalar_ty = pto.i32 - - with pto.for_(0, valid_rows, step=1) as row: - remained0 = pto.scalar_const(64, mask_scalar_ty) - with pto.for_(0, valid_cols, step=pto.get_lanes(dtype), state={"remained": remained0}) as loop: - col = loop.iv - remained = loop.state.remained - mask, next_remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - summed = pto.vadd(lhs, rhs, mask) - pto.vsts(summed, dst[row, col:], mask) - loop.yield_state(remained=next_remained) - - -def build_specialized_kernel(): - return template_tadd.specialize( - src0=pto.TileSpec(shape=(16, 64), dtype=pto.f32), - src1=pto.TileSpec(shape=(16, 64), dtype=pto.f32), - dst=pto.TileSpec(shape=(16, 64), dtype=pto.f32), - ) - - -def main(argv: list[str]) -> int: - materialized = build_specialized_kernel() - - if len(argv) > 2: - print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) - return 2 - - if len(argv) == 2: - output_path = Path(argv[1]) - materialized.emit(output_path) - print(f"wrote MLIR to {output_path}") - return 0 - - print(materialized.mlir_text(), end="") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main(sys.argv)) diff --git a/ptodsl/README.md b/ptodsl/README.md index a4e232ae7..0f017bfe6 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -18,7 +18,7 @@ ptodsl/ │ ├── _bootstrap.py # MLIR path setup + context factory │ ├── _types.py # lazy dtype descriptors and type constructors │ ├── _ops.py # PTO operation wrappers -│ ├── _control_flow.py # vecscope, for_, if_, yield_ context managers +│ ├── _control_flow.py # for_, if_, yield_ context managers │ ├── _jit.py # @pto.jit decorator │ ├── _tracing/ # shared tracing runtime building blocks │ └── _tile_template_tracing.py # internal tile-template tracing implementation @@ -121,10 +121,10 @@ python3 ptodsl/examples/jit/tadd_launch.py ### `flash_attention_softmax_launch.py` -Launchable flash-attention softmax-stage demo. It intentionally keeps the -online softmax update stage only, so the runtime path can be validated without -depending on the still-incomplete SIMT/cube coverage needed for the full -flash-attention stack. +Launchable row-wise softmax demo. The kernel surface is the ordinary +`scores -> out` contract, while the implementation preloads the score matrix to +UB and then uses a packed online-softmax recurrence so one NPU can stream +64-row packs sequentially from UB. Compile-only: @@ -141,8 +141,8 @@ scripts/sim_dsl.sh ptodsl/examples/jit/flash_attention_softmax_launch.py Expected output: ```text -PASS rows8_seq128 -PASS rows17_seq96 +PASS rows64_seq128 +PASS rows81_seq96 All cases passed. ``` @@ -301,9 +301,9 @@ exported on the public surface: `@pto.jit(mode="auto")`, ### Type constructors (eager – require active context) ```python -vf32 = pto.vreg_type(64, pto.float32) # !pto.vreg<64xf32> -tile_col = pto.tile_buf_type([8,1], pto.float32, [-1,1], blayout="ColMajor") -tile_w = pto.tile_buf_type([8,128], pto.float32, [-1,-1]) +vf32 = pto.vreg_type(64, pto.float32) # !pto.vreg<64xf32> +tile_col = pto.alloc_tile(shape=[8, 1], dtype=pto.float32, blayout="ColMajor") +tile_w = pto.alloc_tile(shape=[8, 128], dtype=pto.float32) ``` ### Constants @@ -317,17 +317,20 @@ c64_i64= pto.const(64, dtype=pto.int64) ### Control flow ```python -with pto.vecscope(): # pto.vecscope { … } +with pto.simd(): # pto.simd { … } ... with pto.for_(c0, c16, step=c1) as i: # simple scf.for ... # scf.yield inserted automatically -with pto.for_(c0, c128, step=c64, iter_args=(a, b)) as loop: - x, y = loop.iter_args +loop = pto.for_(c0, c128, step=c64).carry(lhs=a, rhs=b) +with loop: + x = loop.lhs + y = loop.rhs ... - pto.yield_(nx, ny) # scf.yield with values -fx, fy = loop.results + loop.update(lhs=nx, rhs=ny) +fx = loop.final("lhs") +fy = loop.final("rhs") with pto.if_(has_rows) as br: # simple scf.if with br.then_: @@ -361,9 +364,8 @@ s.select(cond, t, f) # arith.select pto.castptr(addr, ptr_type) # pto.castptr pto.addptr(ptr, offset) # pto.addptr pto.vlds(ptr, offset) # pto.vlds, result vreg inferred from ptr element type -pto.vbrc_load(ptr, offset, vreg_type) # pto.vlds {dist="BRC_B32"} +pto.vbr(scalar) # pto.vbr, scalar broadcast -> vreg pto.vsts(v, ptr, offset, mask) # pto.vsts -pto.vsts_1pt(v, ptr, offset, mask) # pto.vsts {dist="1PT_B32"} pto.plt_b32(scalar) # → (mask, scalar_out) pto.pset_b32("PAT_ALL") # pto.pset_b32 → mask pto.vbitcast(v, dtype) # pto.vbitcast @@ -372,7 +374,7 @@ pto.vadd(a, b, mask) # infers result type from a.type pto.vmul / vmax / vdiv / vcmax / vcadd / vdup / vexpdif # similarly pto.make_tensor_view(ptr, shape=…, strides=…) # type inferred pto.partition_view(tv, offsets=…, sizes=…) # type inferred -pto.alloc_tile(shape=…, dtype=…, memory_space=…) # authored surface +pto.alloc_tile(shape=…, dtype=…, memory_space=…, valid_shape=…, addr=…) # authored surface pto.tile.load(part, tile) pto.tile.store(tile, part) tile.as_ptr() / view.as_ptr() diff --git a/ptodsl/docs/user_guide/02-quick-start.md b/ptodsl/docs/user_guide/02-quick-start.md index 986fc5884..8ee4aa5b8 100644 --- a/ptodsl/docs/user_guide/02-quick-start.md +++ b/ptodsl/docs/user_guide/02-quick-start.md @@ -188,9 +188,8 @@ switch that kernel to `mode="explicit"`: def add_rows(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, rows: pto.index, cols: pto.index): VEC = pto.elements_per_vreg(pto.f32) - initial_remained = scalar.index_cast(pto.i32, cols) with pto.for_(0, rows, step=1) as r: - col_loop = pto.for_(0, cols, step=VEC).carry(remained=initial_remained) + col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) with col_loop: c = col_loop.iv remained = col_loop.remained diff --git a/ptodsl/docs/user_guide/05-control-flow.md b/ptodsl/docs/user_guide/05-control-flow.md index b2da72d56..ee300efe7 100644 --- a/ptodsl/docs/user_guide/05-control-flow.md +++ b/ptodsl/docs/user_guide/05-control-flow.md @@ -148,7 +148,7 @@ with col_loop: col_loop.update(remained=remained) ``` -`make_mask(dtype, n)` returns two values: the predicate mask for the current chunk and the updated remaining count. Passing the updated count back via `col_loop.update(remained=...)` feeds it into the next iteration, so each chunk correctly computes how many elements are left. +`make_mask(dtype, n)` returns two values: the predicate mask for the current chunk and the updated remaining count. Passing the updated count back via `col_loop.update(remained=...)` feeds it into the next iteration, so each chunk correctly computes how many elements are left. If `n` is an `index`, the updated remaining count stays an `index`; PTODSL hides the hardware `i32` tail-mask bookkeeping internally. ## 5.3 `pto.if_` — device-side conditionals diff --git a/ptodsl/docs/user_guide/07-data-movement-ops.md b/ptodsl/docs/user_guide/07-data-movement-ops.md index e3d3514a7..072019741 100644 --- a/ptodsl/docs/user_guide/07-data-movement-ops.md +++ b/ptodsl/docs/user_guide/07-data-movement-ops.md @@ -313,7 +313,7 @@ The compiler automatically computes the byte offset from the tile's shape, eleme | `tile[row, col:]` | Tile index | 2D tile row with starting column (vector-width range) | | `tile[start:]` | Tile index | 1D tile with starting element (vector-width range) | | `buf` | `PtrType` (UB) | Pointer to buffer in UB (pointer form) | -| `offset` | `Index` | Byte offset (pointer form) | +| `offset` | `Index` | Element offset (pointer form) | | `dist` | `VLoadDist` or `None` | Optional load distribution: `NORM` (default), `UNPK_B8`/`UNPK_B16`/`UNPK_B32`, `BRC_B8`/`BRC_B16`/`BRC_B32` | **Returns**: @@ -341,7 +341,7 @@ the canonical form. | `tile[row, col:]` | Tile index | 2D tile row with starting column (vector-width range) | | `tile[start:]` | Tile index | 1D tile with starting element (vector-width range) | | `buf` | `PtrType` (UB) | Pointer to buffer in UB (pointer form) | -| `offset` | `Index` | Byte offset (pointer form) | +| `offset` | `Index` | Element offset (pointer form) | | `dist` | `DeinterleaveDist` | `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` (alternating elements) or `BDINTLV` (block deinterleave) | **Returns**: @@ -543,7 +543,7 @@ distributions that use predicate masking. | `tile[row, col:]` | Tile index | 2D destination (vector-width range) | | `tile[start:]` | Tile index | 1D destination (vector-width range) | | `buf` | `PtrType` (UB) | Destination buffer (pointer form) | -| `offset` | `Index` | Byte offset (pointer form) | +| `offset` | `Index` | Element offset (pointer form) | | `mask` | `MaskType` | Predicate mask gating writes | | `dist` | `VStoreDist` or `None` | Store distribution token. When omitted, PTODSL defaults to `NORM_B32` on the current surface. | @@ -596,7 +596,7 @@ into one destination. | `tile[row, col:]` | Tile index | 2D destination (vector-width range) | | `tile[start:]` | Tile index | 1D destination (vector-width range) | | `buf` | `PtrType` (UB) | Destination buffer (pointer form) | -| `offset` | `Index` | Byte offset (pointer form) | +| `offset` | `Index` | Element offset (pointer form) | | `dist` | `InterleaveDist` | `INTLV_B8` / `INTLV_B16` / `INTLV_B32` | | `mask` | `MaskType` | Parameter retained for call-shape regularity; for the `INTLV_B*` family it does not affect the stored result | @@ -659,7 +659,7 @@ block-strided UB destination. Masked-off blocks do not write memory. | `tile[row, col:]` | Tile index | 2D destination (vector-width range) | | `tile[start:]` | Tile index | 1D destination (vector-width range) | | `buf` | `PtrType` (UB) | Destination buffer (pointer form) | -| `offset` | `Index` | Byte offset (all forms) | +| `offset` | `Index` | Element offset (all forms) | **Returns**: None (side-effect operation). diff --git a/ptodsl/docs/user_guide/08-compute-operations.md b/ptodsl/docs/user_guide/08-compute-operations.md index 41703f495..0da510042 100644 --- a/ptodsl/docs/user_guide/08-compute-operations.md +++ b/ptodsl/docs/user_guide/08-compute-operations.md @@ -734,7 +734,7 @@ These combine an arithmetic operation with a math function or activation in a si | Unary | `vexp`, `vln`, `vsqrt`, `vabs`, `vneg`, `vrec`, `vrsqrt`, `vrelu`, `vnot` | | Binary | `vadd`, `vsub`, `vmul`, `vdiv`, `vmax`, `vmin`, `vand`, `vor`, `vxor`, `vshl`, `vshr` | | Vector-scalar | `vadds`, `vsubs`, `vmuls`, `vmaxs`, `vmins`, `vlrelu` | -| Broadcast | `vdup` | +| Broadcast | `vbr`, `vdup` | | Full reduction | `vcadd`, `vcmax`, `vcmin` | | Group reduction | `vcgadd`, `vcgmax`, `vcgmin` | | Scan | `vcpadd` | diff --git a/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md b/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md index 2ba02213a..0362a285d 100644 --- a/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md +++ b/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md @@ -22,23 +22,23 @@ A mask and the vector it gates must share the same granularity: a `mask_b32` gat The recommended front door for creating masks is `pto.make_mask`. It dispatches to the right underlying op based on its arguments. -#### `pto.make_mask(dtype: Type, value: int | MaskPattern) -> MaskType | (MaskType, int)` +#### `pto.make_mask(dtype: Type, value: int-like | MaskPattern) -> MaskType | (MaskType, int-like)` -**Description**: Creates a predicate mask of the granularity matching `dtype`. When `value` is an `int` (typically a remaining-element count in a chunked loop), returns a tuple `(mask, remaining)`. When `value` is a `MaskPattern`, returns just the mask. +**Description**: Creates a predicate mask of the granularity matching `dtype`. When `value` is an integer-like scalar (typically a remaining-element count in a chunked loop), returns a tuple `(mask, remaining)`. When `value` is a `MaskPattern`, returns just the mask. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `dtype` | `Type` | Element type to infer mask granularity from (e.g., `pto.f32` → `mask_b32`, `pto.f16` → `mask_b16`) | -| `value` | `int` or `MaskPattern` | Either a remaining-element count or a pattern token | +| `value` | `int-like` or `MaskPattern` | Either a remaining-element count or a pattern token | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| | `mask` | `MaskType` | The created mask | -| `remained` | `int` | Updated remaining count (only when `value` is `int`) | +| `remained` | `int-like` | Updated remaining count (only when `value` is an integer-like scalar); its scalar kind is preserved, so an `index` remainder stays an `index` | **Example** — chunked SIMD loop with tail handling: @@ -56,7 +56,7 @@ with col_loop: col_loop.update(remained=remained) ``` -`make_mask` generates a tail mask from the remaining count: the first `min(remained, VL)` lanes are active, and `remained` is decremented by `VL` for the next iteration. On the final partial chunk, fewer than `VL` lanes are active. +`make_mask` generates a tail mask from the remaining count: the first `min(remained, VL)` lanes are active, and `remained` is decremented by `VL` for the next iteration. On the final partial chunk, fewer than `VL` lanes are active. PTODSL handles the hardware `i32` tail-mask operand internally, so loop-carried `index` metadata can flow through `make_mask` without manual casts. --- diff --git a/ptodsl/examples/jit/flash_attention_softmax_launch.py b/ptodsl/examples/jit/flash_attention_softmax_launch.py index cb538de77..3ed204bd7 100644 --- a/ptodsl/examples/jit/flash_attention_softmax_launch.py +++ b/ptodsl/examples/jit/flash_attention_softmax_launch.py @@ -7,28 +7,27 @@ # See LICENSE in the root of the software repository for the full text of the License. """ -Flash-attention softmax stage — end-to-end launch demo. +Row-wise softmax — end-to-end launch demo. This example is the launchable counterpart to the compile-only -``flash_attention_sketch.py`` demo. It intentionally keeps only the online -softmax update stage from flash attention because the current PTODSL runtime -path is already strong enough for vector-heavy softmax, while the full -flash-attention stack still depends on simt/cube capabilities that are not yet -complete for an end-to-end runtime demo. +``softmax_dsl.py`` sample. It uses an online-softmax recurrence internally, +but the public kernel surface is the ordinary softmax contract: load an input +score matrix, compute the per-row softmax, and store the normalized output. -Each kernel instance updates one block of up to 8 rows: +Each kernel instance runs on one NPU, preloads the full score matrix to UB, +where the input is already laid out as ``[seq, rows]`` so each UB row +represents one score column, and then streams 64-row packs through the +online-softmax recurrence: - m_next = max(m_prev, row_max(scores)) - p = exp(scores - m_next) - l_next = l_prev * exp(m_prev - m_next) + row_sum(p) - expmax = l_prev * exp(m_prev - m_next) / l_next - out = p / l_next + running_max = max(running_max, score_col) + running_sum = running_sum * exp(old_max - new_max) + exp(score_col - new_max) + out = exp(score_col - final_max) / final_sum -The demo offers two fixed-shape launchable kernels so the current launch ABI -does not need runtime scalar parameters: +The demo offers two launchable kernels so the current launch ABI does not need +an extra runtime tile-width parameter: -- ``rows8_seq128``: full-width 128-column softmax -- ``rows17_seq96``: multi-block + tail-mask coverage +- ``rows64_seq128``: full-width 64-row packed softmax +- ``rows81_seq96``: same single NPU, but two sequential row-pack updates """ import argparse @@ -49,224 +48,151 @@ "Unable to locate the PTODSL Python package root from flash_attention_softmax_launch.py" ) -from ptodsl import pto, scalar - -s = scalar +from ptodsl import pto _DEVICE = "npu:0" -_ROWS_PER_BLOCK = 8 -_PHYSICAL_COLS = 128 -def _make_flash_attention_softmax_kernel(name: str, *, rows: int, seq: int): +def _make_softmax_kernel(name: str, *, rows: int, seq: int): if rows <= 0: raise ValueError("rows must be positive") - if not 0 < seq <= _PHYSICAL_COLS: - raise ValueError(f"seq must be in [1, {_PHYSICAL_COLS}]") + if seq <= 0: + raise ValueError("seq must be positive") @pto.jit( name=name, - kernel_kind="vector", target="a5", mode="explicit", insert_sync=False ) def kernel( - oldmax: pto.tensor_spec(rank=2, dtype=pto.f32), - oldsum: pto.tensor_spec(rank=2, dtype=pto.f32), scores: pto.tensor_spec(rank=2, dtype=pto.f32), - newmax: pto.tensor_spec(rank=2, dtype=pto.f32), - newsum: pto.tensor_spec(rank=2, dtype=pto.f32), - expmax: pto.tensor_spec(rank=2, dtype=pto.f32), out: pto.tensor_spec(rank=2, dtype=pto.f32), ): - c0 = pto.const(0) - c1 = pto.const(1) - c8 = pto.const(_ROWS_PER_BLOCK) - c64 = pto.const(64) - c128 = pto.const(_PHYSICAL_COLS) - c_rows = pto.const(rows) - c_seq = pto.const(seq) - c_rows_x_128 = pto.const(rows * _PHYSICAL_COLS) - - c0_i64 = pto.const(0, dtype=pto.int64) - c128_i64 = pto.const(128, dtype=pto.int64) - c256_i64 = pto.const(256, dtype=pto.int64) - c8448_i64 = pto.const(8448, dtype=pto.int64) - c16640_i64 = pto.const(16640, dtype=pto.int64) - c16768_i64 = pto.const(16768, dtype=pto.int64) - c16896_i64 = pto.const(16896, dtype=pto.int64) - - c0_i32 = pto.const(0, dtype=pto.int32) - c1_i32 = pto.const(1, dtype=pto.int32) - c8_i32 = pto.const(_ROWS_PER_BLOCK, dtype=pto.int32) - c_seq_i32 = pto.const(seq, dtype=pto.int32) - c_rows_i32 = pto.const(rows, dtype=pto.int32) - - block_i64 = pto.get_block_idx() - block_idx = s.index_cast(block_i64) - row_base = s.muli(block_idx, c8) - row_base_i32 = s.index_cast(pto.int32, row_base) - remaining_rows = s.subi(c_rows_i32, row_base_i32) - has_rows = remaining_rows > c0_i32 - too_many_rows = remaining_rows > c8_i32 - row_count_i32 = s.select(too_many_rows, c8_i32, remaining_rows) - row_count = s.index_cast(row_count_i32) - - with pto.if_(has_rows) as has_rows_br: - with has_rows_br.then_: - s1 = [c_rows, c_rows, c_rows, c1, c_rows] - s128 = [c_rows_x_128, c_rows_x_128, c_rows_x_128, c128, c1] - sh1 = [c1, c1, c1, c_rows, c1] - sh128 = [c1, c1, c1, c_rows, c128] - - oldmax_view = pto.make_tensor_view(oldmax, shape=sh1, strides=s1) - oldsum_view = pto.make_tensor_view(oldsum, shape=sh1, strides=s1) - scores_view = pto.make_tensor_view(scores, shape=sh128, strides=s128) - newmax_view = pto.make_tensor_view(newmax, shape=sh1, strides=s1) - newsum_view = pto.make_tensor_view(newsum, shape=sh1, strides=s1) - expmax_view = pto.make_tensor_view(expmax, shape=sh1, strides=s1) - out_view = pto.make_tensor_view(out, shape=sh128, strides=s128) - - off = [c0, c0, c0, row_base, c0] - z1 = [c1, c1, c1, row_count, c1] - zs = [c1, c1, c1, row_count, c_seq] - - oldmax_part = pto.partition_view(oldmax_view, offsets=off, sizes=z1) - oldsum_part = pto.partition_view(oldsum_view, offsets=off, sizes=z1) - scores_part = pto.partition_view(scores_view, offsets=off, sizes=zs) - newmax_part = pto.partition_view(newmax_view, offsets=off, sizes=z1) - newsum_part = pto.partition_view(newsum_view, offsets=off, sizes=z1) - expmax_part = pto.partition_view(expmax_view, offsets=off, sizes=z1) - out_part = pto.partition_view(out_view, offsets=off, sizes=zs) - - tile_col = pto.tile_buf_type([8, 1], pto.float32, [-1, 1], blayout="ColMajor") - tile_w = pto.tile_buf_type([8, 128], pto.float32, [-1, -1]) - - oldmax_tile = pto.alloc_tile(tile_col, addr=c0_i64, valid_row=row_count) - oldsum_tile = pto.alloc_tile(tile_col, addr=c128_i64, valid_row=row_count) - scores_tile = pto.alloc_tile(tile_w, addr=c256_i64, valid_row=row_count, valid_col=c_seq) - out_tile = pto.alloc_tile(tile_w, addr=c8448_i64, valid_row=row_count, valid_col=c_seq) - newmax_tile = pto.alloc_tile(tile_col, addr=c16640_i64, valid_row=row_count) - newsum_tile = pto.alloc_tile(tile_col, addr=c16768_i64, valid_row=row_count) - expmax_tile = pto.alloc_tile(tile_col, addr=c16896_i64, valid_row=row_count) - - pto.tile.load(oldmax_part, oldmax_tile) - pto.tile.load(oldsum_part, oldsum_tile) - pto.tile.load(scores_part, scores_tile) - - pto.set_flag("MTE2", "V", event_id=0) - pto.wait_flag("MTE2", "V", event_id=0) - - with pto.vecscope(): - ptr_ub = pto.ptr(pto.float32, "ub") - vf32 = pto.vreg_type(64, pto.float32) - - ub_om = pto.as_ptr(oldmax_tile, ptr_ub) - ub_os = pto.as_ptr(oldsum_tile, ptr_ub) - ub_scores = pto.as_ptr(scores_tile, ptr_ub) - ub_out = pto.as_ptr(out_tile, ptr_ub) - ub_nm = pto.as_ptr(newmax_tile, ptr_ub) - ub_ns = pto.as_ptr(newsum_tile, ptr_ub) - ub_em = pto.as_ptr(expmax_tile, ptr_ub) - - active = pto.pset_b32(pto.MaskPattern.ALL) - one_mask, _ = pto.plt_b32(c1_i32) - - with pto.for_(c0, row_count, step=c1) as row: - row_scores = s.muli(row, c128) - oldmax_bc = pto.vbrc_load(ub_om, row, vf32) - oldsum_bc = pto.vbrc_load(ub_os, row, vf32) - - with pto.for_(c0, c128, step=c64, iter_args=(oldmax_bc, oldsum_bc)) as softmax_loop: - chunk = softmax_loop.iv - running_max, running_sum = softmax_loop.iter_args - - chunk_i32 = s.index_cast(pto.int32, chunk) - remaining_cols = s.subi(c_seq_i32, chunk_i32) - has_chunk = remaining_cols > c0_i32 - - with pto.if_(has_chunk) as br: - with br.then_: - chunk_mask, _ = pto.plt_b32(remaining_cols) - chunk_base = s.addi(row_scores, chunk) - vec = pto.vlds(ub_scores, chunk_base, vf32) - chunk_max = pto.vcmax(vec, chunk_mask) - chunk_max_bc = pto.vdup(chunk_max, active, position="LOWEST") - merged_max = pto.vmax(running_max, chunk_max_bc, active) - scaled_running = pto.vexpdif(running_max, merged_max, active) - running_sum_scaled = pto.vmul(scaled_running, running_sum, active) - chunk_exp = pto.vexpdif(vec, merged_max, chunk_mask) - chunk_sum = pto.vcadd(chunk_exp, chunk_mask) - chunk_sum_bc = pto.vdup(chunk_sum, active, position="LOWEST") - merged_sum = pto.vadd(running_sum_scaled, chunk_sum_bc, active) - br.assign(next_max=merged_max, next_sum=merged_sum) - with br.else_: - br.assign(next_max=running_max, next_sum=running_sum) - pto.yield_(br.next_max, br.next_sum) - - final_max, final_sum = softmax_loop.results - - raw_em = pto.vexpdif(oldmax_bc, final_max, active) - scaled_oldsum = pto.vmul(raw_em, oldsum_bc, active) - expmax = pto.vdiv(scaled_oldsum, final_sum, active) - - pto.vsts_1pt(final_max, ub_nm, row, one_mask) - pto.vsts_1pt(final_sum, ub_ns, row, one_mask) - pto.vsts_1pt(expmax, ub_em, row, one_mask) - - with pto.for_(c0, c128, step=c64) as chunk2: - rem2 = s.subi(c_seq_i32, s.index_cast(pto.int32, chunk2)) - has_chunk2 = rem2 > c0_i32 - with pto.if_(has_chunk2) as br2: - with br2.then_: - cmask2, _ = pto.plt_b32(rem2) - cbase2 = s.addi(row_scores, chunk2) - vec2 = pto.vlds(ub_scores, cbase2, vf32) - exp2 = pto.vexpdif(vec2, final_max, cmask2) - out2 = pto.vdiv(exp2, final_sum, cmask2) - pto.vsts(out2, ub_out, cbase2, cmask2) - - pto.set_flag("V", "MTE3", event_id=0) - pto.wait_flag("V", "MTE3", event_id=0) - - pto.tile.store(newmax_tile, newmax_part) - pto.tile.store(newsum_tile, newsum_part) - pto.tile.store(expmax_tile, expmax_part) - pto.tile.store(out_tile, out_part) + lane_num = pto.elements_per_vreg(pto.f32) + physical_rows = ((rows + lane_num - 1) // lane_num) * lane_num + scores_tile_bytes = seq * physical_rows * pto.bytewidth(pto.f32) + runtime_seq = scores.shape[0] + runtime_rows = scores.shape[1] + total_elems = runtime_rows * runtime_seq + + scores_view = pto.make_tensor_view( + scores, + shape=[1, 1, 1, runtime_seq, runtime_rows], + strides=[total_elems, total_elems, total_elems, runtime_rows, 1], + ) + out_view = pto.make_tensor_view( + out, + shape=[1, 1, 1, runtime_seq, runtime_rows], + strides=[total_elems, total_elems, total_elems, runtime_rows, 1], + ) + scores_part = pto.partition_view( + scores_view, + offsets=[0, 0, 0, 0, 0], + sizes=[1, 1, 1, runtime_seq, runtime_rows], + ) + out_part = pto.partition_view( + out_view, + offsets=[0, 0, 0, 0, 0], + sizes=[1, 1, 1, runtime_seq, runtime_rows], + ) + + scores_tile = pto.alloc_tile( + shape=[seq, physical_rows], + dtype=pto.float32, + addr=0, + valid_shape=[runtime_seq, runtime_rows], + blayout="RowMajor", + ) + out_tile = pto.alloc_tile( + shape=[seq, physical_rows], + dtype=pto.float32, + addr=scores_tile_bytes, + valid_shape=[runtime_seq, runtime_rows], + blayout="RowMajor", + ) + pto.tile.load(scores_part, scores_tile) + out_tile.fill(0.0) + + pto.set_flag("MTE2", "V", event_id=0) + pto.wait_flag("MTE2", "V", event_id=0) + + with pto.simd(): + row_loop = pto.for_(0, runtime_rows, step=lane_num).carry(remained=runtime_rows) + with row_loop: + row_base = row_loop.iv + remaining_rows = row_loop.remained + active_rows, remaining_after_pack = pto.make_mask(pto.f32, remaining_rows) + running_max = pto.vlds(scores_tile[0, row_base:]) + running_sum = pto.vbr(1.0) + + softmax_loop = pto.for_(1, runtime_seq, step=1).carry( + running_max=running_max, + running_sum=running_sum, + ) + with softmax_loop: + col = softmax_loop.iv + running_max = softmax_loop.running_max + running_sum = softmax_loop.running_sum + col_vec = pto.vlds(scores_tile[col, row_base:]) + merged_max = pto.vmax(running_max, col_vec, active_rows) + running_delta = pto.vsub(running_max, merged_max, active_rows) + scaled_running = pto.vexp(running_delta, active_rows) + running_sum_scaled = pto.vmul(scaled_running, running_sum, active_rows) + col_delta = pto.vsub(col_vec, merged_max, active_rows) + col_exp = pto.vexp(col_delta, active_rows) + merged_sum = pto.vadd(running_sum_scaled, col_exp, active_rows) + softmax_loop.update(running_max=merged_max, running_sum=merged_sum) + + final_max = softmax_loop.final("running_max") + final_sum = softmax_loop.final("running_sum") + + with pto.for_(0, runtime_seq, step=1) as col: + col_vec = pto.vlds(scores_tile[col, row_base:]) + out_delta = pto.vsub(col_vec, final_max, active_rows) + exp_vec = pto.vexp(out_delta, active_rows) + out_vec = pto.vdiv(exp_vec, final_sum, active_rows) + pto.vsts(out_vec, out_tile[col, row_base:], active_rows) + + row_loop.update(remained=remaining_after_pack) + + pto.set_flag("V", "MTE3", event_id=0) + pto.wait_flag("V", "MTE3", event_id=0) + + pto.tile.store(out_tile, out_part) pto.pipe_barrier(pto.Pipe.ALL) return kernel -FLASH_SOFTMAX_ROWS8_SEQ128 = _make_flash_attention_softmax_kernel( - "flash_attention_softmax_rows8_seq128", - rows=8, +SOFTMAX_ROWS64_SEQ128 = _make_softmax_kernel( + "softmax_rows64_seq128", + rows=64, seq=128, ) -FLASH_SOFTMAX_ROWS17_SEQ96 = _make_flash_attention_softmax_kernel( - "flash_attention_softmax_rows17_seq96", - rows=17, +SOFTMAX_ROWS81_SEQ96 = _make_softmax_kernel( + "softmax_rows81_seq96", + rows=81, seq=96, ) KERNELS = ( - FLASH_SOFTMAX_ROWS8_SEQ128, - FLASH_SOFTMAX_ROWS17_SEQ96, + SOFTMAX_ROWS64_SEQ128, + SOFTMAX_ROWS81_SEQ96, ) CASES = [ { - "name": "rows8_seq128", - "kernel": FLASH_SOFTMAX_ROWS8_SEQ128, - "rows": 8, + "name": "rows64_seq128", + "kernel": SOFTMAX_ROWS64_SEQ128, + "rows": 64, "seq": 128, }, { - "name": "rows17_seq96", - "kernel": FLASH_SOFTMAX_ROWS17_SEQ96, - "rows": 17, + "name": "rows81_seq96", + "kernel": SOFTMAX_ROWS81_SEQ96, + "rows": 81, "seq": 96, }, ] @@ -276,28 +202,11 @@ def emit_mlir(): return pto.merge_jit_modules(*KERNELS) -def reference_online_softmax_update(oldmax: np.ndarray, oldsum: np.ndarray, scores: np.ndarray, seq: int): - rows = oldmax.shape[0] - newmax = np.empty_like(oldmax) - newsum = np.empty_like(oldsum) - expmax = np.empty_like(oldsum) - out = np.full_like(scores, np.nan) - - for row in range(rows): - m_prev = float(oldmax[row, 0]) - l_prev = float(oldsum[row, 0]) - row_scores = scores[row, :seq] - m_next = max(m_prev, float(np.max(row_scores))) - shifted = np.exp(row_scores - m_next) - l_scaled = l_prev * np.exp(m_prev - m_next) - l_next = l_scaled + float(np.sum(shifted)) - - newmax[row, 0] = m_next - newsum[row, 0] = l_next - expmax[row, 0] = l_scaled / l_next - out[row, :seq] = shifted / l_next - - return newmax, newsum, expmax, out +def reference_softmax(scores: np.ndarray): + row_max = np.max(scores, axis=0, keepdims=True) + shifted = np.exp(scores - row_max, dtype=np.float32) + row_sum = np.sum(shifted, axis=0, keepdims=True, dtype=np.float32) + return shifted / row_sum def init_runtime(): @@ -319,37 +228,17 @@ def make_case_inputs(case: dict[str, object]): seq = int(case["seq"]) rng = np.random.RandomState(hash(case["name"]) & 0xFFFFFFFF) - oldmax = rng.uniform(-2.0, 2.0, size=(rows, 1)).astype(np.float32) - oldsum = rng.uniform(0.25, 3.0, size=(rows, 1)).astype(np.float32) - scores = np.full((rows, _PHYSICAL_COLS), -1000.0, dtype=np.float32) - scores[:, :seq] = rng.uniform(-4.0, 4.0, size=(rows, seq)).astype(np.float32) - - newmax = np.full((rows, 1), np.nan, dtype=np.float32) - newsum = np.full((rows, 1), np.nan, dtype=np.float32) - expmax = np.full((rows, 1), np.nan, dtype=np.float32) - out = np.full((rows, _PHYSICAL_COLS), np.nan, dtype=np.float32) + scores = rng.uniform(-4.0, 4.0, size=(seq, rows)).astype(np.float32) + out = np.zeros((seq, rows), dtype=np.float32) - return oldmax, oldsum, scores, newmax, newsum, expmax, out + return scores, out def run_case(case: dict[str, object], torch) -> None: - rows = int(case["rows"]) - seq = int(case["seq"]) - grid = (rows + _ROWS_PER_BLOCK - 1) // _ROWS_PER_BLOCK - oldmax, oldsum, scores, newmax, newsum, expmax, out = make_case_inputs(case) - ref_newmax, ref_newsum, ref_expmax, ref_out = reference_online_softmax_update( - oldmax, - oldsum, - scores, - seq, - ) + scores, out = make_case_inputs(case) + ref_out = reference_softmax(scores) - oldmax_t = torch.from_numpy(oldmax).to(_DEVICE) - oldsum_t = torch.from_numpy(oldsum).to(_DEVICE) scores_t = torch.from_numpy(scores).to(_DEVICE) - newmax_t = torch.from_numpy(newmax).to(_DEVICE) - newsum_t = torch.from_numpy(newsum).to(_DEVICE) - expmax_t = torch.from_numpy(expmax).to(_DEVICE) out_t = torch.from_numpy(out).to(_DEVICE) stream = npu_stream(torch) @@ -358,24 +247,14 @@ def run_case(case: dict[str, object], torch) -> None: compile_s = time.perf_counter() - t0 t0 = time.perf_counter() - compiled[grid, stream]( - oldmax_t, - oldsum_t, + compiled[1, stream]( scores_t, - newmax_t, - newsum_t, - expmax_t, out_t, ) torch.npu.synchronize() launch_s = time.perf_counter() - t0 - np.testing.assert_allclose(newmax_t.cpu().numpy(), ref_newmax, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(newsum_t.cpu().numpy(), ref_newsum, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(expmax_t.cpu().numpy(), ref_expmax, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(out_t.cpu().numpy()[:, :seq], ref_out[:, :seq], rtol=1e-5, atol=1e-5) - if seq < _PHYSICAL_COLS: - assert np.isnan(out_t.cpu().numpy()[:, seq:]).all(), "tail columns should remain untouched" + np.testing.assert_allclose(out_t.cpu().numpy(), ref_out, rtol=1e-5, atol=1e-5) print( f"PASS {case['name']} " @@ -383,7 +262,7 @@ def run_case(case: dict[str, object], torch) -> None: ) -def test_flash_attention_softmax() -> None: +def test_softmax() -> None: torch = init_runtime() for case in CASES: run_case(case, torch) @@ -403,7 +282,7 @@ def main(argv=None) -> int: print(emit_mlir()) return 0 - test_flash_attention_softmax() + test_softmax() return 0 diff --git a/ptodsl/examples/softmax_dsl.py b/ptodsl/examples/softmax_dsl.py index d55a271dd..fb417bbe1 100644 --- a/ptodsl/examples/softmax_dsl.py +++ b/ptodsl/examples/softmax_dsl.py @@ -7,240 +7,159 @@ # See LICENSE in the root of the software repository for the full text of the License. """ -Online softmax update kernel – DSL-style builder. - -Generates the same IR as - test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto -using the ``@pto.jit(mode="explicit")`` decorator and the ``pto.*`` namespace. - -The Python maps almost line-for-line to the target MLIR: - - func.func @online_softmax_update_kernel_2d( # function signature - %arg0: !pto.ptr, …, %arg7: i32, …) # arg0: pto.ptr(…), … - - scf.if %has_rows { # with pto.if_(has_rows): - pto.tload ins(…) outs(…) # pto.tile.load(part, tile) - pto.vecscope { # with pto.vecscope(): - scf.for %row = … { # with pto.for_(…) as row: - %final_max, %final_sum = # - scf.for %chunk = … iter_args(…) { # with pto.for_(…, iter_args=…) as loop: - scf.if %has_chunk → (vreg, vreg) { # with pto.if_(…, results=…) as br: - scf.yield %merged_max, %merged_sum # pto.yield_(…) - } else { # with br.else_: - scf.yield %running_max, %running_sum # pto.yield_(…) - } # - scf.yield %next_max, %next_sum # pto.yield_(…) - } # - } # - } # - } # - pto.barrier # pto.pipe_barrier(pto.Pipe.ALL) -""" - -from ptodsl import pto, scalar - -s = scalar # arith shorthand alias +Row-wise softmax kernel – compile-only DSL builder. +This sample mirrors the launchable softmax demo. It uses a transposed logical +GM view so each UB row holds one score column, then processes 64 rows in +parallel with the online-softmax recurrence using only public PTODSL surface +syntax. +""" -@pto.jit( - name="online_softmax_update_kernel_2d", - kernel_kind="vector", - target="a5", - mode="explicit", +from pathlib import Path +import sys + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from softmax_dsl.py" + ) + +from ptodsl import pto + + +def _make_softmax_kernel(name: str, *, rows: int, seq: int): + if rows <= 0: + raise ValueError("rows must be positive") + if seq <= 0: + raise ValueError("seq must be positive") + + @pto.jit( + name=name, + kernel_kind="vector", + target="a5", + mode="explicit", + insert_sync=False, + ) + def kernel( + scores: pto.tensor_spec(rank=2, dtype=pto.f32), + out: pto.tensor_spec(rank=2, dtype=pto.f32), + ): + packed_rows = pto.elements_per_vreg(pto.f32) + physical_rows = ((rows + packed_rows - 1) // packed_rows) * packed_rows + scores_tile_bytes = seq * physical_rows * pto.bytewidth(pto.f32) + runtime_rows = scores.shape[0] + runtime_seq = scores.shape[1] + has_rows = runtime_rows > 0 + + with pto.if_(has_rows) as has_rows_br: + with has_rows_br.then_: + scores_view = pto.make_tensor_view( + scores, + shape=[seq, rows], + strides=[1, seq], + ) + out_view = pto.make_tensor_view( + out, + shape=[seq, rows], + strides=[1, seq], + ) + scores_part = pto.partition_view( + scores_view, + offsets=[0, 0], + sizes=[runtime_seq, runtime_rows], + ) + out_part = pto.partition_view( + out_view, + offsets=[0, 0], + sizes=[runtime_seq, runtime_rows], + ) + + scores_tile = pto.alloc_tile( + shape=[seq, physical_rows], + dtype=pto.float32, + addr=pto.const(0, dtype=pto.i64), + valid_shape=[runtime_seq, runtime_rows], + ) + out_tile = pto.alloc_tile( + shape=[seq, physical_rows], + dtype=pto.float32, + addr=pto.const(scores_tile_bytes, dtype=pto.i64), + valid_shape=[runtime_seq, runtime_rows], + ) + + pto.tile.load(scores_part, scores_tile) + + pto.set_flag("MTE2", "V", event_id=0) + pto.wait_flag("MTE2", "V", event_id=0) + + with pto.simd(): + row_loop = pto.for_(0, runtime_rows, step=packed_rows).carry(remained=runtime_rows) + with row_loop: + row_base = row_loop.iv + remaining_rows = row_loop.remained + active_rows, remaining_after_pack = pto.make_mask(pto.f32, remaining_rows) + running_max = pto.vlds(scores_tile[0, row_base:]) + running_sum = pto.vbr(1.0) + + softmax_loop = pto.for_(1, runtime_seq, step=1).carry( + running_max=running_max, + running_sum=running_sum, + ) + with softmax_loop: + col = softmax_loop.iv + running_max = softmax_loop.running_max + running_sum = softmax_loop.running_sum + col_vec = pto.vlds(scores_tile[col, row_base:]) + merged_max = pto.vmax(running_max, col_vec, active_rows) + running_delta = pto.vsub(running_max, merged_max, active_rows) + scaled_running = pto.vexp(running_delta, active_rows) + running_sum_scaled = pto.vmul(scaled_running, running_sum, active_rows) + col_delta = pto.vsub(col_vec, merged_max, active_rows) + col_exp = pto.vexp(col_delta, active_rows) + merged_sum = pto.vadd(running_sum_scaled, col_exp, active_rows) + softmax_loop.update(running_max=merged_max, running_sum=merged_sum) + + final_max = softmax_loop.final("running_max") + final_sum = softmax_loop.final("running_sum") + + with pto.for_(0, runtime_seq, step=1) as col: + col_vec = pto.vlds(scores_tile[col, row_base:]) + out_delta = pto.vsub(col_vec, final_max, active_rows) + exp_vec = pto.vexp(out_delta, active_rows) + out_vec = pto.vdiv(exp_vec, final_sum, active_rows) + pto.vsts(out_vec, out_tile[col, row_base:], active_rows) + + row_loop.update(remained=remaining_after_pack) + + pto.set_flag("V", "MTE3", event_id=0) + pto.wait_flag("V", "MTE3", event_id=0) + + pto.tile.store(out_tile, out_part) + pto.pipe_barrier(pto.Pipe.ALL) + + return kernel + + +SOFTMAX_ROWS64_SEQ128 = _make_softmax_kernel( + "softmax_rows64_seq128_dsl", + rows=64, + seq=128, +) +SOFTMAX_ROWS81_SEQ96 = _make_softmax_kernel( + "softmax_rows81_seq96_dsl", + rows=81, + seq=96, ) -def online_softmax_update_kernel_2d( - arg0: pto.ptr(pto.float32, "gm"), - arg1: pto.ptr(pto.float32, "gm"), - arg2: pto.ptr(pto.float32, "gm"), - arg3: pto.ptr(pto.float32, "gm"), - arg4: pto.ptr(pto.float32, "gm"), - arg5: pto.ptr(pto.float32, "gm"), - arg6: pto.ptr(pto.float32, "gm"), - arg7: pto.int32, - arg8: pto.int32, -): - # ── Index constants ────────────────────────────────────────────────────── - c0 = pto.const(0) - c1 = pto.const(1) - c8 = pto.const(8) - c64 = pto.const(64) - c128 = pto.const(128) - - # ── i64 address constants (UB tile base addresses) ─────────────────────── - c0_i64 = pto.const(0, dtype=pto.int64) - c1_i64 = pto.const(1, dtype=pto.int64) # noqa: F841 (present in IR) - c8_i64 = pto.const(8, dtype=pto.int64) # noqa: F841 - c16_i64 = pto.const(16, dtype=pto.int64) # noqa: F841 - c32_i64 = pto.const(32, dtype=pto.int64) # noqa: F841 - c64_i64 = pto.const(64, dtype=pto.int64) # noqa: F841 - c128_i64 = pto.const(128, dtype=pto.int64) - c256_i64 = pto.const(256, dtype=pto.int64) - c512_i64 = pto.const(512, dtype=pto.int64) # noqa: F841 - c8448_i64 = pto.const(8448, dtype=pto.int64) - c16640_i64 = pto.const(16640, dtype=pto.int64) - c16768_i64 = pto.const(16768, dtype=pto.int64) - c16896_i64 = pto.const(16896, dtype=pto.int64) - - # ── i32 constants ──────────────────────────────────────────────────────── - c1_i32 = pto.const(1, dtype=pto.int32) - c8_i32 = pto.const(8, dtype=pto.int32) - c64_i32 = pto.const(64, dtype=pto.int32) - c0_i32 = pto.const(0, dtype=pto.int32) - - # ── Block-level row assignment ──────────────────────────────────────────── - block_i64 = pto.get_block_idx() - block_idx = s.index_cast(block_i64) # → index - row_base = s.muli(block_idx, c8) - _ = s.index_cast(pto.int32, c8) # block_rows_i32 - row_base_i32 = s.index_cast(pto.int32, row_base) - remaining_rows= s.subi(arg8, row_base_i32) - has_rows = remaining_rows > c0_i32 - too_many_rows = remaining_rows > c8_i32 - row_count_i32 = s.select(too_many_rows, c8_i32, remaining_rows) - row_count = s.index_cast(row_count_i32) # → index - seq = s.index_cast(arg7) # → index - rows = s.index_cast(arg8) # → index - rows_x_128 = s.muli(rows, c128) - - with pto.if_(has_rows) as has_rows_br: - with has_rows_br.then_: - # ── Tensor views ───────────────────────────────────────────────────── - s1 = [rows, rows, rows, c1, rows] - s128 = [rows_x_128, rows_x_128, rows_x_128, c128, c1] - sh1 = [c1, c1, c1, rows, c1] - sh128= [c1, c1, c1, rows, c128] - - oldmax_view = pto.make_tensor_view(arg0, shape=sh1, strides=s1) - oldsum_view = pto.make_tensor_view(arg1, shape=sh1, strides=s1) - qk_view = pto.make_tensor_view(arg2, shape=sh128, strides=s128) - newmax_view = pto.make_tensor_view(arg3, shape=sh1, strides=s1) - newsum_view = pto.make_tensor_view(arg4, shape=sh1, strides=s1) - expmax_view = pto.make_tensor_view(arg5, shape=sh1, strides=s1) - out_view = pto.make_tensor_view(arg6, shape=sh128, strides=s128) - - # ── Partition views ─────────────────────────────────────────────────── - off = [c0, c0, c0, row_base, c0] - z1 = [c1, c1, c1, row_count, c1] - zs = [c1, c1, c1, row_count, seq] - - oldmax_part = pto.partition_view(oldmax_view, offsets=off, sizes=z1) - oldsum_part = pto.partition_view(oldsum_view, offsets=off, sizes=z1) - qk_part = pto.partition_view(qk_view, offsets=off, sizes=zs) - newmax_part = pto.partition_view(newmax_view, offsets=off, sizes=z1) - newsum_part = pto.partition_view(newsum_view, offsets=off, sizes=z1) - expmax_part = pto.partition_view(expmax_view, offsets=off, sizes=z1) - out_part = pto.partition_view(out_view, offsets=off, sizes=zs) - - # ── UB tile allocation ──────────────────────────────────────────────── - tile_col = pto.tile_buf_type([8, 1], pto.float32, [-1, 1], blayout="ColMajor") - tile_w = pto.tile_buf_type([8, 128], pto.float32, [-1, -1]) - - oldmax_tile = pto.alloc_tile(tile_col, addr=c0_i64, valid_row=row_count) - oldsum_tile = pto.alloc_tile(tile_col, addr=c128_i64, valid_row=row_count) - qk_tile = pto.alloc_tile(tile_w, addr=c256_i64, valid_row=row_count, valid_col=seq) - out_tile = pto.alloc_tile(tile_w, addr=c8448_i64, valid_row=row_count, valid_col=seq) - newmax_tile = pto.alloc_tile(tile_col, addr=c16640_i64, valid_row=row_count) - newsum_tile = pto.alloc_tile(tile_col, addr=c16768_i64, valid_row=row_count) - expmax_tile = pto.alloc_tile(tile_col, addr=c16896_i64, valid_row=row_count) - - # ── Tile loads from GM ──────────────────────────────────────────────── - pto.tile.load(oldmax_part, oldmax_tile) - pto.tile.load(oldsum_part, oldsum_tile) - pto.tile.load(qk_part, qk_tile) - - pto.set_flag("MTE2", "V", event_id=0) - pto.wait_flag("MTE2", "V", event_id=0) - - with pto.vecscope(): - # Materialise typed UB pointers from tile handles - ptr_ub = pto.ptr(pto.float32, "ub") - vf32 = pto.vreg_type(64, pto.float32) - - ub_om = pto.as_ptr(oldmax_tile, ptr_ub) - ub_os = pto.as_ptr(oldsum_tile, ptr_ub) - ub_qk = pto.as_ptr(qk_tile, ptr_ub) - ub_out = pto.as_ptr(out_tile, ptr_ub) - ub_nm = pto.as_ptr(newmax_tile, ptr_ub) - ub_ns = pto.as_ptr(newsum_tile, ptr_ub) - ub_em = pto.as_ptr(expmax_tile, ptr_ub) - - active = pto.pset_b32("PAT_ALL") - one_mask, _ = pto.plt_b32(c1_i32) - - with pto.for_(c0, row_count, step=c1) as row: - row_qk = s.muli(row, c128) - oldmax_bc = pto.vbrc_load(ub_om, row, vf32) - oldsum_bc = pto.vbrc_load(ub_os, row, vf32) - - # scf.for with iter_args: accumulate (running_max, running_sum) - with pto.for_(c0, c128, step=c64, iter_args=(oldmax_bc, oldsum_bc)) as loop: - chunk = loop.iv - running_max, running_sum = loop.iter_args - - chunk_i32 = s.index_cast(pto.int32, chunk) - remaining_cols = s.subi(arg7, chunk_i32) - has_chunk = remaining_cols > c0_i32 - - # scf.if with merged branch values – produce (next_max, next_sum) - with pto.if_(has_chunk) as br: - with br.then_: - chunk_mask, _ = pto.plt_b32(remaining_cols) - chunk_base = s.addi(row_qk, chunk) - vec = pto.vlds(ub_qk, chunk_base, vf32) - chunk_max = pto.vcmax(vec, chunk_mask) - chunk_max_bc = pto.vdup(chunk_max, active, position="LOWEST") - merged_max = pto.vmax(running_max, chunk_max_bc, active) - scaled_running = pto.vexpdif(running_max, merged_max, active) - running_sum_scaled = pto.vmul(scaled_running, running_sum, active) - chunk_exp = pto.vexpdif(vec, merged_max, chunk_mask) - chunk_sum = pto.vcadd(chunk_exp, chunk_mask) - chunk_sum_bc = pto.vdup(chunk_sum, active, position="LOWEST") - merged_sum = pto.vadd(running_sum_scaled, chunk_sum_bc, active) - br.assign(next_max=merged_max, next_sum=merged_sum) - with br.else_: - br.assign(next_max=running_max, next_sum=running_sum) - - pto.yield_(br.next_max, br.next_sum) - - final_max, final_sum = loop.results - - # Compute per-row expmax scalar - raw_em = pto.vexpdif(oldmax_bc, final_max, active) - sc_os = pto.vmul(raw_em, oldsum_bc, active) - expmax = pto.vdiv(sc_os, final_sum, active) - - pto.vsts_1pt(final_max, ub_nm, row, one_mask) - pto.vsts_1pt(final_sum, ub_ns, row, one_mask) - pto.vsts_1pt(expmax, ub_em, row, one_mask) - - # Output normalisation loop - with pto.for_(c0, c128, step=c64) as chunk2: - rem2 = s.subi(arg7, s.index_cast(pto.int32, chunk2)) - has_chunk2= rem2 > c0_i32 - with pto.if_(has_chunk2) as br2: - with br2.then_: - cmask2, _ = pto.plt_b32(rem2) - cbase2 = s.addi(row_qk, chunk2) - vec2 = pto.vlds(ub_qk, cbase2, vf32) - exp2 = pto.vexpdif(vec2, final_max, cmask2) - out2 = pto.vdiv(exp2, final_sum, cmask2) - pto.vsts(out2, ub_out, cbase2, cmask2) - - pto.set_flag("V", "MTE3", event_id=0) - pto.wait_flag("V", "MTE3", event_id=0) - - # Tile stores to GM - pto.tile.store(newmax_tile, newmax_part) - pto.tile.store(newsum_tile, newsum_part) - pto.tile.store(expmax_tile, expmax_part) - pto.tile.store(out_tile, out_part) - - pto.pipe_barrier(pto.Pipe.ALL) def build(): - return online_softmax_update_kernel_2d.mlir_module() + return pto.merge_jit_modules(SOFTMAX_ROWS64_SEQ128, SOFTMAX_ROWS81_SEQ96) if __name__ == "__main__": - print(online_softmax_update_kernel_2d) + print(build()) diff --git a/ptodsl/examples/tadd_dsl.py b/ptodsl/examples/tadd_dsl.py index e02076058..4d1482dce 100644 --- a/ptodsl/examples/tadd_dsl.py +++ b/ptodsl/examples/tadd_dsl.py @@ -18,7 +18,7 @@ %c0_i64 = arith.constant 0 : i64 # pto.const(0, dtype=pto.int64) %c16 = arith.constant 16 : index # pto.const(16, dtype=pto.index) … - pto.vecscope { # with pto.vecscope(): + pto.simd { # with pto.simd(): %0 = pto.castptr %c4096_i64 … # pto.castptr(c4096_i64, …) scf.for %arg0 = %c0 to %c16 … { # with pto.for_(c0, c16, step=c1) as i: %mask, _ = pto.plt_b32 … # pto.plt_b32(c64_i32) @@ -43,7 +43,7 @@ def TADD(): c64_i32 = pto.const(64, dtype=pto.int32) c64 = pto.const(64) - with pto.vecscope(): + with pto.simd(): ptr_f32_ub = pto.ptr(pto.float32, "ub") vf32 = pto.vreg_type(64, pto.float32) ptr_src = pto.castptr(c4096_i64, ptr_f32_ub) diff --git a/ptodsl/examples/tilelang_codegen.py b/ptodsl/examples/tilelang_codegen.py index 7353979a7..de7dbc6ab 100644 --- a/ptodsl/examples/tilelang_codegen.py +++ b/ptodsl/examples/tilelang_codegen.py @@ -78,7 +78,6 @@ def _tilelang_generated_body( pto.wait_flag("MTE3", "V", event_id=iter % 2) with pto.simd(): mask_cnt = 8192 - _ = mask_cnt with pto.for_(0, 128, step=1) as i: mask = pto.pset_b32("PAT_ALL") r0 = pto.vlds( diff --git a/ptodsl/ptodsl/_control_flow.py b/ptodsl/ptodsl/_control_flow.py index 54cacdf97..5520d7093 100644 --- a/ptodsl/ptodsl/_control_flow.py +++ b/ptodsl/ptodsl/_control_flow.py @@ -13,8 +13,8 @@ Public API ────────── ``vecscope()`` – ``pto.vecscope { … }`` -``for_(lo, hi, step, *, iter_args)`` - – ``scf.for`` with optional iter_args or named carry state +``for_(lo, hi, step)`` + – ``scf.for`` with optional named carry state via ``.carry(...)`` ``if_(cond)`` – ``scf.if`` via explicit branch handle + automatic named merge ``yield_(*vals)`` – ``scf.yield`` """ @@ -53,9 +53,9 @@ def vecscope() -> _VecScopeCM: class LoopHandle: """ - Handle for a ``scf.for`` loop with iter_args. + Internal handle for a lowered ``scf.for`` loop. - Attributes available *after* the ``with pto.for_(…) as loop:`` block:: + Attributes used by the control-flow implementation:: loop.iv – induction variable loop.iter_args – tuple of inner (mutable) SSA values @@ -114,25 +114,15 @@ def __exit__(self, *exc): self._ip.__exit__(*exc) -def for_(start, stop, *, step, iter_args=None): +def for_(start, stop, *, step): """ ``scf.for`` context manager. - Without ``iter_args`` – yields the induction variable; ``scf.yield`` is - inserted automatically:: + Yields the induction variable; ``scf.yield`` is inserted automatically:: with pto.for_(c0, c16, step=c1) as i: ... - With ``iter_args`` – yields a :class:`LoopHandle`; the caller must emit - ``pto.yield_(…)`` before the block closes:: - - with pto.for_(c0, c128, step=c64, iter_args=(a, b)) as loop: - x, y = loop.iter_args - ... - pto.yield_(nx, ny) - fa, fb = loop.results - Named carry state is expressed with ``.carry(...)``:: loop = pto.for_(c0, c128, step=c64).carry(acc=tile) @@ -141,7 +131,7 @@ def for_(start, stop, *, step, iter_args=None): loop.update(acc=cur) out = loop.final("acc") """ - return _ForBuilder(start, stop, step, iter_args) + return _ForBuilder(start, stop, step) class _CarryLoopStateView: @@ -252,22 +242,19 @@ def final(self, name): class _ForBuilder: - def __init__(self, start, stop, step, iter_args=None): + def __init__(self, start, stop, step): self._start = start self._stop = stop self._step = step - self._iter_args = iter_args def __enter__(self): - self._cm = _ForCM(self._start, self._stop, self._step, self._iter_args) + self._cm = _ForCM(self._start, self._stop, self._step, None) return self._cm.__enter__() def __exit__(self, *exc): return self._cm.__exit__(*exc) def carry(self, **kwargs): - if self._iter_args is not None: - raise RuntimeError("for_(..., iter_args=...) cannot be combined with .carry(...)") if not kwargs: raise ValueError("carry(...) requires at least one named loop-carried value") for name in kwargs: diff --git a/ptodsl/ptodsl/_diagnostics.py b/ptodsl/ptodsl/_diagnostics.py index 9087a861d..b221eb5df 100644 --- a/ptodsl/ptodsl/_diagnostics.py +++ b/ptodsl/ptodsl/_diagnostics.py @@ -166,12 +166,34 @@ def invalid_jit_mode_error( ) -def removed_ukernel_surface_error() -> AttributeError: - """Return one diagnostic for the removed ``pto.ukernel`` public surface.""" +def unsupported_public_surface_error(name: str) -> AttributeError: + """Return one diagnostic for unsupported names on the public ``pto`` surface.""" + hints = { + "ukernel": ( + 'Use @pto.jit(mode="explicit") for explicit DMA orchestration, and call or inline ' + "@pto.simd/@pto.simt/@pto.cube directly from that kernel." + ), + "tile_buf_type": ( + "Use pto.alloc_tile(shape=..., dtype=..., memory_space=..., valid_shape=..., addr=...) " + "to author tiles, and keep explicit tile-type construction inside internal implementation code only." + ), + "vecscope": ( + "Use @pto.simd for named SIMD helpers, or inline SIMD code with `with pto.simd():`." + ), + "as_ptr": ( + "Use tile.as_ptr(), view.as_ptr(), or partition.as_ptr() on the authored object itself " + "instead of the removed pto.as_ptr(...) helper." + ), + "vbrc_load": ( + 'Use pto.vlds(ptr, offset, dist="BRC_B32") instead of the removed pto.vbrc_load(...) helper.' + ), + "vsts_1pt": ( + 'Use pto.vsts(vec, ptr, offset, mask, dist="1PT_B32") instead of the removed pto.vsts_1pt(...) helper.' + ), + } + suffix = hints.get(name, "Use the documented PTODSL public surface instead.") return AttributeError( - 'pto.ukernel has been removed from the PTODSL public surface. ' - 'Use @pto.jit(mode="explicit") for explicit DMA orchestration, and call or inline ' - "@pto.simd/@pto.simt/@pto.cube directly from that kernel." + f"pto.{name} is not a supported PTODSL public interface. {suffix}" ) @@ -186,9 +208,9 @@ def removed_ukernel_surface_error() -> AttributeError: "illegal_subkernel_placement_error", "invalid_jit_mode_error", "native_python_control_flow_error", - "removed_ukernel_surface_error", "simd_value_escape_error", "subkernel_host_tensor_boundary_error", "subkernel_signature_boundary_error", "tile_row_alignment_error", + "unsupported_public_surface_error", ] diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py index b3a473ca0..9cbfec41f 100644 --- a/ptodsl/ptodsl/_ops.py +++ b/ptodsl/ptodsl/_ops.py @@ -195,25 +195,50 @@ def addptr(base_ptr, index_offset): # ── Vector load / store ─────────────────────────────────────────────────────── -def vlds(src_ptr, offset=None, result_vreg_type=None): +_VLOAD_DIST_TOKENS = { + "NORM", + "UNPK_B8", "UNPK_B16", "UNPK_B32", + "BRC_B8", "BRC_B16", "BRC_B32", + "US_B8", "US_B16", + "DS_B8", "DS_B16", +} + + +def vlds(src_ptr, offset=None, result_vreg_type=None, *, dist=None): """``pto.vlds`` – vector load from a tile slice or from *src_ptr* at *offset*.""" if isinstance(src_ptr, TileSliceValue): if offset is not None or result_vreg_type is not None: raise TypeError("vlds(tile[row, col:]) infers its memref slice and vreg type; do not pass offset/result_vreg_type") + kwargs = {} + if dist is not None: + kwargs["dist"] = _normalize_dist_token( + dist, + allowed=_VLOAD_DIST_TOKENS, + context="vlds(..., dist)", + ) return wrap_surface_value(_pto.VldsOp( _infer_vreg_type_from_tile_slice(src_ptr), unwrap_surface_value(src_ptr), _index_zero(), + **kwargs, ).result) if offset is None: raise TypeError("vlds(ptr, offset, result_vreg_type=None) requires an explicit offset") if result_vreg_type is None: result_vreg_type = _infer_vreg_type_from_address_source(src_ptr) + kwargs = {} + if dist is not None: + kwargs["dist"] = _normalize_dist_token( + dist, + allowed=_VLOAD_DIST_TOKENS, + context="vlds(..., dist)", + ) return wrap_surface_value(_pto.VldsOp( _resolve(result_vreg_type), unwrap_surface_value(src_ptr), unwrap_surface_value(offset), + **kwargs, ).result) @@ -308,18 +333,6 @@ def vldsx2(source, offset_or_dist, dist=None): return wrap_surface_value(op.low), wrap_surface_value(op.high) -def vbrc_load(src_ptr, offset, result_vreg_type): - """``pto.vlds {dist="BRC_B32"}`` – broadcast a scalar into all lanes.""" - return wrap_surface_value( - _pto.VldsOp( - _resolve(result_vreg_type), - unwrap_surface_value(src_ptr), - unwrap_surface_value(offset), - dist="BRC_B32", - ).result - ) - - def vbitcast(vector_value, to_dtype): """``pto.vbitcast`` – reinterpret one vector register as a different element type.""" target_elem = _resolve(to_dtype) @@ -342,37 +355,42 @@ def pbitcast(mask_value, to_type): ) -def vsts(val, dst_ptr, offset, mask=None): +def vsts(val, dst_ptr, offset, mask=None, *, dist=None): """``pto.vsts`` – vector store to a tile slice or to *dst_ptr* at *offset*.""" if isinstance(dst_ptr, TileSliceValue): if mask is not None: raise TypeError("vsts(vec, tile[row, col:], mask) does not accept a separate offset argument") + kwargs = {} + if dist is not None: + kwargs["dist"] = _normalize_dist_token( + dist, + allowed=_VSTORE_DIST_TOKENS, + context="vsts(..., dist)", + ) _pto.VstsOp( unwrap_surface_value(val), unwrap_surface_value(dst_ptr), _index_zero(), unwrap_surface_value(offset), + **kwargs, ) return if mask is None: raise TypeError("vsts(vec, ptr, offset, mask) requires an explicit mask") + kwargs = {} + if dist is not None: + kwargs["dist"] = _normalize_dist_token( + dist, + allowed=_VSTORE_DIST_TOKENS, + context="vsts(..., dist)", + ) _pto.VstsOp( unwrap_surface_value(val), unwrap_surface_value(dst_ptr), unwrap_surface_value(offset), unwrap_surface_value(mask), - ) - - -def vsts_1pt(val, dst_ptr, offset, mask): - """``pto.vsts {dist="1PT_B32"}`` – store only the lowest lane.""" - _pto.VstsOp( - unwrap_surface_value(val), - unwrap_surface_value(dst_ptr), - unwrap_surface_value(offset), - unwrap_surface_value(mask), - dist="1PT_B32", + **kwargs, ) @@ -1076,6 +1094,35 @@ def vstus(align_in, offset, value, base): # ── Vector math (result type inferred from first operand) ───────────────────── +def vbr(value): + """``pto.vbr`` – broadcast one scalar value to all vector lanes.""" + raw_value = unwrap_surface_value(value) + if isinstance(raw_value, bool): + raise TypeError("vbr(value) does not accept bool values") + + if hasattr(raw_value, "type"): + scalar_kind = classify_runtime_scalar_type(raw_value.type) + if scalar_kind == "index": + raise TypeError("vbr(value) does not support index scalars") + scalar_value = raw_value + elem_type = raw_value.type + else: + if isinstance(raw_value, float): + elem_type = F32Type.get() + elif isinstance(raw_value, int): + elem_type = IntegerType.get_signless(32) + else: + raise TypeError("vbr(value) expects a runtime scalar or one Python int/float literal") + scalar_value = materialize_scalar_literal(raw_value, elem_type, context="vbr(value)") + + try: + result_type = _resolve(vreg_type(_elements_per_vreg(elem_type), elem_type)) + except TypeError as exc: + raise TypeError(f"vbr(value) does not support scalar type {elem_type}") from exc + + return wrap_surface_value(_pto.VbrOp(result_type, scalar_value).result) + + def _emit_unary_vec_op(op_ctor, inp, mask): return wrap_surface_value( op_ctor( @@ -1549,7 +1596,7 @@ def alloc_tile( Accepts either the authored surface form: - ``alloc_tile(shape=[...], dtype=..., memory_space=...)`` + ``alloc_tile(shape=[...], dtype=..., memory_space=..., valid_shape=..., addr=...)`` or the low-level explicit-type form: @@ -1561,10 +1608,10 @@ def alloc_tile( if tile_type is None: if shape is None or dtype is None: raise TypeError("alloc_tile() requires either tile_type or both shape= and dtype=") - if addr is not None or valid_row is not None or valid_col is not None: + if valid_row is not None or valid_col is not None: raise TypeError( "alloc_tile(shape=..., dtype=...) uses the authored surface form; " - "addr=/valid_row=/valid_col= are only supported with an explicit tile_type" + "use valid_shape=... instead of valid_row=/valid_col=" ) logical_shape = _normalize_static_tile_shape(shape) physical_shape = _authored_tile_physical_shape(logical_shape) @@ -1588,7 +1635,7 @@ def alloc_tile( value = _pto.AllocTileOp( _resolve(tile_type), - addr=unwrap_surface_value(addr) if addr is not None else None, + addr=_coerce_i64(addr, context="alloc_tile(addr)") if addr is not None else None, valid_row=_coerce_index(valid_row, context="alloc_tile(valid_row)") if valid_row is not None else None, valid_col=_coerce_index(valid_col, context="alloc_tile(valid_col)") if valid_col is not None else None, ).result @@ -2351,10 +2398,10 @@ def tfillpad_inplace(src, dst): ) -def as_ptr(value, result_ptr_type=None): +def as_ptr(value): """Materialize a typed pointer from a tile or tensor-view descriptor.""" wrapped = wrap_surface_value(value) - return emit_as_ptr(wrapped, result_ptr_type) + return emit_as_ptr(wrapped) def _constant_like(value, mlir_type): @@ -2451,7 +2498,7 @@ def _infer_vreg_metadata(vector_value): def _extract_lowest_lane_scalar(vector_value, mask): lanes, elem_type = _infer_vreg_metadata(vector_value) tmp_tile = alloc_tile(shape=[1, lanes], dtype=elem_type, valid_shape=[1, 1]) - vsts_1pt(vector_value, tmp_tile.as_ptr(), _index_zero(), mask) + vsts(vector_value, tmp_tile.as_ptr(), _index_zero(), mask, dist="1PT_B32") from . import scalar as _scalar return _scalar.load(tmp_tile[0, 0]) @@ -2677,9 +2724,15 @@ def make_mask(dtype, value): ) raw_value = unwrap_surface_value(value) + authored_scalar_type = raw_value.type if hasattr(raw_value, "type") else IntegerType.get_signless(32) raw_value = _coerce_i32(raw_value, context="make_mask(..., value)") plt_op = _plt_op_for_mask_bits(mask_bits)(result_type, IntegerType.get_signless(32), raw_value) - return MaskResultValue(plt_op.mask, plt_op.scalar_out) + next_value = coerce_scalar_to_type( + plt_op.scalar_out, + authored_scalar_type, + context="make_mask(..., value) result", + ) + return MaskResultValue(plt_op.mask, next_value) # ── Hardware / sync ─────────────────────────────────────────────────────────── @@ -3164,7 +3217,7 @@ def wait_flag(src: str, dst: str, *, event_id: int = 0): __all__ = [ "const", "castptr", "addptr", - "vlds", "vldas", "vldus", "vldsx2", "vbrc_load", "vsts", "vsts_1pt", "vstsx2", + "vlds", "vldas", "vldus", "vldsx2", "vsts", "vstsx2", "init_align", "plt_b8", "plt_b16", "plt_b32", "pset_b8", "pset_b16", "pset_b32", @@ -3178,6 +3231,7 @@ def wait_flag(src: str, dst: str, *, event_id: int = 0): "vcmp", "vcmps", "plds", "psts", "pstu", "vstar", "vstas", "vstur", "vstus", "vbitcast", + "vbr", "vadd", "vsub", "vmul", "vdiv", "vmax", "vmin", "vand", "vor", "vxor", "vshl", "vshr", "vcmax", "vcadd", "vcmin", "vdup", "vexpdif", diff --git a/ptodsl/ptodsl/_surface_values.py b/ptodsl/ptodsl/_surface_values.py index eaf14d245..be9a89097 100644 --- a/ptodsl/ptodsl/_surface_values.py +++ b/ptodsl/ptodsl/_surface_values.py @@ -347,9 +347,9 @@ def surface_metadata(self): "strides": self.strides, } - def as_ptr(self, result_ptr_type=None): + def as_ptr(self): from ._ops import as_ptr - return as_ptr(self, result_ptr_type) + return as_ptr(self) class PartitionTensorViewValue(_SurfaceValue, PartitionTensorView): @@ -363,9 +363,9 @@ def __init__(self, value, *, root_tensor_view=None, offsets=None, sizes=None): self.shape = self.sizes self.strides = getattr(root_tensor_view, "strides", None) - def as_ptr(self, result_ptr_type=None): + def as_ptr(self): from ._ops import as_ptr - return as_ptr(self, result_ptr_type) + return as_ptr(self) class _TileValidShapeView: @@ -468,9 +468,9 @@ def surface_metadata(self): "valid_shape": self.static_valid_shape, } - def as_ptr(self, result_ptr_type=None): + def as_ptr(self): from ._ops import as_ptr - return as_ptr(self, result_ptr_type) + return as_ptr(self) def fill(self, value): from ._ops import fill_tile @@ -563,11 +563,8 @@ def compose_partition_spec(source, *, offsets, sizes) -> PartitionSpec | None: ) -def infer_ptr_type_from_surface_value(surface_value, result_ptr_type=None): - """Infer a PTO pointer type for `as_ptr()` when the caller omits one.""" - if result_ptr_type is not None: - return _resolve(result_ptr_type) - +def infer_ptr_type_from_surface_value(surface_value): + """Infer a PTO pointer type for `as_ptr()` from the authored source value.""" value_type = surface_value.type tv_type = _maybe_cast_tensor_view_type(value_type) @@ -601,10 +598,10 @@ def infer_ptr_type_from_surface_value(surface_value, result_ptr_type=None): return _resolve(ptr(tile_type.element_type, space_enum)) -def emit_as_ptr(surface_value, result_ptr_type=None): +def emit_as_ptr(surface_value): """Lower `as_ptr()` on a surface value to the appropriate PTO op.""" value = unwrap_surface_value(surface_value) - result_type = infer_address_type_from_surface_value(surface_value, result_ptr_type) + result_type = infer_address_type_from_surface_value(surface_value) if isinstance(surface_value, (TensorViewValue, PartitionTensorViewValue)): return AddressValue(_pto.TensorViewAddrOp(result_type, value).result) @@ -703,9 +700,9 @@ def infer_tile_element_type(tile): return parsed["element_type"] -def infer_address_type_from_surface_value(surface_value, result_ptr_type=None): +def infer_address_type_from_surface_value(surface_value): """Infer the concrete result type emitted by `as_ptr()`.""" - return infer_ptr_type_from_surface_value(surface_value, result_ptr_type) + return infer_ptr_type_from_surface_value(surface_value) def infer_memref_type_from_surface_value(surface_value): diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py index 293fab0c4..da43f3c2b 100644 --- a/ptodsl/ptodsl/pto.py +++ b/ptodsl/ptodsl/pto.py @@ -20,7 +20,7 @@ internally as ``_pto`` (``from mlir.dialects import pto as _pto``). """ -from ._diagnostics import removed_ukernel_surface_error +from ._diagnostics import unsupported_public_surface_error # ── Types ───────────────────────────────────────────────────────────────────── from ._types import ( # noqa: F401 @@ -31,7 +31,6 @@ ui8, ui16, ui32, ui64, index, ptr, vreg_type, mask_type, - tile_buf_type, _resolve, ) from ._surface_types import ( # noqa: F401 @@ -61,7 +60,7 @@ from ._ops import ( # noqa: F401 const, castptr, addptr, - vlds, vldas, vldus, vldsx2, vbrc_load, vsts, vsts_1pt, vstsx2, + vlds, vldas, vldus, vldsx2, vsts, vstsx2, init_align, plt_b8, plt_b16, plt_b32, pset_b8, pset_b16, pset_b32, @@ -76,6 +75,7 @@ vcmp, vcmps, plds, psts, pstu, vstar, vstas, vstur, vstus, vbitcast, + vbr, vadd, vsub, vmul, vdiv, vmax, vmin, vand, vor, vxor, vshl, vshr, vcmax, vcadd, vcmin, vdup, vexpdif, @@ -85,7 +85,7 @@ vaxpy, vaddrelu, vsubrelu, vsel, make_tensor_view, partition_view, - alloc_tile, as_ptr, + alloc_tile, mte_load, mte_store, mte_gm_ub, mte_ub_gm, mte_ub_ub, mte_ub_l1, mem_bar, mte_l1_l0a, mte_l1_l0b, mte_l0c_ub, mad, mad_acc, mad_bias, mad_mx, mad_mx_acc, mad_mx_bias, @@ -99,7 +99,6 @@ # ── Control flow ────────────────────────────────────────────────────────────── from ._control_flow import ( # noqa: F401 - vecscope, for_, if_, yield_, LoopHandle, BranchHandle, ) @@ -122,6 +121,6 @@ def __getattr__(name): - if name == "ukernel": - raise removed_ukernel_surface_error() + if name in {"ukernel", "tile_buf_type", "vecscope", "as_ptr", "vbrc_load", "vsts_1pt"}: + raise unsupported_public_surface_error(name) raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/quick_install.sh b/quick_install.sh index 4a0f6b69d..38509b61d 100755 --- a/quick_install.sh +++ b/quick_install.sh @@ -1,4 +1,12 @@ #!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + # For quick development, build and install ptoas and its python bindings # on top of Docker image https://github.com/learning-chip/agent_docker_npu/pull/8 # assume MLIR is already installed to save time, takes <3min to finish the build of pto extension diff --git a/set_ptoas_env.sh b/set_ptoas_env.sh index f96d4eb33..996d84ad4 100644 --- a/set_ptoas_env.sh +++ b/set_ptoas_env.sh @@ -1,4 +1,12 @@ #!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + # after `quick_install.sh`, run `source set_ptoas_env.sh` in a new shell to find the lib export PTO_SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PTO_INSTALL_DIR="${PTO_INSTALL_DIR:-${PTO_SOURCE_DIR}/install}" diff --git a/test/python/ptodsl_jit_compile.py b/test/python/ptodsl_jit_compile.py index 020a0e41c..047b81b69 100644 --- a/test/python/ptodsl_jit_compile.py +++ b/test/python/ptodsl_jit_compile.py @@ -53,6 +53,13 @@ def expect_parse_roundtrip_and_verify(text: str, label: str) -> None: ) +expect_raises( + TypeError, + lambda: pto.for_(0, 1, step=1, iter_args=(0,)), + "iter_args", +) + + @pto.jit(target="a5") def host_vec_copy( A: pto.tensor_spec(rank=2, dtype=pto.f32), @@ -140,6 +147,31 @@ def runtime_metadata_kernel( pto.tile.store(o_tile, out) +@pto.jit(target="a5", mode="explicit") +def authored_addr_tile_surface_probe( + A: pto.tensor_spec(rank=2, dtype=pto.f32), +): + rows = A.shape[0] + cols = A.shape[1] + tile = pto.alloc_tile(shape=[1, 128], dtype=pto.f32, addr=0, valid_shape=[rows, cols]) + _ = tile + + +@pto.jit(target="a5", mode="explicit") +def dynamic_addr_tile_surface_probe( + A: pto.tensor_spec(rank=2, dtype=pto.f32), +): + rows = A.shape[0] + cols = A.shape[1] + tile = pto.alloc_tile( + shape=[1, 128], + dtype=pto.f32, + addr=scalar.index_cast(pto.i64, rows), + valid_shape=[rows, cols], + ) + _ = tile + + @pto.jit(target="a5") def tile_surface_compute_probe(): lhs = pto.alloc_tile(shape=[2, 16], dtype=pto.f32) @@ -407,6 +439,19 @@ def tile_valid_shape_update_1d_probe( tile.valid_shape = [length] +@pto.jit(target="a5", mode="explicit") +def make_mask_index_roundtrip_probe( + A: pto.tensor_spec(rank=1, dtype=pto.f32), +): + cols = A.shape[0] + col_loop = pto.for_(0, cols, step=64).carry(remained=cols) + with col_loop: + remained = col_loop.remained + mask, remained_after_pack = pto.make_mask(pto.f32, remained) + _ = mask + col_loop.update(remained=remained_after_pack) + + @pto.jit(target="a5") def integer_loop_bound_probe(*, BLOCK: pto.constexpr = 8): row_start = pto.const(0, dtype=pto.i32) @@ -653,7 +698,7 @@ def signed_integer_scalar_probe(): def low_precision_storage_probe(): lp_tile = pto.alloc_tile(shape=[128, 64], dtype=pto.f8e4m3) lp_tile_hif8 = pto.alloc_tile(shape=[64, 64], dtype=pto.hif8) - lp_tile_ty = pto.tile_buf_type([16, 16], pto.f4e2m1x2, [16, 16]) + lp_tile_ty = pto_types.tile_buf_type([16, 16], pto.f4e2m1x2, [16, 16]) _ = lp_tile _ = lp_tile_hif8 _ = lp_tile_ty @@ -663,9 +708,11 @@ def low_precision_storage_probe(): def pointer_vlds_inference_probe(*, BLOCK: pto.constexpr = 128): tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32) vec = pto.vlds(tile.as_ptr(), pto.const(0)) + vec_brc = pto.vlds(tile.as_ptr(), pto.const(0), dist="BRC_B32") ivec = pto.vbitcast(vec, pto.i32) f16_vec = pto.vbitcast(vec, pto.f16) _ = vec + _ = vec_brc _ = ivec _ = f16_vec @@ -960,14 +1007,44 @@ def main() -> None: expect(not hasattr(pto, "tload"), "legacy pto.tload should not remain on the public pto namespace") expect(not hasattr(pto, "tstore"), "legacy pto.tstore should not remain on the public pto namespace") expect(not hasattr(pto, "tadd"), "legacy pto.tadd should not remain on the public pto namespace") + expect(not hasattr(pto, "tile_buf_type"), "pto.tile_buf_type should not remain on the public pto namespace") + expect(not hasattr(pto, "vecscope"), "pto.vecscope should not remain on the public pto namespace") + expect(not hasattr(pto, "as_ptr"), "pto.as_ptr should not remain on the public pto namespace") + expect(not hasattr(pto, "vbrc_load"), "pto.vbrc_load should not remain on the public pto namespace") + expect(not hasattr(pto, "vsts_1pt"), "pto.vsts_1pt should not remain on the public pto namespace") expect(not hasattr(scalar, "sts"), "scalar.sts should not remain in the public scalar namespace") expect(not hasattr(scalar, "cmpi"), "scalar.cmpi should not remain in the public scalar namespace") expect(not hasattr(scalar, "cmpi_sgt"), "scalar.cmpi_sgt should not remain in the public scalar namespace") + removed_tile_buf_type = expect_raises(AttributeError, lambda: getattr(pto, "tile_buf_type")) + expect( + "pto.tile_buf_type is not a supported PTODSL public interface" in str(removed_tile_buf_type), + "removed pto.tile_buf_type should diagnose the authored alloc_tile replacement", + ) + removed_vecscope = expect_raises(AttributeError, lambda: getattr(pto, "vecscope")) + expect( + "pto.vecscope is not a supported PTODSL public interface" in str(removed_vecscope), + "removed pto.vecscope should diagnose the public SIMD replacements", + ) + removed_as_ptr = expect_raises(AttributeError, lambda: getattr(pto, "as_ptr")) + expect( + "pto.as_ptr is not a supported PTODSL public interface" in str(removed_as_ptr), + "removed pto.as_ptr should diagnose the authored object-method replacements", + ) + removed_vbrc_load = expect_raises(AttributeError, lambda: getattr(pto, "vbrc_load")) + expect( + "pto.vbrc_load is not a supported PTODSL public interface" in str(removed_vbrc_load), + "removed pto.vbrc_load should diagnose the public vlds(dist=...) replacement", + ) + removed_vsts_1pt = expect_raises(AttributeError, lambda: getattr(pto, "vsts_1pt")) + expect( + "pto.vsts_1pt is not a supported PTODSL public interface" in str(removed_vsts_1pt), + "removed pto.vsts_1pt should diagnose the public vsts(dist=...) replacement", + ) for name in ("max", "min", "exp", "log", "sqrt", "abs"): expect(hasattr(scalar, name), f"scalar.{name} should be exported from the public scalar namespace") with make_context() as ctx, Location.unknown(ctx): - tile_buf_ty = pto.tile_buf_type( + tile_buf_ty = pto_types.tile_buf_type( [16, 32], pto.f32, [16, 8], @@ -986,6 +1063,8 @@ def main() -> None: host_vec_copy.verify() runtime_metadata_kernel.verify() + authored_addr_tile_surface_probe.verify() + dynamic_addr_tile_surface_probe.verify() tile_surface_compute_probe.verify() shared_subkernel_lowering_probe.verify() simt_helper_lowering_probe.verify() @@ -998,6 +1077,7 @@ def main() -> None: tile_slice_1d_surface_probe.verify() tile_valid_shape_update_probe.verify() tile_valid_shape_update_1d_probe.verify() + make_mask_index_roundtrip_probe.verify() integer_loop_bound_probe.verify() scalar_pointer_offset_probe.verify() addptr_surface_probe.verify() @@ -1086,7 +1166,7 @@ def main() -> None: "pto.mask_b32 should resolve to the public 32-bit mask type", ) - lp_tile_ty = pto.tile_buf_type([16, 16], pto.hif8, [16, 16]) + lp_tile_ty = pto_types.tile_buf_type([16, 16], pto.hif8, [16, 16]) lp_tv_ty = pto_types.tensor_view_type(2, pto.f8e4m3) lp_part_ty = pto_types.part_tensor_view_type(2, pto.f4e2m1x2) expect( @@ -1236,6 +1316,27 @@ def main() -> None: "partition_view sizes derived from tensor metadata should remain runtime MLIR values", ) + authored_addr_tile_text = authored_addr_tile_surface_probe.compile().mlir_text() + expect_parse_roundtrip_and_verify(authored_addr_tile_text, "authored alloc_tile addr specialization") + expect( + "pto.alloc_tile addr = %c0_i64 valid_row = %arg1 valid_col = %arg2 : !pto.tile_buf" in authored_addr_tile_text, + "alloc_tile(shape=..., dtype=..., addr=int, valid_shape=...) should coerce Python ints to i64 operands", + ) + + dynamic_addr_tile_text = dynamic_addr_tile_surface_probe.compile().mlir_text() + expect_parse_roundtrip_and_verify(dynamic_addr_tile_text, "dynamic alloc_tile addr specialization") + expect( + "arith.index_cast %arg1 : index to i64" in dynamic_addr_tile_text, + "alloc_tile(addr=runtime index) should cast dynamic index metadata to i64 before lowering", + ) + expect( + re.search( + r"pto\.alloc_tile addr = %[0-9]+ valid_row = %arg1 valid_col = %arg2 : !pto\.tile_buf", + dynamic_addr_tile_text, + ) is not None, + "alloc_tile(shape=..., dtype=..., addr=runtime value, valid_shape=...) should accept dynamic i64-like operands", + ) + tile_valid_shape_text = tile_valid_shape_update_probe.compile().mlir_text() expect_parse_roundtrip_and_verify(tile_valid_shape_text, "tile valid-shape update specialization") expect( @@ -1256,6 +1357,30 @@ def main() -> None: "tile.valid_shape = [length] should lower to pto.set_validshape on a rank-1 dynamic-valid tile", ) + make_mask_index_roundtrip_text = make_mask_index_roundtrip_probe.compile().mlir_text() + expect_parse_roundtrip_and_verify(make_mask_index_roundtrip_text, "make_mask index round-trip specialization") + expect( + re.search( + r"arith\.index_cast %[a-zA-Z0-9_]+ : index to i32", + make_mask_index_roundtrip_text, + ) is not None, + "make_mask(...) should still cast index counts to the hardware i32 tail-mask operand type", + ) + expect( + re.search( + r"arith\.index_cast %[a-zA-Z0-9_]+ : i32 to index", + make_mask_index_roundtrip_text, + ) is not None, + "make_mask(...) should restore index counts after tail-mask generation so loop-carried state stays in authored index form", + ) + expect( + re.search( + r"scf\.yield %[a-zA-Z0-9_]+ : index", + make_mask_index_roundtrip_text, + ) is not None, + "make_mask(...) should allow loop-carried index remainders without manual i32 casts", + ) + SUBKERNEL_OBSERVATIONS.clear() shared_subkernel_lowering_probe.compile(TRACE_TOKEN=1) expect( @@ -1560,6 +1685,7 @@ def main() -> None: expect("!pto.tile_buf" in low_precision_storage_text, "low-precision tile allocation should preserve HiF8 element types in MLIR") expect("pto.vlds" in pointer_vlds_text, "vlds(ptr, offset) should still lower to pto.vlds") expect("!pto.vreg<64xf32>" in pointer_vlds_text, "vlds(ptr, offset) should infer the result vreg type from the pointer element type") + expect('dist = "BRC_B32"' in pointer_vlds_text, 'vlds(ptr, offset, dist="BRC_B32") should lower the authored load distribution') expect("pto.vbitcast" in pointer_vlds_text, "vbitcast(...) should lower to pto.vbitcast") expect("!pto.vreg<128xf16>" in pointer_vlds_text, "vbitcast(vec, pto.f16) should preserve the 256-byte payload while adjusting the lane count") expect(mask_bitcast_text.count("pto.pbitcast") == 2, "pbitcast(...) should lower to pto.pbitcast for each authored mask reinterpretation") From c2402e4f39f5dcbfcf52ab7796dfc418e2b277cf Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Tue, 26 May 2026 11:24:27 +0200 Subject: [PATCH 25/31] working tadd dynamic --- ptodsl/examples/jit/tadd_launch.py | 58 ++++++++++++++++++++++++++---- ptodsl/ptodsl/_runtime/codegen.py | 26 ++++++++++++++ ptodsl/ptodsl/_runtime/launch.py | 24 +++++++++++++ 3 files changed, 101 insertions(+), 7 deletions(-) diff --git a/ptodsl/examples/jit/tadd_launch.py b/ptodsl/examples/jit/tadd_launch.py index 734041689..4eb0ed64d 100644 --- a/ptodsl/examples/jit/tadd_launch.py +++ b/ptodsl/examples/jit/tadd_launch.py @@ -31,7 +31,7 @@ "Unable to locate the PTODSL Python package root from tadd_launch.py" ) -from ptodsl import pto +from ptodsl import pto, scalar as s _DEVICE = "npu:0" @@ -69,6 +69,35 @@ def _tadd_tile(A, B, C, rows: int, cols: int) -> None: pto.tile.store(c_tile, c_part) +def _tadd_tile_dynamic_rows(A, B, C, rows, *, max_rows: int, cols: int) -> None: + c0 = pto.const(0) + c1 = pto.const(1) + c_rows = s.index_cast(rows) + c_cols = pto.const(cols) + c_elems = s.muli(c_rows, c_cols) + + shape = [c1, c1, c1, c_rows, c_cols] + strides = [c_elems, c_elems, c_elems, c_cols, c1] + off = [c0, c0, c0, c0, c0] + + a_view = pto.make_tensor_view(A, shape=shape, strides=strides) + b_view = pto.make_tensor_view(B, shape=shape, strides=strides) + c_view = pto.make_tensor_view(C, shape=shape, strides=strides) + + a_part = pto.partition_view(a_view, offsets=off, sizes=shape) + b_part = pto.partition_view(b_view, offsets=off, sizes=shape) + c_part = pto.partition_view(c_view, offsets=off, sizes=shape) + + a_tile = pto.alloc_tile(shape=[max_rows, cols], dtype=pto.float32, valid_shape=[c_rows, cols]) + b_tile = pto.alloc_tile(shape=[max_rows, cols], dtype=pto.float32, valid_shape=[c_rows, cols]) + c_tile = pto.alloc_tile(shape=[max_rows, cols], dtype=pto.float32, valid_shape=[c_rows, cols]) + + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) + pto.tile.add(a_tile, b_tile, c_tile) + pto.tile.store(c_tile, c_part) + + @pto.jit( name="TADD_f32_16x64", kernel_kind="vector", @@ -83,19 +112,20 @@ def TADD_f32_16x64( @pto.jit( - name="TADD_f32_32x32", + name="TADD_f32_dyn_rows_x64", kernel_kind="vector", target="a5", ) -def TADD_f32_32x32( +def TADD_f32_dyn_rows_x64( A: pto.tensor_spec(rank=2, dtype=pto.f32), B: pto.tensor_spec(rank=2, dtype=pto.f32), C: pto.tensor_spec(rank=2, dtype=pto.f32), + rows: pto.i32, ): - _tadd_tile(A, B, C, 32, 32) + _tadd_tile_dynamic_rows(A, B, C, rows, max_rows=32, cols=64) -KERNELS = (TADD_f32_16x64, TADD_f32_32x32) +KERNELS = (TADD_f32_16x64, TADD_f32_dyn_rows_x64) def emit_mlir(): @@ -108,7 +138,20 @@ def emit_mlir(): CASES = [ {"name": "f32_16x64", "kernel": TADD_f32_16x64, "shape": (16, 64), "eps": 1e-6}, - {"name": "f32_32x32", "kernel": TADD_f32_32x32, "shape": (32, 32), "eps": 1e-6}, + { + "name": "f32_dyn_rows_16x64", + "kernel": TADD_f32_dyn_rows_x64, + "shape": (16, 64), + "dynamic_rows": True, + "eps": 1e-6, + }, + { + "name": "f32_dyn_rows_32x64", + "kernel": TADD_f32_dyn_rows_x64, + "shape": (32, 64), + "dynamic_rows": True, + "eps": 1e-6, + }, ] @@ -143,7 +186,8 @@ def run_case(case: dict, torch) -> None: compile_s = time.perf_counter() - t0 t0 = time.perf_counter() - compiled[1, stream](a, b, c) + launch_args = (a, b, c, shape[0]) if case.get("dynamic_rows") else (a, b, c) + compiled[1, stream](*launch_args) torch.npu.synchronize() launch_s = time.perf_counter() - t0 diff --git a/ptodsl/ptodsl/_runtime/codegen.py b/ptodsl/ptodsl/_runtime/codegen.py index 75eb4ed0f..35758542a 100644 --- a/ptodsl/ptodsl/_runtime/codegen.py +++ b/ptodsl/ptodsl/_runtime/codegen.py @@ -11,10 +11,32 @@ from mlir.ir import BF16Type, F16Type, F32Type, IndexType, IntegerType +from .. import _types as _pto_types from .._kernel_signature import DeviceParameterSpec, RuntimeScalarParameterSpec, TensorSpecParameterSpec from .._types import _PtrDescriptor, _resolve +_RUNTIME_SCALAR_CPP_TYPES_BY_DESCRIPTOR = { + _pto_types.index: "int64_t", + _pto_types.int1: "bool", + _pto_types.int8: "int8_t", + _pto_types.int16: "int16_t", + _pto_types.int32: "int32_t", + _pto_types.int64: "int64_t", + _pto_types.si8: "int8_t", + _pto_types.si16: "int16_t", + _pto_types.si32: "int32_t", + _pto_types.si64: "int64_t", + _pto_types.ui8: "uint8_t", + _pto_types.ui16: "uint16_t", + _pto_types.ui32: "uint32_t", + _pto_types.ui64: "uint64_t", + _pto_types.float32: "float", + _pto_types.float16: "__fp16", + _pto_types.bf16: "__bf16", +} + + def _elem_cpp_type(elem) -> str: name = getattr(elem, "__name__", repr(elem)).lower() mapping = { @@ -52,6 +74,10 @@ def _device_param_cpp_type(annotation) -> str: def _runtime_scalar_cpp_type(annotation) -> str: + descriptor_cpp_type = _RUNTIME_SCALAR_CPP_TYPES_BY_DESCRIPTOR.get(annotation) + if descriptor_cpp_type is not None: + return descriptor_cpp_type + type_obj = _resolve(annotation) if IndexType.isinstance(type_obj): return "int64_t" diff --git a/ptodsl/ptodsl/_runtime/launch.py b/ptodsl/ptodsl/_runtime/launch.py index c517e7037..084a5b992 100644 --- a/ptodsl/ptodsl/_runtime/launch.py +++ b/ptodsl/ptodsl/_runtime/launch.py @@ -17,6 +17,7 @@ looks_like_host_tensor, ) from .._kernel_signature import DeviceParameterSpec, RuntimeScalarParameterSpec, TensorSpecParameterSpec +from .. import _types as _pto_types from .._types import _resolve from .native_build import build_native_library @@ -26,6 +27,25 @@ from .._kernel_compilation import CompiledKernelHandle +_RUNTIME_SCALAR_CTYPES_BY_DESCRIPTOR = { + _pto_types.index: ctypes.c_int64, + _pto_types.int1: ctypes.c_bool, + _pto_types.int8: ctypes.c_int8, + _pto_types.int16: ctypes.c_int16, + _pto_types.int32: ctypes.c_int32, + _pto_types.int64: ctypes.c_int64, + _pto_types.si8: ctypes.c_int8, + _pto_types.si16: ctypes.c_int16, + _pto_types.si32: ctypes.c_int32, + _pto_types.si64: ctypes.c_int64, + _pto_types.ui8: ctypes.c_uint8, + _pto_types.ui16: ctypes.c_uint16, + _pto_types.ui32: ctypes.c_uint32, + _pto_types.ui64: ctypes.c_uint64, + _pto_types.float32: ctypes.c_float, +} + + def _normalize_stream_ptr(stream): if stream is None: try: @@ -56,6 +76,10 @@ def _as_void_ptr(value): def _ctype_for_runtime_scalar(annotation): + descriptor_ctype = _RUNTIME_SCALAR_CTYPES_BY_DESCRIPTOR.get(annotation) + if descriptor_ctype is not None: + return descriptor_ctype + type_obj = _resolve(annotation) if IndexType.isinstance(type_obj): return ctypes.c_int64 From b987f9a58f4be05e1008848e1290f84725456fc8 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Tue, 26 May 2026 14:02:01 +0200 Subject: [PATCH 26/31] softmax dynamic example --- .../jit/flash_attention_softmax_launch.py | 257 ++++++++---------- ptodsl/ptodsl/_host_tensors.py | 6 +- 2 files changed, 116 insertions(+), 147 deletions(-) diff --git a/ptodsl/examples/jit/flash_attention_softmax_launch.py b/ptodsl/examples/jit/flash_attention_softmax_launch.py index 3ed204bd7..df3bcb878 100644 --- a/ptodsl/examples/jit/flash_attention_softmax_launch.py +++ b/ptodsl/examples/jit/flash_attention_softmax_launch.py @@ -23,175 +23,142 @@ running_sum = running_sum * exp(old_max - new_max) + exp(score_col - new_max) out = exp(score_col - final_max) / final_sum -The demo offers two launchable kernels so the current launch ABI does not need -an extra runtime tile-width parameter: - -- ``rows64_seq128``: full-width 64-row packed softmax -- ``rows81_seq96``: same single NPU, but two sequential row-pack updates +The demo launches the same dynamic-shape kernel for multiple ``[seq, rows]`` +sizes by passing ``rows`` and ``seq`` as runtime scalar arguments. """ import argparse import time -from pathlib import Path -import sys import numpy as np -if __package__ in {None, ""}: - here = Path(__file__).resolve() - for candidate in here.parents: - if (candidate / "ptodsl" / "__init__.py").exists(): - sys.path.insert(0, str(candidate)) - break - else: - raise RuntimeError( - "Unable to locate the PTODSL Python package root from flash_attention_softmax_launch.py" - ) - -from ptodsl import pto +from ptodsl import pto, scalar as s _DEVICE = "npu:0" +_MAX_ROWS = 128 +_MAX_SEQ = 128 -def _make_softmax_kernel(name: str, *, rows: int, seq: int): - if rows <= 0: - raise ValueError("rows must be positive") - if seq <= 0: - raise ValueError("seq must be positive") +@pto.jit( + target="a5", + mode="explicit", + insert_sync=False, +) +def softmax_dynamic_shape( + scores: pto.tensor_spec(dtype=pto.f32), + out: pto.tensor_spec(dtype=pto.f32), + rows: pto.i32, + seq: pto.i32, +): + lane_num = pto.elements_per_vreg(pto.f32) + physical_rows = ((_MAX_ROWS + lane_num - 1) // lane_num) * lane_num + scores_tile_bytes = _MAX_SEQ * physical_rows * pto.bytewidth(pto.f32) + runtime_rows = s.index_cast(rows) + runtime_seq = s.index_cast(seq) + total_elems = s.muli(runtime_rows, runtime_seq) + + scores_view = pto.make_tensor_view( + scores, + shape=[1, 1, 1, runtime_seq, runtime_rows], + strides=[total_elems, total_elems, total_elems, runtime_rows, 1], + ) + out_view = pto.make_tensor_view( + out, + shape=[1, 1, 1, runtime_seq, runtime_rows], + strides=[total_elems, total_elems, total_elems, runtime_rows, 1], + ) + scores_part = pto.partition_view( + scores_view, + offsets=[0, 0, 0, 0, 0], + sizes=[1, 1, 1, runtime_seq, runtime_rows], + ) + out_part = pto.partition_view( + out_view, + offsets=[0, 0, 0, 0, 0], + sizes=[1, 1, 1, runtime_seq, runtime_rows], + ) - @pto.jit( - name=name, - target="a5", - mode="explicit", - insert_sync=False + scores_tile = pto.alloc_tile( + shape=[_MAX_SEQ, physical_rows], + dtype=pto.float32, + addr=0, + valid_shape=[runtime_seq, runtime_rows], + blayout="RowMajor", ) - def kernel( - scores: pto.tensor_spec(rank=2, dtype=pto.f32), - out: pto.tensor_spec(rank=2, dtype=pto.f32), - ): - lane_num = pto.elements_per_vreg(pto.f32) - physical_rows = ((rows + lane_num - 1) // lane_num) * lane_num - scores_tile_bytes = seq * physical_rows * pto.bytewidth(pto.f32) - runtime_seq = scores.shape[0] - runtime_rows = scores.shape[1] - total_elems = runtime_rows * runtime_seq - - scores_view = pto.make_tensor_view( - scores, - shape=[1, 1, 1, runtime_seq, runtime_rows], - strides=[total_elems, total_elems, total_elems, runtime_rows, 1], - ) - out_view = pto.make_tensor_view( - out, - shape=[1, 1, 1, runtime_seq, runtime_rows], - strides=[total_elems, total_elems, total_elems, runtime_rows, 1], - ) - scores_part = pto.partition_view( - scores_view, - offsets=[0, 0, 0, 0, 0], - sizes=[1, 1, 1, runtime_seq, runtime_rows], - ) - out_part = pto.partition_view( - out_view, - offsets=[0, 0, 0, 0, 0], - sizes=[1, 1, 1, runtime_seq, runtime_rows], - ) - - scores_tile = pto.alloc_tile( - shape=[seq, physical_rows], - dtype=pto.float32, - addr=0, - valid_shape=[runtime_seq, runtime_rows], - blayout="RowMajor", - ) - out_tile = pto.alloc_tile( - shape=[seq, physical_rows], - dtype=pto.float32, - addr=scores_tile_bytes, - valid_shape=[runtime_seq, runtime_rows], - blayout="RowMajor", - ) - - pto.tile.load(scores_part, scores_tile) - out_tile.fill(0.0) - - pto.set_flag("MTE2", "V", event_id=0) - pto.wait_flag("MTE2", "V", event_id=0) - - with pto.simd(): - row_loop = pto.for_(0, runtime_rows, step=lane_num).carry(remained=runtime_rows) - with row_loop: - row_base = row_loop.iv - remaining_rows = row_loop.remained - active_rows, remaining_after_pack = pto.make_mask(pto.f32, remaining_rows) - running_max = pto.vlds(scores_tile[0, row_base:]) - running_sum = pto.vbr(1.0) - - softmax_loop = pto.for_(1, runtime_seq, step=1).carry( - running_max=running_max, - running_sum=running_sum, - ) - with softmax_loop: - col = softmax_loop.iv - running_max = softmax_loop.running_max - running_sum = softmax_loop.running_sum - col_vec = pto.vlds(scores_tile[col, row_base:]) - merged_max = pto.vmax(running_max, col_vec, active_rows) - running_delta = pto.vsub(running_max, merged_max, active_rows) - scaled_running = pto.vexp(running_delta, active_rows) - running_sum_scaled = pto.vmul(scaled_running, running_sum, active_rows) - col_delta = pto.vsub(col_vec, merged_max, active_rows) - col_exp = pto.vexp(col_delta, active_rows) - merged_sum = pto.vadd(running_sum_scaled, col_exp, active_rows) - softmax_loop.update(running_max=merged_max, running_sum=merged_sum) - - final_max = softmax_loop.final("running_max") - final_sum = softmax_loop.final("running_sum") - - with pto.for_(0, runtime_seq, step=1) as col: - col_vec = pto.vlds(scores_tile[col, row_base:]) - out_delta = pto.vsub(col_vec, final_max, active_rows) - exp_vec = pto.vexp(out_delta, active_rows) - out_vec = pto.vdiv(exp_vec, final_sum, active_rows) - pto.vsts(out_vec, out_tile[col, row_base:], active_rows) - - row_loop.update(remained=remaining_after_pack) - - pto.set_flag("V", "MTE3", event_id=0) - pto.wait_flag("V", "MTE3", event_id=0) - - pto.tile.store(out_tile, out_part) - pto.pipe_barrier(pto.Pipe.ALL) - - return kernel - - -SOFTMAX_ROWS64_SEQ128 = _make_softmax_kernel( - "softmax_rows64_seq128", - rows=64, - seq=128, -) -SOFTMAX_ROWS81_SEQ96 = _make_softmax_kernel( - "softmax_rows81_seq96", - rows=81, - seq=96, -) + out_tile = pto.alloc_tile( + shape=[_MAX_SEQ, physical_rows], + dtype=pto.float32, + addr=scores_tile_bytes, + valid_shape=[runtime_seq, runtime_rows], + blayout="RowMajor", + ) + + pto.tile.load(scores_part, scores_tile) + out_tile.fill(0.0) + + pto.set_flag("MTE2", "V", event_id=0) + pto.wait_flag("MTE2", "V", event_id=0) + + with pto.simd(): + row_loop = pto.for_(0, runtime_rows, step=lane_num).carry(remained=runtime_rows) + with row_loop: + row_base = row_loop.iv + remaining_rows = row_loop.remained + active_rows, remaining_after_pack = pto.make_mask(pto.f32, remaining_rows) + running_max = pto.vlds(scores_tile[0, row_base:]) + running_sum = pto.vbr(1.0) + + softmax_loop = pto.for_(1, runtime_seq, step=1).carry( + running_max=running_max, + running_sum=running_sum, + ) + with softmax_loop: + col = softmax_loop.iv + running_max = softmax_loop.running_max + running_sum = softmax_loop.running_sum + col_vec = pto.vlds(scores_tile[col, row_base:]) + merged_max = pto.vmax(running_max, col_vec, active_rows) + running_delta = pto.vsub(running_max, merged_max, active_rows) + scaled_running = pto.vexp(running_delta, active_rows) + running_sum_scaled = pto.vmul(scaled_running, running_sum, active_rows) + col_delta = pto.vsub(col_vec, merged_max, active_rows) + col_exp = pto.vexp(col_delta, active_rows) + merged_sum = pto.vadd(running_sum_scaled, col_exp, active_rows) + softmax_loop.update(running_max=merged_max, running_sum=merged_sum) + + final_max = softmax_loop.final("running_max") + final_sum = softmax_loop.final("running_sum") + + with pto.for_(0, runtime_seq, step=1) as col: + col_vec = pto.vlds(scores_tile[col, row_base:]) + out_delta = pto.vsub(col_vec, final_max, active_rows) + exp_vec = pto.vexp(out_delta, active_rows) + out_vec = pto.vdiv(exp_vec, final_sum, active_rows) + pto.vsts(out_vec, out_tile[col, row_base:], active_rows) + + row_loop.update(remained=remaining_after_pack) + + pto.set_flag("V", "MTE3", event_id=0) + pto.wait_flag("V", "MTE3", event_id=0) + + pto.tile.store(out_tile, out_part) + pto.pipe_barrier(pto.Pipe.ALL) + KERNELS = ( - SOFTMAX_ROWS64_SEQ128, - SOFTMAX_ROWS81_SEQ96, + softmax_dynamic_shape, ) CASES = [ { "name": "rows64_seq128", - "kernel": SOFTMAX_ROWS64_SEQ128, + "kernel": softmax_dynamic_shape, "rows": 64, "seq": 128, }, { "name": "rows81_seq96", - "kernel": SOFTMAX_ROWS81_SEQ96, + "kernel": softmax_dynamic_shape, "rows": 81, "seq": 96, }, @@ -250,6 +217,8 @@ def run_case(case: dict[str, object], torch) -> None: compiled[1, stream]( scores_t, out_t, + int(case["rows"]), + int(case["seq"]), ) torch.npu.synchronize() launch_s = time.perf_counter() - t0 diff --git a/ptodsl/ptodsl/_host_tensors.py b/ptodsl/ptodsl/_host_tensors.py index 0217ffb5f..03ee4c080 100644 --- a/ptodsl/ptodsl/_host_tensors.py +++ b/ptodsl/ptodsl/_host_tensors.py @@ -103,8 +103,8 @@ def inspect_host_tensor_metadata(tensor) -> HostTensorMetadata: class TensorSpec: """Static ABI hint for one Python-native ``@pto.jit`` tensor parameter.""" - rank: int dtype: object + rank: int = 2 address_space: str = "gm" def __post_init__(self): @@ -139,9 +139,9 @@ def __repr__(self): ) -def tensor_spec(*, rank: int, dtype, address_space: str = "gm") -> TensorSpec: +def tensor_spec(*, dtype, rank: int = 2, address_space: str = "gm") -> TensorSpec: """Declare the ABI contract of one Python-native ``@pto.jit`` tensor parameter.""" - return TensorSpec(rank=rank, dtype=dtype, address_space=address_space) + return TensorSpec(dtype=dtype, rank=rank, address_space=address_space) class HostTensorValue: From a8425826597b06827bacc39f3283f9da88f05712 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Tue, 26 May 2026 14:43:46 +0200 Subject: [PATCH 27/31] removed rank2 from tadd example --- ptodsl/examples/jit/tadd_launch.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ptodsl/examples/jit/tadd_launch.py b/ptodsl/examples/jit/tadd_launch.py index 4eb0ed64d..b15123adb 100644 --- a/ptodsl/examples/jit/tadd_launch.py +++ b/ptodsl/examples/jit/tadd_launch.py @@ -104,9 +104,9 @@ def _tadd_tile_dynamic_rows(A, B, C, rows, *, max_rows: int, cols: int) -> None: target="a5", ) def TADD_f32_16x64( - A: pto.tensor_spec(rank=2, dtype=pto.f32), - B: pto.tensor_spec(rank=2, dtype=pto.f32), - C: pto.tensor_spec(rank=2, dtype=pto.f32), + A: pto.tensor_spec(dtype=pto.f32), + B: pto.tensor_spec(dtype=pto.f32), + C: pto.tensor_spec(dtype=pto.f32), ): _tadd_tile(A, B, C, 16, 64) @@ -117,9 +117,9 @@ def TADD_f32_16x64( target="a5", ) def TADD_f32_dyn_rows_x64( - A: pto.tensor_spec(rank=2, dtype=pto.f32), - B: pto.tensor_spec(rank=2, dtype=pto.f32), - C: pto.tensor_spec(rank=2, dtype=pto.f32), + A: pto.tensor_spec(dtype=pto.f32), + B: pto.tensor_spec(dtype=pto.f32), + C: pto.tensor_spec(dtype=pto.f32), rows: pto.i32, ): _tadd_tile_dynamic_rows(A, B, C, rows, max_rows=32, cols=64) From abd56718c3a1dc604c1754addc679f87a6a5d2a6 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Wed, 27 May 2026 13:28:19 +0200 Subject: [PATCH 28/31] static matmul example --- .../examples/jit/tmatmul_f16_16x16_launch.py | 216 ++++++++++++++++++ ptodsl/ptodsl/_ops.py | 31 ++- ptodsl/ptodsl/pto.py | 2 +- 3 files changed, 247 insertions(+), 2 deletions(-) create mode 100644 ptodsl/examples/jit/tmatmul_f16_16x16_launch.py diff --git a/ptodsl/examples/jit/tmatmul_f16_16x16_launch.py b/ptodsl/examples/jit/tmatmul_f16_16x16_launch.py new file mode 100644 index 000000000..ea84651bf --- /dev/null +++ b/ptodsl/examples/jit/tmatmul_f16_16x16_launch.py @@ -0,0 +1,216 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +Minimal cube matmul JIT demo that mirrors the manual-address IR shape used by +``TMATMUL_f16_16x16x16``. + +The boundary movement uses explicit GM->L1 fractal loads and L0C->GM store, +while the core matmul stays on the TileOp path through ``pto.tmatmul``. +""" + +import argparse +import time + +import numpy as np +from mlir.dialects import pto as _pto +from mlir.ir import Attribute + +from ptodsl import pto +from ptodsl._surface_values import unwrap_surface_value + +_DEVICE = "npu:0" + + +def _raw(value): + return unwrap_surface_value(value) + + +def _mte_gm_l1_frac(source, destination, *, shape, src_layout, dst_group, ctrl): + _pto.mte_gm_l1_frac( + _raw(source), + _raw(destination), + _raw(shape[0]), + _raw(shape[1]), + _raw(src_layout[0]), + _raw(dst_group[0]), + _raw(dst_group[1]), + _raw(dst_group[2]), + _raw(dst_group[3]), + _raw(ctrl[0]), + _raw(ctrl[1]), + Attribute.parse("#pto"), + ) + + +def _tmatmul(lhs, rhs, dst): + _pto.tmatmul(None, _raw(lhs), _raw(rhs), _raw(dst)) + + +@pto.jit( + name="TMATMUL_f16_16x16x16", + target="a5", + kernel_kind="cube", + mode="explicit", + insert_sync=False, +) +def TMATMUL_f16_16x16x16( + A: pto.tensor_spec(rank=2, dtype=pto.f16), + B: pto.tensor_spec(rank=2, dtype=pto.f16), + C: pto.tensor_spec(rank=2, dtype=pto.f32), +): + c0 = pto.const(0, dtype=pto.i64) + c1 = pto.const(1, dtype=pto.i64) + c16 = pto.const(16, dtype=pto.i64) + c32 = pto.const(32, dtype=pto.i64) + false = pto.const(0, dtype=pto.i1) + + l1_a_tile = pto.alloc_tile( + shape=[16, 16], + dtype=pto.f16, + memory_space=pto.MemorySpace.MAT, + blayout="ColMajor", + slayout="RowMajor", + addr=0, + ) + l1_b_tile = pto.alloc_tile( + shape=[16, 16], + dtype=pto.f16, + memory_space=pto.MemorySpace.MAT, + blayout="ColMajor", + slayout="RowMajor", + addr=512, + ) + l0a_tile = pto.alloc_tile( + shape=[16, 16], + dtype=pto.f16, + memory_space=pto.MemorySpace.LEFT, + blayout="ColMajor", + slayout="RowMajor", + addr=0, + ) + l0b_tile = pto.alloc_tile( + shape=[16, 16], + dtype=pto.f16, + memory_space=pto.MemorySpace.RIGHT, + blayout="RowMajor", + slayout="ColMajor", + addr=0, + ) + l0c_tile = pto.alloc_tile( + shape=[16, 16], + dtype=pto.f32, + memory_space=pto.MemorySpace.ACC, + blayout="ColMajor", + slayout="RowMajor", + fractal_size=1024, + addr=0, + ) + + l1_a = l1_a_tile.as_ptr() + l1_b = l1_b_tile.as_ptr() + l0a = l0a_tile.as_ptr() + l0b = l0b_tile.as_ptr() + l0c = l0c_tile.as_ptr() + + _mte_gm_l1_frac( + A.data_handle, + l1_a, + shape=(c16, c16), + src_layout=(c32,), + dst_group=(c1, c1, c16, c0), + ctrl=(c0, false), + ) + pto.set_flag("MTE2", "MTE1", event_id=0) + pto.wait_flag("MTE2", "MTE1", event_id=0) + pto.mte_l1_l0a(l1_a, l0a, c16, c16) + + _mte_gm_l1_frac( + B.data_handle, + l1_b, + shape=(c16, c16), + src_layout=(c32,), + dst_group=(c1, c1, c16, c0), + ctrl=(c0, false), + ) + pto.set_flag("MTE2", "MTE1", event_id=1) + pto.wait_flag("MTE2", "MTE1", event_id=1) + pto.mte_l1_l0b(l1_b, l0b, c16, c16, transpose=True) + + pto.set_flag("MTE1", "M", event_id=0) + pto.wait_flag("MTE1", "M", event_id=0) + _tmatmul(l0a_tile, l0b_tile, l0c_tile) + + pto.set_flag("M", "FIX", event_id=1) + pto.wait_flag("M", "FIX", event_id=1) + pto.mte_l0c_gm(l0c, C.data_handle, c16, c16, c16, c16, c0, c0) + pto.pipe_barrier(pto.Pipe.ALL) + + +def emit_mlir(): + return pto.merge_jit_modules(TMATMUL_f16_16x16x16) + + +def init_runtime(): + import torch + import torch_npu # noqa: F401 + + torch.npu.config.allow_internal_format = False + torch_npu.npu.set_compile_mode(jit_compile=False) + torch.npu.set_device(_DEVICE) + return torch + + +def npu_stream(torch): + return torch.npu.current_stream()._as_parameter_ # noqa: SLF001 + + +def test_tmatmul() -> None: + torch = init_runtime() + rng = np.random.RandomState(0) + a_np = rng.uniform(-1.0, 1.0, size=(16, 16)).astype(np.float16) + b_np = rng.uniform(-1.0, 1.0, size=(16, 16)).astype(np.float16) + ref = np.matmul(a_np.astype(np.float32), b_np.astype(np.float32)) + + a = torch.from_numpy(a_np).to(_DEVICE) + b = torch.from_numpy(b_np).to(_DEVICE) + c = torch.empty((16, 16), dtype=torch.float32, device=_DEVICE) + stream = npu_stream(torch) + + t0 = time.perf_counter() + compiled = TMATMUL_f16_16x16x16.compile() + compile_s = time.perf_counter() - t0 + + t0 = time.perf_counter() + compiled[1, stream](a, b, c) + torch.npu.synchronize() + launch_s = time.perf_counter() - t0 + + np.testing.assert_allclose(c.cpu().numpy(), ref, rtol=1e-2, atol=1e-2) + print(f"PASS TMATMUL_f16_16x16x16 compile={compile_s:.3f}s launch={launch_s:.3f}s") + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--emit-mlir", + action="store_true", + help="print the generated MLIR module and exit", + ) + args = parser.parse_args(argv) + + if args.emit_mlir: + print(emit_mlir()) + return 0 + + test_tmatmul() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py index 9cbfec41f..8ba892296 100644 --- a/ptodsl/ptodsl/_ops.py +++ b/ptodsl/ptodsl/_ops.py @@ -2986,6 +2986,35 @@ def mte_l1_l0b(source, destination, k, n, *, transpose=False): ) +@_explicit_mode_only("pto.mte_l0c_gm(...)") +def mte_l0c_gm( + source, + destination, + m, + n, + src_stride, + dst_stride, + sid=0, + l2_cache_ctrl=0, + *, + mode="nz2nd", +): + """``pto.mte_l0c_gm`` – ACC to GM store.""" + if isinstance(mode, str): + mode = Attribute.parse(f"#pto") + _pto.mte_l0c_gm( + unwrap_surface_value(source), + unwrap_surface_value(destination), + _coerce_i64(m, context="mte_l0c_gm m"), + _coerce_i64(n, context="mte_l0c_gm n"), + _coerce_i64(src_stride, context="mte_l0c_gm src_stride"), + _coerce_i64(dst_stride, context="mte_l0c_gm dst_stride"), + _coerce_i64(sid, context="mte_l0c_gm sid"), + _coerce_i64(l2_cache_ctrl, context="mte_l0c_gm l2_cache_ctrl"), + mode=mode, + ) + + @_explicit_mode_only("pto.mte_l0c_ub(...)") def mte_l0c_ub(source, destination, m, n, src_stride, dst_stride, sub_blockid=0, *, dst_mode="single"): """``pto.mte_l0c_ub`` – ACC to UB store.""" @@ -3259,7 +3288,7 @@ def wait_flag(src: str, dst: str, *, event_id: int = 0): "tfillpad", "tfillpad_expand", "tfillpad_inplace", "as_ptr", "mte_load", "mte_store", "mte_gm_ub", "mte_ub_gm", "mte_ub_ub", "mte_ub_l1", "mem_bar", - "mte_l1_l0a", "mte_l1_l0b", "mte_l0c_ub", + "mte_l1_l0a", "mte_l1_l0b", "mte_l0c_gm", "mte_l0c_ub", "mad", "mad_acc", "mad_bias", "mad_mx", "mad_mx_acc", "mad_mx_bias", "get_block_idx", "get_block_num", "get_subblock_idx", "get_subblock_num", "store_vfsimt_info", "get_tid_x", "get_tid_y", "get_tid_z", diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py index da43f3c2b..dbb306b52 100644 --- a/ptodsl/ptodsl/pto.py +++ b/ptodsl/ptodsl/pto.py @@ -87,7 +87,7 @@ make_tensor_view, partition_view, alloc_tile, mte_load, mte_store, mte_gm_ub, mte_ub_gm, mte_ub_ub, mte_ub_l1, mem_bar, - mte_l1_l0a, mte_l1_l0b, mte_l0c_ub, + mte_l1_l0a, mte_l1_l0b, mte_l0c_gm, mte_l0c_ub, mad, mad_acc, mad_bias, mad_mx, mad_mx_acc, mad_mx_bias, get_block_idx, get_block_num, get_subblock_idx, get_subblock_num, store_vfsimt_info, get_tid_x, get_tid_y, get_tid_z, From 38abaf0a129142a8fec5d4ec9c21cdbe0e663fcd Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Wed, 27 May 2026 14:27:26 +0200 Subject: [PATCH 29/31] dyanmic size matmul example --- ..._f16_16x16_launch.py => tmatmul_launch.py} | 67 ++++++++++++------- 1 file changed, 41 insertions(+), 26 deletions(-) rename ptodsl/examples/jit/{tmatmul_f16_16x16_launch.py => tmatmul_launch.py} (72%) diff --git a/ptodsl/examples/jit/tmatmul_f16_16x16_launch.py b/ptodsl/examples/jit/tmatmul_launch.py similarity index 72% rename from ptodsl/examples/jit/tmatmul_f16_16x16_launch.py rename to ptodsl/examples/jit/tmatmul_launch.py index ea84651bf..61268d5bb 100644 --- a/ptodsl/examples/jit/tmatmul_f16_16x16_launch.py +++ b/ptodsl/examples/jit/tmatmul_launch.py @@ -21,7 +21,8 @@ from mlir.dialects import pto as _pto from mlir.ir import Attribute -from ptodsl import pto +from ptodsl import pto, scalar as s +from ptodsl._ops import _coerce_i64 from ptodsl._surface_values import unwrap_surface_value _DEVICE = "npu:0" @@ -35,14 +36,14 @@ def _mte_gm_l1_frac(source, destination, *, shape, src_layout, dst_group, ctrl): _pto.mte_gm_l1_frac( _raw(source), _raw(destination), - _raw(shape[0]), - _raw(shape[1]), - _raw(src_layout[0]), - _raw(dst_group[0]), - _raw(dst_group[1]), - _raw(dst_group[2]), - _raw(dst_group[3]), - _raw(ctrl[0]), + _coerce_i64(shape[0], context="mte_gm_l1_frac shape[0]"), + _coerce_i64(shape[1], context="mte_gm_l1_frac shape[1]"), + _coerce_i64(src_layout[0], context="mte_gm_l1_frac src_layout[0]"), + _coerce_i64(dst_group[0], context="mte_gm_l1_frac dst_group[0]"), + _coerce_i64(dst_group[1], context="mte_gm_l1_frac dst_group[1]"), + _coerce_i64(dst_group[2], context="mte_gm_l1_frac dst_group[2]"), + _coerce_i64(dst_group[3], context="mte_gm_l1_frac dst_group[3]"), + _coerce_i64(ctrl[0], context="mte_gm_l1_frac ctrl[0]"), _raw(ctrl[1]), Attribute.parse("#pto"), ) @@ -63,11 +64,14 @@ def TMATMUL_f16_16x16x16( A: pto.tensor_spec(rank=2, dtype=pto.f16), B: pto.tensor_spec(rank=2, dtype=pto.f16), C: pto.tensor_spec(rank=2, dtype=pto.f32), + dim: pto.i32, ): c0 = pto.const(0, dtype=pto.i64) c1 = pto.const(1, dtype=pto.i64) - c16 = pto.const(16, dtype=pto.i64) - c32 = pto.const(32, dtype=pto.i64) + c2 = pto.const(2) + c16 = s.index_cast(dim) + c16_static = pto.const(16) + c32 = s.muli(c16_static, c2) false = pto.const(0, dtype=pto.i1) l1_a_tile = pto.alloc_tile( @@ -148,7 +152,7 @@ def TMATMUL_f16_16x16x16( pto.set_flag("M", "FIX", event_id=1) pto.wait_flag("M", "FIX", event_id=1) - pto.mte_l0c_gm(l0c, C.data_handle, c16, c16, c16, c16, c0, c0) + pto.mte_l0c_gm(l0c, C.data_handle, c16, c16, c16_static, c16_static, c0, c0) pto.pipe_barrier(pto.Pipe.ALL) @@ -173,26 +177,37 @@ def npu_stream(torch): def test_tmatmul() -> None: torch = init_runtime() rng = np.random.RandomState(0) - a_np = rng.uniform(-1.0, 1.0, size=(16, 16)).astype(np.float16) - b_np = rng.uniform(-1.0, 1.0, size=(16, 16)).astype(np.float16) - ref = np.matmul(a_np.astype(np.float32), b_np.astype(np.float32)) - - a = torch.from_numpy(a_np).to(_DEVICE) - b = torch.from_numpy(b_np).to(_DEVICE) - c = torch.empty((16, 16), dtype=torch.float32, device=_DEVICE) stream = npu_stream(torch) + dims = [int(rng.randint(4, 16)) for _ in range(2)] + [16] t0 = time.perf_counter() compiled = TMATMUL_f16_16x16x16.compile() compile_s = time.perf_counter() - t0 - t0 = time.perf_counter() - compiled[1, stream](a, b, c) - torch.npu.synchronize() - launch_s = time.perf_counter() - t0 - - np.testing.assert_allclose(c.cpu().numpy(), ref, rtol=1e-2, atol=1e-2) - print(f"PASS TMATMUL_f16_16x16x16 compile={compile_s:.3f}s launch={launch_s:.3f}s") + for dim in dims: + a_np = np.zeros((16, 16), dtype=np.float16) + b_np = np.zeros((16, 16), dtype=np.float16) + a_np[:dim, :dim] = rng.uniform(-1.0, 1.0, size=(dim, dim)).astype(np.float16) + b_np[:dim, :dim] = rng.uniform(-1.0, 1.0, size=(dim, dim)).astype(np.float16) + ref = np.matmul( + a_np[:dim, :dim].astype(np.float32), + b_np[:dim, :dim].astype(np.float32), + ) + + a = torch.from_numpy(a_np).to(_DEVICE) + b = torch.from_numpy(b_np).to(_DEVICE) + c = torch.empty((16, 16), dtype=torch.float32, device=_DEVICE) + + t0 = time.perf_counter() + compiled[1, stream](a, b, c, dim) + torch.npu.synchronize() + launch_s = time.perf_counter() - t0 + + np.testing.assert_allclose(c.cpu().numpy()[:dim, :dim], ref, rtol=1e-2, atol=1e-2) + print( + f"PASS TMATMUL_f16_{dim}x{dim}x{dim} " + f"compile={compile_s:.3f}s launch={launch_s:.3f}s" + ) def main(argv=None) -> int: From ee7348a825ead0138bd9fd5c6cc8684e56fe82b0 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Wed, 27 May 2026 14:32:38 +0200 Subject: [PATCH 30/31] decreased tolerance for tmatmul example --- ptodsl/examples/jit/tmatmul_launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ptodsl/examples/jit/tmatmul_launch.py b/ptodsl/examples/jit/tmatmul_launch.py index 61268d5bb..41fd52a5d 100644 --- a/ptodsl/examples/jit/tmatmul_launch.py +++ b/ptodsl/examples/jit/tmatmul_launch.py @@ -203,7 +203,7 @@ def test_tmatmul() -> None: torch.npu.synchronize() launch_s = time.perf_counter() - t0 - np.testing.assert_allclose(c.cpu().numpy()[:dim, :dim], ref, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(c.cpu().numpy()[:dim, :dim], ref, rtol=1e-6, atol=1e-6) print( f"PASS TMATMUL_f16_{dim}x{dim}x{dim} " f"compile={compile_s:.3f}s launch={launch_s:.3f}s" From 539f066d86f2342ac0d05a77e68bceda6514b274 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Wed, 27 May 2026 15:15:53 +0200 Subject: [PATCH 31/31] batched matmul dynamic example --- ...mul_launch.py => batched_matmul_launch.py} | 139 ++++++++++++------ 1 file changed, 95 insertions(+), 44 deletions(-) rename ptodsl/examples/jit/{tmatmul_launch.py => batched_matmul_launch.py} (60%) diff --git a/ptodsl/examples/jit/tmatmul_launch.py b/ptodsl/examples/jit/batched_matmul_launch.py similarity index 60% rename from ptodsl/examples/jit/tmatmul_launch.py rename to ptodsl/examples/jit/batched_matmul_launch.py index 41fd52a5d..306adbeb0 100644 --- a/ptodsl/examples/jit/tmatmul_launch.py +++ b/ptodsl/examples/jit/batched_matmul_launch.py @@ -11,7 +11,8 @@ ``TMATMUL_f16_16x16x16``. The boundary movement uses explicit GM->L1 fractal loads and L0C->GM store, -while the core matmul stays on the TileOp path through ``pto.tmatmul``. +while the core matmul stays on the TileOp path through ``pto.tmatmul``. The +same compiled kernel is launched with multiple runtime batch sizes. """ import argparse @@ -61,17 +62,22 @@ def _tmatmul(lhs, rhs, dst): insert_sync=False, ) def TMATMUL_f16_16x16x16( - A: pto.tensor_spec(rank=2, dtype=pto.f16), - B: pto.tensor_spec(rank=2, dtype=pto.f16), - C: pto.tensor_spec(rank=2, dtype=pto.f32), + A: pto.tensor_spec(rank=3, dtype=pto.f16), + B: pto.tensor_spec(rank=3, dtype=pto.f16), + C: pto.tensor_spec(rank=3, dtype=pto.f32), dim: pto.i32, + batch: pto.i32, ): + c0_idx = pto.const(0) + c1_idx = pto.const(1) c0 = pto.const(0, dtype=pto.i64) c1 = pto.const(1, dtype=pto.i64) c2 = pto.const(2) c16 = s.index_cast(dim) + c_batch = s.index_cast(batch) c16_static = pto.const(16) c32 = s.muli(c16_static, c2) + c_tile_elems = s.muli(c16_static, c16_static) false = pto.const(0, dtype=pto.i1) l1_a_tile = pto.alloc_tile( @@ -121,38 +127,74 @@ def TMATMUL_f16_16x16x16( l0a = l0a_tile.as_ptr() l0b = l0b_tile.as_ptr() l0c = l0c_tile.as_ptr() - - _mte_gm_l1_frac( - A.data_handle, - l1_a, - shape=(c16, c16), - src_layout=(c32,), - dst_group=(c1, c1, c16, c0), - ctrl=(c0, false), + a_view = pto.make_tensor_view( + A, + shape=[c_batch, c16_static, c16_static], + strides=[c_tile_elems, c16_static, c1_idx], + ) + b_view = pto.make_tensor_view( + B, + shape=[c_batch, c16_static, c16_static], + strides=[c_tile_elems, c16_static, c1_idx], ) - pto.set_flag("MTE2", "MTE1", event_id=0) - pto.wait_flag("MTE2", "MTE1", event_id=0) - pto.mte_l1_l0a(l1_a, l0a, c16, c16) - - _mte_gm_l1_frac( - B.data_handle, - l1_b, - shape=(c16, c16), - src_layout=(c32,), - dst_group=(c1, c1, c16, c0), - ctrl=(c0, false), + c_view = pto.make_tensor_view( + C, + shape=[c_batch, c16_static, c16_static], + strides=[c_tile_elems, c16_static, c1_idx], ) - pto.set_flag("MTE2", "MTE1", event_id=1) - pto.wait_flag("MTE2", "MTE1", event_id=1) - pto.mte_l1_l0b(l1_b, l0b, c16, c16, transpose=True) - pto.set_flag("MTE1", "M", event_id=0) - pto.wait_flag("MTE1", "M", event_id=0) - _tmatmul(l0a_tile, l0b_tile, l0c_tile) + with pto.for_(c0_idx, c_batch, step=c1_idx) as batch_idx: + a_part = pto.partition_view( + a_view, + offsets=[batch_idx, c0_idx, c0_idx], + sizes=[c1_idx, c16, c16], + ) + b_part = pto.partition_view( + b_view, + offsets=[batch_idx, c0_idx, c0_idx], + sizes=[c1_idx, c16, c16], + ) + c_part = pto.partition_view( + c_view, + offsets=[batch_idx, c0_idx, c0_idx], + sizes=[c1_idx, c16, c16], + ) + a_ptr = a_part.as_ptr() + b_ptr = b_part.as_ptr() + c_ptr = c_part.as_ptr() + + _mte_gm_l1_frac( + a_ptr, + l1_a, + shape=(c16, c16), + src_layout=(c32,), + dst_group=(c1, c1, c16, c0), + ctrl=(c0, false), + ) + pto.set_flag("MTE2", "MTE1", event_id=0) + pto.wait_flag("MTE2", "MTE1", event_id=0) + pto.mte_l1_l0a(l1_a, l0a, c16, c16) + + _mte_gm_l1_frac( + b_ptr, + l1_b, + shape=(c16, c16), + src_layout=(c32,), + dst_group=(c1, c1, c16, c0), + ctrl=(c0, false), + ) + pto.set_flag("MTE2", "MTE1", event_id=1) + pto.wait_flag("MTE2", "MTE1", event_id=1) + pto.mte_l1_l0b(l1_b, l0b, c16, c16, transpose=True) + + pto.set_flag("MTE1", "M", event_id=0) + pto.wait_flag("MTE1", "M", event_id=0) + _tmatmul(l0a_tile, l0b_tile, l0c_tile) + + pto.set_flag("M", "FIX", event_id=1) + pto.wait_flag("M", "FIX", event_id=1) + pto.mte_l0c_gm(l0c, c_ptr, c16, c16, c16_static, c16_static, c0, c0) - pto.set_flag("M", "FIX", event_id=1) - pto.wait_flag("M", "FIX", event_id=1) - pto.mte_l0c_gm(l0c, C.data_handle, c16, c16, c16_static, c16_static, c0, c0) pto.pipe_barrier(pto.Pipe.ALL) @@ -178,34 +220,43 @@ def test_tmatmul() -> None: torch = init_runtime() rng = np.random.RandomState(0) stream = npu_stream(torch) - dims = [int(rng.randint(4, 16)) for _ in range(2)] + [16] + cases = [ + (int(rng.randint(4, 16)), int(rng.randint(1, 5))) + for _ in range(2) + ] + [(16, 3)] t0 = time.perf_counter() compiled = TMATMUL_f16_16x16x16.compile() compile_s = time.perf_counter() - t0 - for dim in dims: - a_np = np.zeros((16, 16), dtype=np.float16) - b_np = np.zeros((16, 16), dtype=np.float16) - a_np[:dim, :dim] = rng.uniform(-1.0, 1.0, size=(dim, dim)).astype(np.float16) - b_np[:dim, :dim] = rng.uniform(-1.0, 1.0, size=(dim, dim)).astype(np.float16) + for dim, batch in cases: + a_np = np.zeros((batch, 16, 16), dtype=np.float16) + b_np = np.zeros((batch, 16, 16), dtype=np.float16) + a_np[:, :dim, :dim] = rng.uniform( + -1.0, 1.0, size=(batch, dim, dim) + ).astype(np.float16) + b_np[:, :dim, :dim] = rng.uniform( + -1.0, 1.0, size=(batch, dim, dim) + ).astype(np.float16) ref = np.matmul( - a_np[:dim, :dim].astype(np.float32), - b_np[:dim, :dim].astype(np.float32), + a_np[:, :dim, :dim].astype(np.float32), + b_np[:, :dim, :dim].astype(np.float32), ) a = torch.from_numpy(a_np).to(_DEVICE) b = torch.from_numpy(b_np).to(_DEVICE) - c = torch.empty((16, 16), dtype=torch.float32, device=_DEVICE) + c = torch.empty((batch, 16, 16), dtype=torch.float32, device=_DEVICE) t0 = time.perf_counter() - compiled[1, stream](a, b, c, dim) + compiled[1, stream](a, b, c, dim, batch) torch.npu.synchronize() launch_s = time.perf_counter() - t0 - np.testing.assert_allclose(c.cpu().numpy()[:dim, :dim], ref, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose( + c.cpu().numpy()[:, :dim, :dim], ref, rtol=1e-6, atol=1e-6 + ) print( - f"PASS TMATMUL_f16_{dim}x{dim}x{dim} " + f"PASS TMATMUL_f16_batch{batch}_{dim}x{dim}x{dim} " f"compile={compile_s:.3f}s launch={launch_s:.3f}s" )